'''
/***************************************************************************
Name		     : FirePrep
Description          : Fire Prep functions for QGIS FMT3 Plugin
copyright            : (C) 2018 by Cheryl Holen
Created              : Sep 05, 2018 - Adapted from QGIS 2.x version
Updated              :
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
 This code is based on work done by:

    Kelcy Smith
    Contractor to USGS/EROS
    kelcy.smith.ctr@usgs.gov

    Fire Prep for Open Source MTBS

    References:
    http://landsathandbook.gsfc.nasa.gov/data_prod/prog_sect11_3.html
    http://landsathandbook.gsfc.nasa.gov/pdfs/L5TMLUTIEEE2003.pdf

'''
from PyQt5.QtWidgets import QMessageBox
from PyQt5.QtGui import QColor

from qgis.core import (QgsLayerTreeLayer, QgsProject, QgsRasterLayer,
                       QgsVectorLayer, QgsCoordinateReferenceSystem,
                       QgsContrastEnhancement)

import os
from osgeo import ogr
from osgeo import osr

from .Utilities import Utilities
from FMT3 import constants as const


class FirePrep():
    '''
    Prepare a fire mapping for analysis in QGIS.

    Given a mapping id, fire_prep_process looks up the mapping and its fire,
    creates the burn boundary and mask shapefiles in the event products
    folder, and writes a QGIS project that loads those shapefiles together
    with the configured vector and WMS layers and the pre-fire, post-fire and
    perimeter scene rasters.
    '''

    # Shapefile schemas as (field name, type tag, width). The type tags are
    # mapped to OGR field types in create_fire_shapefile.
    # Fields shared by the burn boundary and the mask shapefiles.
    _COMMON_FIELDS = (
        ('Id', 'int', 6),
        ('Area', 'real', 18),
        ('Perimeter', 'real', 18),
    )
    # Extra fields of the burn boundary shapefile.
    _BURN_FIELDS = (
        ('Acres', 'real', 18),
        ('Fire_ID', 'str', 30),
        ('Fire_Name', 'str', 50),
        ('Year', 'int', 4),
        ('StartMonth', 'int', 4),
        ('StartDay', 'int', 4),
        ('Confidence', 'str', 6),
        ('Comment', 'str', 80),
    )
    # Extra fields of the mask shapefile.
    _MASK_FIELDS = (
        ('Descript', 'str', 30),
    )

    def __init__(self, mapping_id, parent_path, img_src,
                 vectors_styles_files_list, wms_layer_files_list):
        '''
        :param mapping_id: id of the row in the 'Mappings' table
        :param parent_path: folder containing 'img_proc' and 'event_prods'
        :param img_src: root folder of the source scene images
        :param vectors_styles_files_list: entries of the form
            "<toc name>, <vector path>, <style path>"
        :param wms_layer_files_list: entries of the form
            "<toc name>, <wms uri>"
        '''
        self.parent_path = parent_path
        self.mapping_id = mapping_id
        self.img_src_path = img_src
        self.img_proc_path = os.path.join(self.parent_path, 'img_proc')
        self.event_prods_path = os.path.join(
            self.parent_path, 'event_prods', 'fire')
        self.vectors_styles_files_list = vectors_styles_files_list
        self.wms_layer_files_list = wms_layer_files_list
        self.utils = Utilities()

    def create_fire_shapefile(self, shape_name, shape_loc, image_loc):
        '''
        Create an empty polygon shapefile with the burn boundary or mask
        schema. Does nothing if shape_loc already exists.

        :param shape_name: file name; a name ending in '_mask.shp' selects
            the mask schema, anything else the burn boundary schema
        :param shape_loc: full path of the shapefile to create
        :param image_loc: tif image whose projection is used for the layer
        :raises IOError: if the data source or the layer cannot be created
        '''
        if os.path.exists(shape_loc):
            return

        ogr_types = {
            'int': ogr.OFTInteger,
            'real': ogr.OFTReal,
            'str': ogr.OFTString,
        }
        driver = ogr.GetDriverByName('ESRI Shapefile')
        src_ds = driver.CreateDataSource(shape_loc)
        if src_ds is None:
            raise IOError(
                "ERROR: Could not create shapefile - {}".format(shape_loc))
        try:
            srs = osr.SpatialReference()
            srs.ImportFromWkt(self.utils.get_projection(image_loc))
            layer = src_ds.CreateLayer('Fire', srs, ogr.wkbPolygon)
            if layer is None:
                raise IOError(
                    "ERROR: Could not create layer in - {}".format(shape_loc))
            for name, type_tag, width in self._shapefile_fields(shape_name):
                field = ogr.FieldDefn(name, ogr_types[type_tag])
                field.SetWidth(width)
                layer.CreateField(field)
        finally:
            # Close the data source even if a field could not be created.
            src_ds.Destroy()

    @classmethod
    def _shapefile_fields(cls, shape_name):
        '''Return the (name, type tag, width) field specs for shape_name.'''
        if shape_name.endswith('_mask.shp'):
            return cls._COMMON_FIELDS + cls._MASK_FIELDS
        return cls._COMMON_FIELDS + cls._BURN_FIELDS

    def fire_prep_process(self):
        '''
        Fire preparation process.

        Looks up the mapping and its fire, creates the mapping output folder,
        the burn boundary and mask shapefiles and (after confirmation if it
        already exists) the QGIS project, then reports the outcome in a
        message box.
        '''
        mapping_id = str(self.mapping_id)
        # NOTE(review): ids are concatenated into the SQL text; switch to a
        # parameterized query if Utilities.run_query supports bind values.
        sql = "SELECT * FROM 'Mappings' WHERE id = " + mapping_id
        mapping_fetch = self.utils.run_query(sql)
        if not mapping_fetch:
            QMessageBox.critical(None, 'Fire Prep Error',
                                 'No mapping found with id ' + mapping_id,
                                 QMessageBox.Ok)
            return

        fire_id = str(mapping_fetch[1])
        prefire_id = str(mapping_fetch[6])
        postfire_id = str(mapping_fetch[10])
        perim_id = str(mapping_fetch[12])

        sql = "SELECT * FROM 'Fires' WHERE id = " + fire_id
        fire_fetch = self.utils.run_query(sql)
        if not fire_fetch:
            QMessageBox.critical(None, 'Fire Prep Error',
                                 'No fire found with id ' + fire_id,
                                 QMessageBox.Ok)
            return

        mtbs_id = fire_fetch[1].lower()
        fire_year = str(fire_fetch[4]).split('-')[0]

        # Make Folder
        mapping_dir = os.path.join(self.event_prods_path, fire_year, mtbs_id,
                                   'mtbs_' + mapping_id)
        os.makedirs(mapping_dir, exist_ok=True)

        pre_sub_path, pre_date, pre_path_row, pre_name_template = \
            self.utils.get_image_folder_date_path_row_name(prefire_id)
        post_sub_path, post_date, post_path_row, post_name_template = \
            self.utils.get_image_folder_date_path_row_name(postfire_id)
        perim_sub_path, _, perim_path_row, perim_name_template = \
            self.utils.get_image_folder_date_path_row_name(perim_id)

        # Processed imagery sits in a per-sensor folder chosen by the
        # post-fire scene. Rebuild from parent_path rather than appending to
        # the current value so repeated calls do not nest the sensor folder.
        sensor_dir = (const.LANDSAT if post_sub_path == const.LANDSAT
                      else const.SENTINEL2)
        self.img_proc_path = os.path.join(self.parent_path, 'img_proc',
                                          sensor_dir)

        # REFL and NBR images of each scene
        post_refl = self._scene_path(post_sub_path, post_path_row,
                                     postfire_id, post_name_template,
                                     const.REFL)
        pre_refl = self._scene_path(pre_sub_path, pre_path_row, prefire_id,
                                    pre_name_template, const.REFL)
        perim_refl = self._scene_path(perim_sub_path, perim_path_row,
                                      perim_id, perim_name_template,
                                      const.REFL)
        post_nbr = self._scene_path(post_sub_path, post_path_row, postfire_id,
                                    post_name_template, const.NBR)
        pre_nbr = self._scene_path(pre_sub_path, pre_path_row, prefire_id,
                                   pre_name_template, const.NBR)
        perim_nbr = self._scene_path(perim_sub_path, perim_path_row,
                                     perim_id, perim_name_template,
                                     const.NBR)

        pair_name = (pre_path_row + '_' + pre_date + '_' +
                     post_path_row + '_' + post_date)
        dnbr_out_path = os.path.join(self.img_proc_path, pair_name,
                                     'd' + pair_name + const.TIF_EXT)

        _, _, _, _, post_proj, _ = self.utils.get_geo_info(post_refl)

        has_dnbr = os.path.exists(dnbr_out_path)
        has_perim = os.path.exists(perim_refl)

        # Burn boundary and mask shapefiles
        burn_shape = self._burn_shape_name(mtbs_id, pre_date, post_date,
                                           has_dnbr)
        burn_shape_path = os.path.join(mapping_dir, burn_shape)
        mask_shape = burn_shape.replace('_burn_bndy.shp', '_mask.shp')
        mask_shape_path = os.path.join(mapping_dir, mask_shape)
        try:
            self._create_shapefile_if_missing(burn_shape, burn_shape_path,
                                              post_nbr)
            self._create_shapefile_if_missing(mask_shape, mask_shape_path,
                                              post_nbr)
        except IOError as ex:
            QMessageBox.critical(None, 'Error!',
                                 'Fire Prep Error: ' + str(ex),
                                 QMessageBox.Ok)
            return

        # Create QGIS project and add data
        project_out_file = burn_shape_path.replace('_burn_bndy.shp', '.qgs')
        if self._confirm_overwrite(project_out_file):
            # Pre-fire and dNBR layers are loaded only when the dNBR exists;
            # perimeter layers only when the perimeter refl image exists.
            success = self.create_qgis_project(
                burn_shape_path, mask_shape_path,
                pre_refl if has_dnbr else None,
                post_refl,
                perim_refl if has_perim else None,
                pre_nbr if has_dnbr else None,
                post_nbr,
                perim_nbr if has_perim else None,
                dnbr_out_path if has_dnbr else None,
                project_out_file, mapping_fetch, post_proj)
        else:
            success = True

        if success:
            QMessageBox.information(None,
                                    'Fire Prep Complete',
                                    'Fire prep is Complete',
                                    QMessageBox.Ok)
        else:
            QMessageBox.warning(
                None, 'Fire Prep Error',
                'Issue running Fire Prep, could not create QGIS project',
                QMessageBox.Ok)

    def _scene_path(self, sub_path, path_row, scene_id, name_template, ext):
        '''Return the source image path of a scene with '<ext>' filled in.'''
        return os.path.join(self.img_src_path, sub_path, path_row, scene_id,
                            name_template.replace('<ext>', ext))

    @staticmethod
    def _burn_shape_name(mtbs_id, pre_date, post_date, has_dnbr):
        '''
        Return the burn boundary shapefile name; the pre-fire date is part of
        the name only when a dNBR exists for the scene pair.
        '''
        if has_dnbr:
            return (mtbs_id + '_' + pre_date + '_' + post_date +
                    '_burn_bndy.shp')
        return mtbs_id + '_' + post_date + '_burn_bndy.shp'

    def _create_shapefile_if_missing(self, shape_name, shape_path,
                                     image_path):
        '''Create shape_path, or warn the user if it already exists.'''
        if os.path.exists(shape_path):
            QMessageBox.warning(None,
                                'Shapefile creation',
                                shape_name + ' already exists!',
                                QMessageBox.Ok)
        else:
            self.create_fire_shapefile(shape_name, shape_path, image_path)

    @staticmethod
    def _confirm_overwrite(project_file):
        '''
        Return True if project_file does not exist or the user agrees to
        overwrite it.
        '''
        if not os.path.exists(project_file):
            return True
        # Both buttons must be passed as one flag set; No is the default so
        # pressing Enter does not overwrite an existing project.
        ans = QMessageBox.question(
            None, 'Warning Project Already Exists',
            'Do you want to overwrite the project?',
            QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        return ans == QMessageBox.Yes

    def create_qgis_project(self, burn_bndy, fire_mask, pre_refl, post_refl,
                            perim_refl, pre_nbr, post_nbr, perim_nbr,
                            post_dnbr, project_file_path, mapping_fetch,
                            post_proj):
        '''
        Create the QGIS project, load layers and style files, and write it.

        Optional rasters (pre_refl, perim_refl, pre_nbr, perim_nbr,
        post_dnbr) are skipped when falsy. The shared QgsProject instance is
        cleared afterwards whether or not the run succeeded, so a failed run
        does not leave partial layers behind.

        :param mapping_fetch: 'Mappings' row; items 3 and 7 are passed to
            choose_style_file for the pre and post refl styles
        :param post_proj: spatial reference whose Proj4 string sets the
            project CRS
        :returns: True on success, False if any step failed (the error is
            shown in a message box)
        '''
        project = QgsProject.instance()
        try:
            project.setBackgroundColor(QColor('black'))
            crs = QgsCoordinateReferenceSystem()
            crs.createFromProj(post_proj.ExportToProj4())
            if not crs.isValid():
                raise ValueError(f"Invalid CRS: {post_proj}")
            project.setCrs(crs)
            style_dir = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), 'Style')
            nbr_style = os.path.join(style_dir, "NBRStyle.qml")

            # Burn area boundary on top, then the mask
            self.load_vector_layer(
                project, burn_bndy,
                os.path.join(style_dir, "BurnAreaBndy.qml"), 0,
                "Burned Area Bndy")
            self.load_vector_layer(
                project, fire_mask, os.path.join(style_dir, "Mask.qml"), -1,
                "Mask")

            # Default vector layers from config
            for entry in self.vectors_styles_files_list:
                toc_name, vector_path, style_path = entry.split(', ', 2)
                self.load_vector_layer(project, vector_path, style_path, -1,
                                       toc_name)

            # WMS layers from config; a layer that fails to load is reported
            # but does not stop the project from being built.
            for entry in self.wms_layer_files_list:
                toc_name, uri = entry.split(', ', 1)
                try:
                    rst_layer = QgsRasterLayer(uri, toc_name, 'wms')
                    if not rst_layer.isValid():
                        raise IOError(
                            "Raster layer - {} did not load".format(uri))
                    project.addMapLayer(rst_layer, False)
                    project.layerTreeRoot().insertChildNode(
                        -1, QgsLayerTreeLayer(rst_layer))
                except IOError as ex:
                    QMessageBox.warning(
                        None, 'Error!',
                        'Fire Prep Error: ' + str(ex),
                        QMessageBox.Ok)

            if pre_refl:
                self.load_raster_layer(
                    project, pre_refl,
                    self.choose_style_file(mapping_fetch[3], style_dir), -1,
                    "Pre Scene Refl")

            # There is always a post-fire refl layer
            self.load_raster_layer(
                project, post_refl,
                self.choose_style_file(mapping_fetch[7], style_dir), -1,
                "Post Scene Refl")

            if perim_refl:
                # NOTE(review): uses the post-fire column (index 7) to pick
                # the perimeter scene style; confirm that is intended.
                self.load_raster_layer(
                    project, perim_refl,
                    self.choose_style_file(mapping_fetch[7], style_dir), -1,
                    "Perim Scene Refl")

            if pre_nbr:
                self.load_raster_layer(project, pre_nbr, nbr_style, -1,
                                       "Pre Scene NBR")

            # There is always a post-fire nbr layer
            self.load_raster_layer(project, post_nbr, nbr_style, -1,
                                   "Post Scene NBR")

            if perim_nbr:
                self.load_raster_layer(project, perim_nbr, nbr_style, -1,
                                       "Perim Scene NBR")

            if post_dnbr:
                self.load_raster_layer(
                    project, post_dnbr,
                    os.path.join(style_dir, "dNBRStyle.qml"), -1,
                    "Post Scene dNBR")

            if not project.write(project_file_path):
                raise IOError("ERROR: Could not write project - {}".format(
                    project_file_path))
        except Exception as ex:
            QMessageBox.critical(
                None, 'Error!',
                'Fire Prep Error: ' + str(ex),
                QMessageBox.Ok)
            return False
        finally:
            project.clear()
        return True

    def load_raster_layer(self, project, raster_layer, style_file, index,
                          toc_name):
        '''
        Load a raster layer and style file into project.

        Layers whose toc_name contains 'Refl' additionally get a cumulative
        cut stretch applied to their RGB bands.

        :raises IOError: if the raster or its style file cannot be loaded
        '''
        rst_layer = QgsRasterLayer(raster_layer, toc_name)
        self._add_styled_layer(project, rst_layer, raster_layer, style_file,
                               index, 'Raster')
        if 'Refl' in toc_name:
            # Style file fixes for Refl files; use the layer just added rather
            # than a by-name lookup, which could return an older layer.
            self._stretch_refl_bands(rst_layer)

    @staticmethod
    def _stretch_refl_bands(layer):
        '''
        Replace the red/green/blue contrast enhancement of layer's renderer
        with a min/max stretch between each band's 2% and 98% cumulative
        cut values, then repaint the layer.
        '''
        # Assumes the refl style sets a multiband color renderer (one with
        # redBand/greenBand/blueBand) - TODO confirm for every refl style.
        renderer = layer.renderer()
        provider = layer.dataProvider()
        algorithm = QgsContrastEnhancement.StretchToMinimumMaximum
        bands = (
            (renderer.redBand(), renderer.setRedContrastEnhancement),
            (renderer.greenBand(), renderer.setGreenContrastEnhancement),
            (renderer.blueBand(), renderer.setBlueContrastEnhancement),
        )
        for band, apply_enhancement in bands:
            cut = provider.cumulativeCut(band, 0.02, 0.98)
            enhancement = QgsContrastEnhancement(renderer.dataType(band))
            enhancement.setContrastEnhancementAlgorithm(algorithm, True)
            enhancement.setMinimumValue(cut[0])
            enhancement.setMaximumValue(cut[1])
            apply_enhancement(enhancement)
        layer.triggerRepaint()

    def load_vector_layer(self, project, vector_layer, style_file, index,
                          toc_name):
        '''
        Load a vector layer and style file into project.

        :raises IOError: if the vector or its style file cannot be loaded
        '''
        vlayer = QgsVectorLayer(vector_layer, toc_name)
        self._add_styled_layer(project, vlayer, vector_layer, style_file,
                               index, 'Vector')

    @staticmethod
    def _add_styled_layer(project, layer, source, style_file, index, kind):
        '''
        Validate layer and style_file, apply the style, add the layer to
        project and insert it at index in the layer tree root.

        :param source: path the layer was created from (used in messages)
        :param kind: 'Raster' or 'Vector' (used in messages)
        :raises IOError: if the layer is invalid or the style is missing
        '''
        if not layer.isValid():
            raise IOError(
                "ERROR: {} layer - {} did not load".format(kind, source))
        if not style_file:
            raise IOError(
                "ERROR: No style file found for layer - {}".format(source))
        if not os.path.exists(style_file):
            raise IOError(
                "ERROR: Style file does not exist - {}".format(style_file))
        layer.loadNamedStyle(style_file)
        project.addMapLayer(layer, False)
        project.layerTreeRoot().insertChildNode(index,
                                                QgsLayerTreeLayer(layer))

    def choose_style_file(self, scene_sensor, style_dir, sensor_array=None):
        '''
        Choose and return the style file for a refl raster.

        :param scene_sensor: sensor name; lower-cased and compared with
            item [1] of each sensor_array entry
        :param style_dir: folder containing the style files
        :param sensor_array: entries whose [1] is the sensor key and [2] the
            style file name; defaults to const.SENSOR_ARRAY
        :returns: style path of the last matching entry, or None if
            scene_sensor is None or no entry matches
        '''
        if scene_sensor is None:
            return None
        if sensor_array is None:
            sensor_array = const.SENSOR_ARRAY
        wanted = scene_sensor.lower()
        matches = [entry[2] for entry in sensor_array if entry[1] == wanted]
        if not matches:
            return None
        return os.path.join(style_dir, matches[-1])
