
import json
import math
import os
import shutil
import time

import rasterio
from rasterio.mask import mask
from rasterio.io import MemoryFile
from rasterio.merge import merge

from osgeo import osr
import pycrs
import pyproj
import numpy as np
import fiona

from joblib import Parallel, delayed

import datetime as dt
import argparse

# All output rasters are binary masks: 1 where the pixel meets the output's
# condition and 0 elsewhere (the urban class value itself is no longer passed through).

# This script produces a persistent urban class (PUC) raster, a persistent
# non-urban class (PNC) raster, and a new urban class (NUC) raster.


# Pixel values treated as urban classes in the step-322 land-cover rasters.
_URBAN_CLASS_VALUES = (21, 22, 23, 24, 25)
# Pixel values treated as non-urban classes in the step-322 land-cover rasters.
_NON_URBAN_CLASS_VALUES = (2, 3, 4, 5, 6, 7, 8)


def _persistent_class_mask(old_band, new_band, class_values):
    """Return int32 1/0: 1 where both years hold the same value from ``class_values``."""
    unchanged = old_band == new_band
    return (unchanged & np.isin(new_band, class_values)).astype(np.int32)


def _new_class_mask(old_band, new_band, class_values):
    """Return int32 1/0: 1 where the new year is in ``class_values`` but the old year is not."""
    gained = np.isin(new_band, class_values) & ~np.isin(old_band, class_values)
    return gained.astype(np.int32)


def _write_int32_raster(output_path, first_band, template_meta, shape, transform, nodata):
    """Write ``first_band`` as band 1 of an LZW-compressed int32 GeoTIFF.

    Args:
        output_path: destination file path (printed after writing).
        first_band: 2-D array of values for band index 0.
        template_meta: rasterio metadata dict to start from (copied, not mutated).
        shape: full output shape (bands, rows, cols); bands other than index 0
            are filled with ``nodata``.
        transform: affine transform for the output.
        nodata: nodata value recorded in the output metadata.
    """
    image = np.full(shape, nodata, dtype=np.int32)
    image[0] = first_band
    out_meta = dict(template_meta)
    out_meta.update({"driver": "GTiff",
                     "dtype": "int32",
                     "height": shape[1],
                     "width": shape[2],
                     "transform": transform,
                     "nodata": nodata,
                     "compress": "LZW"})
    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.write(image)
        print(output_path)


def do_step_345(urban_raster_name,non_urban_raster_name,new_urban_class_raster_name,raster_step_322_1985_path,raster_step_322_2020_path,final_output_directory,full_path_to_dissolved_city_boundaries_shapefile):
    """Compare the 1985 and 2020 land-cover rasters inside the city boundaries.

    Writes three binary (1/0) int32 GeoTIFFs into ``final_output_directory``:

    * ``urban_raster_name`` (PUC): 1 where the pixel has the same urban class
      (21-25) in both years.
    * ``non_urban_raster_name`` (PNC): 1 where the pixel has the same
      non-urban class (2-8) in both years.
    * ``new_urban_class_raster_name`` (NUC): 1 where the pixel is urban in
      2020 but was not urban in 1985.

    Only band 1 of each input is compared.

    Args:
        urban_raster_name: file name of the PUC output.
        non_urban_raster_name: file name of the PNC output.
        new_urban_class_raster_name: file name of the NUC output.
        raster_step_322_1985_path: path to the earlier (1985) land-cover raster.
        raster_step_322_2020_path: path to the later (2020) land-cover raster.
        final_output_directory: output directory. WARNING: it is deleted and
            recreated, so any existing contents are lost.
        full_path_to_dissolved_city_boundaries_shapefile: shapefile whose
            geometries define the masking area.

    Raises:
        ValueError: if the two masked rasters do not share the same rows/cols.
    """
    # Start from an empty output directory; previous contents are removed.
    if os.path.exists(final_output_directory):
        shutil.rmtree(final_output_directory)
        # NOTE(review): presumably gives the OS (e.g. Windows) time to release
        # the removed directory before it is recreated -- confirm it is needed.
        time.sleep(10)
    os.makedirs(final_output_directory, exist_ok=True)

    with fiona.open(full_path_to_dissolved_city_boundaries_shapefile, "r") as shapefile:
        boundary_geometries = [feature["geometry"] for feature in shapefile]

    # Mask both rasters to the boundaries. all_touched=True keeps every pixel
    # the geometries touch; crop=False keeps the full raster extent.
    # NOTE(review): per rasterio.mask.mask defaults, pixels outside the
    # boundaries are filled with the source nodata (or 0) and are then
    # classified like any other pixel, so they normally come out as 0.
    with rasterio.open(raster_step_322_1985_path) as source_1985:
        with rasterio.open(raster_step_322_2020_path) as source_2020:
            masked_1985, masked_transform_1985 = mask(
                source_1985, boundary_geometries, crop=False, pad=False,
                pad_width=0.0, all_touched=True)
            masked_2020, _ = mask(
                source_2020, boundary_geometries, crop=False, pad=False,
                pad_width=0.0, all_touched=True)
            template_meta = source_1985.meta

    # Pixels are compared position-by-position, so the grids must line up.
    if masked_1985.shape[1:] != masked_2020.shape[1:]:
        raise ValueError(
            "1985 and 2020 rasters have different dimensions after masking: "
            f"{masked_1985.shape[1:]} vs {masked_2020.shape[1:]}")

    designated_output_nodata_val = -9999
    out_shape = masked_1985.shape
    old_band = masked_1985[0]
    new_band = masked_2020[0]

    # PUC: same urban class in both years.
    _write_int32_raster(
        os.path.join(final_output_directory, urban_raster_name),
        _persistent_class_mask(old_band, new_band, _URBAN_CLASS_VALUES),
        template_meta, out_shape, masked_transform_1985, designated_output_nodata_val)

    # PNC: same non-urban class in both years.
    _write_int32_raster(
        os.path.join(final_output_directory, non_urban_raster_name),
        _persistent_class_mask(old_band, new_band, _NON_URBAN_CLASS_VALUES),
        template_meta, out_shape, masked_transform_1985, designated_output_nodata_val)

    # NUC: urban in 2020 but not urban in 1985. (The earlier -1 "should never
    # happen" marker for old == 2 was unreachable: 2 is not an urban class, so
    # such pixels already evaluate to 1 here.)
    _write_int32_raster(
        os.path.join(final_output_directory, new_urban_class_raster_name),
        _new_class_mask(old_band, new_band, _URBAN_CLASS_VALUES),
        template_meta, out_shape, masked_transform_1985, designated_output_nodata_val)


# Example invocation (Sioux Falls test data; output names are only examples):
#
#   python <this_script>.py \
#       -a "D:\Users\rhussain\Desktop\PycharmProjects\prototype_uhi_phase2\City_Boundaries_Dissolved\ard_fsd_city_dissolved.shp" \
#       -b "Y:\UHI\data\phase2_test_outputs\full_set8\SiouxFalls\UHI_CU_FSD_1985_20211116_C01_LC.tif" \
#       -c "Y:\UHI\data\phase2_test_outputs\full_set8\SiouxFalls\UHI_CU_FSD_2020_20211221_C01_LC.tif" \
#       -d "D:\RezaTemp\uhi_step_345" \
#       -e UHI_CU_FSD_20211116_C01_PUC.tif \
#       -f UHI_CU_FSD_20211221_C01_PNC.tif \
#       -g UHI_CU_FSD_20211221_C01_NUC.tif


if __name__ == '__main__':
    # https://eroslab.cr.usgs.gov/SCIENCE/urban-heat-islands/-/blob/5589d7a9c4a830459f14a3abd3bc90fa8e365224/Phase2/uhi_stats_classbased.py
    # https://stackabuse.com/command-line-arguments-in-python/
    # All arguments are required: do_step_345 cannot work with any of them
    # missing, and failing here gives a clear usage error instead of a
    # TypeError deep inside the processing.
    # NOTE: the CLI flags say "332" while do_step_345's parameters say "322";
    # the flag names are kept as-is for compatibility with existing callers.
    # WARNING: the --final_output_directory is deleted and recreated.
    parser = argparse.ArgumentParser()
    parser.add_argument("--full_path_to_dissolved_city_boundaries_shapefile", "-a", required=True, help="specify full_path_to_dissolved_city_boundaries_shapefile file")
    parser.add_argument("--raster_step_332_1985_path", "-b", required=True, help="raster_step_332_1985_path")
    parser.add_argument("--raster_step_332_2020_path", "-c", required=True, help="raster_step_332_2020_path")
    parser.add_argument("--final_output_directory", "-d", required=True, help="final_output_directory")
    parser.add_argument("--output_urban_raster_name", "-e", required=True, help="output_urban_raster_name")
    parser.add_argument("--output_non_urban_raster_name", "-f", required=True, help="output_non_urban_raster_name")
    parser.add_argument("--output_new_urban_class_raster_name", "-g", required=True, help="output_new_urban_class_raster_name")
    args = parser.parse_args()
    do_step_345(args.output_urban_raster_name,args.output_non_urban_raster_name,args.output_new_urban_class_raster_name,args.raster_step_332_1985_path,args.raster_step_332_2020_path,args.final_output_directory,args.full_path_to_dissolved_city_boundaries_shapefile)
