

import time
import os
import datetime as dt
import numpy as np
from osgeo import gdal

from scipy import stats
import glob, tarfile
import psutil

from joblib import Parallel, delayed


from multiprocessing import Process, Pipe

import argparse

import gc
import utility_functions as utility_functions
import trends_calculation as trends_calculation

# Example year ranges used during testing (kept for reference):
#start_year = 1988
#end_year = 1989

#start_year = 1984
#end_year = 1992

#start_year = 1984
#end_year = 2018
#start_year = 1984
#end_year = 2019
#start_year = 1984
#end_year = 2018

#############################From: lcmap.py############################################
# Bit positions of the pixel-QA flags. The values are used as the `bit`
# argument of `bitmask` (e.g. bitmask(px_qa, qamap['clear'])).
# NOTE(review): these positions appear to follow the Collection 1 ARD PIXELQA
# layout; Collection 2 QA_PIXEL bands assign bits differently -- confirm which
# product the QA files actually come from.
# The 'occulsion' spelling is kept unchanged because it is a lookup key.
qamap = dict(
    fill=0,
    clear=1,
    water=2,
    shadow=3,
    snow=4,
    cloud=5,
    cloud_conf1=6,
    cloud_conf2=7,
    cirrus1=8,
    cirrus2=9,
    occulsion=10,
)

#############################From: Kelcy - SR-stuff.ipynb############################################
def bitmask(arr, bit):
    """
    Return a boolean mask that is True wherever `bit` is set in `arr`.

    `1 << bit` builds a value with only bit position `bit` set; AND-ing it with
    `arr` keeps just that bit, and the `> 0` comparison turns the result into
    booleans (element-wise for NumPy arrays).
    """
    single_bit = 1 << bit
    return (arr & single_bit) > 0

##############################################################################
def get_LSTclear_idx(LST, st_qa, px_qa, max_st_uncertainty_kelvin=10):
    '''
    Return a boolean mask of pixels whose surface temperature is usable.

    A pixel is kept when all of the following hold:
      * its ST_QA value lies in [0, max_st_uncertainty_kelvin] Kelvin,
      * its pixel-QA has the 'water' or the 'clear' bit set,
      * its ST value is above 2731.5 in stored units, i.e. 273.15 K
        (0 degC) given the 0.1 ST scale factor below.

    Scale context (from the ARD DFCB): ST values are un-scaled by multiplying
    by 0.1 and ST_QA values by multiplying by 0.01. The uncertainty threshold
    is given here in Kelvin, so it is converted into ST_QA units by dividing
    it by the 0.01 scale factor.

    Parameters
    ----------
    LST : numpy.ndarray
        Scaled surface temperature values.
    st_qa : numpy.ndarray
        Scaled surface temperature uncertainty values, same shape as LST.
    px_qa : numpy.ndarray of integers
        Pixel-QA bit flags, same shape as LST.
    max_st_uncertainty_kelvin : float, optional
        Largest accepted ST uncertainty in Kelvin (default 10).

    Returns
    -------
    numpy.ndarray of bool
        True where the pixel passes every check.
    '''
    max_st_qa_scaled = max_st_uncertainty_kelvin / 0.01

    valid_uncertainty = (st_qa <= max_st_qa_scaled) & (st_qa >= 0)
    clear_or_water = bitmask(px_qa, qamap['water']) | bitmask(px_qa, qamap['clear'])
    # NaN temperatures compare False here, so they are rejected as well
    above_freezing = LST > 2731.5
    return np.asarray(valid_uncertainty & clear_or_water & above_freezing)

def get_Thermalclear_idx(px_qa):
    """
    Return a mask of pixels flagged 'water' or 'clear' in the pixel-QA band.

    Unlike get_LSTclear_idx, no surface-temperature or ST_QA checks are
    applied. For NumPy boolean masks the `+` below acts as a logical OR.
    """
    is_water = bitmask(px_qa, qamap['water'])
    is_clear = bitmask(px_qa, qamap['clear'])
    return is_water + is_clear



def extract_ST(workpath, HV):
    """
    Extract the ST, STQA and PIXELQA GeoTIFFs from every tar archive of a tile.

    Archives are looked up as ``<workpath>/<HV>/*.tar``. Members whose name
    contains '_ST.tif', '_STQA.tif' or '_PIXELQA.tif' are written to
    ``<workpath>/<HV>/extract``, which is created if missing.

    Parameters
    ----------
    workpath : str
        Root directory holding one sub-directory per tile.
    HV : str
        Tile sub-directory name.
    """
    inputdir = os.path.join(workpath, HV)
    outDir = os.path.join(workpath, HV, 'extract')
    os.makedirs(outDir, exist_ok=True)
    print(inputdir)

    wanted_markers = ('_ST.tif', '_STQA.tif', '_PIXELQA.tif')
    # Python 3.12+ (and security back-ports) provide extraction filters. The
    # 'data' filter refuses members that would land outside outDir (absolute
    # paths, '..' components, links pointing elsewhere), which matters for
    # archives downloaded from remote services.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

    for tarF in glob.glob(inputdir + os.sep + "*.tar"):
        print(tarF)
        # the context manager closes the archive even if an extraction fails
        with tarfile.open(tarF, 'r:') as archive:
            for item in archive:
                if any(marker in item.name for marker in wanted_markers):
                    archive.extract(item, path=outDir, **extract_kwargs)
    
    #--------------------------------------------------


def getPixelStackForMonthsFromYear(chosen_months,chosen_year,file_path_and_date_tuples):
    """
    Load and stack the QA-masked ST images acquired in given months of a year.

    For each matching ST file, the companion STQA and PIXELQA paths are derived
    by replacing '_ST.tif' in the path. Pixels rejected by get_LSTclear_idx are
    set to NaN.

    Parameters
    ----------
    chosen_months : iterable of int
        Calendar months (1-12) to include.
    chosen_year : int
        Acquisition year to include.
    file_path_and_date_tuples : iterable of tuple
        (file_path, file_date, regional_grid_name, representative_year,
        processing_year, collection_number) records. Only file_path and
        file_date (a datetime) are used here.

    Returns
    -------
    numpy.ndarray or None
        A (rows, cols, n_images) stack in input order, or None when no file
        matches. A single match still yields a 3-D array of depth 1.
    """
    months = set(chosen_months)
    applicable_file_paths = [
        record[0] for record in file_path_and_date_tuples
        if record[1].year == chosen_year and record[1].month in months
    ]
    if not applicable_file_paths:
        return None

    layers = []
    for file_path in applicable_file_paths:
        # NOTE(review): ST_annual_stats passes files matched by '*_ST_B*'
        # (e.g. '..._ST_B6.TIF'); for such names these replacements are no-ops
        # and the ST file itself is read as STQA/PIXELQA -- confirm the naming
        # of the files in that directory.
        stQA_file = file_path.replace('_ST.tif', '_STQA.tif')
        pxQA_file = file_path.replace('_ST.tif', '_PIXELQA.tif')

        st_data = utility_functions.get_tifflayer(file_path)
        stQA_data = utility_functions.get_tifflayer(stQA_file, dtype=int)
        pxQA_data = utility_functions.get_tifflayer(pxQA_file, dtype=np.uint16)

        mask = get_LSTclear_idx(st_data, stQA_data, pxQA_data)
        st_data[~mask] = np.nan
        layers.append(st_data)

    # Stack once at the end. Growing the array with dstack inside the loop
    # would re-copy every previously loaded image on each iteration.
    # np.dstack also promotes a lone 2-D image to shape (rows, cols, 1).
    return np.dstack(layers)

def ST_annual_stats( HV_input_path, IATA_code, start_year_inclusive, end_year_inclusive, processing_year_label):
    """
    Write per-year LST statistics and clear-observation counts for one tile.

    For each year in [start_year_inclusive, end_year_inclusive], the QA-masked
    ST images of that year are stacked. Their per-pixel NaN-aware mean, max,
    min and median are converted with ``value / 10 - 273.15`` and written as
    Float32 GeoTIFFs (MEANLST, MAXLST, MINLST, MEDIANLST).

    A UInt16 NUMCLEAR GeoTIFF per year is then written. It holds the number of
    scenes in which each pixel passed get_LSTclear_idx. All outputs go to
    ``<HV_input_path>/ANNUAL``.

    Parameters
    ----------
    HV_input_path : str
        Directory containing the tile's ST files (matched by '*_ST_B*').
    IATA_code : str
        Location identifier embedded in output file names.
    start_year_inclusive, end_year_inclusive : int
        Year range to process, both ends inclusive.
    processing_year_label : str
        Processing label embedded in output file names.

    Raises
    ------
    IndexError
        If no '*_ST_B*' file exists in HV_input_path.
    ValueError
        If a year in the range has no ST image.
    """
    inputdir = HV_input_path
    outDir = os.path.join(HV_input_path, 'ANNUAL')
    os.makedirs(outDir, exist_ok=True)

    st_files = np.asarray(glob.glob(inputdir + os.sep + "*_ST_B*"))
    trans, prj = utility_functions.get_geo(st_files[0])

    # (path, acquisition date, grid name, representative year, processing label, collection)
    file_info_tuples = []
    for file_path in st_files:
        pieces = os.path.basename(file_path).split("_")
        file_date = dt.datetime.strptime(pieces[3], "%Y%m%d")
        file_info_tuples.append((file_path, file_date, pieces[1], str(file_date.year),
                                 processing_year_label, pieces[5]))

    all_months = list(range(1, 13))
    reducers = (("MEANLST", np.nanmean), ("MAXLST", np.nanmax),
                ("MINLST", np.nanmin), ("MEDIANLST", np.nanmedian))

    for iyear in range(start_year_inclusive, end_year_inclusive + 1):
        # Load the year's stack once; it is reused for all four statistics.
        data = getPixelStackForMonthsFromYear(chosen_months=all_months, chosen_year=iyear,
                                              file_path_and_date_tuples=file_info_tuples)
        if data is None:
            raise ValueError("ST_annual_stats: no ST images found for year {} in {}".format(iyear, inputdir))
        y_res, x_res = data.shape[:2]

        # File-name components are taken from the first file of this year.
        (_, _, regional_grid_name, representative_year, processing_year,
         collection_number) = next(info for info in file_info_tuples if info[1].year == iyear)

        # statistics are computed per pixel through the time dimension (axis 2)
        for data_product, reducer in reducers:
            save_file_name = ("UHI_" + regional_grid_name + "_" + IATA_code + "_" + representative_year
                              + "_" + processing_year + "_" + collection_number + "_" + data_product + ".tif")
            utility_functions.saveGeoTiff(os.path.join(outDir, save_file_name),
                                          reducer(data, axis=2) / 10 - 273.15,
                                          [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)

        del data
        gc.collect()

    # ---- NUMCLEAR: per-pixel count of clear observations for each year ----
    # NOTE(review): this glob ('*_ST*') is broader than the '*_ST_B*' one above;
    # with Collection 2 names it also matches '*_ST_QA*' files. The
    # '_ST' -> '_STQA' / '_PIXELQA' replacements in _annual_clear_count also do
    # not produce the '_ST_QA' / '_QA_PIXEL' names used by Collection 2
    # products. Confirm against the actual directory contents.
    st_files = np.asarray(glob.glob(inputdir + os.sep + "*_ST*"))
    trans, prj = utility_functions.get_geo(st_files[0])
    # characters [15:23] of the base name hold the YYYYMMDD acquisition date
    # (e.g. 'LT05_CU_016006_19840414_...')
    acquisition_years = np.asarray(
        [dt.datetime.strptime(os.path.basename(st_file)[15:23], "%Y%m%d").year for st_file in st_files])

    for iyear in range(start_year_inclusive, end_year_inclusive + 1):
        clear_img = _annual_clear_count(st_files[acquisition_years == iyear])

        print('Saving', HV_input_path, iyear)
        y_res, x_res = clear_img.shape

        # regional_grid_name / processing_year / collection_number are the
        # values from the last year processed in the statistics loop above
        data_product = "NUMCLEAR"
        save_file_name = ("UHI_" + regional_grid_name + "_" + IATA_code + "_" + str(iyear) + "_"
                          + processing_year + "_" + collection_number + "_" + data_product + ".tif")
        utility_functions.saveGeoTiff(os.path.join(outDir, save_file_name),
                                      clear_img, [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_UInt16)


def _annual_clear_count(st_files_for_year, default_shape=(5000, 5000)):
    """Return a uint16 image counting, per pixel, the given ST scenes that pass get_LSTclear_idx."""
    clear_img = None
    for st_file in st_files_for_year:
        stQA_file = st_file.replace('_ST', '_STQA')
        pxQA_file = st_file.replace('_ST', '_PIXELQA')

        st_data = utility_functions.get_tifflayer(st_file)
        stQA_data = utility_functions.get_tifflayer(stQA_file, dtype=int)
        pxQA_data = utility_functions.get_tifflayer(pxQA_file, dtype=np.uint16)

        clear = get_LSTclear_idx(st_data, stQA_data, pxQA_data)
        if clear_img is None:
            # size the counter from the data instead of assuming 5000 x 5000
            clear_img = np.zeros(clear.shape, dtype=np.uint16)
        clear_img += clear

    if clear_img is None:
        # no scene in this year: keep the previous all-zero output
        clear_img = np.zeros(default_shape, dtype=np.uint16)
    return clear_img



def ST_seasonal_stats(workpath,HV,IATA_code,start_year_inclusive,end_year_inclusive,processing_year_label):
    """
    Write per-season LST statistics (max, min, mean, median) for one tile.

    Seasons are labelled:
      * S1 - December of the previous year plus January/February,
      * S2 - March-May,
      * S3 - June-August,
      * S4 - September-November.

    Input files are matched by ``<workpath>/<HV>/extract/*_ST.tif``. For each
    year in the inclusive range, and for each season with at least one image,
    the per-pixel NaN-aware statistic of the QA-masked ST stack is computed.
    It is converted with ``value / 10 - 273.15`` and written as a Float32
    GeoTIFF to ``<workpath>/<HV>/SEASONAL``.

    Parameters
    ----------
    workpath : str
        Root directory holding one sub-directory per tile.
    HV : str
        Tile sub-directory name.
    IATA_code : str
        Location identifier embedded in output file names.
    start_year_inclusive, end_year_inclusive : int
        Year range to process, both ends inclusive.
    processing_year_label : str
        Processing label embedded in output file names.

    Raises
    ------
    IndexError
        If no '*_ST.tif' file exists in the extract directory.
    ValueError
        If a year in the range has no ST image.
    """
    inputdir = os.path.join(workpath, HV, 'extract')
    outDir = os.path.join(workpath, HV, 'SEASONAL')
    os.makedirs(outDir, exist_ok=True)

    st_files = np.asarray(glob.glob(inputdir + os.sep + "*_ST.tif"))
    trans, prj = utility_functions.get_geo(st_files[0])

    # (path, acquisition date, grid name, representative year, processing label, collection)
    file_info_tuples = []
    for file_path in st_files:
        pieces = os.path.basename(file_path).split("_")
        file_date = dt.datetime.strptime(pieces[3], "%Y%m%d")
        file_info_tuples.append((file_path, file_date, pieces[1], str(file_date.year),
                                 processing_year_label, pieces[5]))

    reducers = (("MAXLST", np.nanmax), ("MINLST", np.nanmin),
                ("MEANLST", np.nanmean), ("MEDIANLST", np.nanmedian))

    for iyear in range(start_year_inclusive, end_year_inclusive + 1):
        year_info = [info for info in file_info_tuples if info[1].year == iyear]
        if not year_info:
            raise ValueError("ST_seasonal_stats: no ST images found for year {} in {}".format(iyear, inputdir))
        # File-name components are taken from the first file of this year.
        (_, _, regional_grid_name, representative_year, processing_year,
         collection_number) = year_info[0]

        # Each season is loaded once and reused for all four statistics.
        for season_label, season_stack in _iter_seasonal_stacks(iyear, file_info_tuples):
            if season_stack is None:
                continue
            y_res, x_res = season_stack.shape[:2]
            for data_product, reducer in reducers:
                save_file_name = ("UHI_" + regional_grid_name + "_" + IATA_code + "_" + representative_year
                                  + "_" + season_label + "_" + processing_year + "_" + collection_number
                                  + "_" + data_product + ".tif")
                utility_functions.saveGeoTiff(os.path.join(outDir, save_file_name),
                                              reducer(season_stack, axis=2) / 10 - 273.15,
                                              [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_Float32)
        gc.collect()


def _iter_seasonal_stacks(year, file_info_tuples):
    """Yield (season_label, stack or None) for S1-S4 of `year`, loading one season at a time."""
    # S1 spans the year boundary: Jan/Feb of `year` plus Dec of `year - 1`.
    jan_feb = getPixelStackForMonthsFromYear(chosen_months=[1, 2], chosen_year=year,
                                             file_path_and_date_tuples=file_info_tuples)
    prev_dec = getPixelStackForMonthsFromYear(chosen_months=[12], chosen_year=year - 1,
                                              file_path_and_date_tuples=file_info_tuples)
    if jan_feb is not None and prev_dec is not None:
        winter = np.dstack((jan_feb, prev_dec))
    else:
        winter = jan_feb if jan_feb is not None else prev_dec
    yield "S1", winter

    for season_label, months in (("S2", [3, 4, 5]), ("S3", [6, 7, 8]), ("S4", [9, 10, 11])):
        yield season_label, getPixelStackForMonthsFromYear(chosen_months=months, chosen_year=year,
                                                           file_path_and_date_tuples=file_info_tuples)
        #del data




def num_clear(workpath, HV, start_year_inclusive, end_year_inclusive):
    """
    Write, per year, a UInt16 GeoTIFF counting clear LST observations per pixel.

    Reads ``<workpath>/<HV>/extract/*_ST.tif`` together with their
    '_STQA.tif' and '_PIXELQA.tif' companions. For every pixel it counts how
    many scenes of the year pass get_LSTclear_idx, and writes the result to
    ``<workpath>/<HV>/ANNUAL/<HV>_numClear_<year>.tif``.

    Parameters
    ----------
    workpath : str
        Root directory holding one sub-directory per tile.
    HV : str
        Tile sub-directory name.
    start_year_inclusive, end_year_inclusive : int
        Year range to process, both ends inclusive.
    """
    inputdir = os.path.join(workpath, HV, 'extract')
    outDir = os.path.join(workpath, HV, 'ANNUAL')
    os.makedirs(outDir, exist_ok=True)

    # Example names of the Collection 2 QA products (not matched by the
    # '_ST.tif' replacements below):
    # LC08_CU_016006_20130320_20210501_02_QA_PIXEL
    # LC08_CU_002009_20130322_20210501_02_ST_QA.TIF
    st_files = np.asarray(glob.glob(inputdir + os.sep + "*_ST.tif"))
    trans, prj = utility_functions.get_geo(st_files[0])
    # characters [15:23] of the base name hold the YYYYMMDD acquisition date
    acquisition_years = np.asarray(
        [dt.datetime.strptime(os.path.basename(st_file)[15:23], "%Y%m%d").year for st_file in st_files])

    for iyear in range(start_year_inclusive, end_year_inclusive + 1):
        clear_img = None
        for st_file in st_files[acquisition_years == iyear]:
            stQA_file = st_file.replace('_ST.tif', '_STQA.tif')
            pxQA_file = st_file.replace('_ST.tif', '_PIXELQA.tif')

            st_data = utility_functions.get_tifflayer(st_file)
            # builtin int: the np.int alias was removed in NumPy 1.24
            stQA_data = utility_functions.get_tifflayer(stQA_file, dtype=int)
            pxQA_data = utility_functions.get_tifflayer(pxQA_file, dtype=np.uint16)

            clear = get_LSTclear_idx(st_data, stQA_data, pxQA_data)
            if clear_img is None:
                # size the counter from the data instead of assuming 5000 x 5000
                clear_img = np.zeros(clear.shape, dtype=np.uint16)
            clear_img += clear

        if clear_img is None:
            # no scene in this year: keep the previous all-zero output
            clear_img = np.zeros((5000, 5000), dtype=np.uint16)

        print('Saving', HV, iyear)
        y_res, x_res = clear_img.shape
        utility_functions.saveGeoTiff(os.path.join(outDir, '{}_numClear_{}.tif'.format(HV, iyear)),
                    clear_img, [x_res, y_res, 1], trans, prj, dtype=gdal.GDT_UInt16)


def determine_hv(x, y, aff=(-2565585, 150000, 0, 3314805, 0, -150000)):
    """
    Map a projected (x, y) coordinate to an integer (column, row) grid index.

    `aff` is a 6-element affine tuple (x origin, cell width, x rotation,
    y origin, y rotation, cell height); the default describes a 150000-unit
    grid. Fractional indices are truncated toward zero by int().
    """
    x0, cell_w, rot_x, y0, rot_y, cell_h = aff
    col = (x - x0 - y0 * rot_x) / cell_w
    row = (y - y0 - x0 * rot_y) / cell_h
    return int(col), int(row)



def st_stats_forHua(HV_path_list,iata_location_identifier,start_year_inclusive,end_year_inclusive):
    """
    Run the annual LST statistics for every tile directory in HV_path_list.

    The year bounds may be passed as strings (e.g. from the command line);
    they are converted with int(). Every tile of one call shares the same
    processing label: today's date formatted as YYYYMMDD.

    Only ST_annual_stats is run from here. The tar extraction (extract_ST),
    seasonal statistics (ST_seasonal_stats) and trend statistics
    (trends_calculation.trend_stats) stages are currently disabled.
    """
    first_year = int(start_year_inclusive)
    last_year = int(end_year_inclusive)
    run_label = dt.datetime.today().strftime('%Y%m%d')

    for tile_path in HV_path_list:
        ST_annual_stats(tile_path, iata_location_identifier, first_year, last_year, run_label)



if __name__ == '__main__':
    # Command-line entry point: compute the annual LST statistics for one tile.
    parser = argparse.ArgumentParser()
    parser.add_argument("--workpath", "-w", help="specify workpath directory")
    parser.add_argument("--hv_tilename", "-t", help="specify tilename singular")
    parser.add_argument("--iata_location_identifier", "-i", help="iata_location_identifier")

    parser.add_argument("--start_year_inclusive", "-s", help="specify start_year_inclusive")
    parser.add_argument("--end_year_inclusive", "-e", help="specify end_year_inclusive")

    # Read arguments from the command line
    args = parser.parse_args()

    required = ("workpath", "hv_tilename", "iata_location_identifier",
                "start_year_inclusive", "end_year_inclusive")
    missing = [name for name in required if not getattr(args, name)]
    if missing:
        # the IATA code is part of every output file name, so it is required too
        parser.error("missing required option(s): " + ", ".join("--" + name for name in missing))

    print(str(args))
    # st_stats_forHua takes a list of tile directories (it has no workpath
    # parameter). NOTE(review): this assumes the tile's ST files live directly
    # in <workpath>/<hv_tilename> -- confirm against the download layout.
    tile_path = os.path.join(args.workpath, args.hv_tilename)
    st_stats_forHua([tile_path], args.iata_location_identifier,
                    args.start_year_inclusive, args.end_year_inclusive)
        

        





