
import raster_geometry


import gdal
import os

import numpy
import numpy as np
import glob
import csv
import utility_functions as utility_functions
import fiona
import rasterio

from rasterio.mask import mask
import argparse

import shutil
import time
import datetime as dt

from rasterio.io import MemoryFile

import shutil
import time

from joblib import Parallel, delayed

def clip_raster(input, umask, output):
    """Clip a raster to the geometries of a vector file and save it as a uint8 GeoTIFF.

    Args:
        input: Path of the raster to clip. (The name shadows the builtin but is
            kept so existing callers are unaffected.)
        umask: Path of a vector dataset readable by fiona; every feature
            geometry in it is used as part of the clip mask.
        output: Path of the GeoTIFF to write. Its parent folder is created
            when it does not exist yet.

    The written file declares ``dtype='uint8'`` and ``nodata=0``; the clipped
    pixels are cast to uint8 before writing.
    """
    with fiona.open(umask, "r") as shapefile:
        shapes = [feature["geometry"] for feature in shapefile]

    with rasterio.open(input) as src:
        # crop=True shrinks the result to the extent of the shapes.
        out_image, out_transform = mask(src, shapes, crop=True)
        out_meta = src.meta

    out_meta.update({"driver": "GTiff",
                     "height": out_image.shape[1],
                     "width": out_image.shape[2],
                     "transform": out_transform,
                     'dtype': 'uint8',
                     'nodata': 0})
    print(out_image.shape)

    # The original wrapped the array in a masked array and then immediately
    # took `.data` again, which discards the mask; a plain cast is equivalent.
    # NOTE(review): values outside 0-255 would not survive this cast -
    # presumably the inputs are class codes that fit in uint8; confirm.
    out_image = out_image.astype('uint8')

    # rasterio cannot create missing folders, so make sure the target exists.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with rasterio.open(output, "w", **out_meta) as dest:
        dest.write(out_image)


def savetiff(workdir, raster, array, trans, prj, outputid):
    """Write a 2-D array as a Float32 GeoTIFF under ``workdir/outputid``.

    The file name is the base name of ``raster`` without its extension,
    followed by ``_<outputid>.tif``, with any ``_CLIPPED`` tag removed.

    Args:
        workdir: Folder that holds the per-product output folders.
        raster: Path of the source raster; only its base name is used.
        array: 2-D array to write (unpacked as ``rows, cols = array.shape``).
        trans: Passed through to ``utility_functions.saveGeoTiff``.
        prj: Passed through to ``utility_functions.saveGeoTiff``.
        outputid: Product identifier; used as both the sub-folder name and the
            file-name suffix.

    Returns:
        The path of the written GeoTIFF.
    """
    hotspot_out = workdir + os.sep + outputid
    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(hotspot_out, exist_ok=True)

    # Strip the tag from the file name only: running the replacement over the
    # whole path could redirect the file into a folder that was never created.
    file_name = os.path.splitext(os.path.basename(raster))[0] + '_' + outputid + '.tif'
    file_name = file_name.replace("_CLIPPED", "")
    save_path = hotspot_out + os.sep + file_name
    print("save_path:"+save_path)

    y_res, x_res = array.shape
    utility_functions.saveGeoTiff(save_path, array, [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)
    return save_path


def _lst_statistic_tag(raster):
    """Return "MAXLST" or "MEANLST" depending on which tag the raster path contains."""
    tag = None
    if "MAXLST" in raster:
        tag = "MAXLST"
    # Checked second so that MEANLST wins when both tags are present,
    # matching the order of the original checks.
    if "MEANLST" in raster:
        tag = "MEANLST"
    if tag is None:
        raise ValueError(
            "Cannot pick a statistics CSV: neither 'MAXLST' nor 'MEANLST' "
            "appears in raster path " + repr(raster))
    return tag


def _latest_lc_year(lc_rasters):
    """Return the largest year found at characters 11-14 of the given file names (0 if none)."""
    years = (int(os.path.basename(lc)[11:15]) for lc in lc_rasters)
    return max([0, *years])


def _clipped_lc_path(lc_file, lc_clipped_folder):
    """Build the clipped land-cover output path, stamping today's date into the name."""
    processing_date = dt.datetime.today().strftime('%Y%m%d')
    # e.g. 'UHI_CU_FSD_2007_20230124_C01_LC': field 4 is replaced by today's date.
    pieces = os.path.splitext(os.path.basename(lc_file))[0].split("_")
    new_name = "_".join([pieces[0], pieces[1], pieces[2], pieces[3],
                         processing_date, pieces[5], pieces[6]])
    return lc_clipped_folder + os.sep + new_name + '_clipped.tif'


def hotspot(workdir, raster, iata, uhi_lc_root, lc_clipped_folder, ard_city_dissolved_shapefile):
    """Write an INTENSITY_ANNUAL raster: LST minus a per-urban-zone reference value.

    Steps:
      1. Pick the land-cover (LC) raster whose name contains the year taken
         from characters 11-14 of the input file name; when those characters
         are "STAC" the most recent LC year available is used instead.
      2. Clip that LC raster to ``ard_city_dissolved_shapefile``.
      3. Read the LC and LST rasters over the extent returned by
         ``raster_geometry.get_common_extent_from_raster_paths``.
      4. For every data row of the matching ``*_ZMEAN.csv`` (the first row is
         skipped as a header), subtract the value in column 9 from the LST
         pixels that belong to the row's urban zone (column 0) and whose LC
         value is greater than 20.
      5. Set all remaining pixels to -9999 and write the result to
         ``workdir/INTENSITY_ANNUAL/<raster name>_INTENSITY_ANNUAL.tif``.

    Args:
        workdir: City output folder, e.g.
            /caldera/projects/usgs/eros/urban_heat_islands/Phase1/FSD/output_folder.
            It must contain ``UrbanZones`` (with the ``<iata>_sep`` per-zone
            rasters) and ``LST_CLASS_STATS_ANNUAL/ByCSV``.
        raster: Path of the LST raster; its path must contain "MAXLST" or
            "MEANLST".
        iata: City code used in file and folder names.
        uhi_lc_root: Prefix (folder path ending in a separator) of the UHI_LC
            rasters produced by step 332.
        lc_clipped_folder: Folder that receives the clipped LC raster; created
            when missing.
        ard_city_dissolved_shapefile: Vector file used to clip the LC raster.

    Returns:
        None.

    Raises:
        ValueError: If the raster path contains neither "MAXLST" nor "MEANLST".
        IndexError: If one of the required input files cannot be found by glob.
    """
    lst_statistic = _lst_statistic_tag(raster)

    uz_file = glob.glob(workdir + os.sep + "UrbanZones" + os.sep + "*" + iata + ".tif")[0]

    # Use the most recent LC layer for STACK inputs, otherwise the input's own year.
    year = os.path.basename(raster)[11:15]
    if year == "STAC":
        year = str(_latest_lc_year(glob.glob(uhi_lc_root + "*.tif")))
    print(year)

    # LST_CLASS_STATS_ANNUAL is generated by phase2_uhi_stats_classbased.py and
    # lives in the phase1 city output folder.
    stats_dir = workdir + os.sep + "LST_CLASS_STATS_ANNUAL" + os.sep + "ByCSV" + os.sep
    print(stats_dir + "*" + year + "*" + "_ZMEAN.csv")
    stats_csv = glob.glob(stats_dir + "*" + year + "*" + lst_statistic + "*" + "_ZMEAN.csv")[0]
    print(stats_csv)
    with open(stats_csv, newline='') as csv_file:
        data = list(csv.reader(csv_file))

    # UHI_LC* rasters are the output of step 332.
    lc_file = glob.glob(uhi_lc_root + "*" + str(year) + "*.tif")[0]
    print(lc_file)

    lc_clipped_file = _clipped_lc_path(lc_file, lc_clipped_folder)
    print(lc_clipped_file)

    # The clip step cannot create its own output folder.
    os.makedirs(lc_clipped_folder, exist_ok=True)
    clip_raster(lc_file, ard_city_dissolved_shapefile, output=lc_clipped_file)

    # The three rasters do not share an identical grid extent, so each one is
    # read through a window covering their common extent.
    common_bbox = raster_geometry.get_common_extent_from_raster_paths(
        [lc_clipped_file, raster, uz_file])

    with rasterio.open(lc_clipped_file) as lc_ds, \
            rasterio.open(raster) as fn_ds, \
            rasterio.open(uz_file) as uz_ds:
        lc_window = rasterio.features.geometry_window(lc_ds, [common_bbox])
        fn_window = rasterio.features.geometry_window(fn_ds, [common_bbox])
        uz_window = rasterio.features.geometry_window(uz_ds, [common_bbox])

        lc_array = lc_ds.read(1, window=lc_window)
        fn_array = fn_ds.read(1, window=fn_window)

        # Output metadata: the LST raster's metadata restricted to the window read.
        out_meta = fn_ds.meta
        out_meta.update({
            "height": fn_window.height,
            "width": fn_window.width,
            "transform": fn_ds.window_transform(fn_window),
        })

        lc_above_20 = lc_array > 20
        diff_array = np.where(lc_above_20, fn_array, -9999)

        print("Processing Hotspot, Urban vs Non-Urban, and Normalized by LC analyses")

        # The mosaicked urban-zone raster was built from overlapping per-zone
        # rasters that wrote nodata over each other. Instead of reading it,
        # each per-zone raster is written into an in-memory dataset with the
        # mosaic's metadata right before that zone is read back.
        uz_sep_dir = os.path.join(workdir, "UrbanZones", iata + "_sep")
        with MemoryFile() as memfile, memfile.open(**uz_ds.meta) as uz_mosaic:
            for row in data[1:]:
                if not row:
                    # Blank CSV line: nothing to process.
                    continue
                UHI_ID = row[0]
                lc_class = row[1]

                uz_single_path = os.path.join(
                    uz_sep_dir, 'uhi_urbanzone_' + str(UHI_ID) + '.tif')
                uz_single_bbox = raster_geometry.get_common_extent_from_raster_paths(
                    [uz_single_path])
                write_window = rasterio.features.geometry_window(uz_ds, [uz_single_bbox])
                with rasterio.open(uz_single_path) as uz_single_ds:
                    uz_mosaic.write(uz_single_ds.read(1)[np.newaxis, ...],
                                    window=write_window)
                uz_array = uz_mosaic.read(1, window=uz_window)

                # NOTE(review): lc_class is only logged; pixels are selected by
                # urban zone alone, so when a zone has several rows the last
                # row's value wins. Confirm whether a per-class filter is missing.
                diff = fn_array - float(row[9])
                in_zone = uz_array == int(UHI_ID[-4:])
                diff_array = np.where(in_zone & lc_above_20, diff, diff_array)
                print(" Processing for LC class: " + str(lc_class) + " "+str(UHI_ID))

    diff_array[lc_array < 20] = -9999
    diff_array[lc_array == 241] = -9999  # value 241 was observed in the LC raster
    # -500 is a temporary quick solution - -9999 input pixels should really be
    # excluded from the subtraction above.
    diff_array[diff_array < -500] = -9999

    outputid = 'INTENSITY_ANNUAL'
    hotspot_out = workdir + os.sep + outputid
    os.makedirs(hotspot_out, exist_ok=True)
    # Strip the tag from the file name only, never from the folder part.
    file_name = os.path.splitext(os.path.basename(raster))[0] + '_' + outputid + '.tif'
    file_name = file_name.replace("_CLIPPED", "")
    save_path = hotspot_out + os.sep + file_name
    print("save_path:"+save_path)

    with rasterio.open(save_path, "w", **out_meta) as dest:
        dest.write(diff_array[np.newaxis, ...])


def main(argv=None):
    """Run the intensity analysis for one city's clipped annual LST rasters.

    Args:
        argv: Optional list of command-line arguments; defaults to
            ``sys.argv[1:]``. ``--iata`` selects the city (default "DCA") and
            ``--year`` selects which rasters are processed (default "2018").
    """
    parser = argparse.ArgumentParser(
        description="Generate INTENSITY_ANNUAL rasters from clipped annual LST rasters.")
    parser.add_argument("--iata", default="DCA",
                        help="city code used in folder and file names (default: DCA)")
    parser.add_argument("--year", default="2018",
                        help="only rasters whose path contains this text are "
                             "processed (default: 2018)")
    args = parser.parse_args(argv)

    iata = args.iata
    workdir = "/caldera/projects/usgs/eros/urban_heat_islands/Phase1/"+iata+"/output_folder"
    uhi_lc_root = "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/section_332/output" + os.sep + iata + os.sep

    # LST_ANNUAL_CLIPPED comes from phase2_1_uhi_urbanbuffer_clip.py and sits
    # in the city output folder.
    match_string = workdir + os.sep + r'LST_ANNUAL_CLIPPED'+ os.sep +'UHI_CU_'+iata+'_*.tif'
    raster_list = glob.glob(match_string)
    ard_city_dissolved_shapefile = "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/uhi_metrobuffer_generation/City_Boundaries_Buffer/ard_"+iata+"_city_dissolved.shp"
    lc_clipped_output_folder = workdir + os.sep + 'LC_clipped'

    # Existing outputs in LC_clipped and INTENSITY_ANNUAL are intentionally
    # left in place; only make sure the clip folder exists.
    os.makedirs(lc_clipped_output_folder, exist_ok=True)

    # NOTE: step 2_7 is multithreaded and writes into its own folders under
    # UrbanZones. Its outputs (uhi_urbanzone_<iata>.tif and the <iata>_sep/
    # rasters) have to be copied into workdir/UrbanZones before this script
    # runs; this script neither performs that copy nor cleans up those folders.

    for raster in raster_list:
        # Substring match against the whole path, as before.
        if args.year in raster:
            hotspot(workdir, raster, iata, uhi_lc_root, lc_clipped_output_folder, ard_city_dissolved_shapefile)


if __name__ == '__main__':
    main()