import os
import numpy as np
from osgeo import gdal
import datetime as dt
import glob
import utility_functions as utility_functions
import gc
import numpy.ma as ma
import fiona
import rasterio
from rasterio.windows import from_bounds
from rasterio.enums import Resampling

import argparse

from rasterio.mask import mask


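# Leftover note, no code follows it: STstats_clip_urban(workpath, start_year_inclusive,
# end_year_inclusive) was meant to clip each input raster to the respective urban
# boundaries with a 5 km buffer. That function is not defined here, and this note
# does not describe getPixelStack below.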
def getPixelStack(file_path_and_date_tuples):
    """Read band 1 of every listed raster and stack the bands along a third axis.

    Parameters
    ----------
    file_path_and_date_tuples : iterable of tuple
        Six-element tuples ``(file_path, file_date, regional_grid_name,
        IATA_code, processing_year, collection_number)``. Only ``file_path``
        is used here; the other fields are accepted so callers can pass their
        metadata tuples unchanged.

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``(rows, cols, n_files)``, in the order the tuples were
        given. The result is 3-D even when there is a single file. Returns
        ``None`` when no tuples were given.

    Raises
    ------
    OSError
        If GDAL cannot open one of the files.
    """
    # Unpack all six fields so malformed tuples still fail loudly.
    applicable_file_paths = [
        file_path
        for (file_path, _file_date, _grid, _iata, _processing_year, _collection)
        in file_path_and_date_tuples
    ]

    layers = []
    for file_path in applicable_file_paths:
        img = gdal.Open(file_path)
        if img is None:
            # Without gdal.UseExceptions(), gdal.Open returns None instead of raising.
            raise OSError("Unable to open raster: " + str(file_path))
        st_data = np.array(img.GetRasterBand(1).ReadAsArray())
        img = None  # drop the reference so GDAL can close the dataset
        print(np.shape(st_data))
        layers.append(st_data)

    if not layers:
        return None

    # Stack once at the end: calling np.dstack inside the loop re-copies the
    # whole growing cube for every file. np.dstack also promotes a single
    # 2-D layer to shape (rows, cols, 1).
    return np.dstack(layers)

def _read_band_over_bounds(path, bounds, shape):
    """Read band 1 of the raster at *path* over *bounds*, resampled to *shape*.

    *bounds* is ``(left, bottom, right, top)`` in the raster's own coordinates.
    *shape* is ``(rows, cols)``. Passing *shape* guarantees the result lines
    up with the hotspot stack even when the window rounds to a slightly
    different size. Nearest-neighbour resampling copies values rather than
    interpolating them.
    """
    with rasterio.open(path) as src:
        window = from_bounds(*bounds, src.transform)
        return src.read(1, window=window, out_shape=shape,
                        resampling=Resampling.nearest)


def _group_where_share_at_least(hotspot_fraction, threshold, group, nodata):
    """Return *group* where *hotspot_fraction* >= *threshold*, 0 elsewhere.

    Pixels flagged in the boolean array *nodata* are set to -9999.
    """
    selected = np.where(hotspot_fraction >= threshold, 1.0, 0.0)
    result = np.multiply(selected, group)
    result[nodata] = -9999
    return result


def STstats_stacked(workpath, instats, iata, nugpath, pugpath, outDir):
    """Summarise the full (1985-2020) stack of hotspot rasters for one location.

    Reads every ``*_<INSTATS>_HOTSPOT.tif`` in
    ``<workpath>/<iata>/output_folder/HOTSPOT`` (sample name:
    ``UHI_CU_FSD_1984_20210727_C01_MAXLST_HOTSPOT.tif``), stacks them through
    time, and writes four Float32 GeoTIFFs to *outDir*:

    * ``<INSTATS>_HOTSPOT_COUNT``: per-pixel sum of the layers; -9999
      input values are ignored.
    * ``<INSTATS>_HOTSPOT_PERCENT``: that sum as a percentage of the
      number of layers.
    * ``<INSTATS>_HOTSPOT_NUG``: NUG raster values where the sum/layers
      share is >= 0.30, 0 elsewhere.
    * ``<INSTATS>_HOTSPOT_PUG``: PUG raster values where the share is
      >= 0.50, 0 elsewhere.

    Pixels that are -9999 in every layer are written as -9999 in all outputs.

    Parameters
    ----------
    workpath : str
        Root directory containing ``<iata>/output_folder/HOTSPOT``.
    instats : str
        Hotspot statistic name (e.g. ``"MAXLST"``); matched upper-cased.
    iata : str
        Location identifier used in the input directory path.
    nugpath, pugpath : str
        Rasters read over the hotspot rasters' extent.
    outDir : str
        Output directory; created if missing.

    Raises
    ------
    FileNotFoundError
        If no hotspot raster matches the expected pattern.
    """
    inputdir = os.path.join(workpath, iata, "output_folder", "HOTSPOT")
    print(inputdir)
    os.makedirs(outDir, exist_ok=True)

    pattern = os.path.join(inputdir, "*_" + instats.upper() + "_HOTSPOT.tif")
    # Sorted so the stack order and the metadata taken from the last file
    # (used in the output names) don't depend on directory listing order.
    st_files = sorted(glob.glob(pattern))
    if not st_files:
        raise FileNotFoundError("No hotspot rasters match " + pattern)
    print(st_files)

    # The output georeferencing and the NUG/PUG read window both come from
    # this one raster, so they describe the same grid.
    reference_file = st_files[0]
    trans, prj = utility_functions.get_geo(reference_file)

    # Despite its name, this is today's date as YYYYMMDD. It fills the
    # processing-date field of the output filenames.
    processing_year = dt.datetime.today().strftime('%Y%m%d')
    file_info_tuples = []
    for file_path in st_files:
        print(file_path)
        pieces = os.path.basename(file_path).split("_")
        regional_grid_name = pieces[1]
        IATA_code = pieces[2]
        file_date = dt.datetime.strptime(pieces[3], "%Y")
        collection_number = pieces[5]
        file_info_tuples.append(
            (file_path, file_date, regional_grid_name, IATA_code,
             processing_year, collection_number))

    print("Stacking Pixels")
    data = getPixelStack(file_info_tuples)
    data = ma.masked_array(data, data == -9999)
    y_res, x_res, t_res = data.shape

    ds = gdal.Open(reference_file)
    ulx, xres, _xskew, uly, _yskew, yres = ds.GetGeoTransform()
    lrx = ulx + ds.RasterXSize * xres
    lry = uly + ds.RasterYSize * yres
    ds = None  # release the GDAL dataset
    bounds = (ulx, lry, lrx, uly)
    # NOTE(review): assumes NUG/PUG use the same CRS as the hotspot rasters.
    nug = _read_band_over_bounds(nugpath, bounds, (y_res, x_res))
    pug = _read_band_over_bounds(pugpath, bounds, (y_res, x_res))

    def save_product(suffix, array):
        # All products share the naming scheme built from the last file's metadata.
        data_product = instats.upper() + suffix
        save_file_name = ("UHI_" + regional_grid_name + "_" + IATA_code
                          + "_1985_2020_" + processing_year + "_"
                          + collection_number + "_" + data_product + ".tif")
        utility_functions.saveGeoTiff(os.path.join(outDir, save_file_name),
                                      array, [x_res, y_res, 1], trans, prj,
                                      dtype=gdal.GDT_Float32)

    print("SUM")
    # Sum through time (axis 2). On the masked array, -9999 years are skipped.
    # Pixels masked in every year come back masked and are filled with -9999.
    nanhotspot_data = ma.filled(np.nansum(data, axis=2), fill_value=-9999)
    nodata = nanhotspot_data < 0
    save_product("_HOTSPOT_COUNT", nanhotspot_data)

    # Share of layers, computed from the filled sum so nodata stays negative.
    hotspot_fraction = np.asarray(nanhotspot_data, dtype=np.float64) / t_res
    hotspot_percent = hotspot_fraction * 100
    hotspot_percent[nodata] = -9999  # Float32 output (not Byte) keeps -9999
    save_product("_HOTSPOT_PERCENT", hotspot_percent)

    # Each threshold gets its own mask from the fraction. Re-thresholding
    # the 30% mask in place would corrupt the 50% result.
    save_product("_HOTSPOT_NUG",
                 _group_where_share_at_least(hotspot_fraction, 0.30, nug, nodata))
    save_product("_HOTSPOT_PUG",
                 _group_where_share_at_least(hotspot_fraction, 0.50, pug, nodata))

    del data, nanhotspot_data, hotspot_fraction, hotspot_percent, nug, pug
    gc.collect()


if __name__ == '__main__':
    # https://stackabuse.com/command-line-arguments-in-python/

    #parser = argparse.ArgumentParser()
    #parser.add_argument("--workdir", "-w", help="specify workpath directory")
    #parser.add_argument("--stat", "-s", help='hotspot derived stat ["MAXLST"]', nargs='?', const='"MAXLST"')
    #parser.add_argument("--iata_location_identifier", "-i", help="iata_location_identifier")
    #parser.add_argument("--nugpath", "-n", help='New Urban Group ["NUG"]')
    #parser.add_argument("--pugpath", "-p", help='Persistent Urban Group ["PUG"]')
    #parser.add_argument("--outdir", "-o", help="uhi_lc_root")
    #args = parser.parse_args()

    #STstats_stacked(args.workdir, args.stat, args.iata_location_identifier, args.nugpath, args.pugpath,
    #                args.outdir)

    #these are the locations I used and how it is presently set up
    workdir = r'D:\Users\cwmueller\data\UHI'
    iata = "FSD"
    stat = "MAXLST"
<<<<<<< HEAD
    outdir = r'Z:\UHI\data\PreHPCTesting\HOTSPOT_PROBABILITY'
=======
    outdir = r'D:\Users\cwmueller\data\UHI\FSD\output_folder\HOTSPOT_PROBABILITY'
>>>>>>> origin/Phase2additions
    nugpath = r'Z:\UHI\data\PreHPCTesting\UHI_LC_NEW_UGROWTH\uhi_step_334B_sioux_falls\new_imp_85_20.tif'
    pugpath = r'Z:\UHI\data\PreHPCTesting\UHI_LC_PERSISTENT\uhi_step_346_with_non_urban\UHI_CU_FSD_20211116_C01_PUG.tif'
    STstats_stacked(workdir, stat, iata, nugpath, pugpath, outdir)

<<<<<<< HEAD
=======

>>>>>>> origin/Phase2additions
