'''
/***************************************************************************
Name		     : SubsetProcess
Description          : Subset Process functions for QGIS FMT3 Plugin
copyright            : (C) 2018 by Cheryl Holen
Created              : Sep 05, 2018 - Adapted from QGIS 2.x version
Updated              :

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
 This code is based on work by:

    Kelcy Smith
    Contractor to USGS/EROS
    kelcy.smith.ctr@usgs.gov

    dNBR processing for Open Source MTBS

    Computes dNBR values (pre - post) for two given NBR rasters
    Generally works only with projected units (meters); other units are largely untested.

    References:
    http://landsathandbook.gsfc.nasa.gov/data_prod/prog_sect11_3.html
    http://landsathandbook.gsfc.nasa.gov/pdfs/L5TMLUTIEEE2003.pdf

'''
from PyQt5.QtWidgets import QMessageBox

import numpy as np
import os
from osgeo import gdal
import scipy.ndimage as nd
import subprocess

from .Utilities import Utilities
from FMT3 import constants as const


class ScenePrep():
    '''
    Builds the dNBR image (pre - post) and its 1 bit mask for one
    prefire/postfire mapping, including Landsat 7 gap masks when the
    reflectance images are available.

    scene_prep_process() is the entry point; the other public methods are
    the individual mask-building steps it uses.
    '''
    # Number of reflectance bands checked by toa_mask
    BAND_CT = 6
    # Base gdal_translate command used to write 1 bit mask images
    CMD_BASE = 'gdal_translate -co NBITS=1 '
    # dNBR mask value whose pixels are set to NDV in the output
    DNBR_MIN = 0
    # dNBR values at or above this are set to NDV when a mask is applied
    DNBR_MAX = 20000
    # No-data value checked in the NBR inputs and written to the dNBR output
    NDV = -32768
    # NBR file name prefix that triggers the Landsat 7 gap-mask step
    LS7 = '7'

    def __init__(self, mapping_id, scene_dir, img_src, overwrite):
        '''
        :param mapping_id: Id of the row in the Mappings table to process
        :param scene_dir: Directory under which img_proc output is written
        :param img_src: Root directory of the source NBR/reflectance images
        :param overwrite: True to regenerate outputs that already exist
        '''
        self.mapping_id = mapping_id
        self.path = scene_dir
        self.img_src = img_src
        self.overwrite = overwrite
        self.driver = gdal.GetDriverByName('GTiff')
        # Output window size and per-scene read offsets, in pixels; filled
        # in by scene_prep_process before the masks and dNBR are read
        self.out_rows = 0
        self.out_cols = 0
        self.pre_x_offset = 0
        self.pre_y_offset = 0
        self.post_x_offset = 0
        self.post_y_offset = 0
        self.utils = Utilities()

    def toa_mask(self, src_path, dst_path):
        '''
        Creates a 6 band TOA mask.

        Each output band is 1 where the matching input band is > 0 and 0
        elsewhere.

        :param src_path: Input reflectance image path
        :param dst_path: Output mask path
        '''
        src_ds = gdal.Open(src_path)
        rows = src_ds.RasterYSize
        cols = src_ds.RasterXSize

        dst_ds = self.driver.Create(dst_path, cols, rows, self.BAND_CT,
                                    gdal.GDT_Byte)
        dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
        dst_ds.SetProjection(src_ds.GetProjection())

        # GDAL band numbers are 1-based; process one band at a time
        for num in range(1, self.BAND_CT + 1):
            src_array = np.array(
                src_ds.GetRasterBand(num).ReadAsArray(0, 0, cols, rows),
                dtype=np.float32)
            dst_ds.GetRasterBand(num).WriteArray(src_array > 0)

        # Dropping the references flushes and closes the datasets
        dst_ds = None
        src_ds = None

    def buffer_mask(self, mask_path):
        '''
        Buffers a TOA mask by 2 pixels, in place.

        In every band, the zero (invalid) areas are grown by two iterations
        of an 8-connected binary dilation.

        :param mask_path: Mask path; the file is opened for update
        '''
        src_ds = gdal.Open(mask_path, gdal.GA_Update)
        rows = src_ds.RasterYSize
        cols = src_ds.RasterXSize

        struct = nd.generate_binary_structure(2, 2)
        for num in range(1, src_ds.RasterCount + 1):
            band = src_ds.GetRasterBand(num)
            src_array = np.array(band.ReadAsArray(0, 0, cols, rows),
                                 dtype=np.int_)
            invalid = np.logical_not(src_array.astype(np.bool_))
            grown = nd.binary_dilation(invalid, structure=struct,
                                       iterations=2)
            band.WriteArray(np.logical_not(grown))
            band = None

        src_ds = None

    def nbr_mask(self, src_path, dst_path):
        '''
        Creates a 1 band NBR mask from bands 5 and 4 of a TOA mask.

        The output is the product of the two bands, so it is non-zero only
        where both bands are non-zero.

        :param src_path: Input (multi band) mask path
        :param dst_path: Output mask path
        :return: -1 if either band could not be read, 0 on success
        '''
        _ndv, c, r, geo, projection, src_data_type =\
            self.utils.get_geo_info(src_path)
        # The projection is taken straight from GDAL; read-only is enough
        src_ds = gdal.Open(src_path)
        projection = src_ds.GetProjection()
        src_ds = None

        band5_array =\
            self.utils.file_to_array(src_path, src_data_type, c, r, 0, 0, 5)
        band4_array =\
            self.utils.file_to_array(src_path, src_data_type, c, r, 0, 0, 4)
        if band4_array is None or band5_array is None:
            QMessageBox.critical(None,
                                 'Error!!',
                                 'Mask cannot be created, file read failed.',
                                 QMessageBox.Ok)
            return -1  # Error code
        dst_array = band5_array * band4_array
        band5_array = None
        band4_array = None

        dst_ds = self.driver.Create(dst_path, c, r, 1, gdal.GDT_Byte)
        dst_ds.SetGeoTransform(geo)
        dst_ds.SetProjection(projection)
        dst_ds.GetRasterBand(1).WriteArray(dst_array)
        dst_ds = None
        return 0  # Success

    @staticmethod
    def _read_window(path, x_offset, y_offset, cols, rows):
        '''
        Reads a cols x rows window of band 1 of path.

        :return: The window as a numpy array, or None if path does not exist
        '''
        if not os.path.exists(path):
            return None
        src_ds = gdal.Open(path)
        array = np.array(src_ds.GetRasterBand(1).ReadAsArray(
            x_offset, y_offset, cols, rows))
        src_ds = None
        return array

    @staticmethod
    def _combine_masks(pre_array, post_array):
        '''
        Merges the available gap-mask windows into one 0/1 mask.

        A pixel is 1 when any available mask is > 0 there; missing masks
        (None) are ignored.

        :return: The merged mask (dtype of the first available mask), or
                 None when neither mask is available
        '''
        # NOTE(review): this is a union. With two Landsat 7 masks, a pixel
        # stays valid if either scene is valid. With a single mask, that
        # mask alone decides. Confirm union (not intersection) is intended.
        present = [arr for arr in (pre_array, post_array) if arr is not None]
        if not present:
            return None
        valid = np.zeros(present[0].shape, dtype=bool)
        for arr in present:
            valid |= arr > 0
        return valid.astype(present[0].dtype)

    def dnbr_mask(self, pre_mask_path, post_mask_path,
                  dst_path, cols, rows, geo, proj):
        '''
        Creates the dNBR mask from the prefire/postfire gap masks.

        Each existing mask is read at its scene offset over the output
        window (self.out_cols x self.out_rows). Nothing is written when
        neither mask exists.

        :param pre_mask_path: Prefire mask path
        :param post_mask_path: Postfire mask path
        :param dst_path: Output mask path
        :param cols: Output column count
        :param rows: Output row count
        :param geo: Output geotransform
        :param proj: Output projection
        '''
        pre_array = self._read_window(pre_mask_path, self.pre_x_offset,
                                      self.pre_y_offset, self.out_cols,
                                      self.out_rows)
        post_array = self._read_window(post_mask_path, self.post_x_offset,
                                       self.post_y_offset, self.out_cols,
                                       self.out_rows)
        dst_array = self._combine_masks(pre_array, post_array)
        if dst_array is None:
            return

        dst_ds = self.driver.Create(dst_path, cols, rows, 1, gdal.GDT_Byte)
        dst_ds.SetGeoTransform(geo)
        dst_ds.SetProjection(proj)
        dst_ds.GetRasterBand(1).WriteArray(dst_array)
        dst_ds = None

    @staticmethod
    def _extent_from_geo(geo, cols, rows):
        '''
        Returns [x_min, y_min, x_max, y_max] of a raster.

        Assumes a north-up geotransform (geo[5] < 0, no rotation).
        '''
        return [geo[0],
                geo[3] + geo[5] * rows,
                geo[0] + geo[1] * cols,
                geo[3]]

    @staticmethod
    def _overlap_extent(pre_coords, post_coords):
        '''
        Intersects two [x_min, y_min, x_max, y_max] extents.

        Touching edges count as an (empty-area) overlap.

        :return: (x_min, y_min, x_max, y_max) of the overlap, or None when
                 the extents do not intersect
        '''
        x_min = max(pre_coords[0], post_coords[0])
        y_min = max(pre_coords[1], post_coords[1])
        x_max = min(pre_coords[2], post_coords[2])
        y_max = min(pre_coords[3], post_coords[3])
        if x_min > x_max or y_min > y_max:
            return None
        return x_min, y_min, x_max, y_max

    def _build_gap_mask(self, refl_path, nbr_name, gapmask_path,
                        gapmask_temp):
        '''
        Builds the buffered NBR gap mask for a Landsat 7 scene.

        Skipped (reported as success) when the reflectance image is missing
        or nbr_name does not start with LS7.

        :return: 0 on success or skip, the nbr_mask error code otherwise
        '''
        if not os.path.exists(refl_path) or \
                not nbr_name.startswith(self.LS7):
            return 0
        # Start from a fresh 6 band mask
        if os.path.exists(gapmask_path):
            os.remove(gapmask_path)
        self.toa_mask(refl_path, gapmask_path)
        self.buffer_mask(gapmask_path)
        return self.nbr_mask(gapmask_path, gapmask_temp)

    def _to_one_bit(self, src_path, dst_path):
        '''
        Writes a 1 bit copy of src_path to dst_path with gdal_translate.

        Arguments are passed as a list (no shell), so paths containing
        spaces or shell metacharacters are handled safely.

        :return: True if gdal_translate ran and exited with status 0
        '''
        command = self.CMD_BASE.split() + [src_path, dst_path]
        try:
            return subprocess.call(command) == 0
        except OSError:
            # gdal_translate not found / not executable
            return False

    def _replace_with_one_bit(self, temp_path, final_path):
        '''
        Replaces final_path with a 1 bit copy of temp_path, then deletes
        temp_path.

        Does nothing when temp_path does not exist. When the conversion
        fails, temp_path is kept so no data is lost.

        :return: False only if a conversion was attempted and failed
        '''
        if not os.path.exists(temp_path):
            return True
        if os.path.exists(final_path):
            os.remove(final_path)
        if not self._to_one_bit(temp_path, final_path):
            return False
        os.remove(temp_path)
        return True

    def scene_prep_process(self):
        '''
        Processes Scene Prep.

        Steps:
        - Look up the prefire/postfire scenes for self.mapping_id.
        - Compute the extent where the two NBR rasters overlap.
        - Build Landsat 7 gap masks where applicable.
        - Write the dNBR (pre - post) image and its 1 bit mask.

        Problems are reported with a message box and end the run early.
        '''
        # NOTE(review): mapping_id is concatenated into the SQL text; use a
        # parameterized query if Utilities.run_query supports one.
        sql = ("SELECT * FROM 'Mappings' WHERE id = " + str(self.mapping_id))
        fetch = self.utils.run_query(sql)
        prefire_id = str(fetch[6])
        postfire_id = str(fetch[10])

        if prefire_id == 'None':
            QMessageBox.critical(None,
                                 'Error!',
                                 ('prefire_id is set to None\n'
                                  'Scene Prep exiting.'),
                                 QMessageBox.Ok)
            return
        pre_sub_path, pre_date, pre_path_row, pre_name_template = \
            self.utils.get_image_folder_date_path_row_name(prefire_id)

        if postfire_id == 'None':
            QMessageBox.critical(None,
                                 'Error!',
                                 ('postfire_id is set to None\n'
                                  'Scene Prep exiting.'),
                                 QMessageBox.Ok)
            return
        post_sub_path, post_date, post_path_row, post_name_template = \
            self.utils.get_image_folder_date_path_row_name(postfire_id)

        if post_sub_path == const.LANDSAT:
            output_path = os.path.join(self.path, 'img_proc', const.LANDSAT)
            cell_size = 30
        else:
            output_path = os.path.join(self.path, 'img_proc',
                                       const.SENTINEL2)
            cell_size = self.utils.get_resolution(postfire_id)

        pre_nbr = pre_name_template.replace('<ext>', const.NBR)
        post_nbr = post_name_template.replace('<ext>', const.NBR)

        pre_nbr_path = os.path.join(self.img_src, pre_sub_path, pre_path_row,
                                    prefire_id, pre_nbr)
        post_nbr_path = os.path.join(self.img_src, post_sub_path,
                                     post_path_row, postfire_id, post_nbr)
        if not os.path.exists(pre_nbr_path):
            QMessageBox.critical(None,
                                 'Error!!',
                                 ('Prefire scene does not exist.\n'
                                  + pre_nbr_path +
                                  '\nExiting...'),
                                 QMessageBox.Ok)
            return
        if not os.path.exists(post_nbr_path):
            QMessageBox.critical(None,
                                 'Error!!',
                                 ('Postfire  scene does not exist.\n'
                                  + post_nbr_path +
                                  '\nExiting...'),
                                 QMessageBox.Ok)
            return

        pre_refl_path = pre_nbr_path.replace(const.NBR, const.REFL)
        post_refl_path = post_nbr_path.replace(const.NBR, const.REFL)

        # Gap mask paths (only produced for Landsat 7 scenes)
        pre_gapmask_path = pre_nbr_path.replace(const.NBR, const.GM)
        pre_gapmask_path_temp =\
            pre_gapmask_path.replace(const.GM, const.GAP_MASK)
        post_gapmask_path = post_nbr_path.replace(const.NBR, const.GM)
        post_gapmask_path_temp =\
            post_gapmask_path.replace(const.GM, const.GAP_MASK)

        temp_string = (pre_path_row + '_' + pre_date +
                       '_' + post_path_row + '_' + post_date)

        dnbr_out_dir = os.path.join(output_path, temp_string)
        if not os.path.exists(dnbr_out_dir):
            os.makedirs(dnbr_out_dir)

        dnbr_out_path = os.path.join(dnbr_out_dir,
                                     ('d' + temp_string + const.TIF_EXT))
        dnbr_mask_path = os.path.join(dnbr_out_dir,
                                      ('m' + temp_string + const.TIF_EXT))
        mask_temp = dnbr_mask_path.replace(
            const.TIF_EXT, '_temp' + const.TIF_EXT)

        if os.path.exists(dnbr_mask_path) and self.overwrite is True:
            os.remove(dnbr_mask_path)

        if os.path.exists(dnbr_out_path) and self.overwrite is False:
            QMessageBox.warning(None,
                                'Warning!',
                                'DNBR image already exists.\nExiting...',
                                QMessageBox.Ok)
            return

        # Bounding boxes of the two NBR rasters
        pre_ds = gdal.Open(pre_nbr_path)
        pre_band = pre_ds.GetRasterBand(1)
        pre_proj = pre_ds.GetProjection()
        pre_coords = self._extent_from_geo(pre_ds.GetGeoTransform(),
                                           pre_band.XSize, pre_band.YSize)

        post_ds = gdal.Open(post_nbr_path)
        post_band = post_ds.GetRasterBand(1)
        post_proj = post_ds.GetProjection()
        post_coords = self._extent_from_geo(post_ds.GetGeoTransform(),
                                            post_band.XSize, post_band.YSize)

        # Verify same projection
        if pre_proj != post_proj:
            QMessageBox.critical(None,
                                 'Error!',
                                 ('Files are not in the same projection.\n'
                                  'Scene Prep exiting.'),
                                 QMessageBox.Ok)
            return

        overlap = self._overlap_extent(pre_coords, post_coords)
        if overlap is None:
            QMessageBox.critical(None,
                                 'Error!',
                                 ('Scenes do not intersect.\n'
                                  'Scene Prep exiting.'),
                                 QMessageBox.Ok)
            return
        x_min, y_min, x_max, y_max = overlap

        # Output size and per-scene read offsets, in pixels
        self.out_rows = int(abs(y_max - y_min) / cell_size)
        self.out_cols = int(abs(x_max - x_min) / cell_size)
        self.pre_x_offset = int(abs((pre_coords[0] - x_min) / cell_size))
        self.pre_y_offset = int(abs((pre_coords[3] - y_max) / cell_size))
        self.post_x_offset = int(abs((post_coords[0] - x_min) / cell_size))
        self.post_y_offset = int(abs((post_coords[3] - y_max) / cell_size))

        dnbr_geo = (x_min, cell_size, 0, y_max, 0, -cell_size)
        dnbr_proj = pre_proj

        # Gap masks; nbr_mask reports its own errors to the user
        if self._build_gap_mask(pre_refl_path, pre_nbr, pre_gapmask_path,
                                pre_gapmask_path_temp) != 0:
            return
        if self._build_gap_mask(post_refl_path, post_nbr, post_gapmask_path,
                                post_gapmask_path_temp) != 0:
            return

        # Read the overlapping windows, then release the NBR datasets
        pre_array = np.array(pre_band.ReadAsArray(self.pre_x_offset,
                                                  self.pre_y_offset,
                                                  self.out_cols,
                                                  self.out_rows))
        post_array = np.array(post_band.ReadAsArray(self.post_x_offset,
                                                    self.post_y_offset,
                                                    self.out_cols,
                                                    self.out_rows))
        pre_band = None
        post_band = None
        pre_ds = None
        post_ds = None

        # Create the dNBR mask
        if self.overwrite is True or not os.path.exists(mask_temp):
            self.dnbr_mask(pre_gapmask_path_temp, post_gapmask_path_temp,
                           mask_temp, self.out_cols, self.out_rows,
                           dnbr_geo, dnbr_proj)

        # Convert masks to 1 bit images
        conversions = ((mask_temp, dnbr_mask_path),
                       (pre_gapmask_path_temp, pre_gapmask_path),
                       (post_gapmask_path_temp, post_gapmask_path))
        for temp_path, final_path in conversions:
            if not self._replace_with_one_bit(temp_path, final_path):
                QMessageBox.critical(None,
                                     'Error!',
                                     ('Could not convert mask to a 1 bit '
                                      'image.\n' + temp_path +
                                      '\nScene Prep exiting.'),
                                     QMessageBox.Ok)
                return

        dnbr_mask_array = None
        if os.path.exists(dnbr_mask_path):
            mask_ds = gdal.Open(dnbr_mask_path)
            dnbr_mask_array = np.array(
                mask_ds.GetRasterBand(1).ReadAsArray(), dtype=np.byte)
            mask_ds = None

        # Compute the dNBR and output the finished product
        dnbr_array = pre_array - post_array
        if dnbr_mask_array is not None:
            dnbr_array[dnbr_mask_array == self.DNBR_MIN] = self.NDV
            # NOTE(review): DNBR_MAX is applied only when a mask exists;
            # confirm this is intended.
            dnbr_array[dnbr_array >= self.DNBR_MAX] = self.NDV
            dnbr_mask_array = None
        dnbr_array[(pre_array == self.NDV) | (post_array == self.NDV)] =\
            self.NDV

        dnbr_ds = self.driver.Create(dnbr_out_path, self.out_cols,
                                     self.out_rows, 1, gdal.GDT_Int16)
        dnbr_ds.SetGeoTransform(dnbr_geo)
        dnbr_ds.SetProjection(dnbr_proj)
        dnbr_band = dnbr_ds.GetRasterBand(1)
        dnbr_band.WriteArray(dnbr_array)
        dnbr_band.FlushCache()
        dnbr_band = None
        dnbr_ds = None

        QMessageBox.information(None,
                                'Run Scene Prep Complete',
                                'Run Scene Prep step is complete',
                                QMessageBox.Ok)
