import raster_geometry
import numpy as np
import glob
import csv
import utility_functions as utility_functions
import fiona
import rasterio

from rasterio.mask import mask
import argparse

import os
import gdal

import datetime as dt


import shutil
import time

from joblib import Parallel, delayed


def clip_raster(input, umask, output):
    """Clip a raster to the geometries of a vector file and save it as uint8.

    Args:
        input: Path of the raster to clip (opened with rasterio). The name
            shadows the builtin but is kept for caller compatibility.
        umask: Path of the vector dataset (opened with fiona) whose feature
            geometries define the area to keep.
        output: Path of the GeoTIFF to write. Its folder must already exist.

    The raster is cropped to the extent of the geometries. Cells equal to
    -9999 are rewritten to 0, and the result is stored as ``uint8`` with 0
    declared as the nodata value. The shape of the clipped array is printed.
    """
    with fiona.open(umask, "r") as shapefile:
        shapes = [feature["geometry"] for feature in shapefile]

    with rasterio.open(input) as src:
        out_image, out_transform = mask(src, shapes, crop=True)
        out_meta = src.meta

    # Fold the -9999 fill value into 0, which is the nodata value declared
    # for the output file below.
    out_image[out_image == -9999] = 0

    out_meta.update({"driver": "GTiff",
                     "height": out_image.shape[1],
                     "width": out_image.shape[2],
                     "transform": out_transform,
                     'dtype': 'uint8',
                     'nodata': 0
                     })
    print(out_image.shape)

    with rasterio.open(output, "w", **out_meta) as dest:
        dest.write(out_image.astype('uint8'))


def savetiff(workdir, raster, array, trans, prj, outputid):
    """Save a 2-D array as a Float32 GeoTIFF under ``workdir/outputid``.

    Args:
        workdir: Base output folder.
        raster: Path of the source raster; only its file name is used to
            derive the output name. The last four characters of the name
            are dropped, i.e. a 4-character extension such as ".tif" is
            assumed.
        array: 2-D array to write (rows, columns).
        trans: Geotransform forwarded to ``utility_functions.saveGeoTiff``.
        prj: Projection forwarded to ``utility_functions.saveGeoTiff``.
        outputid: Name of the output sub-folder, also appended to the file
            name as a suffix.

    Returns:
        The path of the file that was written.
    """
    out_dir = workdir + os.sep + outputid
    # exist_ok avoids the check-then-create race when several workers save
    # into the same folder, and missing parent folders are created as well.
    os.makedirs(out_dir, exist_ok=True)

    # Strip "_CLIPPED" from the raster stem only; applying it to the whole
    # path would corrupt a folder name that happens to contain that text.
    stem = os.path.basename(raster)[:-4].replace("_CLIPPED", "")
    save_path = out_dir + os.sep + stem + '_' + outputid + '.tif'
    print("save_path:"+save_path)

    y_res, x_res = array.shape
    utility_functions.saveGeoTiff(save_path, array, [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)
    return save_path



# Build a 0/1 hotspot raster for one clipped LST raster
def _first_match(pattern, description):
    """Return the first path matching ``pattern`` or raise FileNotFoundError."""
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError("No " + description + " found matching: " + pattern)
    return matches[0]


def _lst_stat_from_path(raster):
    """Return "MEANLST" or "MAXLST", whichever appears in the raster path."""
    # MEANLST takes precedence when both substrings are present.
    for stat in ("MEANLST", "MAXLST"):
        if stat in raster:
            return stat
    raise ValueError("Raster path contains neither MAXLST nor MEANLST: " + raster)


def _resolve_lc_year(raster, uhi_lc_root):
    """Return the land-cover year (as a string) to pair with ``raster``."""
    # Characters 11-14 of the file name hold the year, or "STAC" for STACK
    # products. NOTE(review): this offset presumably relies on the
    # "UHI_CU_<3-letter IATA>_" prefix - confirm for other naming schemes.
    year = os.path.basename(raster)[11:15]
    if year != "STAC":
        return year

    # STACK input: fall back to the most recent land-cover layer available.
    lc_years = [int(os.path.basename(lc)[11:15]) for lc in glob.glob(uhi_lc_root + "*.tif")]
    if not lc_years:
        raise FileNotFoundError("No UHI_LC rasters found matching: " + uhi_lc_root + "*.tif")
    return str(max(lc_years))


def _hotspot_output_name(raster):
    """Return the HOTSPOT file name derived from the name of ``raster``."""
    pieces = os.path.basename(raster).split("_")
    if len(pieces) < 8:
        raise ValueError("Expected at least 8 '_'-separated fields in raster name: "
                         + os.path.basename(raster))
    processing_date = dt.datetime.today().strftime('%Y%m%d')
    # Field 4 of the source name is replaced by today's date; any field
    # after the eighth is dropped.
    new_filename = "_".join(pieces[0:4] + [processing_date] + pieces[5:8])
    # [:-4] drops the 4-character extension carried by the last field.
    return new_filename[:-4].replace("_CLIPPED", "") + '_' + 'HOTSPOT' + '.tif'


def hotspot(workdir, raster, iata,uhi_lc_root,lc_clipped_folder,ard_city_dissolved_shapefile):
    """Write a 0/1 hotspot raster for one LST raster.

    A pixel is set to 1 when, for some row of the class-statistics CSV, it
    lies in that row's urban zone, has that row's land-cover class, and its
    LST value exceeds the row's threshold (CSV column 4 plus column 6).
    Every other pixel is 0.

    Args:
        workdir: City output folder. It must contain ``UrbanZones``,
            ``LST_CLASS_STATS_ANNUAL/ByCSV`` and an existing ``HOTSPOT``
            folder, which receives the result.
        raster: Path of the LST raster. The path must contain "MAXLST" or
            "MEANLST", and the file name must have at least eight
            "_"-separated fields.
        iata: Location identifier used to find the UrbanZones raster.
        uhi_lc_root: Path prefix of the UHI_LC rasters. It is concatenated
            directly with a glob pattern, so include the trailing separator.
        lc_clipped_folder: Existing folder that receives the clipped
            land-cover raster.
        ard_city_dissolved_shapefile: Vector file used to clip the
            land-cover raster.

    Raises:
        ValueError: If the raster name is malformed or the three input
            windows do not share the same shape.
        FileNotFoundError: If a required input file cannot be found.

    Note:
        The clipped land-cover file is rewritten on every call, so calls
        that resolve to the same land-cover year must not run concurrently.
    """
    # Validate the raster name before doing any expensive work.
    lst_stat = _lst_stat_from_path(raster)
    output_name = _hotspot_output_name(raster)

    uz_file = _first_match(workdir + os.sep + "UrbanZones" + os.sep + "*" + iata + ".tif",
                           "UrbanZones raster")

    year = _resolve_lc_year(raster, uhi_lc_root)
    print(year)

    # Per-class LST statistics, e.g.
    # <workdir>/LST_CLASS_STATS_ANNUAL/ByCSV/UHI_CU_FSD_2017_20210727_C01_MEANLST_CLIPPED_ZMEAN.csv
    match_string = workdir + os.sep + "LST_CLASS_STATS_ANNUAL" + os.sep + "ByCSV" + os.sep + "*" + year + "*" + lst_stat +"*_ZMEAN.csv"
    print(match_string)
    stats_csv = _first_match(match_string, "class statistics CSV")
    print(stats_csv)
    with open(stats_csv) as csv_file:
        data = list(csv.reader(csv_file))

    lc_file = _first_match(uhi_lc_root + "*" + str(year) + "*.tif", "UHI_LC raster")
    print(lc_file)
    lc_clipped_file = lc_clipped_folder + os.sep + os.path.basename(lc_file)[:-4] + '_clipped.tif'
    print(lc_clipped_file)

    clip_raster(lc_file, ard_city_dissolved_shapefile, lc_clipped_file)

    # Read band 1 of all three rasters over their common extent so that the
    # arrays line up pixel for pixel.
    common_bbox = raster_geometry.get_common_extent_from_raster_paths([lc_clipped_file, raster, uz_file])
    with rasterio.open(lc_clipped_file) as lc_ds, \
            rasterio.open(raster) as fn_ds, \
            rasterio.open(uz_file) as uz_ds:
        lc_window = rasterio.features.geometry_window(lc_ds, [common_bbox])
        fn_window = rasterio.features.geometry_window(fn_ds, [common_bbox])
        uz_window = rasterio.features.geometry_window(uz_ds, [common_bbox])

        lc_array = lc_ds.read(1, window=lc_window)
        fn_array = fn_ds.read(1, window=fn_window)
        uz_array = uz_ds.read(1, window=uz_window)

        # The output reuses the LST raster's metadata, re-anchored on the
        # window that was read. Only one band is written.
        out_meta = fn_ds.meta
        out_meta.update({
                         "count": 1,
                         "height": fn_window.height,
                         "width": fn_window.width,
                         "transform": fn_ds.window_transform(fn_window)
                         })

    if not (lc_array.shape == fn_array.shape == uz_array.shape):
        raise ValueError("Input windows differ in shape: land cover " + str(lc_array.shape)
                         + ", LST " + str(fn_array.shape)
                         + ", urban zones " + str(uz_array.shape))

    print("Processing Hotspot, Urban vs Non-Urban, and Normalized by LC analyses")

    # Start from zeros rather than a copy of the LST data, so a source pixel
    # whose value happens to be 1 is not mistaken for a hotspot.
    hotspot_array = np.zeros_like(fn_array)

    # The first CSV row is treated as a header.
    for row in data[1:]:
        if not row:
            continue
        UHI_ID = row[0]
        lc_class = row[1]
        # NOTE(review): presumably the class LST statistic plus one standard
        # deviation, per the original author's notes - confirm the columns.
        LST_threshold = float(row[4])+float(row[6])
        print(" Processing for LC class: " + str(lc_class))

        # The last four characters of the UHI ID are the urban-zone code.
        exceeds = ((uz_array == int(UHI_ID[-4:]))
                   & (lc_array == int(lc_class))
                   & (fn_array > LST_threshold))
        hotspot_array[exceeds] = 1

    save_path = workdir + os.sep + 'HOTSPOT' + os.sep + output_name
    with rasterio.open(save_path,"w",**out_meta) as dest:
        dest.write(hotspot_array, 1)

def main():
    """Build hotspot rasters for every clipped annual LST raster of a city.

    Reads ``<workdir>/LST_ANNUAL_CLIPPED/UHI_CU_<iata>_*.tif``, empties and
    recreates ``<workdir>/HOTSPOT``, and runs ``hotspot`` on each raster in
    parallel.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--workdir", "-w", required=True, help="specify workpath directory")
    parser.add_argument("--iata_location_identifier", "-i", required=True, help="iata_location_identifier")
    parser.add_argument("--uhi_lc_root", "-m", required=True,
                        help="path prefix of the UHI_LC rasters, including the trailing separator")
    parser.add_argument("--ard_city_dissolved_shapefile", default=None,
                        help="dissolved city boundary shapefile (defaults to the caldera project location)")
    parser.add_argument("--n_jobs", type=int, default=20, help="number of parallel workers")
    args = parser.parse_args()

    workdir = args.workdir
    iata = args.iata_location_identifier
    uhi_lc_root = args.uhi_lc_root

    ard_city_dissolved_shapefile = args.ard_city_dissolved_shapefile
    if ard_city_dissolved_shapefile is None:
        ard_city_dissolved_shapefile = "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/uhi_metrobuffer_generation/City_Boundaries_Buffer/ard_"+iata+"_city_dissolved.shp"

    # LST_ANNUAL_CLIPPED is produced by phase2_1_uhi_urbanbuffer_clip.py.
    match_string = workdir + os.sep + r'LST_ANNUAL_CLIPPED'+ os.sep +'UHI_CU_'+iata+'_*.tif'
    raster_list = glob.glob(match_string)

    # hotspot() writes the clipped land-cover rasters here and expects the
    # folder to exist.
    lc_clipped_output_folder = workdir + os.sep + 'LC_clipped'
    os.makedirs(lc_clipped_output_folder, exist_ok=True)

    # Start from an empty output folder. This is the folder hotspot()
    # actually writes to, derived from workdir rather than a fixed path.
    hotspot_out = workdir + os.sep + 'HOTSPOT'
    if os.path.exists(hotspot_out):
        shutil.rmtree(hotspot_out)
        # NOTE(review): pause kept from the original, presumably to let a
        # network file system settle after the delete - confirm it is needed.
        time.sleep(5)
    os.makedirs(hotspot_out, exist_ok=True)

    meanlst_raster_paths = [path for path in raster_list if "MEANLST" in path]
    maxlst_raster_paths = [path for path in raster_list if "MAXLST" in path]

    # A MEANLST and a MAXLST raster of the same year clip the same
    # land-cover file, so the two groups run as separate batches to avoid
    # colliding on it.
    # NOTE(review): a STACK raster resolves to the most recent land-cover
    # year and may still collide with that year's annual raster inside one
    # batch - verify against the inputs.
    for batch in (meanlst_raster_paths, maxlst_raster_paths):
        Parallel(n_jobs=args.n_jobs)(delayed(hotspot)(workdir, raster, iata,uhi_lc_root,lc_clipped_output_folder,ard_city_dissolved_shapefile) for raster in batch)


if __name__ == '__main__':
    main()


    #------------------------------------------------------|







