import os
import numpy as np
from osgeo import gdal
import datetime as dt
import glob
import utility_functions as utility_functions
import gc
import numpy.ma as ma
import fiona
import rasterio

import argparse

from rasterio.mask import mask

import shutil


from shapely.geometry import Polygon
#https://stackoverflow.com/questions/62075847/using-qgis-and-shaply-error-geosgeom-createlinearring-r-returned-a-null-pointer
from shapely import speedups
speedups.disable()


# def STstats_clip_urban(workpath, start_year_inclusive, end_year_inclusive):
# Clips each input raster to the respective urban boundaries with 5 km buffer

def getPixelStack(file_path_and_date_tuples):
    """Load every listed raster and stack the layers along a third axis.

    Args:
        file_path_and_date_tuples: iterable of tuples whose first element is
            the path of a raster file. The remaining elements (the caller in
            this file passes date, grid name, IATA code, processing date and
            collection number) are not used here.

    Returns:
        The ``np.dstack`` of all layers in input order, or ``None`` when the
        input is empty. For 2-D layers the result has shape
        (rows, cols, n_files); a single 2-D layer yields (rows, cols, 1).
    """
    layers = [
        utility_functions.get_tifflayer(file_info[0], dtype=np.float16)
        for file_info in file_path_and_date_tuples
    ]

    if not layers:
        return None

    # A single dstack over all layers avoids re-copying the growing stack for
    # every file, and it promotes a lone 2-D layer to (rows, cols, 1) itself.
    return np.dstack(layers)

def _hotspot_product_path(out_dir, regional_grid_name, IATA_code, processing_date,
                          collection_number, data_product):
    """Return the output path of one product of the 1985-2020 hotspot stack."""
    file_name = ("UHI_" + regional_grid_name + "_" + IATA_code + "_1985_2020_"
                 + processing_date + "_" + collection_number + "_" + data_product + ".tif")
    return os.path.join(out_dir, file_name)


def _threshold_within_grid(percent_data, threshold, grid_data):
    """Return a float32 0/1 raster: 1 where percent >= threshold and grid != 0."""
    flagged = (percent_data >= threshold).astype(np.float32)
    # Boolean-index assignment (rather than a broadcasted "&") keeps a shape
    # mismatch between the two rasters a hard error.
    flagged[grid_data == 0] = 0.0
    return flagged


def STstats_stacked(workpath, instats, outDir, NUG_file_path, PUG_file_path):
    """Summarise the yearly hotspot rasters of one statistic into four GeoTIFFs.

    All files in ``workpath`` matching ``*_<INSTATS>_HOTSPOT.tif`` are stacked
    and the following rasters are written to ``outDir``:

    * ``<INSTATS>_HOTSPOT_COUNT``: per-pixel sum over the stack (Float32),
      -9999 where the masked sum has no valid value.
    * ``<INSTATS>_HOTSPOT_100p``: ``sum / number_of_files * 100`` (UInt16).
    * ``<INSTATS>_HOTSPOT_PROBABILITY_NU``: 1 where the percentage is >= 30
      and the NUG raster is non-zero, else 0 (Float32).
    * ``<INSTATS>_HOTSPOT_PROBABILITY_PU``: 1 where the percentage is >= 50
      and the PUG raster is non-zero, else 0 (Float32).

    Args:
        workpath: directory holding the yearly ``*_HOTSPOT.tif`` rasters.
        instats: statistic name used in the file names (upper-cased here).
        outDir: output directory; created if missing.
        NUG_file_path: path of the raster used to mask the 30 % product.
        PUG_file_path: path of the raster used to mask the 50 % product.

    Raises:
        FileNotFoundError: if no input raster matches the file name pattern.
    """
    # rasterio.features is only reachable as an attribute of the package when
    # some other import has already loaded it, so import it explicitly.
    from rasterio.features import geometry_window

    os.makedirs(outDir, exist_ok=True)

    # Sample Filename: UHI_CU_FSD_1984_20210727_C01_MAXLST_HOTSPOT.tif
    match_string = workpath + os.sep + "*_" + instats.upper() + "_HOTSPOT.tif"
    print(match_string)
    # Sorted so that the file supplying the georeferencing and the name
    # fields does not depend on directory listing order.
    st_files = sorted(glob.glob(match_string))
    print(st_files)
    if not st_files:
        raise FileNotFoundError("No hotspot rasters match " + match_string)
    trans, prj = utility_functions.get_geo(st_files[0])

    # The output names carry today's date (YYYYMMDD), not the date in the inputs.
    processing_date = dt.datetime.today().strftime('%Y%m%d')

    # generate stack of input data
    file_info_tuples = []
    for st_file in st_files:
        print(st_file)
        pieces = os.path.basename(st_file).split("_")
        regional_grid_name = pieces[1]
        IATA_code = pieces[2]
        # Parsing also rejects file names whose fourth field is not a year.
        file_date = dt.datetime.strptime(pieces[3], "%Y")
        collection_number = pieces[5]
        file_info_tuples.append(
            (st_file, file_date, regional_grid_name, IATA_code, processing_date, collection_number))

    print("Stacking Pixels")
    data = getPixelStack(file_info_tuples)
    data = ma.masked_array(data, data == -9999)
    y_res, x_res, t_res = data.shape

    ########################## HOTSPOT Total Count #########################################################
    print("SUM")
    # axis goes 0,1,2, so we are doing these stats per pixel through the time dimension
    hotspot_sum = np.nansum(data, axis=2)
    # Pixels the masked sum leaves masked; remembered so they can be kept out
    # of the percentage raster below.
    nodata_pixels = ma.getmaskarray(hotspot_sum)
    # NOTE(review): if the stack is float16 (the dtype requested from
    # get_tifflayer), -9999 is not exactly representable and the fill value
    # written here is rounded - confirm against utility_functions.
    hotspot_count = ma.filled(hotspot_sum, fill_value=-9999)

    count_path = _hotspot_product_path(outDir, regional_grid_name, IATA_code, processing_date,
                                       collection_number, instats.upper() + "_HOTSPOT_COUNT")
    utility_functions.saveGeoTiff(count_path,
                                  hotspot_count,
                                  [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)

    ########################## HOTSPOT Percentage ##########################################################
    hotspot_percent = np.divide(hotspot_count, t_res) * 100
    # The -9999 fill turns into a large negative percentage, which wraps to a
    # large positive number when cast to uint16 and would then pass the
    # thresholds below. Nodata pixels are written as 0 % instead.
    hotspot_percent[nodata_pixels] = 0

    save_path = _hotspot_product_path(outDir, regional_grid_name, IATA_code, processing_date,
                                      collection_number, instats.upper() + "_HOTSPOT_100p")
    print("save_path:" + save_path)
    utility_functions.saveGeoTiff(save_path,
                                  hotspot_percent.astype(np.uint16),
                                  [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_UInt16)

    ########################## HOTSPOT Probability #########################################################
    # https://landsat.usgs.gov/jira/browse/LSRD-7624
    # Read the percentage raster back together with the parts of the NUG and
    # PUG rasters that fall inside its bounding box.
    with rasterio.open(save_path) as percent_ds, \
            rasterio.open(NUG_file_path) as nug_ds, \
            rasterio.open(PUG_file_path) as pug_ds:
        bbox = percent_ds.bounds
        footprint = Polygon([
            [bbox.left, bbox.bottom],
            [bbox.left, bbox.top],
            [bbox.right, bbox.top],
            [bbox.right, bbox.bottom],
        ])
        percent_data = percent_ds.read(1, window=geometry_window(percent_ds, [footprint]))
        nug_data = nug_ds.read(1, window=geometry_window(nug_ds, [footprint]))
        pug_data = pug_ds.read(1, window=geometry_window(pug_ds, [footprint]))

    nu_path = _hotspot_product_path(outDir, regional_grid_name, IATA_code, processing_date,
                                    collection_number, instats.upper() + "_HOTSPOT_PROBABILITY_NU")
    utility_functions.saveGeoTiff(nu_path,
                                  _threshold_within_grid(percent_data, 30, nug_data),
                                  [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)

    pu_path = _hotspot_product_path(outDir, regional_grid_name, IATA_code, processing_date,
                                    collection_number, instats.upper() + "_HOTSPOT_PROBABILITY_PU")
    utility_functions.saveGeoTiff(pu_path,
                                  _threshold_within_grid(percent_data, 50, pug_data),
                                  [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)


if __name__ == '__main__':
    # Command-line entry point: rebuilds the HOTSPOT_PROBABILITY outputs of
    # one city for every statistic in instats_list.

    instats_list = ["MAXLST", "MEANLST"]

    parser = argparse.ArgumentParser()
    # Both arguments are required: without them the paths below would be
    # built from the string "None".
    parser.add_argument("--iata_location_identifier", "-i", required=True, help="city IATA code")
    parser.add_argument("--phase_1_directory", "-b", required=True, help="phase_1_directory")
    # Directory holding one sub-directory per city with the NUG and PUG rasters.
    # https://landsat.usgs.gov/jira/browse/LSRD-7624
    parser.add_argument("--phase_2_directory", "-g",
                        default="/caldera/projects/usgs/eros/urban_heat_islands/Phase2/section_346",
                        help="directory containing <IATA code>/*NUG* and <IATA code>/*PUG* rasters")

    # Example phase 1 directory: '/caldera/projects/usgs/eros/urban_heat_islands/Phase1'

    args = parser.parse_args()

    workpath = '{}/{}/output_folder/HOTSPOT'.format(args.phase_1_directory, args.iata_location_identifier)
    print(workpath)

    outdir = '{}/{}/output_folder/HOTSPOT_PROBABILITY'.format(args.phase_1_directory, args.iata_location_identifier)

    # use NUG + PUG
    urban_grid_paths = {}
    for grid_kind in ("NUG", "PUG"):
        grid_pattern = '{}/{}/*{}*'.format(args.phase_2_directory, args.iata_location_identifier, grid_kind)
        grid_matches = sorted(glob.glob(grid_pattern))
        if not grid_matches:
            # Stop before the existing outputs are deleted below.
            parser.error("no {} raster matches {}".format(grid_kind, grid_pattern))
        urban_grid_paths[grid_kind] = grid_matches[0]
    NUG_file_path = urban_grid_paths["NUG"]
    PUG_file_path = urban_grid_paths["PUG"]

    # Start from an empty output directory; STstats_stacked recreates it.
    if os.path.exists(outdir):
        shutil.rmtree(outdir)

    for instats in instats_list:
        STstats_stacked(workpath, instats, outdir, NUG_file_path, PUG_file_path)

