'''
/***************************************************************************
Name		     : ThresholdProcess
Description          : Threshold Process functions for QGIS FMT3 Plugin
copyright            : (C) 2018 by Cheryl Holen
Created              : Sep 06, 2018 - Adapted from QGIS 2.x version
Updated              :
******************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
 This code is based off of work done by:

    Kelcy Smith
    Contractor to USGS/EROS
    kelcy.smith.ctr@usgs.gov

    Threshold for Open Source MTBS

    References:
    http://landsathandbook.gsfc.nasa.gov/data_prod/prog_sect11_3.html
    http://landsathandbook.gsfc.nasa.gov/pdfs/L5TMLUTIEEE2003.pdf

'''
from PyQt5.QtWidgets import QMessageBox

from qgis.core import (QgsLayerTreeLayer, QgsProject, QgsRasterLayer,)

import numpy as np
import os
from osgeo import gdal
from osgeo import ogr
import subprocess

from .Utilities import Utilities
from FMT3 import constants as const


class ThresholdProcess():
    '''
    Recode a clipped NBR or dNBR raster into thematic severity classes,
    clip the result to the burn boundary and add it to the QGIS project.

    Class codes produced by the recode methods:
        0 - value did not fall in any threshold range
        1 - unburned
        2 - low
        3 - moderate
        4 - high
        5 - vegetation regrowth (dNBR only)
        6 - no data / masked
    '''

    # Tag searched for in the intermediate VRT; it is replaced by the
    # palette definition below.
    _GRAY_INTERP = '<ColorInterp>Gray</ColorInterp>'

    # Palette written into the VRT. Entries are listed in class code
    # order (entry 0 is the color for class 0, ... entry 6 for class 6).
    _PALETTE_XML = (
        '<ColorInterp>Palette</ColorInterp>\n'
        '       <ColorTable>\n'
        '           <Entry c1="0" c2="0" c3="0" c4="255"/>\n'
        '           <Entry c1="0" c2="100" c3="0" c4="255"/>\n'
        '           <Entry c1="127" c2="255" c3="212" c4="255"/>\n'
        '           <Entry c1="255" c2="255" c3="0" c4="255"/>\n'
        '           <Entry c1="255" c2="0" c3="0" c4="255"/>\n'
        '           <Entry c1="127" c2="255" c3="0" c4="255"/>\n'
        '           <Entry c1="255" c2="255" c3="255" c4="255"/>\n'
        '       </ColorTable>')

    def __init__(self, mapping_id, parent_path, low, mod, high, regrowth):
        '''
        Store the mapping id and thresholds used by the recode.

        :param mapping_id: Id of the row in the 'Mappings' table
        :param parent_path: folder that contains 'event_prods/fire'
        :param low: low severity threshold (converted to int)
        :param mod: moderate severity threshold (converted to int)
        :param high: high severity threshold (converted to int)
        :param regrowth: regrowth threshold; stored as given and only
                         converted to int when a dNBR value is recoded
        '''
        self.mapping_id = mapping_id
        self.low = int(low)
        self.mod = int(mod)
        self.high = int(high)
        self.regrowth = regrowth  # this remains a string until recode

        self.event_prods_path = os.path.join(
                parent_path, 'event_prods', 'fire')
        self.utils = Utilities()

    def recode_dnbr(self, val):
        '''
        Recode a single dNBR value to its class code.

        :param val: dNBR pixel value
        :returns: 6 for val <= -9999, 5 up to and including the regrowth
                  threshold, 1 below low, 2 below mod, 3 below high,
                  4 below 9999 and 0 for anything else
        '''
        # Each test only runs when the previous ones failed, so checking
        # the upper bound of every range is sufficient.
        if val <= -9999:
            return 6
        if val <= int(self.regrowth):
            return 5
        if val < self.low:
            return 1
        if val < self.mod:
            return 2
        if val < self.high:
            return 3
        if val < 9999:
            return 4
        return 0

    def recode_nbr(self, val):
        '''
        Recode a single NBR value to its class code.

        NBR thresholds run downwards: values above low are unburned and
        values at or below high are high severity.

        :param val: NBR pixel value
        :returns: 6 for val <= -9999, 1 above low, 2 above mod,
                  3 above high, 4 for the remaining values above -9999
                  and 0 for anything else
        '''
        if val <= -9999:
            return 6
        if val > self.low:
            return 1
        if val > self.mod:
            return 2
        if val > self.high:
            return 3
        if val > -9999:
            return 4
        return 0

    def _recode_dnbr_array(self, values):
        ''' Array version of recode_dnbr; returns an array of class codes
        with the same shape as values. '''
        # int64 so the comparisons against -9999/9999 and the thresholds
        # cannot overflow narrow input types such as int8.
        values = np.asarray(values, dtype=np.int64)
        regrowth = int(self.regrowth)
        # np.select takes the first condition that is true, which mirrors
        # the order of the tests in recode_dnbr.
        return np.select(
            [values <= -9999,
             values <= regrowth,
             values < self.low,
             values < self.mod,
             values < self.high,
             values < 9999],
            [6, 5, 1, 2, 3, 4],
            default=0)

    def _recode_nbr_array(self, values):
        ''' Array version of recode_nbr; returns an array of class codes
        with the same shape as values. '''
        values = np.asarray(values, dtype=np.int64)
        return np.select(
            [values <= -9999,
             values > self.low,
             values > self.mod,
             values > self.high,
             values > -9999],
            [6, 1, 2, 3, 4],
            default=0)

    @staticmethod
    def _remove_if_exists(path):
        ''' Delete path when it is present on disk. '''
        if os.path.exists(path):
            os.remove(path)

    @staticmethod
    def _update_burn_boundary(burn_bndy_path, confidence, fire_comment):
        ''' Write the Confidence and Comment fields on every feature of
        the burn boundary shapefile; returns False if it cannot be
        opened. '''
        driver = ogr.GetDriverByName('ESRI Shapefile')
        src_ds = driver.Open(burn_bndy_path, 1)  # 1 = open for update
        if src_ds is None:
            return False

        lyr = src_ds.GetLayer()
        for feature in lyr:
            feature.SetField('Confidence', confidence)
            feature.SetField('Comment', fire_comment)
            lyr.SetFeature(feature)
            # Drop the feature reference before the next read
            feature = None

        # Release the layer before the data source that owns it
        lyr = None
        src_ds.Destroy()
        return True

    def _apply_color_table(self, output_temp, output_image):
        ''' Convert the temporary recode raster into output_image with
        the class palette attached, then delete the intermediates. '''
        # Create temp vrt
        out_vrt = output_temp.replace(const.TIF_EXT, '.vrt')
        self._remove_if_exists(out_vrt)
        # Arguments are passed as a list (no shell) so paths containing
        # spaces or shell metacharacters reach GDAL unchanged.
        subprocess.call(
            ['gdal_translate', '-q', '-of', 'VRT', output_temp, out_vrt])

        # Add Color Table
        out_vrt_2 = out_vrt.replace('.vrt', '2.vrt')
        self._remove_if_exists(out_vrt_2)

        with open(out_vrt, 'r') as vrt_in:
            vrt_lines = vrt_in.readlines()

        with open(out_vrt_2, 'w') as vrt_out:
            for line in vrt_lines:
                # Lines without the tag are returned unchanged by replace
                vrt_out.write(line.replace(self._GRAY_INTERP,
                                           self._PALETTE_XML))

        # run gdal_translate to create new image with color map
        self._remove_if_exists(out_vrt)
        self._remove_if_exists(output_image)
        subprocess.call(
            ['gdal_translate', '-q', '-ot', 'Byte', out_vrt_2, output_image])

        # Delete Vrt Files and the temporary raster
        self._remove_if_exists(out_vrt_2)
        self._remove_if_exists(output_temp)

    def _clip_to_boundary(self, burn_bndy_path, output_image, output_temp):
        ''' Clip output_image to the burn boundary with gdalwarp and
        replace it with the clipped result. '''
        subprocess.call(
            ['gdalwarp', '-q', '-cutline', burn_bndy_path,
             output_image, output_temp])

        # Only swap the files when gdalwarp actually produced an output
        if os.path.exists(output_temp):
            self._remove_if_exists(output_image)
            os.rename(output_temp, output_image)

    @staticmethod
    def _add_to_project(project_out_file, output_image, toc_name):
        ''' Load output_image into the QGIS project file as toc_name,
        replacing any layer already listed under that name. '''
        project = QgsProject.instance()
        project.read(project_out_file)

        # Check if layer already exists in project and remove it
        # to replace with new later. Only layer nodes have a layerId().
        for node in project.layerTreeRoot().children():
            if isinstance(node, QgsLayerTreeLayer) and \
                    node.name() == toc_name:
                project.removeMapLayer(node.layerId())

        rst_layer = QgsRasterLayer(output_image, toc_name)
        if not rst_layer.isValid():
            raise IOError(
                "ERROR: Raster layer - {} did not load".format(output_image))
        root = project.layerTreeRoot()
        # False: do not add to the legend automatically, the node is
        # inserted at a fixed position instead.
        project.addMapLayer(rst_layer, False)
        root.insertChildNode(2, QgsLayerTreeLayer(rst_layer))
        project.write(project_out_file)
        project.clear()

    def calculate_thresholds(self):
        '''
        Calculate the thresholds

        Looks up the mapping and fire rows, updates the burn boundary
        shapefile attributes, recodes the clipped dNBR image (or the
        post-fire NBR image when no dNBR image exists), applies the
        mask and gap mask, writes the paletted and clipped result next
        to the inputs and adds it to the mapping's QGIS project.

        A message box is shown and the method returns early when the
        burn boundary or input image is missing or the image has an
        unsupported data type.

        :raises IOError: if the output raster cannot be loaded by QGIS
        '''
        mapping_id = str(self.mapping_id)

        # NOTE(review): the query text is built by concatenation because
        # the signature of run_query is not visible here; mapping_id is
        # presumably an internal numeric id - confirm it can never carry
        # user-typed text.
        sql = ("SELECT * FROM 'Mappings' WHERE Id = " + mapping_id)
        fetch = self.utils.run_query(sql)

        # Column positions follow the 'Mappings' table layout
        # (TODO confirm against the schema).
        fire_id = str(fetch[1])
        prefire_id = str(fetch[6])
        postfire_id = str(fetch[10])
        fire_comment = str(fetch[14])
        confidence = str(fetch[24])

        sql = "SELECT * FROM 'Fires' WHERE Id = " + str(fire_id)
        fetch = self.utils.run_query(sql)

        mtbs_id = str(fetch[1]).lower()

        if fire_comment == '':
            fire_comment = 'None'

        # The year is the part of the fire date before the first '-'
        fire_year = str(fetch[4]).split('-')[0]

        # Make Folder
        mapping_dir = os.path.join(self.event_prods_path,
                                   fire_year,
                                   mtbs_id,
                                   'mtbs_' + mapping_id)

        if not os.path.exists(mapping_dir):
            os.makedirs(mapping_dir)

        _, pre_date, _, _ = \
            self.utils.get_image_folder_date_path_row_name(prefire_id)
        _, post_date, _, _ = \
            self.utils.get_image_folder_date_path_row_name(postfire_id)

        # File name stems: pre/post pair products and post only products
        pair_stem = mtbs_id + '_' + pre_date + '_' + post_date
        post_stem = mtbs_id + '_' + post_date

        post_nbr_clip = os.path.join(
            mapping_dir, post_stem + '_nbr' + const.TIF_EXT)
        dnbr_clip = os.path.join(
            mapping_dir, pair_stem + '_dnbr' + const.TIF_EXT)
        dnbr_mask_clip = os.path.join(
            mapping_dir, pair_stem + '_gapmask' + const.TIF_EXT)
        post_mask_clip = os.path.join(
            mapping_dir, post_stem + '_gapmask' + const.TIF_EXT)

        # The dNBR image decides both the shapefile name and which
        # recode is run.
        has_dnbr = os.path.exists(dnbr_clip)

        burn_shape = ((pair_stem if has_dnbr else post_stem) +
                      '_burn_bndy.shp')
        burn_bndy_path = os.path.join(mapping_dir, burn_shape)

        # Add data to Shapefile
        if not self._update_burn_boundary(burn_bndy_path,
                                          confidence,
                                          fire_comment):
            QMessageBox.critical(None,
                                 'Missing file',
                                 'Burn boundary shapefile is missing',
                                 QMessageBox.Ok)
            return

        # Mask Shapefile and its rasterized counterpart
        mask_shp = burn_shape.replace('_burn_bndy.shp', '_mask.shp')
        mask_shape_path = os.path.join(mapping_dir, mask_shp)
        mask_tif_path = mask_shape_path.replace('.shp', const.TIF_EXT)

        # Start Main Process
        if has_dnbr:
            input_image = dnbr_clip
            suffix = '_dnbr'
            toc_name = 'dNBR6'
        elif os.path.exists(post_nbr_clip):
            input_image = post_nbr_clip
            suffix = '_nbr'
            toc_name = 'NBR6'
        else:
            QMessageBox.critical(None,
                                 'Missing file',
                                 'Input image is missing',
                                 QMessageBox.Ok)
            return

        output_image = input_image.replace(suffix + const.TIF_EXT,
                                           suffix + '6' + const.TIF_EXT)
        output_temp = output_image.replace(const.TIF_EXT,
                                           '_temp' + const.TIF_EXT)

        # Open file
        thresh_ds = gdal.Open(input_image)
        if thresh_ds is None:
            QMessageBox.critical(None,
                                 'Missing file',
                                 'Input image could not be opened',
                                 QMessageBox.Ok)
            return
        thresh_band = thresh_ds.GetRasterBand(1)
        thresh_geotransform = thresh_ds.GetGeoTransform()
        thresh_projection = thresh_ds.GetProjection()
        thresh_cols = thresh_band.XSize
        thresh_rows = thresh_band.YSize

        # Check the input image bit type
        thresh_data_type = gdal.GetDataTypeName(thresh_band.DataType)
        bit_type = {'Int8': np.int8,
                    'Int16': np.int16,
                    'Int32': np.int32}.get(thresh_data_type)
        if bit_type is None:
            QMessageBox.critical(None,
                                 'Incorrect data type',
                                 'Incorrect data type: ' + thresh_data_type,
                                 QMessageBox.Ok)
            return

        # Bring in data as an array
        thresh_array = np.array(
            thresh_band.ReadAsArray(0, 0, thresh_cols, thresh_rows),
            dtype=bit_type)

        # Recode the whole array at once; the class codes are the same
        # as those returned per value by recode_dnbr / recode_nbr.
        if has_dnbr:
            recoded = self._recode_dnbr_array(thresh_array)
        else:
            recoded = self._recode_nbr_array(thresh_array)
        recode_thresh_array = recoded.astype(bit_type)

        # Open Mask as array
        mask_array = self.utils.file_to_array(mask_tif_path)
        to_mask = (mask_array == 1)

        # Open Gap Mask: the dNBR gap mask takes precedence over the
        # post-fire one; gap pixels (value 0) are masked as well.
        if os.path.exists(dnbr_mask_clip):
            gap_mask_path = dnbr_mask_clip
        elif os.path.exists(post_mask_clip):
            gap_mask_path = post_mask_clip
        else:
            gap_mask_path = None

        if gap_mask_path is not None:
            gap_array = self.utils.file_to_array(gap_mask_path)
            to_mask = to_mask | (gap_array == 0)
            gap_array = None

        # Convert masked pixels to 6 (mask MTBS value)
        recode_thresh_array[to_mask] = 6
        mask_array = None
        to_mask = None

        # Setup the recode
        # Delete image if it already exists
        self._remove_if_exists(output_temp)

        # Create output image and write data
        driver = gdal.GetDriverByName('GTiff')
        dst_ds = driver.Create(output_temp,
                               thresh_cols,
                               thresh_rows,
                               1, gdal.GDT_Byte)
        dst_ds.SetGeoTransform(thresh_geotransform)
        dst_ds.SetProjection(thresh_projection)
        out_band = dst_ds.GetRasterBand(1)
        out_band.WriteArray(recode_thresh_array)

        # Release bands before their datasets; dropping dst_ds closes
        # the temporary file so gdal_translate can read it.
        out_band = None
        dst_ds = None
        thresh_band = None
        thresh_ds = None
        recode_thresh_array = None

        self._apply_color_table(output_temp, output_image)

        # Clip Image to Shapefile
        self._clip_to_boundary(burn_bndy_path, output_image, output_temp)

        # Create QGIS project and add data
        project_out_file = burn_bndy_path.replace('_burn_bndy.shp', '.qgs')
        self._add_to_project(project_out_file, output_image, toc_name)
