import os
import numpy as np
from osgeo import gdal
import datetime as dt
import glob
import utility_functions as utility_functions
import gc
import numpy.ma as ma
import fiona
import rasterio

import argparse

from rasterio.mask import mask


# TODO: STstats_clip_urban(workpath, start_year_inclusive, end_year_inclusive) is not
# implemented in this module. It is meant to clip each input raster to its respective
# urban boundary with a 5 km buffer.

def getPixelStack(file_path_and_date_tuples, load_layer=None):
    """Load each listed raster and stack the layers depth-wise.

    Args:
        file_path_and_date_tuples: Iterable of tuples
            ``(file_path, file_date, regional_grid_name, IATA_code,
            processing_year, collection_number)``. Only ``file_path`` is used.
        load_layer: Optional callable that maps a file path to an array.
            Defaults to ``utility_functions.get_tifflayer``.

    Returns:
        The layers joined with ``np.dstack`` in input order, so 2-D layers
        give an array of shape ``(rows, cols, n_files)``. A single 2-D layer
        gives ``(rows, cols, 1)``. Returns ``None`` when no files are listed.
    """
    if load_layer is None:
        load_layer = utility_functions.get_tifflayer

    file_paths = [entry[0] for entry in file_path_and_date_tuples]
    if not file_paths:
        return None

    # Load every layer first and stack once. Calling np.dstack inside the
    # loop re-copies the growing stack on each iteration (quadratic copying).
    # np.dstack also promotes a lone 2-D layer to (rows, cols, 1), so no
    # special case is needed for a single image.
    return np.dstack([load_layer(path) for path in file_paths])

def STstats_stacked(workpath, instats, outDir, thresholds=(0.30, 0.50)):
    """Summarise a time stack of hotspot rasters into count and threshold maps.

    Every ``*_<INSTATS>_HOTSPOT.tif`` in *workpath* is loaded and stacked along
    a third (time) axis, then reduced per pixel. The following are written to
    *outDir*:

    * ``..._<INSTATS>_HOTSPOT_COUNT.tif``: sum of the layers, ignoring -9999
      cells.
    * ``..._<INSTATS>_HOTSPOT_<NN>p.tif``: one map per value in *thresholds*.
      A pixel is 1 where ``count / number_of_layers >= threshold`` and 0
      otherwise. ``NN`` is the threshold as a percentage (0.30 gives ``30p``).

    Pixels that are -9999 in every layer are written as -9999 in all outputs.

    Args:
        workpath: Directory holding the input hotspot GeoTIFFs.
        instats: Statistic name used in the input file names (e.g. ``"MAXLST"``).
            It is upper-cased before matching.
        outDir: Output directory. It is created if it does not exist.
        thresholds: Fractions of stacked layers used for the binary maps.
            Defaults to the 30 % and 50 % products.

    Raises:
        FileNotFoundError: If no input raster matches the expected pattern.
    """
    nodata = -9999
    stat = instats.upper()
    os.makedirs(outDir, exist_ok=True)

    # Sample Filename: UHI_CU_FSD_1984_20210727_C01_MAXLST_HOTSPOT.tif
    pattern = os.path.join(workpath, "*_" + stat + "_HOTSPOT.tif")
    # Sort so the stack order does not depend on directory listing order. This
    # also fixes which file's name pieces are reused for the outputs.
    st_files = sorted(glob.glob(pattern))
    if not st_files:
        raise FileNotFoundError("No hotspot rasters match " + pattern)
    trans, prj = utility_functions.get_geo(st_files[0])

    # Today's date (YYYYMMDD), stamped into the output file names.
    processing_year = dt.datetime.today().strftime('%Y%m%d')

    file_info_tuples = []
    for st_file in st_files:
        print(st_file)
        pieces = os.path.basename(st_file).split("_")
        regional_grid_name = pieces[1]
        IATA_code = pieces[2]
        file_date = dt.datetime.strptime(pieces[3], "%Y")
        collection_number = pieces[5]
        file_info_tuples.append(
            (st_file, file_date, regional_grid_name, IATA_code, processing_year, collection_number))

    print("Stacking Pixels")
    data = getPixelStack(file_info_tuples)
    data = ma.masked_array(data, data == nodata)
    y_res, x_res, t_res = data.shape

    # Output names reuse the grid, IATA and collection pieces of the last input file.
    # NOTE(review): the year span is fixed at 1985_2020 regardless of which
    # years were actually found (the sample input name above is 1984). Confirm.
    name_prefix = ("UHI_" + regional_grid_name + "_" + IATA_code + "_1985_2020_"
                   + processing_year + "_" + collection_number + "_" + stat)

    def _save(product_suffix, array):
        # Write one single-band Float32 GeoTIFF with the inputs' georeferencing.
        save_path = os.path.join(outDir, name_prefix + product_suffix + ".tif")
        utility_functions.saveGeoTiff(save_path, array, [x_res, y_res, 1],
                                      trans, prj, dtype=gdal.GDT_Float32)

    ########################## HOTSPOT Total Count #########################################################
    print("SUM")
    # Per-pixel sum through the time axis (axis 2). Masked -9999 cells are
    # skipped. A pixel masked in every layer stays masked and is filled with
    # nodata on output.
    hotspot_count = np.nansum(data, axis=2)
    _save("_HOTSPOT_COUNT", ma.filled(hotspot_count, fill_value=nodata))

    # Fraction of all stacked layers in which the pixel was flagged. This is
    # computed once from the unfilled count, so nodata never becomes a
    # negative fraction.
    # NOTE(review): layers that are -9999 at a pixel still count in the
    # denominator (t_res). Confirm this is the intended probability.
    hotspot_fraction = hotspot_count / float(t_res)
    for threshold in thresholds:
        # ma.where keeps the mask, so all-nodata pixels stay nodata.
        exceeds = ma.where(hotspot_fraction >= threshold, 1.0, 0.0)
        label = "_HOTSPOT_{:d}p".format(int(round(threshold * 100)))
        _save(label, ma.filled(exceeds, fill_value=nodata))

    # Drop the large arrays before returning so their memory can be reclaimed.
    del data, hotspot_count, hotspot_fraction
    gc.collect()


if __name__ == '__main__':

    # Statistics whose hotspot stacks are summarised.
    instats_list = ["MAXLST"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--hotspot_directory", "-p", required=True,
                        help="uhi phase 2 hotspot outputs")
    parser.add_argument("--iata_location_identifier", "-i", required=True,
                        help="city IATA code")
    args = parser.parse_args()

    city_output = os.path.join(args.hotspot_directory, args.iata_location_identifier,
                               "output_folder")
    # STstats_stacked globs *_<STAT>_HOTSPOT.tif, so read from the HOTSPOT
    # folder and write the count/threshold products to HOTSPOT_PROBABILITY.
    workpath = os.path.join(city_output, "HOTSPOT")
    outdir = os.path.join(city_output, "HOTSPOT_PROBABILITY")

    for instats in instats_list:
        STstats_stacked(workpath, instats, outdir)