'''
/***************************************************************************
Name		     : PreprocessData
Description          : Pre-process data downloaded from EarthExplorer or 
                       Copernicus
Created              : Jul 19, 2024
Updated              :
******************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import traceback
import glob
import math
import numpy as np
import os
import shutil
import tarfile
import zipfile
import time
import gzip
from xml.dom import minidom
try:
    from osgeo import gdal
except ImportError:
    import gdal

from qgis.core import QgsMessageLog, Qgis
from PyQt5.QtWidgets import QMessageBox

from FMT3 import constants as const
from FMT3 import preprocess_util as util


class PreprocessData(object):

    def __init__(self, reproject_flag, output_dir, base_dir, area, sampling, work_dir):
        '''
        Initialize preprocess script

        Params:
            reproject_flag
            output_dir
            base_dir
            area
            sampling
        '''
        self.reproject_flag = reproject_flag
        self.output_dir = output_dir
        self.base_dir = base_dir
        self.area = area
        self.sampling = sampling
        self.work_dir = work_dir

    def prepare_landsat_file(self, input_file):
        '''
        Prepare a Landsat archive to create utm, refl, and nbr outputs.

        Steps: recreate the scratch folder, create the per path/row output
        folder (asking before overwriting an existing one), extract the
        archive, decompress the Landsat 7 gap masks, stack and reproject the
        bands, remove Landsat 4/5/7 fringes, then build TOA and NBR images.

        Params:
            input_file (str) - path to the Landsat tar archive

        Returns:
            True on success; False if any step raised (the error and its
            traceback are logged); None if the user declines to overwrite an
            existing output folder.
        '''
        try:
            # start every scene from an empty scratch folder
            if os.path.exists(self.work_dir):
                shutil.rmtree(self.work_dir)
            os.makedirs(self.work_dir)

            # make landsat dir if it doesn't exist
            landsat_dir = os.path.join(self.output_dir, const.LANDSAT)
            if not os.path.exists(landsat_dir):
                os.makedirs(landsat_dir)

            # get file name from input file path
            scene_id = os.path.basename(input_file)
            QgsMessageLog.logMessage('Begin processing for : ' + scene_id,
                                     level=Qgis.Info)

            # input file name format:
            # ex. LC08_L1TP_021037_20240312_20240401_02_T1.tar
            # 'LC08' = sensor name ('08' = sensor number)
            # 'L1'   = level
            # '021'  = path
            # '037'  = row
            # '20240312' = first date field, used in the output folder name
            # '02'   = collection
            sensor_name = scene_id[0:4]
            level = scene_id[5:7]
            # NOTE: a single character, e.g. '8' for LC08 or '7' for LE07
            self.sensor = scene_id[3:4]
            self.path = int(scene_id[10:13])
            self.row = int(scene_id[13:16])
            self.collection = scene_id.split('_')[5]

            temp_path = str(self.path).zfill(3)
            temp_row = str(self.row).zfill(3)

            # create output location if needed
            pr_dir = os.path.join(landsat_dir,
                                  temp_path + temp_row)
            scene_id_short = (self.sensor + temp_path + temp_row
                              + scene_id[17:25] + '_' + level)
            scene_id_path = os.path.join(pr_dir, scene_id_short)
            if not os.path.exists(scene_id_path):
                os.makedirs(scene_id_path)
            else:
                reply = QMessageBox.question(None, 'Overwrite?',
                                             ('Output folder for file ' + scene_id +
                                              ' already exists.'
                                              '\nOverwrite?'),
                                             QMessageBox.Yes,
                                             QMessageBox.No)
                if reply == QMessageBox.No:
                    return
                shutil.rmtree(scene_id_path)
                os.mkdir(scene_id_path)

            # untar files
            self.untar_file(input_file)

            # Landsat 7 archives also carry gap-mask files. self.sensor is a
            # single digit ('7') and can never equal '07', so compare the
            # two-digit sensor number taken from the scene id instead.
            if scene_id[2:4] == '07':
                self.untar_gapmask(self.work_dir)

            # stack and reproject
            self.stack_and_reproject(self.work_dir, const.LANDSAT, landsat_dir,
                                     self.area, self.sampling, sensor_name, level=level)

            # remove landsat fringes; remove_landsat_fringes matches the
            # lower-case sensor names ('lt04', 'lt05', 'le07')
            self.remove_landsat_fringes(self.work_dir, sensor_name.lower())

            # toa images
            # NOTE(review): the sensor name is passed upper-case here but
            # lower-case for Sentinel-2; confirm what create_toa_images expects.
            sunEarthDist = os.path.join(self.base_dir, const.SUN_EARTH_DIST_FILE)
            self.create_toa_images(sensor_name, self.work_dir, scene_id_path, sunEarthDist, level=level)

            # nbr images
            self.create_nbr_images(self.work_dir, scene_id_path, const.LANDSAT, sensor_name.lower(), level=level)

            return True

        except Exception as ex:
            QgsMessageLog.logMessage(str(ex) + "\n" + traceback.format_exc(), level=Qgis.Critical)
            return False

    
    def check_landsat_datatype(self, file, datatype, collectionType):
        '''Check whether the image's data type is one the user selected.
        (Based on LPGS Script, Step01_Uncompress_Tarballs.py)

        Params:
            file (str) - path to the image file (used only in the log message)
            datatype (str) - data type code from the file name: 'L1TP',
                'L1GT' or 'L1GS'
            collectionType (str) - collection number from the file name

        Returns:
            bool - False for collection '01' data (a log message explains
            why). Otherwise True only when datatype is one of the three codes
            above and its flag in const.LANDSAT_DATATYPES ('L1T', 'L1Gt',
            'L1G' respectively) is truthy.
        '''
        dataTypes = const.LANDSAT_DATATYPES
        if collectionType == '01':
            QgsMessageLog.logMessage('The current data file ' + os.path.split(file)[1] +
                                     ' is collection 1 data and collection 1 data processing'
                                     ' is not allowed in this version of LPGS', level=Qgis.Info)
            return False
        # file-name data type code -> key of the user's selection flag
        selectionKey = {'L1TP': 'L1T', 'L1GT': 'L1Gt', 'L1GS': 'L1G'}.get(datatype)
        # Use truthiness: a bitwise '&' misjudges non-bool flags
        # (e.g. 2 & True == 0) and raises for non-integer values.
        return selectionKey is not None and bool(dataTypes[selectionKey])
    
    def untar_file(self, tar_file):
        '''
        Extract a tar or tar.gz archive into the working directory.

        The compression is auto-detected by tarfile.open; the extracted
        files are placed in self.work_dir.

        Params:
            tar_file (str) - path to the archive to extract
        '''
        msg = 'Extracting ' + tar_file + '...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)
        # The context manager closes the archive even if extraction fails.
        with tarfile.open(tar_file) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Python versions with PEP 706 extraction filters: refuse
                # members that would land outside work_dir (absolute paths,
                # '..' components, unsafe links).
                tar.extractall(self.work_dir, filter='data')
            else:
                tar.extractall(self.work_dir)
        QgsMessageLog.logMessage('File extraction completed for: ' + tar_file,
                                 level=Qgis.Info)
        return
    
    def untar_gapmask(self, baseDir):
        '''
        Decompress the Landsat 7 gap-mask .gz files for bands 4 and 7.

        Searches baseDir recursively for folders named 'gap_mask' (any case).
        A .gz file found there is decompressed next to itself, with the '.gz'
        suffix dropped, when:
          - its band number is 4 or 7 (the band number is the last character
            of the 9th '_'-separated field of the file name), and
          - the decompressed file does not exist yet.
        Problems are logged rather than raised.

        Params:
            baseDir (str) - directory to search
        '''
        msg = 'Extracting Landsat 7 Gap Mask...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)

        bands = (4, 7)
        gzFiles = []
        for rootDir, _, files in os.walk(baseDir):
            # Match names case-insensitively, but keep the real-case paths
            # for I/O: lower-casing the path itself breaks on case-sensitive
            # file systems.
            if os.path.basename(rootDir).lower() != 'gap_mask':
                continue
            for f in files:
                if not f.lower().endswith('.gz'):
                    continue
                gzPath = os.path.join(rootDir, f)
                try:
                    band = int(f.split('.')[0].split('_')[8][-1])
                except (IndexError, ValueError):
                    QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> failed to get the band number for ' + f,
                                             level=Qgis.Info)
                    continue
                if band not in bands:
                    continue
                if not os.path.exists(os.path.splitext(gzPath)[0]):
                    gzFiles.append(gzPath)
                    QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> file to extract: ' + gzPath,
                                             level=Qgis.Info)
                else:
                    QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> ' + gzPath + ' was already extracted',
                                             level=Qgis.Info)

        for f in gzFiles:
            try:
                # Decompress fully before creating the output, so a corrupt
                # archive never leaves a partial file that would later count
                # as "already extracted".
                with gzip.open(f, 'rb') as inFile:
                    contents = inFile.read()
                with open(os.path.splitext(f)[0], 'wb') as outFile:
                    outFile.write(contents)
            except Exception as err:
                # Not every exception carries an errno (e.g. a corrupt gzip
                # stream); os.strerror(None) would raise inside this handler.
                errorNumber = getattr(err, 'errno', None)
                reason = os.strerror(errorNumber) if errorNumber else str(err)
                QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> ERROR: cannot decompress ' + f + ' because ' + reason.lower(),
                                         level=Qgis.Info)
            else:
                QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> Successfully extracted ' + f,
                                         level=Qgis.Info)

    def prepare_sentinel2_file(self, input_file, stack10mFlag):
        '''Prepare a Sentinel-2 archive to create utm, refl, and nbr outputs.

        Steps: recreate the scratch folder, work out the per-tile output
        folders (asking before overwriting existing ones), unzip the archive
        and flatten its .jp2 files, stack and reproject (one pass per
        10/20/60 m resolution for L2A), build TOA images (non-L2A only),
        then build NBR images.

            Params:
                input_file (str) - path to the Sentinel-2 zip archive
                stack10mFlag (bool) - also produce the 10 m outputs

            Returns:
                True on success; False if any step raised (the error and its
                traceback are logged); None if the user declines to
                overwrite existing output folders.
        '''
        try:
            # start every scene from an empty scratch folder
            if os.path.exists(self.work_dir):
                shutil.rmtree(self.work_dir)
            os.makedirs(self.work_dir)

            sentinel2_dir = os.path.join(self.output_dir, const.SENTINEL2)
            if not os.path.exists(sentinel2_dir):
                os.makedirs(sentinel2_dir)

            # Parse the product name, e.g.
            #   S2A_MSIL1C_20240221T163311_N0510_R083_T16SDB_20240221T200923.SAFE.zip
            #   sensor = 'S2A', level = 'L1C', date = '20240221',
            #   tile 'T16SDB' -> path = '16S', row = 'DB'
            scene_id = os.path.basename(input_file)
            fields = scene_id.split('_')
            sensor = fields[0]
            level = fields[1][3:]
            date = fields[2][0:8]
            tile = fields[5]
            path, row = tile[1:4], tile[4:]

            pr_dir = os.path.join(sentinel2_dir, path, path + row)
            scene_id_folder = sensor[2:] + path + row + date
            scene_id_path = os.path.join(pr_dir, scene_id_folder)
            # one output folder per resolution, e.g. <scene>_20m_L1C
            resolution_dirs = [scene_id_path + '_' + res + '_' + level
                               for res in ('10m', '20m', '30m', '60m')]

            if not os.path.exists(pr_dir):
                os.makedirs(pr_dir)
            elif any(os.path.exists(out_dir) for out_dir in resolution_dirs):
                reply = QMessageBox.question(None, 'Overwrite?',
                                             ('Output folder for file ' + scene_id +
                                              ' already exists.'
                                              '\nOverwrite?'),
                                             QMessageBox.Yes,
                                             QMessageBox.No)
                if reply == QMessageBox.No:
                    return
                for out_dir in resolution_dirs:
                    if os.path.exists(out_dir):
                        shutil.rmtree(out_dir)

            # unzip, then flatten the buried .jp2 files into work_dir
            self.unzip_file(input_file)
            self.copy_jp2_files(self.work_dir)

            sensor_key = sensor.lower()
            if level == 'L2A':
                # L2A ships separate 10/20/60 m products; handle each one
                for res in [10, 20, 60]:
                    self.stack_and_reproject_sentinel_l2(self.work_dir, const.SENTINEL2, scene_id_path,
                                                         self.area, self.sampling, sensor_key, level,
                                                         stack10mSentinel=stack10mFlag,
                                                         scene_id_name=scene_id_folder, l2a_res=res)
            else:
                self.stack_and_reproject(self.work_dir, const.SENTINEL2, scene_id_path,
                                         self.area, self.sampling, sensor_key, level=level,
                                         stack10mSentinel=stack10mFlag,
                                         scene_id_name=scene_id_folder)
                # TOA images are only built for non-L2A products
                sunEarthDist = os.path.join(self.base_dir, const.SUN_EARTH_DIST_FILE)
                self.create_toa_images(sensor_key, self.work_dir, scene_id_path,
                                       sunEarthDist, level)

            self.create_nbr_images(self.work_dir, scene_id_path,
                                   const.SENTINEL2, sensor_key, level=level)
            return True

        except Exception as ex:
            QgsMessageLog.logMessage(str(ex) + "\n" + traceback.format_exc(), level=Qgis.Critical)
            return False
        
    def unzip_file(self, zip_file):
        '''
        Extract a zip archive into the working directory.

        Params:
            zip_file (str) - path to the zip archive; its contents are
                extracted into self.work_dir
        '''
        msg = 'Extracting ' + zip_file + '...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)
        # Extract the zip archive into the working directory; the context
        # manager closes it, so no explicit close() is needed.
        with zipfile.ZipFile(zip_file, 'r') as archive:
            archive.extractall(self.work_dir)
        QgsMessageLog.logMessage('File extraction completed for: ' + zip_file,
                                 level=Qgis.Info)
        return
    
    def copy_jp2_files(self, directory):
        """Copy every .jp2 file found below ``directory`` into ``directory`` itself.

        When unzipped, the Sentinel-2 jp2 files are buried inside a folder
        structure; copying them to the base directory makes them easy to
        find. shutil.copy overwrites, so for files with the same name in
        different sub-folders the last one walked wins.

        Params:
            directory (str) - folder to search, and destination of the copies
        """
        for rootDirectory, _, files in os.walk(directory):
            # .jp2 files already at the top level are left in place: copying
            # a file onto itself makes shutil.copy raise SameFileError.
            # os.walk yields the top directory exactly as it was passed in.
            if rootDirectory == directory:
                continue
            for rootFile in files:
                if rootFile.lower().endswith('.jp2'):
                    source = os.path.join(rootDirectory, rootFile)
                    shutil.copy(source, directory)
                    QgsMessageLog.logMessage('File copied from ' + source
                                             + ' to ' + directory, level=Qgis.Info)

    def stack_and_reproject(self, baseDir, satellite, copyToDir, area, sampling, sensor, level='', stack10mSentinel = False,
                            scene_id_name = ""):
        """Stacks and reprojects landsat tiffs and sentinel-2 jp2 images to GeoTiffs
        in USGS Albers projection.

        It works on every directory under baseDir that contains input images:
        'L*.TIF' for Landsat, or '*.jp2' excluding 'MSK_' files for Sentinel-2.
        A directory is only processed when at least one expected stacked
        output (see setup_processing_dict) is missing. The steps are:
          1. Each band listed in satelliteBands[sensor] is written by
             util.reproject_image to '<image>_<res>m_albers.tif':
               - Landsat: at 30 m.
               - Sentinel-2 bands listed under satelliteBands[10.0]: at 10 m
                 (only if stack10mSentinel), plus 20 m and 30 m (except B08).
               - Other Sentinel-2 bands: at 20 m, then 10 m; the 30 m output
                 is resampled from the 10 m output.
          2. For each resolution (10 m only if stack10mSentinel) whose stacked
             output is missing, the existing per-band outputs are stacked
             with util.stack_bands into processingDict[res]['outputImage'].
          3. Sentinel-2 stacks additionally get pyramid layers and are copied
             to copyToDir + '_<res>m' (+ '_<level>' when level is non-empty).

        Parameters:
            baseDir : directory tree searched for input images
            satellite : const.LANDSAT or const.SENTINEL2
            copyToDir : destination path prefix; only used for Sentinel-2 (step 3)
            area : area code forwarded to setup_processing_dict
            sampling : resampling option forwarded to util.reproject_image
            sensor : s2a,s2b,lt04,lt05,le07,lc08 (any case; lower-cased here)
            level : processing level from the file name, e.g. 'L2' for Landsat
            stack10mSentinel : also produce the 10 m outputs
            scene_id_name : scene folder name, only used by sentinel2
        """
        msg = 'Stack and reproject...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)

        sensor = sensor.lower()  # switch sensor case to lower to match dictionary
        # NOTE(review): for Sentinel-2, copy_jp2_files has copied the band files
        # into baseDir, so the same bands may also be found (and processed
        # again) in their original sub-folders - confirm this is intended.
        for rootDir, _, _ in os.walk(baseDir):
            if satellite == const.LANDSAT:
                images = glob.glob(os.path.join(rootDir, 'L*.TIF'))
                satelliteBands = const.LANDSAT_BANDS
            else:
                images = glob.glob(os.path.join(rootDir, "*.jp2"))
                satelliteBands = const.SENTINEL2_BANDS
                if images:
                    # skip mask files (names containing "MSK_")
                    images = [image for image in images if "MSK_" not in image] 

            if images:
                #strip off the band and the extension from the image name
                # (split on '_' + the last '_'-separated token, e.g. '_B1.TIF')
                imageRootName = images[0].split('_' + images[0].split('_')[-1])[0]
                if (satellite == const.LANDSAT and level == 'L2'):
                    # presumably images[0] can be a QA_* file for L2 scenes
                    # while the band files carry 'SR' - TODO confirm
                    imageRootName = imageRootName.replace('QA', 'SR')
                        
                processingDict = self.setup_processing_dict(satellite,imageRootName,sensor,area,scene_id_name, level=level)
                outputsExist = self.check_if_outputs_exist(processingDict)
                if (outputsExist == False):
                    # step 1: reproject / resample each supported band
                    for image in images:
                        # band code is the last '_' token without extension
                        band = os.path.splitext(image)[0].split('_')[-1]
                        try:
                            # membership test only; x is not used
                            x = satelliteBands[sensor].index(band)
                        except:
                            QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> not processing band ' + band,
                                                    level=Qgis.Info)
                        else:
                            spatialResolution = 30.0
                            if (satellite == const.SENTINEL2):
                                # NOTE(review): every Sentinel-2 band not in
                                # the 10.0 list is treated as a 20 m band
                                try:
                                    x = satelliteBands[10.0].index(band)
                                except:
                                    spatialResolution = 20.0
                                else:
                                    spatialResolution = 10.0

                            if (spatialResolution == 10.0):
                                # native 10 m band: 10 m output only when
                                # requested; 20/30 m outputs for all but B08
                                processingDict = self.setup_processing_dict(satellite, imageRootName, sensor, area, scene_id_name, level=level)
                                for resampleResolution in [10.0, 20.0, 30.0]:
                                    if ((resampleResolution == 10.0) & (stack10mSentinel == True)) | ((resampleResolution != 10.0) & (band != 'B08')):
                                        processingDict = util.get_inputImage_projection_information(processingDict,image,satellite,resampleResolution,1.0/(resampleResolution/10.0),45000)
                                        outputImageName = os.path.splitext(image)[0] + '_' + str(int(resampleResolution)) + 'm_albers.tif'
                                        if (resampleResolution == 10.0):
                                            QgsMessageLog.logMessage('Time: '+util.time_string()+' -> reprojecting '+ image +' to create '+ outputImageName,
                                                                    level=Qgis.Info)
                                        else:
                                            QgsMessageLog.logMessage('Time: ' + util.time_string()+' -> resampling '+image+' from 10m to '+
                                                                    str(int(resampleResolution)) + ' m to create '+outputImageName,
                                                                    level=Qgis.Info)
                                        
                                        util.reproject_image(image,outputImageName,processingDict,resampleResolution,sampling)
                            elif (spatialResolution == 20.0):
                                # native 20 m band: 20 m and 10 m are made from
                                # the source image, 30 m from the 10 m output
                                # (always made here, even without stack10mSentinel)
                                processingDict = self.setup_processing_dict(satellite,imageRootName,sensor,area,scene_id_name, level=level)
                                for resolution in [20.0,10.0,30.0]:
                                    if (resolution == 20.0):
                                        rescaleFactor = 1
                                        previousOutput = image
                                        padding = 45000
                                    elif (resolution == 10.0):
                                        rescaleFactor = 2
                                        padding = 45000
                                        previousOutput = image
                                    elif (resolution == 30.0):
                                        rescaleFactor = 1.0/3.0
                                        previousOutput = os.path.splitext(image)[0] + '_10m_albers.tif'  
                                        padding = 0
                                    processingDict = util.get_inputImage_projection_information(processingDict,previousOutput,satellite,resolution,rescaleFactor,padding)
                                    outputImageName = os.path.splitext(image)[0] + '_' + str(int(resolution)) + 'm_albers.tif'
                                    if (resolution == 20.0):
                                        QgsMessageLog.logMessage('Time:' + util.time_string() + ' -> reprojecting '+ previousOutput + ' to create ' + outputImageName, level=Qgis.Info)
                                    elif (resolution == 10.0):
                                        QgsMessageLog.logMessage('Time: '+util.time_string()+' -> resampling '+previousOutput+' from 20m to 10m to create '+outputImageName, level=Qgis.Info)
                                    else:
                                        QgsMessageLog.logMessage('Time: '+util.time_string()+' -> resampling '+previousOutput+' from 10m to 30m to create '+outputImageName, level=Qgis.Info)
                                        
                                    util.reproject_image(previousOutput,outputImageName,processingDict,resolution,sampling)
                                    # overwritten at the top of the next iteration
                                    previousOutput = outputImageName
                            elif (satellite == const.LANDSAT):
                                # Landsat: single 30 m reprojection per band
                                processingDict = self.setup_processing_dict(satellite,imageRootName,sensor,area,scene_id_name, level=level)
                                processingDict = util.get_inputImage_projection_information(processingDict, image, satellite, spatialResolution, 1, 45000)
                                outputImageName = os.path.splitext(image)[0] + '_' + str(int(spatialResolution)) + 'm_albers.tif'
                                QgsMessageLog.logMessage('Time: '+ util.time_string()+' -> reprojecting '+image+' to create '+outputImageName, level=Qgis.Info)
                                util.reproject_image(image,outputImageName,processingDict,spatialResolution,sampling)

                    # step 2/3: stack the per-band outputs per resolution
                    for spatialResolution in [10.0, 20.0, 30.0]:
                        if ((spatialResolution == 10.0) & (stack10mSentinel == True)) | (spatialResolution != 10.0):
                            if (processingDict.get(spatialResolution) != None):
                                if (os.path.exists(processingDict[spatialResolution]['outputImage']) == False):
                                    extension = os.path.splitext(processingDict[spatialResolution]['outputImage'])[1]
                                    # resolution-specific band list if one exists,
                                    # otherwise the sensor's full list
                                    bands = satelliteBands[sensor]
                                    if (satelliteBands.get(spatialResolution) != None):
                                        bands = satelliteBands[spatialResolution]

                                    inputImages = []

                                    # keep only per-band outputs that exist on disk
                                    for band in bands:
                                        inputImage = imageRootName + '_' + band + '_' + str(int(spatialResolution)) + 'm_albers' + extension
                                        if (os.path.exists(inputImage)):
                                            inputImages.append(imageRootName + '_' + band + '_' + str(int(spatialResolution)) + 'm_albers' + extension)
                                        else:
                                            QgsMessageLog.logMessage('Time:'+util.time_string()+' -> '+inputImage+' does not exist ', level=Qgis.Info)
                                    if inputImages:
                                        processingDict = util.get_inputImage_projection_information(processingDict,inputImages[0],satellite,spatialResolution,1,0)
                                        util.stack_bands(inputImages, processingDict[spatialResolution]['outputImage'],
                                                        processingDict[spatialResolution],spatialResolution,
                                                        processingDict['dataType'])
                                        if (satellite == const.SENTINEL2):
                                            util.create_pyramid_layers(processingDict[spatialResolution]['outputImage'])
                                            copyTo = copyToDir + "_" + str(int(spatialResolution)) + "m"
                                            if level:
                                                copyTo = copyTo + "_" + level
                                            self.copy_outputs(processingDict[spatialResolution]['outputImage'],copyTo)

    def stack_and_reproject_sentinel_l2(self, baseDir, satellite, copyToDir, area, sampling, sensor, level, stack10mSentinel = False,
                               scene_id_name = "", l2a_res = 10):
        """Reproject and stack the Sentinel-2 L2A bands of one resolution.

        Does nothing (besides logging) when l2a_res is 10 and stack10mSentinel
        is False. Otherwise it works on every directory under baseDir that
        contains '*<l2a_res>m.jp2' files (excluding 'MSK_' files):
          - if the stacked output is missing, every band listed in both
            satelliteBands[sensor] and satelliteBands[l2a_res] is
            reprojected to '<image>_albers.tif';
          - the existing per-band outputs are stacked into
            processingDict[l2a_res]['outputImage'];
          - the stack gets pyramid layers and is copied to
            copyToDir + '_<l2a_res>m_' + level.

        Parameters:
            baseDir : directory tree searched for input images
            satellite : const.SENTINEL2 (the caller passes this)
            copyToDir : destination path prefix for the stacked output
            area : area code forwarded to setup_processing_dict
            sampling : resampling option forwarded to util.reproject_image
            sensor : s2a/s2b (any case; lower-cased here)
            level : processing level, used in the copy destination name
            stack10mSentinel : also process the 10 m inputs
            scene_id_name : scene folder name used for the output file name
            l2a_res : input/output resolution in metres (10, 20 or 60)
        """
        msg = 'Stack and reproject...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)

        if l2a_res==10 and not stack10mSentinel:
            QgsMessageLog.logMessage("Not processing 10m inputs... Quitting", level=Qgis.Info)
        else:
            sensor = sensor.lower()  # switch sensor case to lower to match dictionary
            # NOTE(review): copy_jp2_files has copied the band files into
            # baseDir, so the same bands may also be found (and processed
            # again) in their original sub-folders - confirm this is intended.
            for rootDir, _, _ in os.walk(baseDir):
                images = glob.glob(os.path.join(rootDir, "*" + str(l2a_res) + "m.jp2"))
                satelliteBands = const.SENTINEL2_BANDS
                if images:
                    images = [image for image in images if "MSK_" not in image]

                if images:
                    #strip off the band and the extension from the image name
                    # (drops the last two '_' tokens: band and '<res>m.jp2')
                    imageRootName = images[0].split('_' + images[0].split('_')[-2] 
                                                    + '_' + images[0].split('_')[-1])[0]
                    
                    processingDict = self.setup_processing_dict(satellite, imageRootName, sensor, area,
                                                                scene_id_name, level='L2A', l2a_res=l2a_res)
                    outputsExist = self.check_if_outputs_exist_sentinel_l2(processingDict)
                    if (outputsExist == False):
                        for image in images:
                            # band code is the second-to-last '_' token
                            band = os.path.splitext(image)[0].split('_')[-2]
                            try:
                                # membership tests only; x is not used
                                x = satelliteBands[sensor].index(band)
                                x = satelliteBands[l2a_res*1.0].index(band)
                            except:
                                # NOTE(review): no level given, so QGIS's default
                                # message level applies here (others use Qgis.Info)
                                QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> not processing band ' + band)
                            else:
                                resampleResolution = l2a_res*1.0
                                # NOTE(review): the rescale factor is relative to 10 m
                                # (0.5 for 20 m, 1/6 for 60 m), whereas
                                # stack_and_reproject uses 1 for a native-resolution
                                # reprojection - confirm against util.
                                processingDict = util.get_inputImage_projection_information(processingDict,image,satellite,resampleResolution,1.0/(resampleResolution/10.0),45000)
                                # '<root>_<band>_<res>m_albers.tif', the name the
                                # stacking step below looks for
                                outputImageName = os.path.splitext(image)[0] + '_albers.tif'
                                # NOTE(review): looks like a leftover debug log line
                                QgsMessageLog.logMessage(outputImageName + "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
                                util.reproject_image(image,outputImageName,processingDict,resampleResolution,sampling)
                    # stack (skipped when the stacked output already exists)
                    if (processingDict.get(l2a_res*1.0) != None):
                        if (os.path.exists(processingDict[l2a_res*1.0]['outputImage']) == False):
                            extension = os.path.splitext(processingDict[l2a_res*1.0]['outputImage'])[1]
                            # resolution-specific band list if one exists,
                            # otherwise the sensor's full list
                            bands = satelliteBands[sensor]
                            if (satelliteBands.get(l2a_res*1.0) != None):
                                bands = satelliteBands[l2a_res*1.0]
                            
                            inputImages = []
                            # keep only per-band outputs that exist on disk
                            for band in bands:
                                inputImage = imageRootName + '_' + band + '_' + str(l2a_res) + 'm_albers' + extension
                                if (os.path.exists(inputImage)):
                                    inputImages.append(imageRootName + '_' + band + '_' + str(l2a_res) + 'm_albers' + extension)
                                else:
                                    QgsMessageLog.logMessage('Time:'+util.time_string()+' -> '+inputImage+' does not exist ', level=Qgis.Info)

                            if inputImages:
                                processingDict = util.get_inputImage_projection_information(processingDict,inputImages[0],satellite,l2a_res*1.0,1,0)
                                util.stack_bands(inputImages, processingDict[l2a_res*1.0]['outputImage'],
                                                processingDict[l2a_res*1.0],l2a_res*1.0,
                                                processingDict['dataType'])
                                if (satellite == const.SENTINEL2):
                                    util.create_pyramid_layers(processingDict[l2a_res*1.0]['outputImage'])
                                    copyTo = copyToDir + "_" + str(int(l2a_res)) + "m_" + level
                                    self.copy_outputs(processingDict[l2a_res*1.0]['outputImage'],copyTo)

    def setup_processing_dict(self, satellite, imageRootName, sensor, area, outImgName, level='', l2a_res=''):
        """Build the dictionary that holds the per-scene processing settings.

        Params:
            satellite - const.LANDSAT or const.SENTINEL2
            imageRootName (str) - image path with band and extension stripped
            sensor (str) - sensor code, any case (e.g. 'lt05', 's2a')
            area (int) - 1 CONUS, 2 Alaska, 3 Hawaii, 4 Puerto Rico; any
                other value leaves the 'area' key unset
            outImgName (str) - base file name of the Sentinel-2 stacked outputs
            level (str) - processing level; 'L2A' selects the single-resolution
                Sentinel-2 layout
            l2a_res - output resolution for Sentinel-2 L2A (must be set then)

        Returns:
            dict with
              'area' - area name (only for area codes 1-4),
              'dataType' - GDAL type code: 1 (Byte) for lt04/lt05/le07,
                           otherwise 2 (UInt16),
              30.0 - {'outputImage': imageRootName + '_30m_albers_stacked.tif'},
            and, for Sentinel-2, one {'outputImage': <dir>/<outImgName>_REFL_<res>.tif}
            entry per output resolution: int(l2a_res) for L2A, or 10.0/20.0/30.0
            otherwise (replacing the 30.0 entry above).
        """
        areaNames = {1: 'CONUS', 2: 'Alaska', 3: 'Hawaii', 4: 'Puerto Rico'}
        processingDict = {}
        sensor = sensor.lower()

        if area in areaNames:
            processingDict['area'] = areaNames[area]

        # 1 = gdal.GDT_Byte (Landsat 4/5/7), 2 = gdal.GDT_UInt16 (all others)
        processingDict['dataType'] = 1 if sensor in ('lt04', 'lt05', 'le07') else 2
        processingDict[30.0] = {'outputImage': imageRootName + '_30m_albers_stacked.tif'}

        if satellite == const.SENTINEL2:
            # L2A is processed one resolution per call (int key); other
            # levels get 10/20/30 m stacks in one go (float keys).
            if level == 'L2A':
                resolutions = [int(l2a_res)]
            else:
                resolutions = [10.0, 20.0, 30.0]
            outDir = os.path.split(imageRootName)[0]
            for spatialResolution in resolutions:
                processingDict[spatialResolution] = {
                    'outputImage': os.path.join(
                        outDir,
                        outImgName + '_REFL_' + str(int(spatialResolution)) + '.tif')
                }

        return processingDict
    
    def check_if_outputs_exist(self, output_dict):
        """
        Report whether every expected 10/20/30 m output image is on disk.

        Resolutions missing from ``output_dict`` (or mapped to None) are
        ignored, so an empty dict yields True.
        """
        expectedPaths = [output_dict[res]['outputImage']
                         for res in (10.0, 20.0, 30.0)
                         if output_dict.get(res) is not None]
        return all([os.path.exists(path) for path in expectedPaths])
    
    def check_if_outputs_exist_sentinel_l2(self, output_dict):
        """
        Report whether every expected 10/20/60 m Sentinel-2 L2 output is on disk.

        Resolutions missing from ``output_dict`` (or mapped to None) are
        ignored, so an empty dict yields True.
        """
        expectedPaths = [output_dict[res]['outputImage']
                         for res in (10.0, 20.0, 60.0)
                         if output_dict.get(res) is not None]
        return all([os.path.exists(path) for path in expectedPaths])

    def remove_landsat_fringes(self, baseDir, sensor):
        """
        Removes fringes from landsat 4,5,7 scenes.

        Removes fringes occurring along left and right sides of landsat 4-7 scenes.
        When ``sensor`` is 'lt04', 'lt05' or 'le07', every directory under
        ``baseDir`` that contains exactly one ``*_30m_albers_stacked.tif`` has
        that image defringed (see defringe_image) into
        ``<image name>_defringed.tif``, unless that output already exists.
        Directories with more than one matching image are reported and skipped.
        Directories with none are skipped silently.  Other sensors only
        produce the start and finish log messages.
        """
        msg = 'Remove Landsat Fringes...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)

        # The sensor does not change while walking, so test it once.
        if sensor in ('lt05', 'lt04', 'le07'):
            for rootDir, _, _ in os.walk(baseDir):
                # Escape the directory so glob metacharacters in its name match literally.
                images = glob.glob(os.path.join(glob.escape(rootDir), '*_30m_albers_stacked.tif'))
                if not images:
                    continue
                if len(images) > 1:
                    QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> ERROR: there are too many images including the name "_30m_albers_stacked.tif".  There should only be one image; I do not know which one to process.',
                                             level=Qgis.Info)
                    continue
                image = images[0]
                outputImage = os.path.splitext(image)[0] + '_defringed.tif'
                if not os.path.exists(outputImage):
                    QgsMessageLog.logMessage('Time: '+util.time_string()+' -> removing fringes of this image: '+ image,
                                             level=Qgis.Info)
                    self.defringe_image(image, outputImage)

        # One snapshot so month/day/year/h/m/s all come from the same instant.
        now = time.localtime()
        QgsMessageLog.logMessage(str(now[1]) + '/' + str(now[2]) + '/'
                                 + str(now[0]) + ' ' + str(now[3]) + ':'
                                 + str(now[4]) + ':' + str(now[5]) + ' -> Finished',
                                 level=Qgis.Info)
        
    def defringe_image(self, imageName, outputImage):
        """
        Remove fringes of L5 & L7 scenes.

        Creates ``outputImage`` with 6 GDT_Byte bands through GDAL's 'HFA'
        driver, copying the size, geotransform and projection of ``imageName``.
        A pixel is "valid" when input bands 1-6 are all non-zero there.  For
        every pixel, the valid pixels under the non-zero cells of
        const.FOCAL_WINDOW_41x41 are counted.  Pixels whose count passes the
        threshold keep their input values in all six bands, and all other pixels
        are written as 0.  The threshold is count > 100 when the file name
        starts with 'le07', and count >= 279 otherwise.

        The image is processed in 4000 x 4000 tiles.  Each tile is read with up
        to a 21-pixel overlap on every side so that counts near tile edges see
        neighbouring data.

        NOTE(review): pixels closer to the image border than half the focal
        window never receive a count, so they are always written as 0.
        NOTE(review): the 'HFA' driver is used although callers name the file
        '*_defringed.tif'; confirm this is intended.

        Params:
            imageName: path of the input image; bands 1-6 are read.
            outputImage: path of the image to create.
        """

        image = gdal.Open(imageName)

        inputProjection = image.GetProjection()
        inputTransformImage = image.GetGeoTransform()
        x_size = image.RasterXSize
        y_size = image.RasterYSize

        outputImageDriver = gdal.GetDriverByName('HFA')
        maskOutputImage = outputImageDriver.Create(outputImage, x_size, y_size, 6, gdal.GDT_Byte)
        maskOutputImage.SetGeoTransform(inputTransformImage)
        maskOutputImage.SetProjection(inputProjection)
        del inputProjection
        del inputTransformImage
        del outputImageDriver

        # 1-based band number -> GDAL band handle, for output and input.
        outputBands = {}
        imageBands = {}
        for band in range(1,7):
            outputBands[band] = maskOutputImage.GetRasterBand(band)
            imageBands[band] = image.GetRasterBand(band)
        del band
        
        tileSizeX = 4000
        tileSizeY = 4000

        # Overlap read around each tile so the focal window has neighbours.
        imageBufferSizeX = 21
        imageBufferSizeY = 21
        for col in range(0,x_size,tileSizeX):
            for row in range(0,y_size,tileSizeY):  
                # Read window: the tile plus the overlap on each side where the
                # image allows it (no overlap before column/row 0).
                xoff = col
                yoff = row
                bufferX = imageBufferSizeX * 1 + tileSizeX
                bufferY = imageBufferSizeY * 1 + tileSizeY
                if (col - imageBufferSizeX >= 0):
                    xoff = col - imageBufferSizeX
                    bufferX = imageBufferSizeX * 2 + tileSizeX
                if (row - imageBufferSizeY >= 0):
                    yoff = row - imageBufferSizeY
                    bufferY = imageBufferSizeY * 2 + tileSizeY

                # Clip the read window to the image extent.
                if (xoff + bufferX > x_size):
                    bufferX = x_size - xoff
                if (yoff + bufferY > y_size):
                    bufferY = y_size - yoff
                    
                # mask == 1 where bands 1-6 are all non-zero, 0 elsewhere.
                bandData = imageBands[1].ReadAsArray(xoff,yoff,bufferX,bufferY)
                mask = np.zeros(bandData.shape,dtype = int)
                mask[mask==0] = 1
                mask[bandData==0] = 0
                del bandData

                for band in range(2,7):
                    bandData = imageBands[band].ReadAsArray(xoff,yoff,bufferX,bufferY)
                    mask[bandData==0] = 0
                    del bandData
                del bufferX
                del bufferY

                # Half-window sizes; fwWidth/fwHeight become the window size minus
                # one so the last row/column can be handled separately below.
                (fwHeight, fwWidth) = const.FOCAL_WINDOW_41x41.shape
                outX = fwWidth//2
                outY = fwHeight//2
                fwWidth = outX*2
                fwHeight = outY*2

                # Focal sum of the mask over the window's non-zero cells, built
                # from shifted slices.  Only the interior [outY:-outY, outX:-outX]
                # (where a full window fits in the buffered tile) is filled.
                filteredMaskTile = np.zeros(mask.shape,dtype = int)
                for x in range(fwWidth):
                    for y in range(fwHeight):
                        if (const.FOCAL_WINDOW_41x41[y,x] != 0):
                            filteredMaskTile[outY:-outY, outX:-outX] += (mask[y:y-(outY+1*outY), x:x-(outX+1*outX)])
                # Last window row (slices running to the end of the array).
                for x in range(fwWidth):
                    if (const.FOCAL_WINDOW_41x41[-1,x] != 0):
                        filteredMaskTile[outY:-outY, outX:-outX] += mask[(outY+1*outY):, x:x-(outX+1*outX)]
                # Last window column.
                for y in range(fwHeight):
                    if (const.FOCAL_WINDOW_41x41[y,-1] != 0):
                        filteredMaskTile[outY:-outY, outX:-outX] += mask[y:y-(outY+1*outY), (outX+1*outX):]
                # Bottom-right window cell.
                if (const.FOCAL_WINDOW_41x41[-1,-1] != 0):
                    filteredMaskTile[outY:-outY, outX:-outX] += mask[(outY+1*outY):, (outX+1*outX):]

                del mask
                del fwHeight
                del fwWidth
                del outX
                del outY

                # Threshold the focal count into a keep (1) / drop (0) mask; the
                # threshold depends on whether the file name starts with 'le07'.
                if (os.path.split(imageName)[1].split('_')[0].lower() == 'le07'):
                    filteredMaskTile[filteredMaskTile <= 100] = 0
                    filteredMaskTile[filteredMaskTile > 100] = 1
                else:
                    filteredMaskTile[filteredMaskTile < 279] = 0
                    filteredMaskTile[filteredMaskTile >= 279] = 1

                # Offset of the un-buffered tile inside the buffered read
                # (0 when no overlap was read on that side).
                maskX = imageBufferSizeX
                maskY = imageBufferSizeY
                if (xoff == col):
                    maskX = 0
                if (yoff == row):
                    maskY = 0
                del xoff
                del yoff

                dim = filteredMaskTile.shape
                nrows = maskY + tileSizeY
                ncols = maskX + tileSizeX
                if (nrows > dim[0]):
                    nrows = dim[0]
                if (ncols > dim[1]):
                    ncols = dim[1]
                del dim

                # Crop off the overlap so the mask lines up with the tile at (col, row).
                filteredMaskTile = filteredMaskTile[maskY:nrows,maskX:ncols]
                del maskX
                del maskY
                del nrows
                del ncols
                
                # Copy each band's tile, zeroing pixels that failed the threshold.
                for band in range(1,7):
                    tileX = tileSizeX
                    tileY = tileSizeY
                    if (col + tileSizeX > x_size):
                        tileX = x_size - col
                    if (row + tileSizeY > y_size):
                        tileY = y_size - row
                        
                    bandData = imageBands[band].ReadAsArray(col,row,tileX,tileY)
                    bandData[filteredMaskTile==0] = 0
                    outputBands[band].WriteArray(bandData,col,row)
                    del bandData
                del filteredMaskTile
                del band

        for band in range(1,7):
            outputBands[band].FlushCache()

        # Drop the GDAL handles so the datasets are closed.
        del image
        del x_size
        del y_size
        del maskOutputImage
        del outputBands
        del imageBands
        del tileSizeX
        del tileSizeY

    def create_toa_images(self, sensor, baseDir, copyToDir, sunEarthDistFile, level=''):
        """
        Create top-of-atmosphere images.

        Create top-of-atmosphere images. The equations used in this script come from
        these references:
        * Chander, G., B.L. Markham, and D.L. Helder. 2009. Summary of current
        radiometric calibration coefficients for Landsat MSS, TM, ETM+, and EO-1 ALI
        sensors.  RSE 113:893-903.
        * https://landsat.usgs.gov/using-usgs-landsat-8-product

        Behaviour by ``level``:
        * anything but 'L2':
          - Nothing is done when ``sunEarthDistFile`` is missing.
          - Landsat: each ``*_mtl.txt`` under ``baseDir`` whose
            ``*_30m_albers_stacked.tif`` exists gets per-band TOA images
            (calculate_toa).  These are stacked into the REFL image named by
            get_metadata and copied to ``copyToDir``.
          - Sentinel-2: each ``*msil1c.xml`` gets TOA bands for every existing
            10/20/30 m REFL image (calculate_sentinel_toa).  The stacked TOA
            bands overwrite that REFL image, which is then copied to
            ``copyToDir + '_<res>m_' + level``.
        * 'L2': the Landsat surface-reflectance bands (``*_SR_B<n>.tif``) are
          stacked into the REFL image and copied to ``copyToDir``.  No TOA
          calculation is done.
        Landsat scenes whose REFL image already exists are skipped.

        Params:
            sensor: sensor code, e.g. 'lt05' or 'lc08' (compared lower-cased).
            baseDir: directory tree searched for metadata files.
            copyToDir: destination directory for the outputs.
            sunEarthDistFile: text file of "<day-of-year> <distance>" lines.
            level: processing level; 'L2' selects the surface-reflectance path.

        Author: Bonnie Ruefenacht PhD, Senior Specialist, RedCastle Resources Inc.
        """
        if level == 'L2':
            self._stack_landsat_sr_images(sensor, baseDir, copyToDir)
            return

        msg = 'Create TOA Images...'
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage(msg, level=Qgis.Info)
        if not os.path.exists(sunEarthDistFile):
            QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> cannot find SunEarthDistance.txt file',
                                     level=Qgis.Info)
            return
        sunEarthDist = self._read_sun_earth_distances(sunEarthDistFile)

        QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> Base Directory: ' + baseDir,
                                 level=Qgis.Info)
        QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> sunEarthDistFile: ' + sunEarthDistFile,
                                 level=Qgis.Info)

        # Each metadata file is handled when it is encountered.  The image-exists
        # check must not run for unrelated files (it used to reuse a stale path).
        for rootDir, _, files in os.walk(baseDir):
            for file in files:
                file = file.lower()
                if file.endswith('_mtl.txt'):
                    self._landsat_toa_for_mtl(sensor, rootDir, file, copyToDir, sunEarthDist)
                # Sentinel TOA correction new as of 20220519 (Sentinel changed Jan 26 2022).
                elif file.endswith('msil1c.xml'):
                    self._sentinel_toa_for_xml(rootDir, file, copyToDir, level)

    def _read_sun_earth_distances(self, sunEarthDistFile):
        """Parse "<day-of-year> <distance>" lines into {day: distance}; skip blank/short lines."""
        sunEarthDist = {}
        with open(sunEarthDistFile, 'r') as distFile:
            for line in distFile:
                fields = line.split()
                if len(fields) >= 2:
                    sunEarthDist[int(fields[0])] = float(fields[1])
        return sunEarthDist

    def _raster_grid_info(self, imageName):
        """Return the size/geotransform/projection dict passed to util.stack_bands."""
        image = gdal.Open(imageName)
        projection = {
            'xsize': image.RasterXSize,
            'ysize': image.RasterYSize,
            'outputTransform': image.GetGeoTransform(),
            'outputProjection': image.GetProjection(),
        }
        image = None
        return projection

    def _landsat_toa_for_mtl(self, sensor, rootDir, mtlName, copyToDir, sunEarthDist):
        """Create, stack and copy the TOA bands of one Landsat Level-1 scene."""
        mtlFile = os.path.join(rootDir, mtlName)
        scenePrefix = mtlFile.split('_mtl.txt')[0]
        inputImageName = scenePrefix + '_30m_albers_stacked.tif'
        QgsMessageLog.logMessage("inputImageName: " + inputImageName, level=Qgis.Info)
        if not os.path.exists(inputImageName):
            return
        metadata = self.get_metadata(mtlFile)
        if not metadata:
            # get_metadata returns {} when no sun elevation was found.
            QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> no usable metadata in ' + mtlFile,
                                     level=Qgis.Info)
            return
        outputName = metadata[1]['outputName']
        if os.path.exists(outputName):
            return

        sensor = sensor.lower()
        if sensor in ('lc08', 'lc09', 'lo09'):
            bands = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 9}
        else:
            bands = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 7}
        self.calculate_toa(sensor, inputImageName, bands, metadata, sunEarthDist)

        inputImages = [scenePrefix.upper() + '_B' + str(bands[band]) + '_toa.tif'
                       for band in sorted(bands)]
        util.stack_bands(inputImages, outputName, self._raster_grid_info(inputImageName), 30.0, 2)
        util.create_pyramid_layers(outputName)
        self.copy_outputs(outputName, copyToDir)

    def _sentinel_toa_for_xml(self, rootDir, xmlName, copyToDir, level):
        """Create, stack and copy the TOA bands of one Sentinel-2 L1C scene."""
        mtlFile = os.path.join(rootDir, xmlName)
        # Holds BANDOFFSET and QUANTIFICATION_VALUE.
        metadata = self.get_sentinel_metadata(mtlFile)
        if not metadata:
            QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> not processing sentinel for TOA',
                                     level=Qgis.Info)
            return
        QgsMessageLog.logMessage('Time: ' + util.time_string() + ' ->  processing sentinel for TOA',
                                 level=Qgis.Info)

        # Image names are derived from the directory layout.  The grid id is the
        # parent directory name minus its first character.  The sensor letter is
        # the last character of the scene directory's first '_' field, and the
        # date is the first 8 characters of its third field.
        pathList = rootDir.split(os.path.sep)
        gridId = pathList[-2][1:].upper()
        sceneParts = pathList[-1].split('_')
        sensorLetter = sceneParts[0][-1].upper()
        imageDate = sceneParts[2][0:8].upper()

        for res in ('10', '20', '30'):
            inputImageName = os.path.join(rootDir, sensorLetter + gridId + imageDate + '_REFL_' + res + '.tif')
            QgsMessageLog.logMessage(inputImageName, level=Qgis.Info)
            if not os.path.exists(inputImageName):
                continue
            if res == '10':
                bands = {1: 2, 2: 3, 3: 4, 4: 8}
            else:
                bands = {1: 2, 2: 3, 3: 4, 4: 8, 5: 11, 6: 12}
            self.calculate_sentinel_toa(inputImageName, bands, metadata, res)

            # Band 8 files are named '8A', matching calculate_sentinel_toa.
            inputImages = []
            for band in sorted(bands):
                bandVal = '8A' if bands[band] == 8 else str(bands[band])
                inputImages.append(inputImageName.replace('_REFL_' + res + '.tif',
                                                          '_B' + bandVal + '_' + res + '_toa.tif'))

            # The stacked TOA bands overwrite the REFL image they were computed from.
            util.stack_bands(inputImages, inputImageName, self._raster_grid_info(inputImageName), float(res), 2)
            util.create_pyramid_layers(inputImageName)
            self.copy_outputs(inputImageName, copyToDir + "_" + res + "m_" + level)

    def _stack_landsat_sr_images(self, sensor, baseDir, copyToDir):
        """Stack and copy the Level-2 surface-reflectance bands of each Landsat scene under baseDir."""
        QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> Base Directory: ' + baseDir,
                                 level=Qgis.Info)
        for rootDir, _, files in os.walk(baseDir):
            for file in files:
                file = file.lower()
                if not file.endswith('_mtl.txt'):
                    continue
                mtlFile = os.path.join(rootDir, file)
                scenePrefix = mtlFile.split('_mtl.txt')[0]
                inputImageName = scenePrefix + '_sr_30m_albers_stacked.tif'
                QgsMessageLog.logMessage("inputImageName: " + inputImageName, level=Qgis.Info)
                if not os.path.exists(inputImageName):
                    continue
                metadata = self.get_metadata(mtlFile)
                if not metadata:
                    QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> no usable metadata in ' + mtlFile,
                                             level=Qgis.Info)
                    continue
                outputName = metadata[1]['outputName']
                if os.path.exists(outputName):
                    continue

                sensor = sensor.lower()
                if sensor in ('lc08', 'lc09', 'lo09'):
                    bands = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}
                else:
                    bands = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 7}

                inputImages = [scenePrefix.upper() + '_SR_B' + str(bands[band]) + '.tif'
                               for band in sorted(bands)]
                util.stack_bands(inputImages, outputName, self._raster_grid_info(inputImageName), 30.0, 2)
                util.create_pyramid_layers(outputName)
                self.copy_outputs(outputName, copyToDir)

    def calculate_sentinel_toa(self, image, bands, metadata, res):
        """
        Write one TOA GeoTIFF per requested band of a Sentinel-2 L1C REFL image.

        TOA = (DN + BANDOFFSET[band]) / QUANTIFICATION_VALUE * 1000, stored as
        Float32.  Input pixels with DN 0 are written as 0.  Outputs are named
        ``<image up to '_REFL', upper-cased>_B<band>_<res>_toa.tif``, and
        existing outputs are not recomputed.

        Params:
            image: path of the ``*_REFL_<res>.tif`` image.
            bands: {1-based raster band index in ``image``: Sentinel-2 band
                    number}; band 8 is named and looked up as '8A'.
            metadata: dict from get_sentinel_metadata with 'BANDOFFSET' and
                      'QUANTIFICATION_VALUE'.
            res: resolution string (e.g. '10') used in the output names.
        """
        QgsMessageLog.logMessage('Time: '+util.time_string()+' -> calculating TOA for '+image,
                                 level=Qgis.Info)
        ds = gdal.Open(image)
        outputPrefix = image.split('_REFL')[0].upper()

        for band in sorted(bands):
            bandVal = '8A' if str(bands[band]) == '8' else str(bands[band])
            outputBandName = outputPrefix + '_B' + bandVal + '_' + res + '_toa.tif'
            if os.path.exists(outputBandName):
                continue

            outputImageDriver = gdal.GetDriverByName("GTiff")
            outputImage = outputImageDriver.Create(outputBandName, ds.RasterXSize, ds.RasterYSize, 1, gdal.GDT_Float32)
            outputImage.SetGeoTransform(ds.GetGeoTransform())
            outputImage.SetProjection(ds.GetProjection())
            outputBand = outputImage.GetRasterBand(1)

            QgsMessageLog.logMessage('Time: '+util.time_string()+' -> processing ' + image + ' band '+ bandVal,
                                     level=Qgis.Info)
            dn = ds.GetRasterBand(band).ReadAsArray().astype(np.float64)
            quantVal = float(metadata['QUANTIFICATION_VALUE'])
            offsetVal = float(metadata['BANDOFFSET'][bandVal])
            # calculate toa (dn + offset)/quantification val and scaled by 1000
            toa = ((dn + offsetVal) / quantVal) * 1000
            # DN 0 is treated as no data elsewhere in this module (calculate_toa,
            # calc_index).  Keep it 0 instead of letting the offset turn it into a
            # non-zero value.
            toa[dn == 0] = 0
            outputBand.WriteArray(toa)
            outputBand.FlushCache()
            outputBand = None
            outputImage = None

        ds = None
    
    def calculate_toa(self, sensor, inputImageName,bands,metadata,sunEarthDistance):
        """
        Write one top-of-atmosphere (TOA) GeoTIFF per band of a stacked Landsat image.

        Outputs are named ``<prefix upper-cased>_B<n>_toa.tif``, where the prefix
        is ``inputImageName`` up to ``_30m_albers_stacked``.  Existing outputs
        are not recomputed.  Values are scaled by 400, and pixels whose input DN
        is 0 are written as 0.  For 'le07', the first ``gap_mask/*_gapmask.tif``
        next to the input (if any) is also multiplied in.

        Landsat 4/5/7: radiance = GAIN * DN + BIAS, then
        TOA = pi * radiance * d^2 / (S_EXO * TAU).  Here d is
        sunEarthDistance[metadata[1]['DAY']], and S_EXO comes from
        const.S_EXO[sensor][band] (presumably exoatmospheric irradiance).
        Other sensors: TOA = (REFLECTANCE_MULT * DN + REFLECTANCE_ADD) / TAU.

        Params:
            sensor: lower-case sensor code ('lt04', 'lt05', 'le07', 'lc08', ...).
            inputImageName: path of the ``*_30m_albers_stacked.tif`` image.
            bands: {1-based raster band index in the input: Landsat band number}.
            metadata: dict from get_metadata.
            sunEarthDistance: {day of year: Earth-Sun distance}.
        """
        QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> calculating TOA for ' + inputImageName,
                                 level=Qgis.Info)

        gapMaskImage = None
        gapMaskImageBand = None
        if sensor == 'le07':
            gapMaskNames = glob.glob(os.path.join(os.path.split(inputImageName)[0],
                                                  'gap_mask',
                                                  '*_gapmask.tif'))
            if gapMaskNames:
                gapMaskImageName = gapMaskNames[0]
                QgsMessageLog.logMessage('Time: ' + util.time_string() + ' -> gap mask: ' + gapMaskImageName,
                                         level=Qgis.Info)
                gapMaskImage = gdal.Open(gapMaskImageName)
                gapMaskImageBand = gapMaskImage.GetRasterBand(1)
        useGapMask = gapMaskImageBand is not None
        isLegacySensor = sensor in ('lt04', 'lt05', 'le07')

        image = gdal.Open(inputImageName)
        x_size = image.RasterXSize
        y_size = image.RasterYSize
        for band in sorted(bands):
            outputBandName = inputImageName.split('_30m_albers_stacked')[0].upper() + '_B' + str(bands[band]) + '_toa.tif'
            if os.path.exists(outputBandName):
                continue

            outputImageDriver = gdal.GetDriverByName("GTiff")
            outputImage = outputImageDriver.Create(outputBandName, x_size, y_size, 1, gdal.GDT_Float32)
            outputImage.SetGeoTransform(image.GetGeoTransform())
            outputImage.SetProjection(image.GetProjection())
            outputBand = outputImage.GetRasterBand(1)

            QgsMessageLog.logMessage('Time: ' + util.time_string()+' -> processing ' 
                                     + inputImageName + ' band ' + str(bands[band]),
                                     level=Qgis.Info)
            imageBand = image.GetRasterBand(band)
            bandMetadata = metadata[bands[band]]
            # Process in 1000 x 1000 tiles to bound memory use.
            for x in range(0, x_size, 1000):
                bufferX = min(1000, x_size - x)
                for y in range(0, y_size, 1000):
                    bufferY = min(1000, y_size - y)

                    bandData = imageBand.ReadAsArray(x, y, bufferX, bufferY).astype(np.float64)
                    # 1 where the input DN is non-zero, 0 where it is 0.
                    mask = (bandData != 0).astype(int)
                    if isLegacySensor:
                        bandData = (bandMetadata['GAIN'] * bandData) + bandMetadata['BIAS']
                        distance = sunEarthDistance[metadata[1]['DAY']]
                        bandData = math.pi * bandData * distance * distance
                        bandData = bandData / (const.S_EXO[sensor][band] * metadata[1]['TAU'])
                    else:
                        bandData = (bandMetadata['REFLECTANCE_MULT_BAND'] * bandData + bandMetadata['REFLECTANCE_ADD_BAND']) / metadata[1]['TAU']
                    if useGapMask:
                        # Builtin int: the np.int alias was removed in NumPy 1.24.
                        gapMaskData = gapMaskImageBand.ReadAsArray(x, y, bufferX, bufferY).astype(int)
                        bandData = bandData * 400.0 * mask * gapMaskData
                    else:
                        bandData = bandData * 400.0 * mask
                    outputBand.WriteArray(bandData, x, y)

            outputBand.FlushCache()
            outputBand = None
            outputImage = None
            imageBand = None

        image = None
        gapMaskImageBand = None
        gapMaskImage = None
    
    def get_metadata(self, mtlFile):
        """
        Extract information from the landsat .mtl file.

        Returns:
            dict keyed by band number.  Each band entry holds the
            RADIANCE_MAXIMUM_BAND, RADIANCE_MINIMUM_BAND, QUANTIZE_CAL_MAX_BAND,
            QUANTIZE_CAL_MIN_BAND, REFLECTANCE_MULT_BAND and REFLECTANCE_ADD_BAND
            values found for it.  When all four radiance/quantize values are
            present it also holds 'GAIN' and 'BIAS' (radiance = GAIN * DN + BIAS).
            Entry 1 additionally holds:
              - 'TAU': cosine of (90 - SUN_ELEVATION) degrees.
              - 'DAY': day of year from LANDSAT_SCENE_ID, when present.
              - 'outputName': path of the REFL image to create.
            Returns {} when no SUN_ELEVATION line was found.
        """
        itemsToFind = ['RADIANCE_MAXIMUM_BAND_','RADIANCE_MINIMUM_BAND_','QUANTIZE_CAL_MAX_BAND_','QUANTIZE_CAL_MIN_BAND_','REFLECTANCE_MULT_BAND_','REFLECTANCE_ADD_BAND_']
        metadataDict = {1: {'TAU': 0.0, 'outputName': ''}}

        with open(mtlFile, 'r') as mtl:
            lines = mtl.readlines()

        for line in lines:
            line = line.split('\n')[0].upper()
            for itemToFind in itemsToFind:
                if (line.find(itemToFind) >= 0):
                    items = line.split(' = ')
                    # The band token runs up to the next '_' ("6_VCID_1" -> "6")
                    # and may have several digits ("10", "11").  Reading only its
                    # first character made bands 10/11 overwrite band 1.
                    # Non-numeric tokens (e.g. "ST_B10") are skipped.
                    bandToken = items[0].split('_BAND_')[1].split('_')[0].strip()
                    if not bandToken.isdecimal():
                        continue
                    band = int(bandToken)
                    metadataDict.setdefault(band, {})[itemToFind[:-1]] = float(items[1])
            if (line.find('LANDSAT_SCENE_ID = "') >= 0):
                metadataDict[1]['DAY'] = int(line.split('LANDSAT_SCENE_ID = "')[1][:-1][13:16])
                LANDSAT_SCENE_ID = line.split(' = "')[1][:-1]
                # Slice from the first digit up to the next letter after it.
                firstPosition = -1
                lastPosition = 0
                for i, char in enumerate(LANDSAT_SCENE_ID):
                    if (firstPosition < 0):
                        if (char.isdigit()):
                            firstPosition = i
                    elif (char.isalpha()):
                        lastPosition = i
                        break
                metadataDict[1]['outputName'] = os.path.join(os.path.split(mtlFile)[0],
                                                             LANDSAT_SCENE_ID[firstPosition:lastPosition][:7] 
                                                             + os.path.split(os.path.split(mtlFile)[1])[1].split('_')[3] 
                                                             + '_REFL.tif')
            if (line.find('SUN_ELEVATION = ') >= 0):
                metadataDict[1]['TAU'] = math.cos((90.0 - float(line.split(' = ')[1])) * (math.pi / 180.0))

        radiometricKeys = ('RADIANCE_MAXIMUM_BAND', 'RADIANCE_MINIMUM_BAND',
                           'QUANTIZE_CAL_MAX_BAND', 'QUANTIZE_CAL_MIN_BAND')
        for bandInfo in metadataDict.values():
            if not all(key in bandInfo for key in radiometricKeys):
                continue
            gain = ((bandInfo['RADIANCE_MAXIMUM_BAND'] - bandInfo['RADIANCE_MINIMUM_BAND'])
                    / (bandInfo['QUANTIZE_CAL_MAX_BAND'] - bandInfo['QUANTIZE_CAL_MIN_BAND']))
            bandInfo['GAIN'] = gain
            bandInfo['BIAS'] = bandInfo['RADIANCE_MINIMUM_BAND'] - gain * bandInfo['QUANTIZE_CAL_MIN_BAND']

        if (metadataDict[1]['TAU'] == 0.0):
            metadataDict = {}
        
        return metadataDict
    
    def get_sentinel_metadata(self, mtlFile):
        """
        Read radiometric offsets and the quantification value from a Sentinel-2 L1C XML.

        Params:
            mtlFile: path of the product metadata XML (``*msil1c.xml``).

        Returns:
            For PROCESSING_BASELINE >= 4.0, a dict with:
              - 'BANDOFFSET': {band_id string: offset string}; band_id '8' is
                stored as '8A'.
              - 'QUANTIFICATION_VALUE': string, read from the direct children
                of Product_Image_Characteristics.
            {} for older baselines.
            None when a required element is missing.
        """
        xmlRoot = minidom.parse(mtlFile)

        baselineNodes = xmlRoot.getElementsByTagName('PROCESSING_BASELINE')
        if not baselineNodes or not baselineNodes[0].childNodes:
            QgsMessageLog.logMessage("couldn't find the processing baseline", level=Qgis.Info)
            return None
        processingBaseline = baselineNodes[0].childNodes[0].data
        if float(processingBaseline) < 4.0:
            # Only baselines >= 4.0 are handled; the caller skips {} results.
            return {}

        characteristics = xmlRoot.getElementsByTagName('Product_Image_Characteristics')
        if not characteristics or not characteristics[0].childNodes:
            QgsMessageLog.logMessage("couldn't find the product image characteristics", level=Qgis.Info)
            return None

        metadataDict = {'BANDOFFSET': {}}
        # Whitespace text nodes sit between the elements; only elements matter.
        for element in characteristics[0].childNodes:
            if element.nodeType != element.ELEMENT_NODE:
                continue
            if element.tagName == "Radiometric_Offset_List":
                for child in element.childNodes:
                    if child.nodeType != child.ELEMENT_NODE:
                        continue
                    bandId = child.attributes['band_id'].value
                    bandId = "8A" if bandId == '8' else bandId  # this fixes for 8a.
                    metadataDict['BANDOFFSET'][bandId] = child.childNodes[0].data
            elif element.tagName == "QUANTIFICATION_VALUE":
                metadataDict["QUANTIFICATION_VALUE"] = element.childNodes[0].data

        if "QUANTIFICATION_VALUE" not in metadataDict:
            # Without it calculate_sentinel_toa cannot run; let the caller skip the scene.
            QgsMessageLog.logMessage("couldn't find the quantification value", level=Qgis.Info)
            return None
        return metadataDict

    def copy_outputs(self, image, copyToDirectory):
        """
        Copy ``image`` and every file sharing its stem into ``copyToDirectory``.

        Every regular file whose path starts with ``image`` minus its extension
        is copied, keeping its file name.  Existing files of the same name are
        overwritten.  The directory is created if needed.  No post-copy
        verification is performed.
        """
        os.makedirs(copyToDirectory, exist_ok=True)
        # Escape the stem so '[', ']', '*' or '?' in paths are matched literally.
        pattern = glob.escape(os.path.splitext(image)[0]) + '*'
        for copyFrom in glob.glob(pattern):
            if not os.path.isfile(copyFrom):
                # A directory sharing the prefix cannot be copied with shutil.copy.
                continue
            shutil.copy(copyFrom, os.path.join(copyToDirectory, os.path.basename(copyFrom)))

    def create_nbr_images(self, baseDir, copyToDir, satellite, sensor, level=''):
        """
        Create normalized burn ratio images.

        When ``sensor`` is in const.SENSOR_LIST, every ``*_REFL_*.tif`` and
        ``*_REFL.tif`` under ``baseDir`` is processed as follows:
          - An NBR image (``_REFL`` -> ``_NBR``) is written next to it.  Inputs
            ending in ``_10.tif`` get an NDVI image (``_NDVI``) instead.
            Existing index images are not recomputed.
          - Pyramids are built for the index image.
          - The index image is copied.  Sentinel-2 outputs go to
            ``copyToDir + '_<res>m'`` (plus ``'_' + level`` when level is
            set); other outputs go to ``copyToDir``.

        Author: Bonnie Ruefenacht PhD, Senior Specialist, RedCastle Resources Inc.
        """
        QgsMessageLog.logMessage('------------------------------------', level=Qgis.Info)
        QgsMessageLog.logMessage('Create NBR Images...', level=Qgis.Info)
        QgsMessageLog.logMessage('Base Directory: ' + baseDir, level=Qgis.Info)
        
        QgsMessageLog.logMessage('Time: '+util.time_string()+' -> Base Directory: ' + baseDir,
                                 level=Qgis.Info)

        # The sensor is fixed for the whole walk; unsupported sensors produce nothing.
        if sensor not in const.SENSOR_LIST:
            return

        for rootDir, _, _ in os.walk(baseDir):
            searchDir = glob.escape(rootDir)
            images = glob.glob(searchDir + '/*_REFL_*.tif')
            images.extend(glob.glob(searchDir + '/*_REFL.tif'))
            for inputImage in images:
                # Split at the last '_REFL' so a '_REFL' in a directory name is not used.
                stem, suffix = inputImage.rsplit('_REFL', 1)
                if inputImage.endswith(('_10.TIF', '_10.tif')):
                    outputImageName = stem + '_NDVI' + suffix
                else:
                    outputImageName = stem + '_NBR' + suffix
                if not os.path.exists(outputImageName):
                    QgsMessageLog.logMessage('Time: '+util.time_string()+' -> creating NBR/NDVI image for '+inputImage,
                                             level=Qgis.Info)
                    self.calc_index(inputImage, outputImageName, satellite, sensor)
                util.create_pyramid_layers(outputImageName)
                if satellite == const.SENTINEL2:
                    res = inputImage.split('_')[-1].split('.')[0]
                    copyTo = copyToDir + '_' + res + 'm'
                    if level:
                        copyTo = copyTo + '_' + level
                else:
                    copyTo = copyToDir
                self.copy_outputs(outputImageName, copyTo)

    def calc_index(self, inputImage, outputImageName, satellite, sensor):
        """
        Calculate either NDVI or NBR index.

        Writes a single-band Int16 GeoTIFF to ``outputImageName`` holding
        ``1000 * (NIR - X) / (NIR + X)``, processed in 1000x1000-pixel
        windows. X is the 'swir' band from ``const.NBR_BANDS`` or, for
        ``*_10.tif``/``*_10.TIF`` inputs, band 2. The output copies the input's
        geotransform and projection. Pixels where either band is 0, or where
        NIR + X == 0, are set to -32768.

        If anything fails after the output file has been created, the partial
        file is removed before the exception propagates. Otherwise the
        existence check in create_nbr_images would never regenerate it.

        Params:
            inputImage: path of the multi-band reflectance image
            outputImageName: path of the GeoTIFF to create
            satellite, sensor: keys into const.NBR_BANDS giving band numbers
        """
        block_size = 1000
        fill_value = -32768

        image = gdal.Open(inputImage)
        x_size = image.RasterXSize
        y_size = image.RasterYSize

        # Resolve the bands before creating the output, so that a bad
        # satellite/sensor key cannot leave an empty output file behind.
        bands = const.NBR_BANDS[satellite][sensor]
        imageNIR = image.GetRasterBand(bands['nir'])
        if inputImage.endswith(('_10.TIF', '_10.tif')):
            imageSWIR = image.GetRasterBand(2)
        else:
            imageSWIR = image.GetRasterBand(bands['swir'])

        outputImageDriver = gdal.GetDriverByName('GTiff')
        outputImage = outputImageDriver.Create(outputImageName, x_size, y_size, 1, gdal.GDT_Int16)
        outputImage.SetGeoTransform(image.GetGeoTransform())
        outputImage.SetProjection(image.GetProjection())
        outputBand = outputImage.GetRasterBand(1)

        try:
            for x in range(0, x_size, block_size):
                bufferX = min(block_size, x_size - x)
                for y in range(0, y_size, block_size):
                    bufferY = min(block_size, y_size - y)

                    NIRData = imageNIR.ReadAsArray(x, y, bufferX, bufferY).astype(np.float64)
                    SWIRData = imageSWIR.ReadAsArray(x, y, bufferX, bufferY).astype(np.float64)
                    denominator = NIRData + SWIRData

                    with np.errstate(divide='ignore', invalid='ignore'):
                        NBR = ((NIRData - SWIRData) / denominator) * 1000.0
                    # A zero in either band is treated as missing data. A zero
                    # sum with non-zero bands (negative values) would otherwise
                    # give +/-inf, which has no Int16 representation.
                    NBR[(NIRData == 0) | (SWIRData == 0) | (denominator == 0)] = fill_value
                    outputBand.WriteArray(NBR, x, y)

                    del NIRData, SWIRData, denominator, NBR

            outputBand.FlushCache()
        except Exception:
            # Close the dataset before removing the file so that no handle
            # stays open (required on Windows).
            del outputBand
            del outputImage
            if os.path.exists(outputImageName):
                os.remove(outputImageName)
            raise

        # Release the band before its dataset, then close the input.
        del outputBand
        del outputImage
        del imageNIR, imageSWIR
        del image

    def time_string(self):
        """
        Return an easy to read time string.

        The current local time is formatted as zero-padded 'HH:MM:SS'. The
        clock is read only once, so hour, minute and second always come from
        the same instant. Separate reads could straddle a boundary and give,
        e.g., '10:00:00' for a call made at 10:59:59 -> 11:00:00.
        """
        return time.strftime('%H:%M:%S', time.localtime())
