'''
/***************************************************************************
Name		     : GenerateMetadata
Description          : Generate Metadata functions for QGIS FMT3 Plugin
copyright            : (C) 2018 by Cheryl Holen
Created              : Sep 10, 2018 - Adapted from QGIS 2.x version
Updated              :
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
 This code is based off of work done by:

    Kelcy Smith
    Contractor to USGS/EROS
    kelcy.smith.ctr@usgs.gov

    Generate metadata for Open Source MTBS

    References:
    http://landsathandbook.gsfc.nasa.gov/data_prod/prog_sect11_3.html
    http://landsathandbook.gsfc.nasa.gov/pdfs/L5TMLUTIEEE2003.pdf

'''
from PyQt5.QtWidgets import QMessageBox

import calendar
import math
import os
from osgeo import ogr
from osgeo import osr

from .Utilities import Utilities
from FMT3 import constants as const

# Unit factors used when turning decimal degrees into
# degrees / minutes / seconds text for the metadata file.
MINPERDEG = 60.0  # arc-minutes per degree
SECPERDEG = 3600.0  # arc-seconds per degree


class GenerateMetadata():
    def __init__(self, mapping_id, parent_path):
        '''
        Store the mapping id and build the working folder paths

        :param mapping_id: Mapping Id; used as text in the Mappings query
                           and in the mapping folder name
        :param parent_path: Base folder; the img_proc and
                            event_prods/fire paths are built under it
        '''
        self.parent_path, self.mapping_id = parent_path, mapping_id
        self.driver = ogr.GetDriverByName('ESRI Shapefile')

        base_dir = self.parent_path
        self.img_proc_path = os.path.join(base_dir, 'img_proc')
        self.event_prods_path = os.path.join(base_dir, 'event_prods', 'fire')

        self.utils = Utilities()

    def get_bounding_box(self, x_alb, y_alb, buffer_dist):
        '''
        Get Bounding Box Coords

        Each centre value is passed through Utilities.odd_even and the
        buffer distance is then subtracted (min) or added (max).

        :param x_alb: X Albers value of the box centre
        :param y_alb: Y Albers value of the box centre
        :param buffer_dist: Distance applied on both sides of the centre
        :returns: Tuple (xmin, ymin, xmax, ymax); every item is the
                  str() of the computed number
        '''
        odd_even = self.utils.odd_even
        centre = (x_alb, y_alb)

        # Same evaluation order as before: x min, y min, x max, y max
        lower = [str(odd_even(value) - buffer_dist) for value in centre]
        upper = [str(odd_even(value) + buffer_dist) for value in centre]

        return lower[0], lower[1], upper[0], upper[1]

    def get_spatial_ref(self, epsg_code):
        '''
        Build a spatial reference from an EPSG code

        :param epsg_code: EPSG code handed to ImportFromEPSG (callers in
                          this class pass integers, e.g. 5070 and 4326)
        :returns: osr.SpatialReference initialised from the EPSG code,
                  using x/y (lon, lat) axis order
        '''
        spatial_ref = osr.SpatialReference()
        spatial_ref.ImportFromEPSG(epsg_code)

        # GDAL 3+ follows the authority axis order (lat, lon for
        # EPSG:4326). The callers in this class read TransformPoint
        # results as [0] = longitude, [1] = latitude, so keep the
        # traditional x/y order. GDAL 2 has no such setting (and already
        # behaves this way), hence the attribute check.
        if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
            spatial_ref.SetAxisMappingStrategy(
                osr.OAMS_TRADITIONAL_GIS_ORDER)

        return spatial_ref

    def generate_metadata_process(self):
        '''
        Generate the metadata file

        Reads the mapping and fire records, derives the product file
        names inside the mapping folder, converts the bounding box
        corners to WGS84 and fills templates/template_metadata.txt,
        writing the result into the mapping folder as *_metadata.txt.

        Side effect: self.img_proc_path is set to the sensor specific
        sub folder of <parent_path>/img_proc.

        :raises IOError: If filling in the template fails (a warning
                         dialog is shown first); the original error is
                         chained as the cause
        '''
        # EPSG code for USGS Albers CONUS
        conus_spatial_ref = self.get_spatial_ref(5070)
        # EPSG code for WGS84
        wgs84_spatial_ref = self.get_spatial_ref(4326)
        wgs_coord_transform = osr.CoordinateTransformation(conus_spatial_ref,
                                                           wgs84_spatial_ref)

        # get mapping info
        # NOTE(review): the query text is built by concatenation;
        # run_query is not visible here - confirm whether it accepts
        # bound parameters before changing this.
        sql = ("SELECT * FROM 'Mappings' WHERE Id = " + self.mapping_id)
        mapping = self.utils.run_query(sql)

        fire_id = str(mapping[1])
        pre_sensor = str(mapping[3])
        prefire_id = str(mapping[6])
        post_sensor = str(mapping[7])
        postfire_id = str(mapping[10])
        perim_sensor = str(mapping[11])
        perim_id = str(mapping[12])
        # NOTE(review): every falsy value (0 and '' as well as None) is
        # replaced by the placeholder text below - confirm that 0 is
        # never a meaningful stored offset/threshold.
        dnbr_offset = (mapping[16] or
                       'No RdNBR produced; Single scene assessment')
        low_threshold = mapping[17] or 'None'
        moderate_threshold = mapping[18] or 'None'
        high_threshold = mapping[19] or 'None'
        perimeter_comment = mapping[25] or 'None'
        no_data_threshold = mapping[35] or 'None'
        increased_green_threshold = mapping[36] or 'None'
        comments = mapping[37] or 'None'
        fire_x = float(mapping[42])
        fire_y = float(mapping[43])
        pred_strat = str(mapping[44])
        sd_offset = mapping[45]

        # get fire info
        sql = "SELECT * FROM 'Fires' WHERE Id = " + str(fire_id)
        fire = self.utils.run_query(sql)

        mtbs_id = str(fire[1])
        mtbs_id_lower = mtbs_id.lower()
        fire_name = str(fire[2])

        # The date is split on '-' into year / month / day
        fire_date_parts = str(fire[4]).split('-')
        fire_year = fire_date_parts[0]
        fire_month = fire_date_parts[1]
        fire_day = fire_date_parts[2]
        fire_acres = float(fire[7])
        fire_path = fire[16]
        fire_row = fire[17]

        # Get Potential Fire Area: acres -> hectares (0.404686) ->
        # square metres (10000); the square root is the side of a square
        # with that area, padded by 5000
        buffer_dist =\
            round(math.sqrt(float(fire_acres) * 0.404686 * 10000), 0) + 5000
        # Get bounding Box coordinates
        xmin, ymin, xmax, ymax = self.get_bounding_box(fire_x,
                                                       fire_y, buffer_dist)

        # Make Folder
        mapping_dir = os.path.join(self.event_prods_path,
                                   fire_year,
                                   mtbs_id_lower,
                                   'mtbs_' + self.mapping_id)

        if not os.path.exists(mapping_dir):
            os.makedirs(mapping_dir)

        # REFL Images
        _, pre_date, pre_path_row, _ = \
            self.utils.get_image_folder_date_path_row_name(prefire_id)
        post_sub_path, post_date, post_path_row, _ = \
            self.utils.get_image_folder_date_path_row_name(postfire_id)
        _, perim_date, _, _ = \
            self.utils.get_image_folder_date_path_row_name(perim_id)

        if post_sub_path == const.LANDSAT:
            sensor_folder = const.LANDSAT
        else:
            sensor_folder = const.SENTINEL2
        sensor = sensor_folder.capitalize()
        # Rebuilt from parent_path on every call so that running the
        # process twice on one instance does not stack sensor folders
        # (img_proc/<sensor>/<sensor>).
        self.img_proc_path = os.path.join(self.parent_path, 'img_proc',
                                          sensor_folder)

        # Get Image clip output names
        pre_refl_clipped = os.path.join(mapping_dir,
                                        (mtbs_id_lower + '_' + pre_date + '_' +
                                         self.utils.get_sensor(prefire_id) +
                                         '_refl' + const.TIF_EXT))
        if os.path.exists(pre_refl_clipped):
            pre_sensor_type = pre_sensor
            pre_sensor_info = self._scene_info(pre_date, prefire_id)
        else:
            pre_sensor_type = 'Not applicable'
            pre_sensor_info = 'Single scene assessment'

        post_refl_clipped = os.path.join(mapping_dir,
                                         (mtbs_id_lower + '_' +
                                          post_date + '_' +
                                          self.utils.get_sensor(postfire_id) +
                                          '_refl' + const.TIF_EXT))
        if os.path.exists(post_refl_clipped):
            post_sensor_type = post_sensor
            post_sensor_info = self._scene_info(post_date, postfire_id)
        else:
            post_sensor_type = 'Not applicable'
            post_sensor_info = 'Subset Post-fire refl was not found'

        perim_refl_clip = os.path.join(mapping_dir,
                                       (mtbs_id_lower + '_' +
                                        perim_date + '_' +
                                        self.utils.get_sensor(perim_id) +
                                        '_refl.tif'))
        if os.path.exists(perim_refl_clip):
            perim_sensor_type = perim_sensor
            perim_sensor_info = self._scene_info(perim_date, perim_id)
        else:
            perim_sensor_type = 'Not applicable'
            perim_sensor_info = 'No Perimeter scene used'

        # Clip Post Nbr
        post_nbr_clipped = os.path.join(mapping_dir,
                                        (mtbs_id_lower + '_' +
                                         post_date + '_nbr' + const.TIF_EXT))
        nbr_threshold = post_nbr_clipped.replace('_nbr' + const.TIF_EXT,
                                                 '_nbr6' + const.TIF_EXT)

        # dNBR Images
        dnbr_output_dir = os.path.join(self.img_proc_path,
                                       (pre_path_row + '_' + pre_date + '_' +
                                        post_path_row + '_' + post_date))
        dnbr_image = ('d' + pre_path_row + '_' + pre_date + '_' +
                      post_path_row + '_' + post_date + '.tif')
        dnbr_output_path = os.path.join(dnbr_output_dir, dnbr_image)

        # Clip dNBR location
        dnbr_clipped = os.path.join(mapping_dir,
                                    (mtbs_id_lower + '_' +
                                     pre_date + '_' + post_date + '_dnbr' +
                                     const.TIF_EXT))
        dnbr_threshold = dnbr_clipped.replace('_dnbr' + const.TIF_EXT,
                                              '_dnbr6' + const.TIF_EXT)

        if os.path.exists(dnbr_clipped):
            threshold_image = dnbr_threshold
        else:
            threshold_image = nbr_threshold

        # Get Rows/Cols of OutputImage
        _, xsize, ysize, _, _, _ =\
            self.utils.get_geo_info(threshold_image)
        cols = str(xsize)
        rows = str(ysize)

        # Thematic NBR or dNBR
        if os.path.exists(dnbr_threshold):
            thematic = os.path.basename(dnbr_threshold)
            them_stmt = ('Thematic dNBR; Derived by thresholding dNBR subset '
                         '(8-bit GeoTIFF)')
        else:
            thematic = os.path.basename(nbr_threshold)
            them_stmt = ('Thematic NBR; Derived by thresholding post-fire NBR '
                         'subset (8-bit GeoTIFF)')

        # Fire Shapefile name; only used to derive the metadata file name
        if os.path.exists(dnbr_output_path):
            burn_shape = (mtbs_id_lower + '_' + pre_date + '_' +
                          post_date + '_burn_bndy.shp')
        else:
            burn_shape = (mtbs_id_lower + '_' + post_date + '_burn_bndy.shp')
        burn_boundary = os.path.join(mapping_dir, burn_shape)

        # Corner coordinates moved half a cell inwards (pixel centre)
        half_cell = self.utils.HALFCELL
        upper_left_x = float(xmin) + half_cell
        upper_left_y = float(ymax) - half_cell
        lower_right_x = float(xmax) - half_cell
        lower_right_y = float(ymin) + half_cell

        # Get WGS 84 coords; results are read as [0] = lon, [1] = lat
        wgs_lower_right = wgs_coord_transform.TransformPoint(lower_right_x,
                                                             lower_right_y)
        wgs_upper_left = wgs_coord_transform.TransformPoint(upper_left_x,
                                                            upper_left_y)

        north_latitude = str(round(wgs_upper_left[1], 7))
        west_longitude = str(round(wgs_upper_left[0], 7))
        south_latitude = str(round(wgs_lower_right[1], 7))
        east_longitude = str(round(wgs_lower_right[0], 7))
        center_latitude = str(round((float(north_latitude) +
                                     float(south_latitude)) / 2.0, 5))
        center_longitude = str(round((
            float(east_longitude) + float(west_longitude)) / 2.0, 5))

        # Calculate Degrees, Minutes, Seconds
        def to_dms(decimal_text):
            return self._decimal_to_dms(float(decimal_text),
                                        MINPERDEG, SECPERDEG)

        north_latitude_dms = to_dms(north_latitude)
        west_longitude_dms = to_dms(west_longitude)
        south_latitude_dms = to_dms(south_latitude)
        east_longitude_dms = to_dms(east_longitude)
        center_latitude_dms = to_dms(center_latitude)
        center_longitude_dms = to_dms(center_longitude)

        template_meta_file = os.path.join(self.parent_path,
                                          'templates',
                                          'template_metadata.txt')
        meta_file = burn_boundary.replace('_burn_bndy.shp', '_metadata.txt')

        try:
            # (placeholder, value) pairs, applied to every template line
            # in exactly this order. The order matters: a value inserted
            # by an earlier pair is visible to the later ones.
            replacements = [
                ('Fire_Id', mtbs_id),
                ('Fire_Name', fire_name),
                ('F_Date', (calendar.month_name[int(fire_month)] + ' ' +
                            fire_day + ', ' + fire_year)),
                ('Assess_Type', pred_strat),
                ('Fire_Acres', str(fire_acres)),
                ('Fire_Pr', str(fire_path) + '/' + str(fire_row)),
                ('Pre_Sensor_Type', pre_sensor_type),
                ('Pre_Sensor_Info', pre_sensor_info),
                ('Post_Sensor_Type', post_sensor_type),
                ('Post_Sensor_Info', post_sensor_info),
                ('Perimeter_Sensor_Type', perim_sensor_type),
                ('Perimeter_Sensor_Info', perim_sensor_info),
                # Get center pixel
                ('Fix_Ulx', str(upper_left_x)),
                ('Fix_Uly', str(upper_left_y)),
                ('Fix_Lrx', str(lower_right_x)),
                ('Fix_Lry', str(lower_right_y)),
                ('Fix_Rows', rows),
                ('Fix_Cols', cols),
                # WGS conversion
                ('Fix_Nlat', north_latitude),
                ('Nlat_Degrees_Minutes_Seconds', north_latitude_dms),
                ('Fix_Slat', south_latitude),
                ('Slat_Degrees_Minutes_Seconds', south_latitude_dms),
                ('Fix_Elon', east_longitude),
                ('Elon_Degrees_Minutes_Seconds', east_longitude_dms),
                ('Fix_Wlon', west_longitude),
                ('Wlon_Degrees_Minutes_Seconds', west_longitude_dms),
                ('Fix_Centlat', center_latitude),
                ('Centlat_Degrees_Minutes_Seconds', center_latitude_dms),
                ('Fix_Centlon', center_longitude),
                ('Centlon_Degrees_Minutes_Seconds', center_longitude_dms),
                ('dnbr_offset', str(dnbr_offset)),
                ('No_Thresh', str(no_data_threshold)),
                ('Increased_Thresh', str(increased_green_threshold)),
                ('Low_Thresh', str(low_threshold)),
                ('Mod_Thresh', str(moderate_threshold)),
                ('High_Thresh', str(high_threshold)),
                ('Thematic', thematic),
                ('Them_Statement', them_stmt),
                ('mapping_comments', str(comments)),
                ('Perim_comments', str(perimeter_comment)),
                ('sd_offset', str(sd_offset)),
                ('Sensor_Name', str(sensor)),
            ]

            with open(template_meta_file, 'r') as temp_meta,\
                    open(meta_file, 'w') as metadata_out:
                for line in temp_meta:
                    metadata_out.write(self._substitute(line, replacements))
        except Exception as err:
            QMessageBox.warning(None,
                                'Metadata creation failed',
                                'Metadata creation failed',
                                QMessageBox.Ok)
            # Still an IOError for callers, but keep the real cause
            raise IOError('Metadata creation failed: ' + str(err)) from err

    @staticmethod
    def _scene_info(scene_date, scene_id):
        ''' Format a YYYYMMDD scene date and scene id as
        "YYYY-MM-DD / id" '''
        return (scene_date[0:4] + '-' + scene_date[4:6] + '-' +
                scene_date[6:8] + ' / ' + scene_id)

    @staticmethod
    def _substitute(line, replacements):
        ''' Apply every (placeholder, value) pair to line, in order '''
        for placeholder, value in replacements:
            line = line.replace(placeholder, value)
        return line

    @staticmethod
    def _decimal_to_dms(decimal_degrees, min_per_deg=60.0,
                        sec_per_deg=3600.0):
        '''
        Format decimal degrees as a "degrees minutes seconds" string

        :param decimal_degrees: Angle in decimal degrees (may be negative)
        :param min_per_deg: Arc-minutes per degree
        :param sec_per_deg: Arc-seconds per degree
        :returns: e.g. "45 15 36.0"; a negative angle keeps its sign on
                  the degrees part, seconds are rounded to 5 decimals
        '''
        sign = '-' if decimal_degrees < 0 else ''
        # Round the total first so the seconds part can never round up
        # to a full minute afterwards
        total_seconds = round(abs(decimal_degrees) * sec_per_deg, 5)
        degrees, remainder = divmod(total_seconds, sec_per_deg)
        minutes, seconds = divmod(remainder, sec_per_deg / min_per_deg)
        return '{}{} {} {}'.format(sign, int(degrees), int(minutes),
                                   round(seconds, 5))
