'''
/***************************************************************************
Name	   : SubsetProcess
Description: Subset Process functions for QGIS FMT3 Plugin
copyright  : (C) 2018 by Cheryl Holen
Created    : 09/06/2018 - Adapted from QGIS 2.x version
Updated    : 08/22/2019 - cholen - Added handling for empty histogram
 ******************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
 This code is based off of work done by:

    Kelcy Smith
    Contractor to USGS/EROS
    kelcy.smith.ctr@usgs.gov

    Clip and subset for Open Source MTBS

    References:
    http://landsathandbook.gsfc.nasa.gov/data_prod/prog_sect11_3.html
    http://landsathandbook.gsfc.nasa.gov/pdfs/L5TMLUTIEEE2003.pdf

'''
from PyQt5.QtWidgets import QMessageBox
from qgis.core import QgsRasterLayer

import numpy as np
import os
from osgeo import ogr
from osgeo import gdal
from osgeo.gdalconst import GA_ReadOnly
import subprocess

from .Utilities import Utilities
from FMT3 import constants as const


class Subset():
    '''
    Subset (clip) the imagery for one MTBS mapping and derive its initial
    thresholds.

    ``process_subset`` reads the mapping and fire records, fills in the
    attribute tables of the burn-boundary and mask shapefiles, clips the
    reflectance / NBR / dNBR / gap-mask rasters to a buffered envelope around
    the burn boundary, rasterizes both shapefiles and estimates the dNBR
    offset and the severity breakpoints.
    '''
    # Buffer (in map units) added to each side of the burn-boundary envelope
    # before clipping.
    ENV_VAL = 5000.0
    # Square metres -> acres. NOTE(review): assumes the shapefile CRS is in
    # metres - confirm.
    ACRE_CONV = 0.000247105

    def __init__(self, mapping_id, parent_path, img_src):
        '''
        :param mapping_id: Id of the row in the Mappings table, as a string
            (it is concatenated into SQL text and directory names)
        :param parent_path: Directory holding ``img_proc`` and ``event_prods``
        :param img_src: Root directory of the source imagery
        '''
        self.mapping_id = mapping_id
        self.driver = ogr.GetDriverByName('ESRI Shapefile')
        self.img_src_path = img_src
        self.img_proc_path =\
            os.path.join(parent_path, 'img_proc')
        self.event_prods_path =\
            os.path.join(parent_path, 'event_prods', 'fire')
        # Clip extent; set from the burn boundary in process_subset and used
        # by clip_raster and _rasterize.
        self.shp_xmin = 0.0
        self.shp_ymin = 0.0
        self.shp_xmax = 0.0
        self.shp_ymax = 0.0
        self.utils = Utilities()

    @staticmethod
    def _read_band_array(path, dtype=None):
        '''
        Read band 1 of a raster into a new numpy array.

        :param path: Raster file path
        :param dtype: Target numpy dtype; None keeps the band's own dtype
        :raises IOError: if GDAL cannot open ``path``
        '''
        dataset = gdal.Open(path, GA_ReadOnly)
        if dataset is None:
            raise IOError(path + ' could not be opened')
        band_array = dataset.GetRasterBand(1).ReadAsArray()
        # Release the dataset before returning; the array is independent.
        dataset = None
        return np.array(band_array, dtype=dtype)

    @staticmethod
    def _offset_from_arrays(data_in_array, perim_array, mask_array,
                            low, high):
        '''
        Median and standard deviation of the pixels of ``data_in_array``
        that lie strictly between ``low`` and ``high`` and are 0 (not 1) in
        both ``perim_array`` and ``mask_array``.

        The input arrays are not modified.

        :returns: ``(median, std)`` rounded to ints, or ``(None, None)`` when
            no pixel qualifies
        '''
        data = np.array(data_in_array, dtype=np.float64)
        excluded = ((data <= low) |
                    (data >= high) |
                    (perim_array == 1) |
                    (mask_array == 1))
        data[excluded] = np.nan

        # nanmedian of an all-NaN array is NaN, and int(round(nan)) raises
        # ValueError; report "no offset" instead.
        if np.all(np.isnan(data)):
            return None, None

        off_dnbr = int(round(np.nanmedian(data)))
        off_std = int(round(np.nanstd(data)))
        return off_dnbr, off_std

    def calc_rdnbr_std_offset(
            self, in_image, perim_image, mask_image, low, high):
        '''
        Calculate the rDNBR offset from the unburned pixels of an image.

        A pixel of ``in_image`` is used only when its value lies strictly
        between ``low`` and ``high`` and it is 0 (not 1) in both the
        perimeter raster and the mask raster.

        :param in_image: Input (dNBR) raster path
        :param perim_image: Perimeter (burn boundary) raster path
        :param mask_image: Mask raster path
        :param low: Low value break (exclusive)
        :param high: High value break (exclusive)

        :returns: ``(offset, std)`` - the median and standard deviation of
            the selected pixels rounded to ints, or ``(None, None)`` when no
            pixel qualifies
        :raises IOError: if one of the rasters cannot be opened
        '''
        data_in_array = self._read_band_array(in_image, np.float64)
        perim_array = self._read_band_array(perim_image, np.byte)
        mask_array = self._read_band_array(mask_image, np.byte)
        return self._offset_from_arrays(
            data_in_array, perim_array, mask_array, low, high)

    def calc_histogram(self, image, nbins=256):
        '''
        Histogram used by the Otsu threshold method.

        Integer arrays are counted per integer value (``np.bincount``) and
        leading empty bins are dropped; other dtypes use ``np.histogram``
        with ``nbins`` bins.

        :param image: Array of pixel values (any shape)
        :param nbins: Bin count for non-integer data

        :returns: ``(hist, bin_centers)``; both empty when ``image`` is empty
        '''
        image = np.asarray(image)
        # Test the element count, not len(): an array of shape (n, 0) has a
        # non-zero len() but no values, and np.min would raise on it.
        if image.size == 0:
            return np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.float64)

        if not np.issubdtype(image.dtype, np.integer):
            hist, bin_edges = np.histogram(image.flat, nbins)
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.
            return hist, bin_centers

        # For integer types, histogramming with bincount is more efficient,
        # but bincount needs non-negative values, so shift negatives up.
        offset = 0
        image_min = np.min(image)
        if image_min < 0:
            offset = image_min
            image_range = np.max(image).astype(np.int64) - image_min
            # smallest dtype that can hold both the min and the offset max
            offset_data_type =\
                np.promote_types(np.min_scalar_type(image_range),
                                 np.min_scalar_type(image_min))
            if image.dtype != offset_data_type:
                # prevent overflow errors when offsetting
                image = image.astype(offset_data_type)
            image = image - offset
        hist = np.bincount(image.ravel())
        bin_centers = np.arange(len(hist)) + offset

        # bincount counts from 0; start the histogram at the first populated
        # bin (one always exists because the minimum value is present).
        first = np.nonzero(hist)[0][0]
        return hist[first:], bin_centers[first:]

    def threshold_otsu(self, image, nbins=256):
        '''
        Threshold using the Otsu method
        (http://en.wikipedia.org/wiki/Otsu%27s_method).

        :param image: Array of pixel values
        :param nbins: Bin count for non-integer data

        :returns: Bin centre that maximizes the between-class variance, or 0
            when no threshold can be computed (e.g. empty input or a single
            populated bin)
        '''
        hist, bin_centers = self.calc_histogram(image, nbins)
        hist = hist.astype(float)

        # class probabilities for all possible thresholds
        weight1 = np.cumsum(hist)
        weight2 = np.cumsum(hist[::-1])[::-1]
        if not np.all(weight1):
            return 0
        # class means for all possible thresholds
        mean1 = np.cumsum(hist * bin_centers) / weight1
        mean2 = (np.cumsum((hist * bin_centers)[::-1]) / weight2[::-1])[::-1]

        # Clip ends to align class 1 and class 2 variables:
        # The last value of `weight1`/`mean1` should pair with zero values in
        # `weight2`/`mean2`, which do not exist.
        variance12 = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2
        if not len(variance12):
            return 0

        idx = np.argmax(variance12)
        return bin_centers[:-1][idx]

    def get_breakpoints(self, src_image):
        '''
        Get approximate severity breakpoints for an NBR or dNBR image.

        The image values are split into three ranges and an Otsu threshold
        is computed for each range. A threshold of 0 (none found) is replaced
        by a fallback derived from the image min/max.

        :param src_image: Raster path ending in ``_nbr.tif`` or ``_dnbr.tif``

        :returns: ``(low, mod, high, regrowth, std)`` - rounded int
            thresholds, the regrowth break (-150 for dNBR, None for NBR) and
            the standard deviation of the in-bounds values
        :raises ValueError: if ``src_image`` has neither suffix
        :raises IOError: if the raster cannot be opened
        '''
        image_array = self._read_band_array(src_image)
        min_val = np.min(image_array)
        max_val = np.max(image_array)

        if src_image.endswith('_nbr.tif'):
            # Set NBR bounds (-1000 >= NBR <= 1000)
            mask_array =\
                image_array[(image_array >= -1000) & (image_array <= 1000)]
            # These ranges weight the distribution of image values
            # to the left- or right-hand data distribution
            mask_array_low = mask_array[mask_array > 300]
            mask_array_mod =\
                mask_array[(mask_array <= 300) & (mask_array >= -65)]
            mask_array_high = mask_array[mask_array < -65]
            regrowth = None
        elif src_image.endswith('_dnbr.tif'):
            # Set dNBR bounds (-2000 > dNBR < 2000)
            mask_array =\
                image_array[(image_array > -2000) & (image_array < 2000)]
            # These ranges weight the distribution of image values
            # to the left- or right-hand data distribution. The moderate
            # range starts right after the low range (> 269), so no value
            # (previously 270) falls between the two.
            mask_array_low =\
                mask_array[(mask_array > -100) & (mask_array <= 269)]
            mask_array_mod =\
                mask_array[(mask_array > 269) & (mask_array <= 439)]
            mask_array_high = mask_array[mask_array >= 440]
            regrowth = -150
        else:
            raise ValueError(src_image +
                             ' is neither an _nbr.tif nor a _dnbr.tif image')

        # Standard Deviation of masked image
        std = np.std(mask_array)

        # Determine Threshold using Otsu
        threshold_low = self.threshold_otsu(mask_array_low)
        threshold_mod = self.threshold_otsu(mask_array_mod)
        threshold_high = self.threshold_otsu(mask_array_high)

        # Deal with cases where threshold val from otsu is 0
        if regrowth is None:
            # NBR image
            if threshold_low == 0:
                threshold_low = min_val - 100
            if threshold_mod == 0:
                threshold_mod = min_val - 200
            if threshold_high == 0:
                threshold_high = -9999
        else:
            # dNBR image
            if threshold_low == 0:
                threshold_low = max_val + 100
            if threshold_mod == 0:
                threshold_mod = max_val + 200
            if threshold_high == 0:
                threshold_high = 9999

        low = int(round(threshold_low))
        mod = int(round(threshold_mod))
        high = int(round(threshold_high))
        return low, mod, high, regrowth, std

    def _source_path(self, sub_path, path_row, image_id, name_template, ext):
        '''Path of a source image: ``<img_src>/<sub>/<path_row>/<id>/<name>``
        where ``<ext>`` in the name template is replaced by ``ext``.'''
        return os.path.join(self.img_src_path, sub_path, path_row, image_id,
                            name_template.replace('<ext>', ext))

    def _extent_args(self):
        '''Current clip extent as command-line strings: xmin ymin xmax ymax.'''
        return [str(self.shp_xmin), str(self.shp_ymin),
                str(self.shp_xmax), str(self.shp_ymax)]

    def _populate_shapefile(self, shp_path, fire_fields=None):
        '''
        Write Id/Area/Perimeter to every feature of a shapefile in place.
        When ``fire_fields`` (field name -> value) is given, also write
        Acres and each of those fields.

        :raises IOError: if the shapefile cannot be opened for update
        '''
        src_ds = self.driver.Open(shp_path, 1)
        if src_ds is None:
            raise IOError(shp_path + ' could not be opened')
        lyr = src_ds.GetLayer()

        for count, feature in enumerate(lyr):
            geom = feature.GetGeometryRef()
            geo_area = round(float(geom.GetArea()), 5)

            feature.SetField('Id', count)
            feature.SetField('Area', geo_area)
            feature.SetField('Perimeter', geom.Boundary().Length())
            if fire_fields is not None:
                feature.SetField(
                    'Acres', round(float(geo_area * self.ACRE_CONV), 5))
                for field_name, value in fire_fields.items():
                    feature.SetField(field_name, value)
            lyr.SetFeature(feature)

        lyr = None
        src_ds.Destroy()

    def _collect_extents(self, shp_path):
        '''
        Group the envelopes ``(minx, maxx, miny, maxy)`` of a shapefile's
        features by their Fire_Id value. Features without a geometry are
        skipped.

        :raises IOError: if the shapefile cannot be opened
        '''
        src_ds = self.driver.Open(shp_path, 0)
        if src_ds is None:
            raise IOError(shp_path + ' could not be opened')
        lyr = src_ds.GetLayer()

        extents = {}
        for feature in lyr:
            fire_key = feature.GetField('Fire_Id')
            geom = feature.GetGeometryRef()
            if geom is None:
                continue
            extents.setdefault(fire_key, []).append(geom.GetEnvelope())

        # Need to state if looping through multiple shapes
        lyr.ResetReading()
        lyr = None
        src_ds = None
        return extents

    def _rasterize(self, shp_path, tif_path, res):
        '''
        Burn the value 1 from a shapefile into a 1-bit Byte raster covering
        the current clip extent with pixel size ``res``. Does nothing when
        ``shp_path`` does not exist.
        '''
        if not os.path.exists(shp_path):
            return
        # Argument list without a shell, so paths containing spaces or shell
        # metacharacters are passed through intact.
        cmd = (['gdal_rasterize', '-burn', '1', '-te'] +
               self._extent_args() +
               ['-ot', 'Byte', '-co', 'NBITS=1',
                '-tr', str(res), str(res), '-q', shp_path, tif_path])
        subprocess.call(cmd)

    def process_subset(self):
        '''
        Run the subset process for ``self.mapping_id``.

        Reads the mapping and fire records, fills in the burn-boundary and
        mask shapefile attributes, clips every input raster to the buffered
        burn-boundary envelope, rasterizes both shapefiles, and estimates the
        dNBR offset and severity breakpoints. Shows a message box when done.

        :returns: ``(offset, low, mod, high, regrowth, offset_std,
            bb_center_x, bb_center_y)``. ``low``/``mod``/``high``/``regrowth``
            and the centre coordinates are strings (``'None'`` when never
            set). ``offset`` and ``offset_std`` are strings, or None when no
            dNBR offset was computed.
        :raises IOError: if neither a dNBR clip nor a post-fire NBR image is
            available to derive breakpoints from
        '''
        offset_calc = None
        offset_std = None
        bb_center_x = None
        bb_center_y = None

        fire_table = 'Fires'
        mapping_table = 'Mappings'

        # NOTE(review): ids are concatenated into the SQL text; the signature
        # of Utilities.run_query is not visible here, so a parameterized query
        # cannot be used from this block. Presumably the ids are numeric.
        sql = ('SELECT * FROM ' + mapping_table +
               ' WHERE Id = ' + self.mapping_id)
        mapping_row = self.utils.run_query(sql)

        fire_id = str(mapping_row[1])
        prefire_id = str(mapping_row[6])
        postfire_id = str(mapping_row[10])
        perim_id = str(mapping_row[12])
        confidence = str(mapping_row[24])

        sql = 'SELECT * FROM ' + fire_table + ' WHERE Id = ' + fire_id
        fire_row = self.utils.run_query(sql)

        mtbs_id = str(fire_row[1])
        mtbs_id_lower = mtbs_id.lower()
        fire_name = str(fire_row[2])
        fire_comment = str(fire_row[14])
        if fire_comment == '':
            fire_comment = 'None'

        fire_date_parts = str(fire_row[4]).split('-')
        fire_year = fire_date_parts[0]
        fire_month = fire_date_parts[1]
        fire_day = fire_date_parts[2]

        # Make Folder
        mapping_dir = os.path.join(self.event_prods_path, fire_year,
                                   mtbs_id_lower, 'mtbs_' + self.mapping_id)
        os.makedirs(mapping_dir, exist_ok=True)

        get_image_info = self.utils.get_image_folder_date_path_row_name
        pre_sub_path, pre_date, pre_path_row, pre_name_template = \
            get_image_info(prefire_id)
        post_sub_path, post_date, post_path_row, post_name_template = \
            get_image_info(postfire_id)
        perim_sub_path, perim_date, perim_path_row, perim_name_template = \
            get_image_info(perim_id)

        # NOTE(review): the sensor directory is appended to the instance
        # attribute itself, so a second call on the same instance would nest
        # it twice; presumably one Subset is created per run - confirm.
        if post_sub_path == const.LANDSAT:
            self.img_proc_path = os.path.join(self.img_proc_path,
                                              const.LANDSAT)
            res = 30
        else:
            self.img_proc_path = os.path.join(self.img_proc_path,
                                              const.SENTINEL2)
            res = self.utils.get_resolution(postfire_id)

        def clipped(name):
            '''Path of ``<mtbs_id>_<name><TIF_EXT>`` inside mapping_dir.'''
            return os.path.join(mapping_dir,
                                mtbs_id_lower + '_' + name + const.TIF_EXT)

        # Source reflectance images
        post_refl = self._source_path(post_sub_path, post_path_row,
                                      postfire_id, post_name_template,
                                      const.REFL)
        pre_refl = self._source_path(pre_sub_path, pre_path_row,
                                     prefire_id, pre_name_template,
                                     const.REFL)
        perim_refl = self._source_path(perim_sub_path, perim_path_row,
                                       perim_id, perim_name_template,
                                       const.REFL)

        # Clipped reflectance outputs
        pre_refl_clipped = clipped(
            pre_date + '_' + self.utils.get_sensor(prefire_id) + '_refl')
        post_refl_clipped = clipped(
            post_date + '_' + self.utils.get_sensor(postfire_id) + '_refl')
        perim_refl_clipped = clipped(
            perim_date + '_' + self.utils.get_sensor(perim_id) + '_refl')

        # Source NBR images and the post-fire gap mask
        post_nbr = self._source_path(post_sub_path, post_path_row,
                                     postfire_id, post_name_template,
                                     const.NBR)
        pre_nbr = self._source_path(pre_sub_path, pre_path_row,
                                    prefire_id, pre_name_template,
                                    const.NBR)
        perim_nbr = self._source_path(perim_sub_path, perim_path_row,
                                      perim_id, perim_name_template,
                                      const.NBR)
        post_nbr_mask = post_nbr.replace(const.NBR, const.GM)

        # Clipped NBR / gap mask outputs
        post_nbr_clip = clipped(post_date + '_nbr')
        pre_nbrl_clipped = clipped(pre_date + '_nbr')
        perim_nbr_clipped = clipped(perim_date + '_nbr')
        post_mask_clip = clipped(post_date + '_gapmask')

        # dNBR images (source and clipped)
        scene_pair = (pre_path_row + '_' + pre_date + '_' +
                      post_path_row + '_' + post_date)
        dnbr_out_dir = os.path.join(self.img_proc_path, scene_pair)
        dnbr_out_path = os.path.join(dnbr_out_dir,
                                     'd' + scene_pair + const.TIF_EXT)
        dnbr_mask_path = os.path.join(dnbr_out_dir,
                                      'm' + scene_pair + const.TIF_EXT)
        dnbr_clip = clipped(pre_date + '_' + post_date + '_dnbr')
        dnbr_mask_clip = clipped(pre_date + '_' + post_date + '_gapmask')

        # Fire and mask shapefiles (named after the dNBR pair when a dNBR
        # image exists, otherwise after the post-fire date only)
        if os.path.exists(dnbr_out_path):
            burn_shape = (mtbs_id_lower + '_' + pre_date + '_' +
                          post_date + '_burn_bndy.shp')
        else:
            burn_shape = mtbs_id_lower + '_' + post_date + '_burn_bndy.shp'
        burn_bndy_path = os.path.join(mapping_dir, burn_shape)
        burn_tif = burn_bndy_path.replace('.shp', const.TIF_EXT)

        mask_shape = burn_shape.replace('_burn_bndy.shp', '_mask.shp')
        mask_shape_path = os.path.join(mapping_dir, mask_shape)
        mask_tif_path = mask_shape_path.replace('.shp', const.TIF_EXT)

        # Add data to the shapefiles
        self._populate_shapefile(burn_bndy_path, {
            'Fire_Id': mtbs_id,
            'Fire_Name': fire_name,
            'Year': fire_year,
            'StartMonth': fire_month,
            'StartDay': fire_day,
            'Confidence': confidence,
            'Comment': fire_comment,
        })
        self._populate_shapefile(mask_shape_path)

        # For each Fire_Id, clip to the combined envelope of its features
        extents = self._collect_extents(burn_bndy_path)
        for envelopes in extents.values():
            xmin = min(env[0] for env in envelopes)
            ymin = min(env[2] for env in envelopes)
            xmax = max(env[1] for env in envelopes)
            ymax = max(env[3] for env in envelopes)

            # Buffered envelope for each individual output raster
            self.shp_xmin = self.utils.odd_even(float(xmin) - self.ENV_VAL)
            self.shp_ymin = self.utils.odd_even(float(ymin) - self.ENV_VAL)
            self.shp_xmax = self.utils.odd_even(float(xmax) + self.ENV_VAL)
            self.shp_ymax = self.utils.odd_even(float(ymax) + self.ENV_VAL)

            # Centre of the buffered bounding box
            bb_center_x = (self.shp_xmin + self.shp_xmax) / 2
            bb_center_y = (self.shp_ymin + self.shp_ymax) / 2

            self.clip_raster(post_refl, post_refl_clipped)
            self.clip_raster(pre_refl, pre_refl_clipped)
            self.clip_raster(perim_refl, perim_refl_clipped)
            self.clip_raster(dnbr_out_path, dnbr_clip)
            self.clip_raster(post_nbr, post_nbr_clip)
            self.clip_raster(pre_nbr, pre_nbrl_clipped)
            self.clip_raster(perim_nbr, perim_nbr_clipped)
            self.clip_raster(dnbr_mask_path, dnbr_mask_clip)
            self.clip_raster(post_nbr_mask, post_mask_clip)

            # Binary masks from the burn boundary and mask shapefiles
            self._rasterize(burn_bndy_path, burn_tif, res)
            self._rasterize(mask_shape_path, mask_tif_path, res)

        # dNBR offset, when dNBR imagery is used
        if os.path.exists(dnbr_clip):
            if not os.path.exists(dnbr_mask_clip):
                # NOTE(review): this retries the pre-fire NBR clip, not the
                # dNBR gap-mask clip the condition tests; clip_raster skips
                # existing outputs, so it is usually a no-op. Confirm intent.
                self.clip_raster(pre_nbr, pre_nbrl_clipped)
            offset_calc, offset_std = self.calc_rdnbr_std_offset(
                    dnbr_clip, burn_tif, mask_tif_path, -100, 100)

        # Breakpoints come from the dNBR clip, else from the post-fire NBR
        if os.path.exists(dnbr_clip):
            input_image = dnbr_clip
        elif os.path.exists(post_nbr):
            input_image = post_nbr_clip
        else:
            message = ('No dNBR or post-fire NBR image available for '
                       'mapping ' + str(self.mapping_id))
            QMessageBox.warning(None, 'Subset error', message,
                                QMessageBox.Ok)
            raise IOError(message)

        low, mod, high, regrowth, std = self.get_breakpoints(input_image)

        # Typecast to strings for return; an offset of 0 is a real value and
        # must be converted too, only None means "not computed".
        if offset_calc is not None:
            offset_calc = str(offset_calc)
        if offset_std is not None:
            offset_std = str(offset_std)

        QMessageBox.information(None,
                                'Subset Complete',
                                'Subset step is Complete',
                                QMessageBox.Ok)

        return (offset_calc, str(low), str(mod), str(high), str(regrowth),
                offset_std, str(bb_center_x), str(bb_center_y))

    def clip_raster(self, src_file, dst_file):
        '''
        Clip an input raster to the current clip extent with gdalwarp.

        Does nothing when ``src_file`` is missing or ``dst_file`` already
        exists. Shows a warning when the clipped output is not a valid
        raster layer.

        :param src_file: Input filename
        :param dst_file: Output filename
        :raises IOError: if ``src_file`` is not a valid raster layer
        '''
        if not os.path.exists(src_file) or os.path.exists(dst_file):
            return

        src_rstr_lyr = QgsRasterLayer(src_file, os.path.basename(src_file))
        if not src_rstr_lyr.isValid():
            QMessageBox.warning(None,
                                'Raster layer problem',
                                src_file + ' is invalid!',
                                QMessageBox.Ok)
            raise IOError(src_file + ' is invalid!')

        # Argument list without a shell, so paths containing spaces or shell
        # metacharacters are passed through intact.
        cmd = (['gdalwarp', '-q', '-te'] + self._extent_args() +
               [src_file, dst_file])
        subprocess.call(cmd)

        dst_rstr_lyr = QgsRasterLayer(dst_file, os.path.basename(dst_file))
        if not dst_rstr_lyr.isValid():
            QMessageBox.warning(None,
                                'Clip - subset error',
                                dst_file + ' is invalid!',
                                QMessageBox.Ok)
