'''
/***************************************************************************
Name		     : RdNBR
Description          : RdNBR functions for QGIS FMT3 Plugin
copyright            : (C) 2018 by Cheryl Holen
Created              : Sep 06, 2018 - Adapted from QGIS 2.x version
Updated              :
 ******************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
 This code is based on work done by:

    Kelcy Smith
    Contractor to USGS/EROS
    kelcy.smith.ctr@usgs.gov

    RdNBR creation for Open Source MTBS

    References:
    http://landsathandbook.gsfc.nasa.gov/data_prod/prog_sect11_3.html
    http://landsathandbook.gsfc.nasa.gov/pdfs/L5TMLUTIEEE2003.pdf

'''
from PyQt5.QtWidgets import QMessageBox

from qgis.core import (QgsLayerTreeLayer, QgsProject, QgsRasterLayer,)

import numpy as np
import os
from osgeo import gdal

from .Utilities import Utilities
from FMT3 import constants as const


class RdNBR():
    '''Create a Relativized dNBR (RdNBR) raster for one fire mapping.

    RdNBR is computed per pixel as ``(dNBR - offset) / sqrt(|preNBR / 1000|)``
    from the clipped pre-fire NBR and dNBR GeoTIFFs, which must already exist
    in the mapping's folder under ``<parent_path>/event_prods/fire``. The
    result is written as a single-band Int16 GeoTIFF in the same folder and
    added as layer "RdNBR" to the mapping's QGIS project file.
    '''

    def __init__(self, mapping_id, parent_path, offset):
        '''
        :param mapping_id: Id of the row in the 'Mappings' table to process
            (int or str; it is used as text in SQL and in the folder name).
        :param parent_path: Root folder containing ``event_prods/fire``.
        :param offset: dNBR offset subtracted before relativizing; any value
            accepted by ``float()``.
        '''
        self.mapping_id = mapping_id
        self.offset = offset
        self.event_prods_path = os.path.join(parent_path,
                                             'event_prods', 'fire')
        self.utils = Utilities()

    def rdnbr_process(self):
        ''' RdNBR Process

        Look up the mapping's fire and pre/post scene dates, compute RdNBR
        from the clipped pre-fire NBR and dNBR rasters, write it next to the
        dNBR and register it in the mapping's QGIS project.

        Shows a critical message box and returns early when an input raster
        is missing or the two inputs differ in size.

        :raises IOError: if a raster cannot be opened or created, the RdNBR
            layer fails to load in QGIS, or the style file is missing.
        '''
        # Accept int or str ids: both the SQL and the folder name need text.
        mapping_id = str(self.mapping_id)

        # NOTE(review): ids are concatenated into SQL; switch to a bound
        # parameter if Utilities.run_query supports one.
        sql = ("SELECT * FROM 'Mappings' WHERE Id = " + mapping_id)
        fetch = self.utils.run_query(sql)

        fire_id = str(fetch[1])
        prefire_id = str(fetch[6])
        postfire_id = str(fetch[10])

        sql = ("SELECT * FROM 'Fires' WHERE Id = " + fire_id)
        fetch = self.utils.run_query(sql)

        mtbs_id_lower = str(fetch[1]).lower()
        # Presumably a 'YYYY-MM-DD'-like date; only the leading year is used.
        fire_year = str(fetch[4]).split('-')[0]

        # Make Folder
        mapping_dir = os.path.join(self.event_prods_path,
                                   fire_year,
                                   mtbs_id_lower,
                                   'mtbs_' + mapping_id)
        os.makedirs(mapping_dir, exist_ok=True)

        _, pre_date, _, _ = \
            self.utils.get_image_folder_date_path_row_name(prefire_id)
        _, post_date, _, _ = \
            self.utils.get_image_folder_date_path_row_name(postfire_id)

        # All outputs share the '<mtbs_id>_<pre_date>_<post_date>' stem.
        pair_stem = mtbs_id_lower + '_' + pre_date + '_' + post_date
        pre_nbr_clipped = os.path.join(
            mapping_dir,
            mtbs_id_lower + '_' + pre_date + '_nbr' + const.TIF_EXT)
        dnbr_clipped = os.path.join(
            mapping_dir, pair_stem + '_dnbr' + const.TIF_EXT)
        rdnbr_out_path = os.path.join(
            mapping_dir, pair_stem + '_rdnbr' + const.TIF_EXT)
        project_out_file = os.path.join(mapping_dir, pair_stem + '.qgs')

        if not os.path.exists(pre_nbr_clipped):
            QMessageBox.critical(None,
                                 'RdNBR Error!',
                                 'There is no NBR image to use.\n' +
                                 pre_nbr_clipped +
                                 '\n Exiting!',
                                 QMessageBox.Ok)
            return
        if not os.path.exists(dnbr_clipped):
            QMessageBox.critical(None,
                                 'RdNBR Error!',
                                 'There is no dNBR image to use.\n' +
                                 dnbr_clipped +
                                 '\n Exiting!',
                                 QMessageBox.Ok)
            return

        pre_array, pre_geotransform, pre_proj = \
            self._read_band_as_float32(pre_nbr_clipped)
        dnbr_array, _, _ = self._read_band_as_float32(dnbr_clipped)

        # The arithmetic is pixel-by-pixel, so both grids must match exactly.
        if pre_array.shape != dnbr_array.shape:
            QMessageBox.critical(None,
                                 'RdNBR Error!',
                                 'The NBR and dNBR images differ in size.\n' +
                                 pre_nbr_clipped + '\n' +
                                 dnbr_clipped +
                                 '\n Exiting!',
                                 QMessageBox.Ok)
            return

        rdnbr_array = self._compute_rdnbr(dnbr_array, pre_array, self.offset)
        del pre_array, dnbr_array  # release inputs before writing

        # Output RdNBR raster, georeferenced like the pre-fire NBR.
        self._write_int16_raster(rdnbr_out_path, rdnbr_array,
                                 pre_geotransform, pre_proj)
        del rdnbr_array

        # Create QGIS project and add data
        self._add_to_project(rdnbr_out_path, project_out_file, 'RdNBR')

    @staticmethod
    def _read_band_as_float32(path):
        '''Read band 1 of the raster at ``path``.

        :returns: ``(array, geotransform, projection)`` where ``array`` is
            float32 with shape ``(rows, cols)``.
        :raises IOError: if GDAL cannot open ``path``.
        '''
        dataset = gdal.Open(path)
        if dataset is None:
            raise IOError("ERROR: Could not open raster - {}".format(path))
        band = dataset.GetRasterBand(1)
        # float32, not float16: float16 cannot represent integers above 2048
        # exactly and overflows to inf above 65504, corrupting RdNBR values.
        array = np.asarray(band.ReadAsArray(0, 0, band.XSize, band.YSize),
                           dtype=np.float32)
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        band = None
        dataset = None  # dereference to close the file
        return array, geotransform, projection

    @staticmethod
    def _compute_rdnbr(dnbr_array, pre_array, offset):
        '''Return ``(dnbr - offset) / sqrt(|pre / 1000|)`` as float32.

        Pixels whose pre-fire NBR is 0 become +/-inf (nan for 0/0); no numpy
        warnings are emitted for them.

        :param offset: any value accepted by ``float()``.
        '''
        dnbr = np.asarray(dnbr_array, dtype=np.float32)
        pre = np.asarray(pre_array, dtype=np.float32)
        # Scoped errstate (not np.seterr) so the process-wide numpy error
        # settings of the host QGIS process are left untouched.
        with np.errstate(divide='ignore', invalid='ignore'):
            return (dnbr - float(offset)) / np.sqrt(np.abs(pre / 1000.0))

    @staticmethod
    def _write_int16_raster(path, array, geotransform, projection):
        '''Write 2-D ``array`` to a single-band Int16 GeoTIFF at ``path``.

        :raises IOError: if GDAL cannot create the file.
        '''
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = array.shape
        dataset = driver.Create(path, cols, rows, 1, gdal.GDT_Int16)
        if dataset is None:
            raise IOError("ERROR: Could not create raster - {}".format(path))
        dataset.SetGeoTransform(geotransform)
        dataset.SetProjection(projection)
        band = dataset.GetRasterBand(1)
        band.WriteArray(array)
        # Dropping every reference closes the dataset, which flushes pixels
        # and the TIFF header to disk; QGIS opens this file right afterwards.
        band = None
        dataset = None

    @staticmethod
    def _add_to_project(raster_path, project_file, toc_name):
        '''Add ``raster_path`` as layer ``toc_name`` to ``project_file``.

        A top-level layer already named ``toc_name`` is removed first, so a
        rerun replaces it. The project is then written back and the shared
        ``QgsProject`` instance is cleared.

        :raises IOError: if the raster does not load or the style is missing.
        '''
        style_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), 'Style')
        # NOTE(review): RdNBR reuses the dNBR style file - confirm intended.
        style_file = os.path.join(style_dir, "dNBRStyle.qml")

        # Validate before loading the shared project, so a failure does not
        # leave a half-edited project open in QGIS.
        rst_layer = QgsRasterLayer(raster_path, toc_name)
        if not rst_layer.isValid():
            raise IOError(
                "ERROR: Raster layer - {} did not load".format(raster_path))
        if not os.path.exists(style_file):
            raise IOError(
                "ERROR: Style file does not exist - {}".format(style_file))
        rst_layer.loadNamedStyle(style_file)

        project = QgsProject.instance()
        project.read(project_file)
        root = project.layerTreeRoot()

        # Check if layer already exists in project and remove it to replace
        # with the new one; only layer nodes (not groups) have layerId().
        for node in root.children():
            if isinstance(node, QgsLayerTreeLayer) and \
                    node.name() == toc_name:
                project.removeMapLayer(node.layerId())

        # addToLegend=False: the tree node is inserted manually at index 2.
        project.addMapLayer(rst_layer, False)
        root.insertChildNode(2, QgsLayerTreeLayer(rst_layer))
        project.write(project_file)
        project.clear()
