
import json
import math
import os
import shutil
import time

import rasterio
from rasterio.mask import mask
from rasterio.io import MemoryFile
from rasterio.merge import merge

from osgeo import osr
import pycrs
import pyproj
import numpy as np
import fiona

from joblib import Parallel, delayed

import datetime as dt

import argparse

"""
https://losrlcmp10.cr.usgs.gov/jira/browse/LSRD-6281
https://losrlcmp10.cr.usgs.gov/jira/browse/LSRD-6381
https://losrlcmp10.cr.usgs.gov/jira/browse/LSRD-6343

My summary:
    Take the pixels marked by the persistently urban class and use the UHI/data/LST_STACK_outputs values for those pixels
    to output MAXLST_MEAN, MEANLST_MEAN and MINLST_MEAN rasters.
    Do the same for the persistent urban coverage (and for the other class/coverage rasters the script processes).


#    PUC = persistent urban class
#    PNC = persistent non urban class
#    PUG = persistent urban coverage
#    PNG = persistent non urban coverage
"""


def do_step_354(raster_path, lookup_raster_path, final_output_directory, final_output_file_name):
    """Copy lookup-raster values onto the pixels selected by a class/coverage raster.

    The class raster (``raster_path``) is read only over the extent of the
    lookup raster (``lookup_raster_path``), so the two band-1 arrays line up
    pixel for pixel. Every pixel whose class value is > 0 and whose lookup
    value is not -9999 receives the lookup value; every other pixel is set to
    -9999.0, which is also written as the output nodata value. The result is
    written as an LZW-compressed, single-band GeoTIFF that reuses the lookup
    raster's metadata and transform, and the output path is printed.

    NOTE(review): both rasters are assumed to share CRS and pixel grid; this
    is not checked here.
    NOTE(review): the array is float32 while the file dtype is the lookup
    raster's band-1 dtype (kept from the original) -- confirm the cast on
    write is intended.

    :param raster_path: path of the class/coverage raster (band 1 is used).
    :param lookup_raster_path: path of the raster whose band-1 values are copied.
    :param final_output_directory: directory the output file is written to.
    :param final_output_file_name: file name of the output GeoTIFF.
    :raises ValueError: if the lookup raster has fewer rows or columns than
        the class-raster window read over its bounds.
    """
    designated_output_nodata_val = -9999.0
    # Value in the lookup raster that marks "no data" and is never copied.
    lookup_missing_val = -9999

    # Context managers guarantee both datasets are closed, even on error.
    with rasterio.open(raster_path) as raster, rasterio.open(lookup_raster_path) as lookup_raster:
        # The lookup raster is smaller than the class raster, so read only the
        # class-raster window that matches the lookup raster's bounds.
        # https://github.com/rasterio/rasterio/issues/622
        raster_equivalent_window = raster.window(*lookup_raster.bounds)
        raster_data = raster.read(1, window=raster_equivalent_window)
        lookup_raster_data = lookup_raster.read(1)

        # .meta returns a fresh dict, so updating it does not touch the dataset.
        new_out_meta = lookup_raster.meta
        new_out_meta['dtype'] = lookup_raster.dtypes[0]
        lookup_transform = lookup_raster.transform

    rows, cols = raster_data.shape
    if lookup_raster_data.shape[0] < rows or lookup_raster_data.shape[1] < cols:
        raise ValueError(
            "lookup raster {} is {}x{} but the matching window of {} is {}x{}".format(
                lookup_raster_path, lookup_raster_data.shape[0], lookup_raster_data.shape[1],
                raster_path, rows, cols))
    # Compare against the top-left rows x cols of the lookup array, exactly
    # like the pixel-by-pixel indexing this replaces.
    lookup_subset = lookup_raster_data[:rows, :cols]

    new_image = np.full((1, rows, cols), designated_output_nodata_val, dtype=np.float32)
    selected = (raster_data > 0) & (lookup_subset != lookup_missing_val)
    new_image[0][selected] = lookup_subset[selected]

    new_out_meta.update({"driver": "GTiff",
                         "count": 1,  # only one band is written below
                         "height": rows,
                         "width": cols,
                         "transform": lookup_transform,
                         "nodata": designated_output_nodata_val,
                         "compress": "LZW"})

    output_path = os.path.join(final_output_directory, final_output_file_name)
    with rasterio.open(output_path, "w", **new_out_meta) as dest:
        dest.write(new_image)
        print(output_path)



def pull_forward_values_from_phase1_slope_et_all_calcuations(raster_from_354_filepath,source_phase1_raster_filepath,final_output_directory,final_output_file_name):
    """Copy phase-1 raster values onto the pixels kept by a step-354 output.

    Every pixel whose value in the step-354 raster (band 1) is not -9999.0
    (the nodata value ``do_step_354`` writes) receives the band-1 value of
    the phase-1 raster; all other pixels are -9999.0, which is also the
    output nodata value. The result is written as an LZW-compressed,
    single-band GeoTIFF that reuses the phase-1 raster's metadata and
    transform, and the output path is printed.

    The phase-1 sources are the TOTAL-folder tifs, e.g.
    ``.../Phase1/SMF/output_folder/TOTAL/UHI_CU_SMF_1984_2020_20210830_C01_MINLST_NUMCLEAR.tif``;
    the caller builds ``final_output_file_name`` from the last part of that name.

    NOTE(review): both rasters are assumed to share the same pixel grid; this
    is not checked here.

    :param raster_from_354_filepath: path of a raster written by ``do_step_354``.
    :param source_phase1_raster_filepath: path of the phase-1 raster to copy from.
    :param final_output_directory: directory the output file is written to.
    :param final_output_file_name: file name of the output GeoTIFF.
    :raises ValueError: if the phase-1 raster has fewer rows or columns than
        the step-354 raster.
    """
    designated_output_nodata_val = -9999.0

    # Context managers guarantee both datasets are closed, even on error.
    with rasterio.open(raster_from_354_filepath) as raster_354_dataset, \
            rasterio.open(source_phase1_raster_filepath) as phase1_raster_dataset:
        raster_354 = raster_354_dataset.read(1)
        phase1_raster = phase1_raster_dataset.read(1)

        # .meta returns a fresh dict, so updating it does not touch the dataset.
        new_out_meta = phase1_raster_dataset.meta
        new_out_meta['dtype'] = phase1_raster_dataset.dtypes[0]
        phase1_transform = phase1_raster_dataset.transform

    rows, cols = raster_354.shape
    if phase1_raster.shape[0] < rows or phase1_raster.shape[1] < cols:
        raise ValueError(
            "phase-1 raster {} is {}x{} but step-354 raster {} is {}x{}".format(
                source_phase1_raster_filepath, phase1_raster.shape[0], phase1_raster.shape[1],
                raster_from_354_filepath, rows, cols))
    # Use the top-left rows x cols of the phase-1 array, exactly like the
    # pixel-by-pixel indexing this replaces.
    phase1_subset = phase1_raster[:rows, :cols]

    new_image = np.full((1, rows, cols), designated_output_nodata_val, dtype=np.float32)
    keep = raster_354 != designated_output_nodata_val
    new_image[0][keep] = phase1_subset[keep]

    new_out_meta.update({"driver": "GTiff",
                         "count": 1,  # only one band is written below
                         "height": rows,
                         "width": cols,
                         "transform": phase1_transform,
                         "nodata": designated_output_nodata_val,
                         "compress": "LZW"})

    output_path = os.path.join(final_output_directory, final_output_file_name)
    with rasterio.open(output_path, "w", **new_out_meta) as dest:
        dest.write(new_image)
        print(output_path)




def _is_geotiff(filename):
    """Return True when *filename* ends in .tif or .tiff (case-insensitive)."""
    return filename.lower().endswith((".tif", ".tiff"))


def _find_raster(directory, *name_parts):
    """Return the path of the GeoTIFF in *directory* whose name contains every string in *name_parts*.

    Only .tif/.tiff files are considered, so sidecars such as
    ``*.tif.aux.xml`` or ``*.tif.ovr`` are never picked up. When several
    files match, the lexicographically last name is returned so the choice
    does not depend on ``os.listdir`` order.

    :raises FileNotFoundError: if no file in *directory* matches.
    """
    matches = sorted(
        name for name in os.listdir(directory)
        if _is_geotiff(name) and all(part in name for part in name_parts)
    )
    if not matches:
        raise FileNotFoundError(
            "no .tif file in {} whose name contains {}".format(directory, " and ".join(name_parts)))
    return os.path.join(directory, matches[-1])


def main():
    """Run step 354 and the phase-1 pull-forward for one city.

    For each class/coverage raster (PNC, PUC, NUC from step 345; PUG, NUG
    from step 346) and each LST stack raster (MAXLST_MEAN, MEANLST_MEAN,
    MINLST_MEAN), ``do_step_354`` writes
    ``UHI_CU_<IATA>_STACK_20211118_C01_<STAT>_MEAN_<CLASS>.TIF`` into the
    output directory. Then, for every phase-1 TOTAL GeoTIFF whose name
    contains the matching statistic (MAXLST / MEANLST / MINLST),
    ``pull_forward_values_from_phase1_slope_et_all_calcuations`` copies its
    values onto the pixels kept by the corresponding step-354 output.

    WARNING: ``--final_output_directory`` is deleted and recreated on every run.
    """
    # https://eroslab.cr.usgs.gov/SCIENCE/urban-heat-islands/-/blob/5589d7a9c4a830459f14a3abd3bc90fa8e365224/Phase2/uhi_stats_classbased.py
    # https://stackabuse.com/command-line-arguments-in-python/
    uhi_root = "/caldera/projects/usgs/eros/urban_heat_islands/"
    stack_date = "20211118"

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_345_path", "-a",
                        help="directory with the step-345 PUC/PNC/NUC rasters "
                             "(default: " + uhi_root + "Phase2/section_345/<IATA>)")
    parser.add_argument("--output_346_path", "-b",
                        help="directory with the step-346 PUG/NUG rasters "
                             "(default: " + uhi_root + "Phase2/section_346/<IATA>)")
    parser.add_argument("--stack_path", "-c",
                        help="directory with the LST STACK rasters "
                             "(default: " + uhi_root + "Phase1/<IATA>/output_folder/STACK)")
    parser.add_argument("--iata_location_identifier", "-d", required=True, help="city IATA code")
    # e.g. /caldera/projects/usgs/eros/urban_heat_islands/Phase1/MSO/output_folder/TOTAL
    parser.add_argument("--phase1_total_folder_path", "-e", required=True, help="phase1_total_folder_path")
    parser.add_argument("--final_output_directory", "-f", required=True, help="final_output_directory")
    args = parser.parse_args()

    iata_code = args.iata_location_identifier

    # Use the directories given on the command line; fall back to the
    # standard /caldera layout for the city when one is omitted.
    output_345_path = args.output_345_path or (uhi_root + "Phase2/section_345/" + iata_code)
    output_346_path = args.output_346_path or (uhi_root + "Phase2/section_346/" + iata_code)
    stack_path = args.stack_path or (uhi_root + "Phase1/" + iata_code + "/output_folder/STACK")

    # Step 345 gives the class rasters (PUC/PNC/NUC), step 346 the coverage
    # rasters (PUG/NUG). The step-346 PNG raster is not processed here.
    # Resolving every input up front fails fast with a clear message when one is missing.
    class_rasters = [
        ("PNC", _find_raster(output_345_path, "PNC", iata_code)),
        ("PUC", _find_raster(output_345_path, "PUC", iata_code)),
        ("NUC", _find_raster(output_345_path, "NUC", iata_code)),
        ("PUG", _find_raster(output_346_path, "PUG", iata_code)),
        ("NUG", _find_raster(output_346_path, "NUG", iata_code)),
    ]
    stack_rasters = [
        ("MAXLST", _find_raster(stack_path, "MAXLST_MEAN")),
        ("MEANLST", _find_raster(stack_path, "MEANLST_MEAN")),
        ("MINLST", _find_raster(stack_path, "MINLST_MEAN")),
    ]

    final_output_directory = args.final_output_directory
    print(final_output_directory)
    if os.path.exists(final_output_directory):
        shutil.rmtree(final_output_directory)
        # NOTE(review): pause kept from the original; presumably lets a shared
        # filesystem settle after the delete -- confirm it is still needed.
        time.sleep(10)
    if not os.path.exists(final_output_directory):
        os.makedirs(final_output_directory)

    output_prefix = "UHI_CU_" + iata_code + "_STACK_" + stack_date + "_C01_"
    # (output base name without extension, statistic) for every step-354
    # raster, recorded in the order the files are written.
    step_354_outputs = []
    for class_tag, class_raster_path in class_rasters:
        for stat_name, stack_raster_path in stack_rasters:
            base_name = output_prefix + stat_name + "_MEAN_" + class_tag
            do_step_354(class_raster_path, stack_raster_path, final_output_directory, base_name + ".TIF")
            step_354_outputs.append((base_name, stat_name))

    print("doing slope et al now")
    print("phase1 path " + args.phase1_total_folder_path)

    for filename in os.listdir(args.phase1_total_folder_path):
        # Skip non-raster files such as *.tif.aux.xml sidecars.
        if not _is_geotiff(filename):
            continue
        # e.g. UHI_CU_MSO_1984_2020_20220829_C01_MINLST_NUMCLEAR.tif -> "_NUMCLEAR.tif"
        filename_suffix_we_want = "_" + filename.split("_")[-1]
        source_phase1_raster_filepath = os.path.join(args.phase1_total_folder_path, filename)
        for raster_filename, filename_filter in step_354_outputs:
            if filename_filter not in filename:
                continue

            output_file_name = raster_filename + filename_suffix_we_want
            raster_from_354_filepath = os.path.join(final_output_directory, raster_filename + ".TIF")

            print("raster_from_354_filepath= " + raster_from_354_filepath)
            print("phase_1_raster= " + source_phase1_raster_filepath)
            print("output= " + os.path.join(final_output_directory, output_file_name))

            pull_forward_values_from_phase1_slope_et_all_calcuations(
                raster_from_354_filepath, source_phase1_raster_filepath,
                final_output_directory, output_file_name)


if __name__ == '__main__':
    main()
