import gdal
import ogr
import osr
from shapely.affinity import translate
from shapely.geometry import Polygon
import pandas as pd
import geopandas as gpd
from shutil import copyfile
import os

import shutil
import time

# ---------------------------------------------------------------------------
# File locations. Every input and output lives under one project root on caldera.
# ---------------------------------------------------------------------------
_ROOT = "/caldera/projects/usgs/eros/urban_heat_islands/Phase2/uhi_metrobuffer_generation"

# Input files
ard = _ROOT + "/conus_ard_grid.shp"    # CONUS ARD tile grid (earlier note: EPSG 6326 — TODO confirm)
# City boundary shapefile. Earlier runs used one of these instead:
#   _ROOT + "/usa_city_bd_2019.shp"
#   _ROOT + "/tl_rd22_us_uac20.shp"
city0 = _ROOT + "/conus_uac20.shp"

# Output folders
ard_city = _ROOT + "/ARD Tiles"                             # ARD tile subset per city
city5km = _ROOT + "/City_Boundaries_Buffer"                 # city boundaries with a 5 km buffer
city5km_dissolved = _ROOT + "/City_Boundaries_Dissolved"    # per-city clips (lower-case IATA file names)




# Per-city ARD tile lookup, keyed by IATA airport code:
#     "IATA": ["City name", [[h, v], ...]]
# Each [h, v] pair names one tile of the CONUS ARD grid by the values of its
# 'h' and 'v' attribute columns (matched in extract_ARDbyIATA). An earlier note
# described the city name as the boundary file's 'NAME10' value; the name itself
# is not read by the code in this file. Insertion order is the processing order.
ard_dic = {
    "CLT": ["Charlotte", [[25, 12], [26, 12], [26, 11]]],
    "COS": ["ColoradoSprings", [[11, 10], [12, 10], [11, 9], [12, 9]]],
    "DCA": ["Washington", [[27, 9], [28, 9], [27, 8]]],
    "CVG": ["Cincinnati", [[23, 9], [23, 10]]],
    "MCI": ["KansasCity", [[17, 9], [17, 10], [18, 10]]],
    "PDX": ["Portland", [[3, 3]]],
    "PHX": ["Phoenix", [[6, 13], [7, 13]]],
    # [4, 12] used to be listed twice here, which wrote a duplicate tile.
    "SAN": ["SanDiego", [[3, 13], [4, 12], [4, 13]]],
    "SLC": ["SaltLakeCity", [[8, 7], [8, 8]]],
    "SMF": ["Sacramento", [[2, 8], [3, 8]]],

    "CYS": ["Cheyenne", [[12, 8]]],
    "MSO": ["Missoula", [[7, 3], [8, 3]]],
    "FAR": ["Fargo", [[16, 4]]],
    "LIT": ["LittleRock", [[19, 13]]],
    "BOI": ["Boise", [[6, 5]]],
    "PIT": ["Pittsburg", [[25, 8], [26, 8]]],
    "SDF": ["Louisville", [[22, 10], [23, 10]]],
    "MSP": ["Minneapolis", [[18, 5], [18, 6]]],
    "OMA": ["Omaha", [[16, 8], [17, 8]]],
    "ABQ": ["Alburquerque", [[10, 12], [10, 13]]],
    "MKE": ["Milwaukee", [[21, 6], [21, 7]]],
    "LAS": ["LasVegas", [[5, 11]]],
    "DTW": ["Detroit", [[23, 6], [24, 6], [23, 7], [24, 7]]],
    "BOS": ["Boston", [[30, 5], [30, 6]]],
    "OKC": ["OklahomaCity", [[16, 12], [16, 13]]],
    "MEM": ["Memphis", [[20, 12], [20, 13]]],
    "BNA": ["Nashville", [[22, 11], [22, 12]]],
    "ELP": ["ElPaso", [[10, 15]]],

    "DEN": ["Denver", [[11, 9], [12, 9]]],
    "IND": ["Indianopolis", [[22, 9]]],

    "SFO": ["SanFrancisco", [[1, 8], [1, 9], [2, 8], [2, 9]]],
    "CMH": ["Columbus", [[24, 8], [24, 9]]],

    "DFW": ["Dallas", [[16, 14], [16, 15]]],
    "SAT": ["SanAntonio", [[15, 17]]],
    "PHL": ["Philadelphia", [[28, 7], [28, 8], [29, 8]]],
    "DSM": ["DesMoines", [[18, 8]]],
    "JAX": ["Jacksonville", [[26, 15], [26, 16]]],
    "IAH": ["Houston", [[17, 16], [17, 17]]],

    "ORD": ["Chicago", [[21, 7], [21, 8], [22, 8]]],
    "LAX": ["LosAngeles", [[3, 12], [4, 12]]],
    "LGA": ["NewYork", [[28, 7], [29, 6], [29, 7], [29, 8]]],
    "FSD": ["Sioux Falls, SD", [[16, 6]]],
    "BHX": ["Birmingham", [[22, 13], [22, 14]]],

    "MSY": ["NewOrleans", [[20, 16], [20, 17], [21, 16], [21, 17]]],
    "MIA": ["Miami", [[27, 18], [27, 19]]],
    "RDU": ["Raleigh", [[27, 11]]],
    "ATL": ["Atlanta", [[23, 13], [23, 14], [24, 13], [24, 14]]],
    "BWI": ["Baltimore", [[27, 8], [27, 9], [28, 8], [28, 9]]],

    "SEA": ["Seattle", [[3, 2], [3, 1], [4, 1], [4, 2]]],
    "AUS": ["Austin", [[15, 16], [16, 16], [15, 17]]],
}

# Check the city shapefile's projection; if it differs from the ARD grid's projection, reproject it to match.

def getproj(shp):
    """Return the EPSG code of a shapefile's spatial reference, as a string.

    The code is read from the spatial reference's AUTHORITY node. If there is
    no AUTHORITY node (common in ESRI-style .prj files), GDAL's
    AutoIdentifyEPSG() is tried. If that also fails, None is returned.

    Args:
        shp: Path to an ESRI Shapefile.

    Returns:
        The EPSG code as a string, or None if it cannot be determined.

    Raises:
        IOError: If the shapefile cannot be opened.
        ValueError: If the layer has no spatial reference.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')

    dataset = driver.Open(shp)
    if dataset is None:
        # Without this check the failure surfaces as "NoneType has no attribute GetLayer".
        raise IOError("Could not open shapefile: " + shp)
    layer = dataset.GetLayer()
    spatialRef = layer.GetSpatialRef()
    if spatialRef is None:
        raise ValueError("Shapefile has no spatial reference (.prj): " + shp)

    proj = spatialRef.GetAttrValue('AUTHORITY', 1)
    if proj is None:
        # Best-effort fallback: let GDAL try to match the definition to an EPSG code.
        try:
            if spatialRef.AutoIdentifyEPSG() == 0:  # 0 == OGRERR_NONE
                proj = spatialRef.GetAuthorityCode(None)
        except RuntimeError:
            # Raised instead of an error code when GDAL exceptions are enabled.
            proj = None

    return proj

def getprj(input):
    """Return the first line of the .prj file that accompanies a shapefile.

    The returned text is what this script assigns as a CRS to the GeoDataFrames
    it writes.

    Args:
        input: Path to the .shp file. The .prj path is derived by swapping the
            final extension. (The name shadows the builtin ``input`` inside this
            function only; it is kept for backward compatibility.)

    Returns:
        The first line of the .prj file, with surrounding whitespace removed.

    Raises:
        OSError: If the .prj file cannot be opened.
        ValueError: If the .prj file is empty.
    """
    # Replace only the final extension; str.replace would also rewrite ".shp"
    # appearing elsewhere in the path (e.g. inside a directory name).
    prj_file = os.path.splitext(input)[0] + ".prj"
    with open(prj_file, 'r') as fh:
        first_line = fh.readline()
    if not first_line:
        raise ValueError("Projection file is empty: " + prj_file)
    return first_line.strip()

def _srs_from_epsg(code):
    """Build an osr.SpatialReference from an EPSG code (int or numeric string).

    GDAL >= 3 follows the authority's axis order, which puts latitude first
    for geographic CRSs and would swap x/y when transforming shapefile
    coordinates. When the running GDAL supports it, the traditional GIS order
    (x = easting/longitude) is forced. GDAL 2 has no such setting and already
    uses that order.
    """
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(int(code))
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    return srs


def reproj(shp, inepsg, outepsg):
    """Reproject every feature of a shapefile from one EPSG code to another.

    The result is written next to the input as ``shp[:-4] + "_proj.shp"``,
    overwriting any existing file of that name. This is the same path the
    __main__ block switches to after calling this function. The output layer
    carries the target spatial reference, so a .prj sidecar is written.

    Args:
        shp: Path to the input shapefile.
        inepsg: EPSG code of the input geometries (int or numeric string).
        outepsg: EPSG code to project to (int or numeric string).

    Returns:
        Path of the reprojected shapefile.

    Raises:
        IOError: If the input shapefile cannot be opened.
    """
    print("Reprojecting City Shapefile to match CONUS ARD inputs.")
    driver = ogr.GetDriverByName('ESRI Shapefile')

    inSpatialRef = _srs_from_epsg(inepsg)
    outSpatialRef = _srs_from_epsg(outepsg)
    coordTrans = osr.CoordinateTransformation(inSpatialRef, outSpatialRef)

    inDataSet = driver.Open(shp)
    if inDataSet is None:
        raise IOError("Could not open shapefile: " + shp)
    inLayer = inDataSet.GetLayer()

    # Derive the output name from the input being reprojected (previously the
    # global ARD grid path was used by mistake).
    outputShapefile = shp[:-4] + "_proj.shp"
    if os.path.exists(outputShapefile):
        driver.DeleteDataSource(outputShapefile)
    outDataSet = driver.CreateDataSource(outputShapefile)
    # Passing srs makes the driver write the .prj that getprj()/getproj() read later.
    outLayer = outDataSet.CreateLayer("usa_city_bd_reproj", srs=outSpatialRef,
                                      geom_type=ogr.wkbMultiPolygon)

    # Copy the attribute schema.
    inLayerDefn = inLayer.GetLayerDefn()
    for i in range(inLayerDefn.GetFieldCount()):
        outLayer.CreateField(inLayerDefn.GetFieldDefn(i))

    # ogr.Feature needs the layer's FeatureDefn, not the layer itself.
    outLayerDefn = outLayer.GetLayerDefn()
    for inFeature in inLayer:
        outFeature = ogr.Feature(outLayerDefn)
        geom = inFeature.GetGeometryRef()
        if geom is not None:  # features with a null geometry are copied without one
            geom.Transform(coordTrans)
            outFeature.SetGeometry(geom)
        for i in range(outLayerDefn.GetFieldCount()):
            outFeature.SetField(outLayerDefn.GetFieldDefn(i).GetNameRef(), inFeature.GetField(i))
        outLayer.CreateFeature(outFeature)
        outFeature = None

    # Dereference the datasets so GDAL flushes and closes the files now.
    outDataSet = None
    inDataSet = None
    return outputShapefile

def extract_ARDbyIATA(ard, iata, outdir):
    """Write the ARD grid tiles listed for one city to ``<outdir>/ard_<iata>.shp``.

    The [h, v] pairs come from the module-level ``ard_dic[iata][1]``. They are
    matched against the grid's 'h' and 'v' attribute columns. If the output
    file already exists, nothing is recomputed.

    Args:
        ard: Path to the CONUS ARD grid shapefile.
        iata: Key into ``ard_dic``.
        outdir: Folder that receives the tile subset.

    Returns:
        Path of the tile subset shapefile (existing or newly written).
    """
    outfile = os.path.join(outdir, "ard_" + iata + ".shp")

    if not os.path.exists(outfile):
        print("Extracting ARD Tiles for: " + iata)
        prj = getprj(ard)
        grid = gpd.read_file(ard)
        # Build one combined row mask. Chaining two boolean indexers (df[m1][m2])
        # relies on pandas reindexing the second mask, which pandas warns about.
        # A single mask also keeps a tile listed twice from being written twice.
        mask = pd.Series(False, index=grid.index)
        for h, v in ard_dic[iata][1]:
            mask |= (grid['h'] == h) & (grid['v'] == v)
        df = grid[mask].copy()
        df.crs = prj
        df.to_file(outfile)

    return outfile

def clip_CityByARD(city_buffer, iata, target_ARD, outdir):
    """Clip the city polygons to one city's ARD tile footprint.

    Writes ``<outdir><sep>ard_<iata>_city_dissolved.shp`` unless that file
    already exists.

    Args:
        city_buffer: Shapefile with the (buffered, dissolved) city polygons.
        iata: Code embedded in the output file name.
        target_ARD: Shapefile of ARD tiles used as the clip mask. Its .prj text
            is assigned as the CRS of the output.
        outdir: Destination folder.

    Returns:
        Path of the clipped shapefile (existing or newly written).
    """
    clipped_path = os.sep.join([outdir, "ard_" + iata + "_city_dissolved.shp"])
    if os.path.exists(clipped_path):
        return clipped_path

    print("Clipping buffered cities to ARD bounds: " + iata)
    crs_text = getprj(target_ARD)
    cities = gpd.read_file(city_buffer)
    tiles = gpd.read_file(target_ARD)
    clipped = gpd.clip(cities, tiles)
    clipped.crs = crs_text
    clipped.to_file(clipped_path)
    return clipped_path

def createBuffer(input, outputBuffer, bufferDist, epsg):
    """Buffer every geometry in a shapefile and write the result to a folder.

    The output is ``<outputBuffer>/<input stem>_<km>km.shp``, where <km> is
    ``bufferDist / 1000`` formatted with ``{:g}`` (5000 -> "5", 500 -> "0.5").
    Truncating to an int made buffers under 1 km, or fractional ones, collide.
    If the file already exists, nothing is recomputed. Only the buffered
    geometries are written; the input's attribute columns are not carried over.

    Args:
        input: Path to the input shapefile. (The name shadows the builtin; it is
            kept for backward compatibility with keyword callers.)
        outputBuffer: Destination folder.
        bufferDist: Buffer distance in the units of the input CRS (the
            __main__ block passes meters).
        epsg: Unused; kept so existing calls keep working.

    Returns:
        Path of the buffered shapefile.
    """
    km_label = '{:g}'.format(bufferDist / 1000)
    stem = os.path.splitext(os.path.basename(input))[0]
    outfile = os.path.join(outputBuffer, stem + '_' + km_label + 'km.shp')

    if not os.path.exists(outfile):
        print("Generating " + str(bufferDist / 1000) + " km buffer")
        buffered = gpd.read_file(input).buffer(bufferDist)
        buffered.to_file(outfile)

    return outfile

def dissolveCity(incity):
    """Union all polygons in a shapefile, then split the union into single parts.

    Overlapping input polygons merge into one feature, and disjoint areas stay
    separate features. The output is written next to the input as
    ``<name>_dissolved.shp`` unless it already exists. It takes its CRS from
    the input's .prj text.

    Args:
        incity: Path to the (buffered) city boundary shapefile.

    Returns:
        Path of the dissolved shapefile (existing or newly written).
    """
    # Swap only the final extension so ".shp" elsewhere in the path is untouched.
    outfile = os.path.splitext(incity)[0] + '_dissolved.shp'

    if not os.path.exists(outfile):
        print("Dissolving overlapping city bounds")
        prj = getprj(incity)
        geoseries = gpd.read_file(incity).geometry
        # GeoSeries.unary_union is deprecated in geopandas 1.0 in favour of union_all().
        if hasattr(geoseries, 'union_all'):
            merged = geoseries.union_all()
        else:
            merged = geoseries.unary_union
        df = gpd.GeoDataFrame(geometry=[merged])
        df = df.explode().reset_index(drop=True)
        df.crs = prj
        df.to_file(outfile)

    return outfile


if __name__ == "__main__":
    # Prefer a city boundary file that an earlier run already reprojected.
    if os.path.exists(city0[:-4] + "_proj.shp"):
        city0 = city0[:-4] + "_proj.shp"

    output_dirs = (ard_city, city5km, city5km_dissolved)

    # Start from empty output folders so every product is regenerated this run.
    # NOTE(review): the 10 s pause after each delete has no stated reason here;
    # presumably it lets a shared filesystem settle — confirm before removing.
    for folder in output_dirs:
        if os.path.exists(folder):
            shutil.rmtree(folder)
            time.sleep(10)
    for folder in output_dirs:
        if not os.path.exists(folder):
            os.mkdir(folder)

    # The ARD grid defines the target projection. Reproject the city
    # boundaries when their EPSG code differs from it.
    city0_epsg = getproj(city0)
    ard_epsg = getproj(ard)
    if city0_epsg != ard_epsg:
        print("Projecting City Boundaries to match ARD")
        print(city0_epsg)
        print("-----------------------------------------")
        print(ard_epsg)
        reproj(city0, city0_epsg, ard_epsg)
        city0 = city0[:-4] + "_proj.shp"

    # 5 km buffer around each boundary (bufferDist is in meters), then merge overlaps.
    citybuffers = createBuffer(city0, city5km, bufferDist=5000, epsg=ard_epsg)
    citydissolve = dissolveCity(citybuffers)

    # For each city: cut out its ARD tiles, then clip the dissolved buffers to
    # them twice — upper-case IATA names into city5km, lower-case into
    # city5km_dissolved.
    for iata in ard_dic:
        target_ARD = extract_ARDbyIATA(ard, iata, ard_city)
        for code, folder in ((iata, city5km), (iata.lower(), city5km_dissolved)):
            clip_CityByARD(citydissolve, code, target_ARD, folder)


# Per Hua, we need a standardized way of identifying the resulting boundaries within a shapefile.
# One thought: XXXYYY, where XXX is the IATA code and YYY is the rank of the city by distance
# from the primary urban center. This may cause YYY to change in future UHI versions, but such
# changes should be infrequent and can be easily captured in documentation.


#print("Script complete.")