################################################################################
#
# Author:
#
# Bonnie Ruefenacht, PhD
# Senior Specialist
# RedCastle Resources, Inc.
# Working onsite at: 
# USDA Forest Service 
# Remote Sensing Applications Center (RSAC) 
# 2222 West 2300 South
# Salt Lake City, UT 84119
# Office: (801) 975-3828 
# Mobile: (801) 694-9215
# Email: bruefenacht@fs.fed.us
# RSAC FS Intranet website: http://fsweb.rsac.fs.fed.us/
# RSAC FS Internet website: http://www.fs.fed.us/eng/rsac/
#
################################################################################

import os
import glob
import time
try:
    from osgeo import gdal
    from osgeo import osr
except ImportError:
    import gdal
    import osr

from qgis.core import QgsMessageLog, Qgis
from FMT3 import constants as const

def reproject_image(inputImage, outputImageName, projection, spatialResolution, sampling):
    """
    Reproject an image onto the output grid stored in a projection dictionary.

    The output raster is created up front with the size, geotransform and
    projection recorded under ``projection[spatialResolution]``, then filled
    with ``gdal.Warp``. If ``outputImageName`` already exists, only a log
    message is emitted.

    Parameters:
        * inputImage -- path of the source raster. NOTE(review): the whole path
          (directories included) is upper-cased before opening, which only
          resolves on case-insensitive file systems -- confirm this is intended.
        * outputImageName -- path of the raster to write. A ``.tif`` extension
          (in any letter case) selects the GTiff driver; anything else uses HFA.
        * projection -- dictionary holding ``'dataType'`` (GDAL data type of the
          output bands) and, under the ``spatialResolution`` key,
          ``'coordinates'`` (ulx/uly/lrx/lry), ``'outputTransform'``,
          ``'inputProjection'`` and ``'outputProjection'``
          (osr.SpatialReference objects).
        * spatialResolution -- output pixel size; also the key into
          ``projection``.
        * sampling -- resampling algorithm passed to gdal.WarpOptions.
    """
    inputImage = inputImage.upper()

    if os.path.exists(outputImageName):
        QgsMessageLog.logMessage( '-> ' + outputImageName + ' already exists.',
                                level=Qgis.Info)
        return

    srcDS = gdal.Open(inputImage)

    # Compare the extension case-insensitively so 'out.TIF' becomes a GeoTIFF
    # instead of an HFA file that merely carries a .TIF name.
    if os.path.splitext(outputImageName)[1].lower() == '.tif':
        driver = gdal.GetDriverByName('GTiff')
    else:
        driver = gdal.GetDriverByName('HFA')

    resInfo = projection[spatialResolution]
    coords = resInfo['coordinates']
    # Pixel counts covering the output extent, including both edge pixels.
    xSize = int(((coords['lrx'] - coords['ulx']) // spatialResolution) + 1)
    ySize = int(((coords['uly'] - coords['lry']) // spatialResolution) + 1)

    inputWKT = resInfo['inputProjection'].ExportToWkt()
    outputWKT = resInfo['outputProjection'].ExportToWkt()

    dstDS = driver.Create(outputImageName, xSize, ySize,
                          srcDS.RasterCount, projection['dataType'])
    dstDS.SetGeoTransform(resInfo['outputTransform'])
    dstDS.SetProjection(outputWKT)

    opts = gdal.WarpOptions(srcSRS=inputWKT, dstSRS=outputWKT, resampleAlg=sampling)
    gdal.Warp(dstDS, srcDS, options=opts)

    # Dereference the datasets so GDAL flushes and closes the files.
    dstDS = None
    srcDS = None
    driver = None

def create_pyramid_layers(image):
    """
    Create pyramid layers (overviews) and compute statistics for an image.

    Each step runs only when its side-car file is missing: ``<image>.ovr``
    for the overviews and ``<image>.aux.xml`` for the statistics. Nothing
    happens when ``image`` itself does not exist.

    Parameters:
        * image -- path of the raster to process.
    """
    if not os.path.exists(image):
        return

    if not os.path.exists(image + '.ovr'):
        QgsMessageLog.logMessage('Time: ' + time_string() + ' -> creating pyramid layers for ' + image,
                                 level=Qgis.Info)
        # Opened read-only (0); the '.ovr' existence check above is what tracks
        # whether overviews were already built for this file.
        ds = gdal.Open(image, 0)
        # COMPRESS_OVERVIEW is a process-wide GDAL setting shared with the rest
        # of QGIS, so put back whatever value was there before once done.
        previousCompression = gdal.GetConfigOption('COMPRESS_OVERVIEW')
        gdal.SetConfigOption('COMPRESS_OVERVIEW', 'LZW')
        try:
            ds.BuildOverviews("NEAREST", [2, 4, 8, 16, 32])
        finally:
            gdal.SetConfigOption('COMPRESS_OVERVIEW', previousCompression)
            ds = None

    if not os.path.exists(image + '.aux.xml'):
        QgsMessageLog.logMessage('Time: ' + time_string() + ' -> generating statistics for ' + image,
                                 level=Qgis.Info)
        ds = gdal.Open(image, 0)
        # stats=True with approxStats=False requests exact statistics; the
        # '.aux.xml' check above assumes GDAL persists them when ds is closed.
        infoOptions = gdal.InfoOptions(format='text', deserialize=True, computeMinMax=True, reportHistograms=True, reportProj4=False, stats=True, approxStats=False, computeChecksum=True, showGCPs=True, showMetadata=True, showRAT=True, showColorTable=True, listMDD=False, showFileList=True, allMetadata=True, extraMDDomains=None, wktFormat=None)
        gdal.Info(ds, options=infoOptions)
        ds = None

def stack_bands(inputImages, outputImageName, projection, spatialResolution, dataType):
    """
    Stack single-band images into one multi-band image.

    A temporary VRT with one band per input (``separate=True``) is built next
    to the output and translated to ``outputImageName`` as a tiled raster with
    128x128 blocks; the VRT is removed afterwards. Nothing is written if any
    input is missing or ``inputImages`` is empty.

    Parameters:
        * inputImages -- input raster paths, in output band order.
        * outputImageName -- path of the stacked raster to write (no explicit
          format is passed to gdal.Translate).
        * projection -- not used by the current implementation; kept for
          interface compatibility.
        * spatialResolution -- not used by the current implementation; kept
          for interface compatibility.
        * dataType -- GDAL data type of the output bands.
    """
    bandList = []
    allFound = True
    for inputImage in inputImages:
        if os.path.exists(inputImage):
            bandList.append(inputImage)
        else:
            allFound = False
            QgsMessageLog.logMessage('Time: ' + time_string() + ' -> cannot find ' + inputImage,
                                     level=Qgis.Info)

    if not (allFound and bandList):
        return

    QgsMessageLog.logMessage('Time: ' + time_string() + ' -> creating ' + outputImageName,
                             level=Qgis.Info)

    # Derive the VRT name from the output's stem. str.replace('.tif', ...)
    # left names such as 'out.TIF' or 'out.img' unchanged, so the VRT and the
    # output collided and the final os.remove() deleted the result.
    vrtName = os.path.splitext(outputImageName)[0] + '.vrt'

    vrtOptions = gdal.BuildVRTOptions(separate=True)
    vrt = gdal.BuildVRT(vrtName, bandList, options=vrtOptions)
    try:
        if vrt is None:
            QgsMessageLog.logMessage('Time: ' + time_string() + ' -> could not build ' + vrtName,
                                     level=Qgis.Critical)
            return
        outputTif = gdal.Translate(outputImageName, vrt,
                                   outputType=dataType,
                                   creationOptions=["TILED=YES", "BLOCKXSIZE=128", "BLOCKYSIZE=128"])
        # Dereference the output so GDAL flushes and closes it.
        outputTif = None
    finally:
        # Release the VRT handle before deleting its file, even if Translate failed.
        vrt = None
        if os.path.exists(vrtName):
            os.remove(vrtName)

def _open_projection_source(image, satellite):
    """
    Open a GeoTIFF copy of ``image`` and return ``(dataset, copyPath)``.

    Sentinel-2 inputs are translated from '.JP2' to a sibling '.tif'; other
    inputs are translated from '.TIF' to '<name>_temp.tif'. An existing copy is
    opened instead of being recreated. NOTE(review): the suffix replacement is
    case-sensitive -- if ``image`` lacks the exact '.JP2'/'.TIF' text, the copy
    path equals ``image`` and the original file is opened directly.
    """
    if satellite == 'sentinel2':
        outImg = image.replace('.JP2', '.tif')
    else:
        outImg = image.replace('.TIF', '_temp.tif')

    if not os.path.exists(outImg):
        QgsMessageLog.logMessage(image, level=Qgis.Info)
        QgsMessageLog.logMessage(outImg, level=Qgis.Info)
        return gdal.Translate(outImg, image), outImg
    return gdal.Open(outImg, gdal.GA_ReadOnly), outImg


def get_inputImage_projection_information(projDict, image, satellite, spatialResolution, rescaleFactor, padding):
    """
    Get the input/output projection information for an image.

    Any of the per-resolution entries ``'inputProjection'``,
    ``'outputProjection'``, ``'coordinates'`` and ``'outputTransform'`` that
    are missing from ``projDict[spatialResolution]`` are computed from the
    image and stored; existing entries are kept.

    Parameters:
        * projDict -- projection dictionary (may be None). Must provide
          ``'area'`` whenever the output projection or coordinates have to be
          computed. Only the top level is copied, so an existing
          ``projDict[spatialResolution]`` sub-dictionary is updated in place.
        * image -- path of the input image.
        * satellite -- satellite name; ``'sentinel2'`` selects the JP2 handling.
        * spatialResolution -- output pixel size; also the dictionary key.
        * rescaleFactor -- factor applied to the input raster dimensions when
          computing the lower-right corner.
        * padding -- distance added around the snapped output extent.

    Returns:
        The (shallow-copied) projection dictionary with the entries filled in.
    """
    # dict(None) raises TypeError, so handle a missing dictionary before copying.
    projectionDict = dict(projDict) if projDict is not None else {}

    inputImage, outImg = _open_projection_source(image, satellite)

    if projectionDict.get(spatialResolution) is None:
        projectionDict[spatialResolution] = {}
    resInfo = projectionDict[spatialResolution]

    if resInfo.get('inputProjection') is None:
        resInfo['inputProjection'] = osr.SpatialReference(inputImage.GetProjection())

    if resInfo.get('outputProjection') is None:
        outputProjection = osr.SpatialReference()
        outputProjection.ImportFromWkt(const.PROJ_STRINGS[projectionDict['area']])
        resInfo['outputProjection'] = outputProjection

    if resInfo.get('coordinates') is None:
        # NOTE(review): with GDAL >= 3 a geographic CRS uses lat/lon axis order
        # unless SetAxisMappingStrategy is applied -- confirm both CRSs here are
        # projected.
        transformCoordinates = osr.CoordinateTransformation(resInfo['inputProjection'],
                                                            resInfo['outputProjection'])
        srcTransform = inputImage.GetGeoTransform()

        # The same scaling applies to every satellite.
        x_size = int(inputImage.RasterXSize * rescaleFactor)
        y_size = int(inputImage.RasterYSize * rescaleFactor)

        # Project the upper-left and lower-right corners into the output CRS.
        (ulx, uly, ulz) = transformCoordinates.TransformPoint(srcTransform[0], srcTransform[3])
        (lrx, lry, lrz) = transformCoordinates.TransformPoint(srcTransform[0] + spatialResolution * x_size,
                                                              srcTransform[3] + spatialResolution * y_size * -1)

        # Snap each corner onto the 30-unit grid anchored at this area's NLCD
        # corner coordinates, then widen the extent by `padding`.
        nlcd = const.NLCD_COORD[projectionDict['area']]
        resInfo['coordinates'] = {
            'ulx': nlcd['ulx'] - 30 * (int((nlcd['ulx'] - ulx) / 30) + 1) - padding,
            'uly': nlcd['uly'] - 30 * (int((nlcd['uly'] - uly) / 30) + 1) + padding,
            'lrx': nlcd['lrx'] - 30 * (int((nlcd['lrx'] - lrx) / 30) + 1) + padding,
            'lry': nlcd['lry'] - 30 * (int((nlcd['lry'] - lry) / 30) + 1) - padding,
        }

    if resInfo.get('outputTransform') is None:
        # Read the geotransform here too: when 'coordinates' was already cached
        # the block above is skipped, but the rotation terms are still needed.
        srcTransform = inputImage.GetGeoTransform()
        resInfo['outputTransform'] = (resInfo['coordinates']['ulx'],
                                      spatialResolution, srcTransform[2],
                                      resInfo['coordinates']['uly'],
                                      srcTransform[4], spatialResolution * -1)

    # Release the dataset before deleting files that may back it. Every file
    # matching *_temp.* in the copy's directory is removed, not only the one
    # created above. os.path.join keeps a bare filename in the current
    # directory instead of globbing the filesystem root.
    inputImage = None
    for tempFile in glob.glob(os.path.join(os.path.dirname(outImg), '*_temp.*')):
        os.remove(tempFile)

    coords = resInfo['coordinates']
    QgsMessageLog.logMessage('********************** PROJECTION INFORMATION **********************',
                             level=Qgis.Info)
    QgsMessageLog.logMessage('Time: ' + time_string() + ' -> INPUT PROJECTION: ' + str(resInfo['inputProjection']),
                             level=Qgis.Info)
    QgsMessageLog.logMessage('Time: ' + time_string() + ' -> OUTPUT PROJECTION: ' + str(resInfo['outputProjection']),
                             level=Qgis.Info)
    QgsMessageLog.logMessage('Time: ' + time_string() + ' -> OUTPUT COORDINATES: ULX: ' + str(coords['ulx'])
                             + ', LRX: ' + str(coords['lrx']) + ', ULY: '
                             + str(coords['uly']) + ', LRY: ' + str(coords['lry']),
                             level=Qgis.Info)
    QgsMessageLog.logMessage('********************************************************************',
                             level=Qgis.Info)

    return projectionDict

def time_string(timestamp=None):
    """
    Return an easy to read ``HH:MM:SS`` time string (24-hour, zero-padded).

    Parameters:
        * timestamp -- optional ``time.struct_time`` (or 9-tuple) to format;
          defaults to the current local time.
    """
    # Take one snapshot. Reading time.localtime() separately for the hour,
    # minute and second could mix values from either side of a rollover
    # (e.g. 10:59:59 -> "10:00:00").
    if timestamp is None:
        timestamp = time.localtime()
    return time.strftime('%H:%M:%S', timestamp)
