
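# Load the stack images as NumPy arrays.
# Load the UHI land-cover (LC) raster as a NumPy array.
# Rasterize the polygon features to determine the polygon zones.
# Compute statistics over all pixels that fall in the target polygon and are urban, and store them in a table.
# Save the zonal mean value to all pixels that fall in the target polygon and are urban (this step is currently commented out).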

import gdal
import ogr
import os
import numpy as np
import csv
import utility_functions as utility_functions
import glob
import rasterio
from rasterio.merge import merge
import argparse
from rasterio.mask import mask

import shutil
import time


def boundingBoxToOffsets(bbox, geot, xsize, ysize, lc):
    """Convert a map-space bounding box into pixel offsets of a raster.

    Args:
        bbox: Envelope as (min_x, max_x, min_y, max_y), the order returned by
            OGR ``Geometry.GetEnvelope()``.
        geot: GDAL geotransform; indices 0/1 are the x origin and pixel width,
            indices 3/5 the y origin and pixel height. The edge checks assume a
            north-up raster (negative pixel height).
        xsize: Raster width in pixels.
        ysize: Raster height in pixels.
        lc: When true, a row range clamped at the top or bottom raster edge
            is pulled in by one row (first row 1, end row ysize - 1).

    Returns:
        [row1, row2, col1, col2]: first row, one-past-last row, first column,
        one-past-last column. A side of the box that lies beyond the raster
        is clamped to the raster edge.
    """
    lcadj = 1 if lc else 0

    min_x, max_x, min_y, max_y = bbox[0], bbox[1], bbox[2], bbox[3]
    origin_x, pixel_w = geot[0], geot[1]
    origin_y, pixel_h = geot[3], geot[5]

    col1 = int((min_x - origin_x) / pixel_w)
    col2 = int((max_x - origin_x) / pixel_w) + 1
    row1 = int((max_y - origin_y) / pixel_h)
    row2 = int((min_y - origin_y) / pixel_h) + 1

    if min_x < origin_x:
        col1 = 0
    if max_x > origin_x + xsize * pixel_w:
        # Use the pixel count directly: recomputing it from the edge
        # coordinate can round down to xsize - 1 in floating point.
        col2 = xsize
    if max_y > origin_y:
        row1 = lcadj
    if min_y < origin_y + ysize * pixel_h:
        row2 = ysize - lcadj

    return [row1, row2, col1, col2]


def geotFromOffsets(row_offset, col_offset, geot):
    """Return the geotransform of a window starting at (row_offset, col_offset).

    The pixel width/height (geot[1], geot[5]) are carried over unchanged; the
    origin is shifted by the given number of pixels. The two rotation terms
    of the result are always 0.0, whatever *geot* holds. The result is a list.
    """
    pixel_w = geot[1]
    pixel_h = geot[5]
    window_x = geot[0] + (col_offset * pixel_w)
    window_y = geot[3] + (row_offset * pixel_h)
    return [window_x, pixel_w, 0.0, window_y, 0.0, pixel_h]


def setFeatureStats(uhi_id, min, max, mean, median, sd, sum, count, nulst_mean,
                    names=("id", "min", "max", "mean", "median", "sd", "sum", "count", "nulst_mean")):
    """Build one row of the zonal-statistics table as a dict.

    The value parameters are stored in order under the first nine entries of
    *names*. The parameter names ``min``/``max``/``sum`` shadow builtins but
    are kept because callers may pass them by keyword.

    Args:
        uhi_id: Zone identifier.
        min, max, mean, median, sd, sum, count, nulst_mean: Statistic values.
        names: Column names in the same order as the values. The default is
            a tuple so no mutable object is shared between calls.

    Returns:
        dict mapping column name to value, in column order.

    Raises:
        IndexError: if *names* has fewer than nine entries.
    """
    values = (uhi_id, min, max, mean, median, sd, sum, count, nulst_mean)
    if len(names) < len(values):
        raise IndexError("names must provide at least %d column names" % len(values))
    return dict(zip(names, values))

def _nodataFeatureStats(uhi_id, nodata):
    """Return a stats row for *uhi_id* in which every statistic is *nodata*."""
    return setFeatureStats(uhi_id, nodata, nodata, nodata, nodata,
                           nodata, nodata, nodata, nodata)


def _urbanZoneStats(uhi_id, r_array, c_array, tr_array, nodata):
    """Return the statistics row for one zone, or None if no urban pixel is usable.

    A value pixel is masked when it equals *nodata* or when the rasterized
    zone (*tr_array*) is 0 there. Statistics are taken over the pixels whose
    land-cover value (*c_array*) is > 0; the three arrays must share a shape.
    """
    maskarray = np.ma.MaskedArray(
        r_array,
        mask=np.logical_or(r_array == nodata, np.logical_not(tr_array)))
    # Literal -9999 values become NaN so the nan-aware reductions skip them.
    # NOTE(review): assigning NaN requires a floating-point value raster.
    maskarray[maskarray == -9999] = np.nan

    # Selected after the NaN substitution above (order matters).
    urban = maskarray[np.where(c_array > 0)]
    if False not in urban.mask:
        # All urban pixels are masked, or there are none: no row is produced.
        return None

    count = urban.count() - np.count_nonzero(np.isnan(urban))
    # NOTE(review): np.copy returns a plain ndarray, dropping the nodata/zone
    # mask, so this mean covers every pixel with land-cover value < 20 in the
    # bounding window (including outside the polygon). Confirm it is intended.
    nulst_mean = np.nanmean(np.copy(maskarray[np.where(c_array < 20)]))

    return setFeatureStats(
        uhi_id,
        np.nanmin(urban),
        np.nanmax(urban),
        np.nanmean(urban),
        np.nanmedian(urban),
        np.nanstd(urban),
        np.nansum(urban),
        count,
        nulst_mean)


def ClassBasedStats_byUrbanZone(fn_raster, fn_zones, fn_class, workdir, itype,iata):
    """Compute per-zone statistics of a raster over urban pixels and write a CSV.

    Args:
        fn_raster: Path of the value raster (band 1 is used).
        fn_zones: Path of the zone vector. Each feature needs a ``UHI_ID``
            field whose last four characters are digits (used as burn value).
        fn_class: Path of the land-cover raster (band 1); pixels with a value
            > 0 are the ones summarized.
        workdir: Output root directory.
        itype: Suffix of the ``PERSISTENT_CITYCEN_STATS_`` output folder.
        iata: City code, used for the ``UrbanZones/<iata>_sep`` folder.

    Side effects:
        * Writes each rasterized zone to
          ``<workdir>/UrbanZones/<iata>_sep/uhi_urbanzone_<UHI_ID>.tif``.
        * Writes one row per zone to
          ``<workdir>/PERSISTENT_CITYCEN_STATS_<itype>/ByCSV/<name>_ZMEAN.csv``
          where <name> is the raster file name minus its last four characters,
          with ``_MEAN_`` in the file name replaced by ``_SMEAN_``.
          Features without geometry and zones whose urban pixels are all
          masked get no row. Zones whose pixel windows cannot be read, or that
          do not overlap the value raster, get a row filled with the raster's
          nodata value.
    """
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"

    r_ds = gdal.Open(fn_raster)
    c_ds = gdal.Open(fn_class)
    p_ds = ogr.Open(fn_zones)

    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    cgeot = c_ds.GetGeoTransform()
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()
    _, prj = utility_functions.get_geo(fn_raster)
    xsize = r_ds.RasterXSize
    ysize = r_ds.RasterYSize
    xcsize = c_ds.RasterXSize
    ycsize = c_ds.RasterYSize

    # Output folders do not depend on the feature, so create them once. This
    # also guarantees `classout` exists for the CSV step below.
    uzout_indiv = os.path.join(workdir, 'UrbanZones', iata + '_sep')
    os.makedirs(uzout_indiv, exist_ok=True)
    classout = os.path.join(workdir, 'PERSISTENT_CITYCEN_STATS_' + itype)
    os.makedirs(classout, exist_ok=True)

    zstats = []

    # Iterating the layer advances to the next feature on every pass, so a
    # feature without geometry can no longer stall the loop.
    for p_feat in lyr:
        uhi_id = p_feat.GetField("UHI_ID")
        print(uhi_id)

        geom = p_feat.GetGeometryRef()
        if geom is None:
            continue

        # Pixel windows of the zone's bounding box in the value and class rasters.
        envelope = geom.GetEnvelope()
        offsets = boundingBoxToOffsets(envelope, geot, xsize, ysize, lc=False)
        c_offsets = boundingBoxToOffsets(envelope, cgeot, xcsize, ycsize, lc=True)
        new_geot = geotFromOffsets(offsets[0], offsets[2], geot)
        n_cols = offsets[3] - offsets[2]
        n_rows = offsets[1] - offsets[0]
        if n_cols <= 0 or n_rows <= 0:
            # The zone does not overlap the value raster.
            zstats.append(_nodataFeatureStats(uhi_id, nodata))
            continue

        # Copy the target feature into a temporary in-memory layer to rasterize it.
        if os.path.exists(shp_name):
            mem_driver.DeleteDataSource(shp_name)
        tp_ds = mem_driver.CreateDataSource(shp_name)
        tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
        tp_lyr.CreateFeature(p_feat.Clone())

        # Burn the zone into a raster aligned with the value-raster window.
        # The burn value is the last four characters of UHI_ID; pixels outside
        # the polygon are expected to stay 0 (the masking relies on that).
        tr_ds = mem_driver_gdal.Create("", n_cols, n_rows, 1, gdal.GDT_UInt16)
        tr_ds.SetGeoTransform(new_geot)
        gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[int(uhi_id[-4:])])
        tr_array = tr_ds.ReadAsArray()

        r_array = r_ds.GetRasterBand(1).ReadAsArray(
            offsets[2], offsets[0], n_cols, n_rows)
        # The class window starts at its own offsets but uses the value
        # window's size so the two arrays align pixel for pixel.
        c_array = c_ds.GetRasterBand(1).ReadAsArray(
            c_offsets[2], c_offsets[0], n_cols, n_rows)

        save_path_uz = os.path.join(uzout_indiv, 'uhi_urbanzone_' + str(uhi_id) + '.tif')
        y_res, x_res = tr_array.shape
        utility_functions.saveGeoTiff(save_path_uz, tr_array, [x_res, y_res, 1],
                                      new_geot, prj, dtype=gdal.GDT_Float32)

        if r_array is None or c_array is None:
            # ReadAsArray returns None when a window cannot be read.
            zstats.append(_nodataFeatureStats(uhi_id, nodata))
        else:
            row = _urbanZoneStats(uhi_id, r_array, c_array, tr_array, nodata)
            if row is not None:
                zstats.append(row)

        # Release the temporary in-memory datasets before the next feature.
        tp_lyr = None
        tp_ds = None
        tr_ds = None

    csvout = os.path.join(classout, "ByCSV")
    os.makedirs(csvout, exist_ok=True)

    csv_name = os.path.basename(fn_raster)[:-4] + '_ZMEAN' + ".csv"
    # Rename only the file name, never the directories above it.
    fn_csv = os.path.join(csvout, csv_name.replace("_MEAN_", "_SMEAN_"))

    if zstats:
        col_names = list(zstats[0].keys())
    else:
        # No zone produced a row; still write a header-only CSV.
        col_names = list(setFeatureStats(*([None] * 9)).keys())
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)



def uhi_stats_classbased(workdir, iata, shapefile_root, input_type,uhi_lc_root,lst_stat,section_354_directory):
    """Run ClassBasedStats_byUrbanZone on every PUG/PNC raster of one city.

    Args:
        workdir: Output root, passed through to ClassBasedStats_byUrbanZone.
        iata: City code; used (lower-cased) in the zone shapefile name and
            (as given) in the land-cover file pattern.
        shapefile_root: Directory holding ``ard_<iata>_city_dissolved.shp``.
        input_type: Sequence of type labels (e.g. ``["STACK"]``); each label
            names one output folder. A single string is treated as one label.
        uhi_lc_root: Directory holding ``UHI_CU_<iata>_<year>*LC.tif`` files.
        lst_stat: Accepted for interface compatibility.
            NOTE(review): not used by this function — confirm whether it was
            meant to select rasters.
        section_354_directory: Directory holding the input rasters; files
            matching ``*PUG*`` and then ``*PNC*`` are processed.

    Raises:
        FileNotFoundError: if no land-cover raster matches a raster's year.
    """
    # A bare string is one type label, not a sequence of one-letter labels.
    if isinstance(input_type, str):
        input_type = [input_type]

    # Neither the rasters nor the zone shapefile depend on the type label, so
    # resolve them once. Sorting makes the processing order reproducible.
    rasters = sorted(glob.glob(os.path.join(section_354_directory, "*PUG*")))
    rasters.extend(sorted(glob.glob(os.path.join(section_354_directory, "*PNC*"))))
    fn_zones = os.path.join(shapefile_root, "ard_" + iata.lower() + "_city_dissolved.shp")

    for itype in input_type:
        for fn_raster in rasters:
            print(fn_raster)

            # Characters 11-14 of the file name are read as the year; STACK
            # rasters have "STAC" there and are paired with the 2020 LC layer.
            year = os.path.basename(fn_raster)[11:15]
            if year == "STAC":
                year = "2020"

            lc_pattern = os.path.join(uhi_lc_root, "UHI_CU_" + iata + "_" + year + "*LC.tif")
            print(lc_pattern)
            lc_matches = sorted(glob.glob(lc_pattern))
            if not lc_matches:
                raise FileNotFoundError("No land-cover raster matches " + lc_pattern)

            ClassBasedStats_byUrbanZone(fn_raster, fn_zones, lc_matches[0], workdir, itype, iata)

if __name__ == '__main__':
    # Command-line entry point: compute class-based UHI statistics for one city.
    # Source history:
    # https://eroslab.cr.usgs.gov/SCIENCE/urban-heat-islands/-/blob/5589d7a9c4a830459f14a3abd3bc90fa8e365224/Phase2/uhi_stats_classbased.py
    parser = argparse.ArgumentParser()

    # Every argument is used to build paths below, so all are required; a
    # missing one used to surface later as an unrelated TypeError.
    parser.add_argument("--workdir", "-a", required=True, help="specify workpath directory")
    parser.add_argument("--iata_code", "-b", required=True, help="iata_code")
    parser.add_argument("--metrobuffer_shapefile_root", "-c", required=True, help="metrobuffer_shapefile_root")
    parser.add_argument("--uhi_lc_root", "-d", required=True, help="uhi_lc_root")
    parser.add_argument("--section_354_directory", "-e", required=True, help="section_354_directory")

    args = parser.parse_args()

    shapefile_root = args.metrobuffer_shapefile_root + os.sep + "City_Boundaries_Dissolved"
    iata = args.iata_code

    # Recreate the STACK output folder empty.
    classout = args.workdir + os.sep + 'PERSISTENT_CITYCEN_STATS_' + "STACK"
    if os.path.exists(classout):
        shutil.rmtree(classout)
    # NOTE(review): presumably lets a network filesystem settle after the delete — confirm.
    time.sleep(5)
    os.makedirs(classout, exist_ok=True)

    input_type = ["STACK"]
    section_354_directory = args.section_354_directory
    # NOTE(review): uhi_stats_classbased does not use lst_stat yet, so both
    # passes process the same rasters and rewrite the same outputs.
    for lst_stat in ('MAXLST', 'MEANLST'):
        uhi_stats_classbased(args.workdir, iata, shapefile_root, input_type,
                             args.uhi_lc_root, lst_stat, section_354_directory)