import gdal
import os

import numpy
import numpy as np
import glob
import csv
import utility_functions as utility_functions
import fiona
import rasterio

from rasterio.mask import mask
import argparse

def clip_raster(input, umask, output):
    """Clip a raster to the geometries of a vector dataset and save it as GeoTIFF.

    Args:
        input: Path of the raster to clip (opened with ``rasterio.open``).
            The name shadows the builtin but is kept so keyword callers
            keep working.
        umask: Path of the vector dataset (opened with ``fiona.open``) whose
            feature geometries define the area to keep.
        output: Path of the GeoTIFF file to write.
    """
    # Every feature geometry in the vector file contributes to the clip mask.
    with fiona.open(umask, "r") as shapefile:
        shapes = [feature["geometry"] for feature in shapefile]

    with rasterio.open(input) as src:
        # crop=True asks rasterio to crop the result to the extent of the shapes.
        out_image, out_transform = mask(src, shapes, crop=True)
        out_meta = src.meta

    # NOTE: an earlier revision wrapped out_image in
    # np.ma.masked_where(out_image == 0, ...) and then immediately took
    # ``.data`` again, which discards the mask. That round-trip never changed
    # the pixels written below, so it has been dropped.

    # out_image is indexed as (band, row, col) here: axis 1 supplies the
    # height and axis 2 the width of the clipped raster.
    out_meta.update({"driver": "GTiff",
                     "height": out_image.shape[1],
                     "width": out_image.shape[2],
                     "transform": out_transform})

    with rasterio.open(output, "w", **out_meta) as dest:
        dest.write(out_image)

def boundingBoxToOffsets(geot, cgeot, xsize, ysize, lc, xcsize=None, ycsize=None):
    """Express one grid's extent as row/column offsets on another grid.

    Both ``geot`` and ``cgeot`` are indexed like six-element geotransforms:
    element 0 is the x origin, 1 the pixel width, 3 the y origin and 5 the
    pixel height (presumably GDAL geotransforms -- confirm with callers).

    Args:
        geot: Geotransform of the grid whose extent is being located.
        cgeot: Geotransform of the reference grid the offsets refer to.
        xsize: Column count paired with ``cgeot`` when computing its right edge.
        ysize: Row count paired with ``cgeot`` when computing its lower edge.
        lc: When equal to ``True``, clamped row bounds are pulled in by one
            row (``row1`` becomes 1 instead of 0, ``row2`` loses one row).
        xcsize: Column count paired with ``geot``.
        ycsize: Row count paired with ``geot``.

    Returns:
        ``[row1, row2, col1, col2]`` offsets on the ``cgeot`` grid.

    Raises:
        TypeError: If ``xcsize`` or ``ycsize`` is not supplied.
    """
    # The body always needed these two sizes but the signature never provided
    # them, so every call used to die with a NameError. They are optional
    # only to keep the positional signature backward-compatible.
    if xcsize is None or ycsize is None:
        raise TypeError(
            "boundingBoxToOffsets() requires xcsize and ycsize "
            "(the column/row counts that belong to geot)"
        )

    # Equality (not truthiness) is kept on purpose to match the old check.
    lcadj = 1 if lc == True else 0  # noqa: E712

    col1 = int((geot[0] - cgeot[0]) / cgeot[1])
    col2 = int((geot[0] - cgeot[0] + xcsize * geot[1]) / cgeot[1]) + 1
    row1 = int((geot[3] - cgeot[3]) / cgeot[5])
    row2 = int((geot[3] - cgeot[3] + ycsize * geot[5]) / cgeot[5]) + 1

    # Clamp each side that sticks out beyond the reference grid.
    if geot[0] < cgeot[0]:
        col1 = 0
    if (geot[0] + xcsize * geot[1]) > (cgeot[0] + xsize * cgeot[1]):
        radj = cgeot[0] + (xsize * cgeot[1])
        col2 = int((radj - cgeot[0]) / cgeot[1])
    if geot[3] > cgeot[3]:
        row1 = 0 + lcadj
    if (geot[3] + ycsize * geot[5]) < (cgeot[3] + ysize * cgeot[5]):
        ladj = cgeot[3] + (ysize * cgeot[5])
        row2 = int((ladj - cgeot[3]) / cgeot[5]) - lcadj

    return [row1, row2, col1, col2]

def geotFromOffsets(row_offset, col_offset, geot):
    """Return a geotransform whose origin is shifted by whole pixels.

    Args:
        row_offset: Number of rows to move the y origin (``geot[3]``) by,
            scaled by the pixel height ``geot[5]``.
        col_offset: Number of columns to move the x origin (``geot[0]``) by,
            scaled by the pixel width ``geot[1]``.
        geot: Source geotransform, indexed at positions 0, 1, 3 and 5.

    Returns:
        A six-element list: shifted x origin, pixel width, 0.0, shifted
        y origin, 0.0, pixel height. Both rotation terms are always 0.0.
    """
    pixel_width = geot[1]
    pixel_height = geot[5]

    shifted_x = geot[0] + (col_offset * pixel_width)
    shifted_y = geot[3] + (row_offset * pixel_height)

    return [shifted_x, pixel_width, 0.0, shifted_y, 0.0, pixel_height]

def savetiff(workdir, raster, array, trans, prj, outputid):
    """Save a 2-D array as ``<workdir>/<outputid>/<raster stem>_<outputid>.tif``.

    Args:
        workdir: Directory under which the ``outputid`` sub-folder is created.
        raster: Path of the source raster; only its base name (without the
            extension) is used to build the output file name.
        array: Two-dimensional array to write (its shape is unpacked as
            ``rows, cols``).
        trans: Geotransform forwarded to ``utility_functions.saveGeoTiff``.
        prj: Projection forwarded to ``utility_functions.saveGeoTiff``.
        outputid: Name of the output sub-folder and suffix of the file name.
    """
    hotspot_out = workdir + os.sep + outputid

    # makedirs(exist_ok=True) creates missing parents too and does not fail
    # when the folder appears between a check and the creation.
    os.makedirs(hotspot_out, exist_ok=True)

    # splitext strips the real extension instead of assuming it is exactly
    # four characters long (".tif").
    stem = os.path.splitext(os.path.basename(raster))[0]
    save_path = hotspot_out + os.sep + stem + '_' + outputid + '.tif'

    y_res, x_res = array.shape
    utility_functions.saveGeoTiff(save_path, array, [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)

def _first_match(pattern):
    """Return the first path matching *pattern*; raise a descriptive IndexError if none."""
    matches = glob.glob(pattern)
    if not matches:
        # IndexError is what the former bare ``glob.glob(...)[0]`` raised.
        raise IndexError("no file matches " + repr(pattern))
    return matches[0]


def hotspot(workdir, raster, iata, umask=None, lc_clipped_dir=None):
    """Write a binary hotspot raster for one input raster.

    A pixel is set to 1 when, for some row of the class-statistics CSV, it
    lies in that row's urban zone, has that row's land-cover class and its
    value exceeds that row's threshold; every other pixel is 0. The result
    is saved through ``savetiff`` with the output id ``'HOTSPOT'``.

    Args:
        workdir: Root folder holding the ``UrbanZones``, ``UHI_LC`` and
            ``CLASS_STATS_ANNUAL/ByCSV`` sub-folders; also receives the output.
        raster: Path of the input raster. Characters 11-14 of its base name
            select the year; the value ``"STAC"`` selects the newest
            ``UHI_LC`` layer instead.
        iata: Identifier used to pick the ``UrbanZones/*<iata>.tif`` file.
        umask: Vector file used to clip the land-cover raster. Defaults to
            the path that used to be hard-coded here.
        lc_clipped_dir: Folder receiving the clipped land-cover raster.
            Defaults to the path that used to be hard-coded here.

    Raises:
        IndexError: If a required input file cannot be found.
    """
    if umask is None:
        umask = r'D:\Users\cwmueller\data\UHI\FSD\output_folder\Shapefiles\ard_fsd_city_dissolved.shp'
    if lc_clipped_dir is None:
        lc_clipped_dir = r'D:\Users\cwmueller\data\UHI\FSD\output_folder\LC_clipped'

    ### LOAD INPUT RASTERS ###
    uz_file = _first_match(workdir + os.sep + "UrbanZones" + os.sep + "*" + iata + ".tif")
    uz_raster = gdal.Open(uz_file)
    uz_array = np.array(uz_raster.ReadAsArray())

    fn_raster = gdal.Open(raster)
    fn_array = np.array(fn_raster.ReadAsArray())
    trans, prj = utility_functions.get_geo(raster)

    # Year comes from a fixed position in the file name; STACK inputs have no
    # year of their own and use the most recent UHI_LC layer instead.
    year = os.path.basename(raster)[11:15]
    if year == "STAC":
        lc_rasters = glob.glob(workdir + os.sep + "UHI_LC" + os.sep + "*.tif")
        lc_years = [int(os.path.basename(lc)[11:15]) for lc in lc_rasters]
        year = str(max([0] + lc_years))

    print(year)

    stats_csv = _first_match(workdir + os.sep + "CLASS_STATS_ANNUAL" + os.sep + "ByCSV" + os.sep + "*" + year + "*" + "_ZMEAN.csv")
    print(stats_csv)
    # newline='' is what the csv module asks for when reading files.
    with open(stats_csv, newline='') as csv_file:
        data = list(csv.reader(csv_file))

    lc_file = _first_match(workdir + os.sep + "UHI_LC" + os.sep + "*" + str(year) + "*.tif")
    print(lc_file)
    lc_clipped_file = lc_clipped_dir + os.sep + os.path.basename(lc_file)[:-4] + '_clipped.tif'
    print(lc_clipped_file)

    clip_raster(lc_file, umask, lc_clipped_file)
    lc_clipped_raster = gdal.Open(lc_clipped_file)
    # NOTE(review): the last row of the clipped raster is dropped, presumably
    # so its shape lines up with the other two arrays -- confirm.
    lc_array = np.array(lc_clipped_raster.ReadAsArray()[:-1])

    ### GENERATE A HOTSPOT RASTER ###
    print("Processing Hotspot, and Normalized by LC analyses")

    # Accumulate hits in a separate boolean mask. The previous version wrote
    # 1s into a copy of the input and then zeroed everything != 1, so input
    # pixels that already equalled 1 were reported as hotspots.
    hot_mask = np.zeros(fn_array.shape, dtype=bool)

    # Row 0 of the CSV is treated as a header and skipped.
    for row in data[1:]:
        UHI_ID = row[0]
        lc_class = row[1]
        # Threshold is column 4 plus column 6 (presumably the class statistic
        # and one standard deviation -- confirm against the CSV header).
        LST_threshold = float(row[4]) + float(row[6])
        print(" Processing for LC class: " + str(lc_class))

        # The last four characters of the UHI id are the urban-zone code.
        hot_mask = hot_mask | (
            (uz_array == int(UHI_ID[-4:]))
            & (lc_array == int(lc_class))
            & (fn_array > LST_threshold)
        )

    # Same 0/1 encoding and dtype as before. The normalised and intensity
    # products that used to be sketched here in comments are not produced.
    hotspot_array = hot_mask.astype(fn_array.dtype)

    savetiff(workdir, raster, hotspot_array, trans, prj, 'HOTSPOT')


if __name__ == '__main__':
    # Command-line entry point: run hotspot() on every annual MAXLST
    # intensity raster found for one city.
    # https://stackabuse.com/command-line-arguments-in-python/
    parser = argparse.ArgumentParser()

    parser.add_argument("--workdir", "-w", help="specify workpath directory")
    parser.add_argument("--iata", "-i", help="iata_location_identifier")

    # Read arguments from the command line
    args = parser.parse_args()

    # The options used to be parsed and then ignored, and supplying both
    # triggered a call to an undefined function. They now override the FSD
    # defaults, which still apply when an option is omitted.
    workdir = args.workdir or r'D:\Users\cwmueller\data\UHI\FSD\output_folder'
    iata = args.iata or 'FSD'

    # With the defaults on Windows this is the same pattern as before:
    # <workdir>\INTENSITY_ANNUAL\MAXLST\UHI_CU_FSD_*.tif
    raster_list = glob.glob(
        workdir + os.sep + 'INTENSITY_ANNUAL' + os.sep + 'MAXLST'
        + os.sep + 'UHI_CU_' + iata + '_*.tif'
    )

    for raster in raster_list:
        hotspot(workdir, raster, iata)