
import gdal
import ogr
import os
import utility_functions as utility_functions
import glob
import numpy as np
import rasterio
from rasterio.mask import mask
from rasterio.merge import merge
from datetime import date
import fiona

import datetime as dt

import argparse
import shutil
import time


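# This script iterates through the features of a shapefile and creates rasterized outputs
#   (one raster per feature / city center and a mosaic for the whole area), (re)assigns a
#   UHI_ID to every shapefile feature (adding the UHI_ID field first if it does not exist),
#   and clips the annual LST rasters to the city boundary.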


def boundingBoxToOffsets(bbox, geot, xsize, ysize, lc):
    """Convert a bounding box to a pixel window ``[row1, row2, col1, col2]`` of a raster.

    :param bbox: ``(minX, maxX, minY, maxY)`` -- the order produced by OGR's
        ``Geometry.GetEnvelope()``, which is what zonal_array passes in.
    :param geot: GDAL geotransform of the raster; assumes a north-up raster
        (geot[1] > 0, geot[5] < 0) -- TODO confirm for all inputs.
    :param xsize: raster width in pixels.
    :param ysize: raster height in pixels.
    :param lc: when truthy, a box reaching past the top edge is clamped to
        row 1 and one reaching past the bottom edge to row ``ysize - 1``
        (one row inside the raster) instead of to the edges themselves.
    :returns: ``[row1, row2, col1, col2]``; the window size is
        ``(row2 - row1)`` rows by ``(col2 - col1)`` columns.
    """
    lcadj = 1 if lc else 0

    min_x, max_x, min_y, max_y = bbox[0], bbox[1], bbox[2], bbox[3]
    left_edge = geot[0]
    top_edge = geot[3]
    right_edge = geot[0] + xsize * geot[1]
    bottom_edge = geot[3] + ysize * geot[5]

    # "+ 1" so the window includes the pixel containing the max coordinate.
    col1 = int((min_x - left_edge) / geot[1])
    col2 = int((max_x - left_edge) / geot[1]) + 1
    row1 = int((max_y - top_edge) / geot[5])
    row2 = int((min_y - top_edge) / geot[5]) + 1

    # Clamp to the raster. The right/bottom checks are inclusive: a box ending
    # exactly on the edge would otherwise yield xsize + 1 / ysize + 1. The sizes
    # are used directly rather than recomputed through floating point, which
    # could truncate to one pixel short.
    if min_x < left_edge:
        # city boundary extends west of the raster: use the western edge
        col1 = 0
    if max_x >= right_edge:
        # city boundary extends east of the raster: use the eastern edge
        col2 = xsize
    if max_y > top_edge:
        row1 = lcadj
    if min_y <= bottom_edge:
        row2 = ysize - lcadj

    return [row1, row2, col1, col2]

def geotFromOffsets(row_offset, col_offset, geot):
    """Return the geotransform of a window starting at pixel (row_offset, col_offset).

    The window keeps the pixel width (geot[1]) and height (geot[5]) of *geot*;
    its upper-left corner is shifted by the given offsets, and both rotation
    terms are written as 0.0.

    :returns: a new 6-element list.
    """
    pixel_width = geot[1]
    pixel_height = geot[5]
    upper_left_x = geot[0] + (col_offset * pixel_width)
    upper_left_y = geot[3] + (row_offset * pixel_height)
    return [upper_left_x, pixel_width, 0.0, upper_left_y, 0.0, pixel_height]

def add_fields(shapefile, iata):
    """Ensure *shapefile* has an ``Area`` field and fill it with each feature's area.

    The shapefile is opened in update mode. A real-valued ``Area`` field
    (width 32, precision 2) is created if it is missing. Then every feature
    that has a geometry gets ``Area = geometry.GetArea()``, in squared units
    of the layer's coordinates. Features without a geometry are reported and
    skipped rather than crashing; zonal_array likewise only processes
    features that have a geometry.

    :param shapefile: path to an ESRI Shapefile.
    :param iata: unused; kept so existing callers keep working.
    :returns: the areas that were written, sorted ascending.
    :raises IOError: if the shapefile cannot be opened for update.
    """
    driver = ogr.GetDriverByName("ESRI Shapefile")
    dataSource = driver.Open(shapefile, 1)
    if dataSource is None:
        raise IOError("Could not open shapefile for update: " + str(shapefile))
    layer = dataSource.GetLayer()
    ldefn = layer.GetLayerDefn()

    schema = [ldefn.GetFieldDefn(n).GetName() for n in range(ldefn.GetFieldCount())]
    if "Area" not in schema:
        new_field = ogr.FieldDefn("Area", ogr.OFTReal)
        new_field.SetWidth(32)
        new_field.SetPrecision(2)
        layer.CreateField(new_field)

    alist = []
    for feature in layer:
        geom = feature.GetGeometryRef()
        if geom is None:
            print("skipping feature " + str(feature.GetFID()) + ": no geometry")
            continue
        area = geom.GetArea()
        print("area: " + str(area))
        alist.append(area)
        feature.SetField("Area", area)
        layer.SetFeature(feature)

    alist.sort()
    dataSource = None  # dereference to close the data source and flush the edits
    return alist

def add_uhi_id(shapefile, iata, alist):
    driver = ogr.GetDriverByName("ESRI Shapefile")
    dataSource = driver.Open(shapefile, 1)
    layer = dataSource.GetLayer()
    year = str(date.today().year)[2:]  # last 2 digits of currentyear

    for feature in layer:
        area = feature.GetField('Area')

        feature.SetField("Area", area)
        layer.SetFeature(feature)

        if area > 0:
            for a in alist:
                if "{:.2f}".format(a[0]) == "{:.2f}".format(area):
                    id = iata + year + "{:02d}".format(a[1])

                    break

        feature.SetField("UHI_ID", id)
        layer.SetFeature(feature)
        ABC = None

    dataSource.SyncToDisk()
    dataSource = None
    ABC = None

def add_idfield(shapefile):
    """Make sure *shapefile* has a string field named ``UHI_ID``.

    The shapefile is opened in update mode. If the field is absent it is
    created with width 7 and precision 0; otherwise the layer is left as is.
    A status line is printed in both cases. Width 7 fits the IDs built by
    add_uhi_id for a 3-character iata: code + 2-digit year + 2-digit rank.
    """
    shp_ds = ogr.GetDriverByName("ESRI Shapefile").Open(shapefile, 1)
    shp_layer = shp_ds.GetLayer()
    layer_defn = shp_layer.GetLayerDefn()
    existing_names = [layer_defn.GetFieldDefn(i).name
                      for i in range(layer_defn.GetFieldCount())]

    if "UHI_ID" in existing_names:
        print("UHI_ID field exists")
    else:
        id_field = ogr.FieldDefn("UHI_ID", ogr.OFTString)
        id_field.SetWidth(7)
        id_field.SetPrecision(0)
        shp_layer.CreateField(id_field)
        print("Added UHI_ID field")

    shp_ds = None


def _rasterize_feature(feature, geot, xsize, ysize):
    """Burn *feature* (value 1) into an in-memory Byte raster aligned with *geot*.

    The raster covers the pixel window of the feature's envelope.

    :returns: ``(array, window_geot)``. ``array`` is None when the computed
        window has no rows or no columns (e.g. a feature east of the raster).
    """
    envelope = feature.GetGeometryRef().GetEnvelope()
    offsets = boundingBoxToOffsets(envelope, geot, xsize, ysize, lc=False)
    window_geot = geotFromOffsets(offsets[0], offsets[2], geot)
    width = offsets[3] - offsets[2]
    height = offsets[1] - offsets[0]
    if width <= 0 or height <= 0:
        return None, window_geot

    # Single-feature in-memory layer so RasterizeLayer burns only this polygon.
    tp_ds = ogr.GetDriverByName("Memory").CreateDataSource("temp")
    tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
    tp_lyr.CreateFeature(feature.Clone())

    tr_ds = gdal.GetDriverByName("MEM").Create("", width, height, 1, gdal.GDT_Byte)
    tr_ds.SetGeoTransform(window_geot)
    gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
    return tr_ds.ReadAsArray(), window_geot


def _mosaic_rasters(paths, out_fp):
    """Merge the rasters in *paths* with rasterio.merge.merge and write the result to *out_fp*.

    The output metadata is copied from the last input, with driver, height,
    width and transform replaced by the mosaic's. Every input is closed
    afterwards, even on error.
    """
    sources = []
    try:
        for path in paths:
            sources.append(rasterio.open(path))
        mosaic, out_transform = merge(sources)
        out_meta = sources[-1].meta.copy()
    finally:
        for src in sources:
            src.close()

    out_meta.update({"driver": "GTiff",
                     "height": mosaic.shape[1],
                     "width": mosaic.shape[2],
                     "transform": out_transform})
    with rasterio.open(out_fp, "w", **out_meta) as dest:
        dest.write(mosaic)


def zonal_array(fn_raster, fn_zones, workdir, iata):
    """Rasterize each urban-zone polygon of *fn_zones* on the grid of *fn_raster*.

    Side effects:
      * fills the ``Area`` field and (re)assigns ``UHI_ID``s in *fn_zones*
        (via add_fields, add_idfield and add_uhi_id);
      * writes one raster per feature to
        ``<workdir>/UrbanZones/<iata>_sep/urbanzone_<UHI_ID>.tif``;
      * writes a mosaic of all per-feature rasters to
        ``<workdir>/UrbanZones/uhi_urbanzone_<iata[:3]>.tif``.

    Features without a geometry, or whose window on the raster grid is
    empty, are skipped.

    :param fn_raster: raster whose geotransform, size and projection define the grid.
    :param fn_zones: zone shapefile (opened for update by the helpers above).
    :param workdir: directory under which ``UrbanZones`` is created.
    :param iata: code used as the UHI_ID prefix and in output names.
    :raises IOError: if *fn_raster* or *fn_zones* cannot be opened.
    """
    save_path = workdir + os.sep + "UrbanZones"
    sep_dir = save_path + os.sep + iata + "_sep"
    for directory in (save_path, sep_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    alist = add_fields(fn_zones, iata)  # fills "Area"; returns the areas ascending
    add_idfield(fn_zones)               # adds the UHI_ID field if it does not exist

    # Assign UHI_IDs once, BEFORE the zones are opened for reading below, so every
    # feature read already carries its ID. The smallest area gets rank
    # len(alist), the largest rank 1.
    n_areas = len(alist)
    add_uhi_id(fn_zones, iata, [[area, n_areas - i] for i, area in enumerate(alist)])

    r_ds = gdal.Open(fn_raster)
    if r_ds is None:
        raise IOError("Could not open raster: " + str(fn_raster))
    geot = r_ds.GetGeoTransform()
    xsize = r_ds.RasterXSize
    ysize = r_ds.RasterYSize
    r_ds = None
    _, prj = utility_functions.get_geo(fn_raster)

    p_ds = ogr.Open(fn_zones)
    if p_ds is None:
        raise IOError("Could not open shapefile: " + str(fn_zones))
    lyr = p_ds.GetLayer()

    # Iterate through shapefile features (i.e. 5 km buffered urban zones).
    mosaic_list_uz = []
    for p_feat in lyr:
        print("new iteration")
        if p_feat.GetGeometryRef() is None:
            continue

        tr_array, new_geot = _rasterize_feature(p_feat, geot, xsize, ysize)
        if tr_array is None:
            print("skipping feature " + str(p_feat.GetFID()) + ": no overlap with raster grid")
            continue

        uhi_id = p_feat.GetField("UHI_ID")
        output = sep_dir + os.sep + 'urbanzone_' + str(uhi_id) + '.tif'
        print(output)
        mosaic_list_uz.append(output)
        y_res, x_res = tr_array.shape
        utility_functions.saveGeoTiff(output, tr_array, [x_res, y_res, 1], new_geot, prj,
                                      dtype=gdal.GDT_Float32, nodata_value=0)
    p_ds = None

    # Merge once, after all per-feature rasters exist.
    if not mosaic_list_uz:
        print("No urban zone rasters were created; skipping mosaic")
        return
    out_fp_uz = save_path + os.sep + 'uhi_urbanzone_' + str(iata)[:3] + '.tif'
    _mosaic_rasters(mosaic_list_uz, out_fp_uz)

def clip_raster(input, umask, output):
    """Clip the raster *input* to the polygons in shapefile *umask* and save it as *output*.

    Every feature geometry in *umask* is used as a mask. The raster is cropped
    to their extent with ``rasterio.mask.mask(..., crop=True)`` and written as
    a GTiff. The output keeps the source metadata, updated to the clipped
    height, width and transform.
    """
    with fiona.open(umask, "r") as zones:
        geometries = [zone["geometry"] for zone in zones]

    with rasterio.open(input) as source:
        clipped, clipped_transform = mask(source, geometries, crop=True)
        meta = source.meta

    meta.update({"driver": "GTiff",
                 "height": clipped.shape[1],
                 "width": clipped.shape[2],
                 "transform": clipped_transform})

    with rasterio.open(output, "w", **meta) as sink:
        sink.write(clipped)




if __name__ == '__main__':
    # Rasterize one city's urban zones (zonal_array) and clip that city's annual
    # MAXLST / MEANLST rasters to its dissolved city boundary.
    #
    # Related:
    # https://eroslab.cr.usgs.gov/SCIENCE/urban-heat-islands/-/blob/5589d7a9c4a830459f14a3abd3bc90fa8e365224/Phase2/uhi_stats_classbased.py
    # https://eroslab.cr.usgs.gov/SCIENCE/urban-heat-islands/-/commit/b4879f67dba6ef55d95afe0a704ceb18b139fe19
    # https://stackabuse.com/command-line-arguments-in-python/
    #
    # Every option defaults to the SMF configuration on caldera, so running the
    # script without arguments behaves as before; e.g. `--iata FSD` selects the
    # FSD inputs/outputs under the same directory layout.

    phase1_root = '/caldera/projects/usgs/eros/urban_heat_islands/Phase1/'
    phase2_root = '/caldera/projects/usgs/eros/urban_heat_islands/Phase2/'

    parser = argparse.ArgumentParser(
        description="Rasterize urban zones and clip annual LST rasters to a city boundary.")
    parser.add_argument("--iata", "-a", default="SMF",
                        help="city code used in paths and UHI_IDs (default: SMF)")
    parser.add_argument("--workdir", "-b", default=phase2_root + 'uhi_urbanbuffer_clip',
                        help="directory that receives the UrbanZones outputs")
    parser.add_argument("--fn_zones", "-c", default=None,
                        help="dissolved city boundary shapefile (default: <Phase2>/uhi_metrobuffer_generation/"
                             "City_Boundaries_Dissolved/ard_<iata>_city_dissolved.shp)")
    parser.add_argument("--clipped_output_directory", "-d", default=None,
                        help="output directory for clipped rasters "
                             "(default: <Phase1>/<IATA>/output_folder/LST_ANNUAL_CLIPPED)")
    parser.add_argument("--raster_list_base_path", "-e", default=None,
                        help="directory with the annual *MAXLST.tif / *MEANLST.tif rasters "
                             "(default: <Phase1>/<IATA>/output_folder/ANNUAL)")
    args = parser.parse_args()

    iata = args.iata
    workdir = args.workdir
    fn_zones = args.fn_zones or (
        phase2_root + 'uhi_metrobuffer_generation/City_Boundaries_Dissolved/'
        + 'ard_' + iata.lower() + '_city_dissolved.shp')
    clipped_output_directory = args.clipped_output_directory or (
        phase1_root + iata + '/output_folder/' + 'LST_ANNUAL_CLIPPED')
    raster_list_base_path = args.raster_list_base_path or (
        phase1_root + iata + '/output_folder/ANNUAL')

    # clip_raster writes into this directory; create it if it is missing.
    if not os.path.isdir(clipped_output_directory):
        os.makedirs(clipped_output_directory)

    # One date stamp per run, so all outputs of a run carry the same date.
    processing_date = dt.datetime.today().strftime('%Y%m%d')
    umask = fn_zones

    for piece in ["*MAXLST.tif", "*MEANLST.tif"]:
        r_path = raster_list_base_path + os.sep + piece
        # Sorted so that rlist[0] (the grid used by zonal_array) is deterministic.
        rlist = sorted(glob.glob(r_path))
        if not rlist:
            raise IOError("No rasters match " + r_path)

        zonal_array(rlist[0], fn_zones, workdir, iata)

        for fn_raster in rlist:
            print("raster: " + fn_raster)
            print("umask: " + umask)

            # File names are "_"-separated with at least 7 parts: the 4th (index 3)
            # is the year, and the 5th (index 4) is replaced by the processing date.
            year = os.path.basename(fn_raster).split("_")[3]
            if int(year) == 1984:
                continue

            long_name = os.path.basename(fn_raster).split(".tif")[0]
            split_version = long_name.split("_")
            new_long_name = "_".join([split_version[0], split_version[1], split_version[2],
                                      split_version[3], processing_date,
                                      split_version[5], split_version[6]])

            outfile = clipped_output_directory + os.sep + new_long_name + '_CLIPPED.tif'
            clip_raster(fn_raster, umask, outfile)
