'''
/***************************************************************************
Name		     : Utility functions for FMT3
Description          : Utility functions for QGIS FMT3 Plugin
copyright            : (C) 2018 by Cheryl Holen
Created              : Sep 06, 2018
Updated              :
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''
import math
import numpy as np
import os
from osgeo import osr
from osgeo import gdal
from osgeo.gdalconst import GA_ReadOnly
import sqlite3
from FMT3 import constants as const


# File name of the plugin's SQLite database; run_query/write_db open a file
# with this name located in this module's directory.
DB_NAME = 'FireInfo.sqlite'


class Utilities():

    # Step used by odd_even() when snapping coordinates onto a grid;
    # presumably half of a 30 m pixel (TODO confirm).
    HALFCELL = 15.0

    def get_geo_info(self, fileName):
        '''
        Read georeferencing information from an image.

        :param fileName: Input image filename
        :returns: tuple of
            (no-data value of band 1,
             column count,
             row count,
             GDAL geotransform,
             osr.SpatialReference built from the dataset's projection WKT,
             GDAL data type name of band 1)
        '''
        dataset = gdal.Open(fileName, GA_ReadOnly)
        first_band = dataset.GetRasterBand(1)

        spatial_ref = osr.SpatialReference()
        spatial_ref.ImportFromWkt(dataset.GetProjectionRef())

        return (
            first_band.GetNoDataValue(),
            dataset.RasterXSize,
            dataset.RasterYSize,
            dataset.GetGeoTransform(),
            spatial_ref,
            gdal.GetDataTypeName(first_band.DataType),
        )

    def get_projection(self, img):
        '''
        Return the projection of an image as reported by GDAL.

        :param img: Input image filename
        :returns: Projection string from the dataset's GetProjectionRef()
        '''
        dataset = gdal.Open(img)
        wkt = dataset.GetProjectionRef()
        # Drop the dataset reference explicitly before returning (GDAL's
        # Python bindings close a dataset once it is no longer referenced).
        del dataset
        return wkt

    def file_to_array(self, fName, data_type=None, cols=None, rows=None,
                      cOffset=None, rOffset=None, band=None):
        '''
        Read one raster band of an image into a numpy array.

        A window is read when cols and rows are non-zero and both offsets
        are given (an offset of 0 is valid). Otherwise the whole band is
        read.

        :param fName: Input image filename
        :param data_type: GDAL data type name (e.g. 'UInt16'); when None the
            data type of the band being read is used
        :param cols: Column count of the window
        :param rows: Row count of the window
        :param cOffset: Column offset of the window
        :param rOffset: Row offset of the window
        :param band: 1-based band number; defaults to band 1
        :returns: numpy array cast to the matching numpy type, or None when
            the data type is not UInt16, Int16, UInt8, Byte or UInt32
        '''
        # Supported GDAL type names and their numpy equivalents. Any other
        # type (e.g. Float32) yields None, as before.
        numpy_types = {
            'UInt16': np.uint16,
            'Int16': np.int16,
            'UInt8': np.uint8,
            'Byte': np.uint8,
            'UInt32': np.uint32,
        }
        if band is None:
            band = 1

        src_ds = gdal.Open(fName, GA_ReadOnly)
        src_band = src_ds.GetRasterBand(band)
        if data_type is None:
            data_type = gdal.GetDataTypeName(src_band.DataType)

        in_type = numpy_types.get(data_type)
        if in_type is None:
            return None

        # Offsets are compared against None because 0 is a valid offset.
        # A truthiness test would silently turn a window at (0, 0) into a
        # full-band read.
        if cols and rows and cOffset is not None and rOffset is not None:
            return src_band.ReadAsArray(
                cOffset, rOffset, cols, rows).astype(in_type)
        return src_band.ReadAsArray().astype(in_type)
    
    def get_image_folder_date_path_row_name(self, image_id):
        '''
        Derive folder sub path, date, path/row and file-name template from
        an image id.

        The sensor family is recognised from the id prefix:
        - Landsat: prefixes of const.SENSOR_ARRAY[0:4]
        - Sentinel-2: prefixes of const.SENSOR_ARRAY[4:6]

        If both match, the Sentinel-2 values win. Fields that cannot be
        derived are returned as the string 'None'.

        :param image_id: Input image id
        :returns: tuple of
            (folder sub path,
             image date,
             image path/row,
             image name template with an '<ext>' placeholder; the Landsat
             and Sentinel-2 templates differ)
        '''
        sub_path = date = path_row = name_template = 'None'

        landsat_prefixes = tuple(entry[0] for entry in const.SENSOR_ARRAY[0:4])
        if image_id.startswith(landsat_prefixes):
            path_row = image_id[1:7]
            sub_path = const.LANDSAT
            date = image_id[7:15]
            stem = image_id.split('_')[0]
            name_template = stem + "_<ext>" + const.TIF_EXT

        sentinel_prefixes = tuple(
            entry[0] for entry in const.SENSOR_ARRAY[4:6])
        if image_id.startswith(sentinel_prefixes):
            path_row = image_id[1:6]
            sub_path = os.path.join(const.SENTINEL2, image_id[1:4])
            date = image_id[6:14]
            pieces = image_id.split('_', 1)
            # Resolution text is whatever follows the first '_' up to 'm'.
            resolution = pieces[1].split('m')[0]
            name_template = \
                pieces[0] + "_<ext>_" + resolution + const.TIF_EXT

        return sub_path, date, path_row, name_template

    def get_path_row(self, img_folder):
        '''
        Get the path and row from an image folder name.

        Landsat folders (prefixes of const.SENSOR_ARRAY[0:4]) must be at
        least 7 characters long. Sentinel-2 folders (prefixes of
        const.SENSOR_ARRAY[4:6]) must be at least 6 characters long. If
        both match, the Sentinel-2 values win.

        :param img_folder: Image folder name (may be empty or None)
        :returns: tuple (path, row) as string slices of the folder name, or
            (None, None) when nothing matches
        '''
        if not img_folder:
            return None, None

        path = row = None

        landsat_prefixes = tuple(entry[0] for entry in const.SENSOR_ARRAY[0:4])
        if img_folder.startswith(landsat_prefixes) and len(img_folder) >= 7:
            # NOTE(review): takes two characters each for path and row;
            # confirm against the Landsat folder naming scheme.
            path, row = img_folder[2:4], img_folder[5:7]

        sentinel_prefixes = tuple(
            entry[0] for entry in const.SENSOR_ARRAY[4:6])
        if img_folder.startswith(sentinel_prefixes) and len(img_folder) >= 6:
            path, row = img_folder[1:4], img_folder[4:6]

        return path, row
    
    def get_resolution(self, img_id):
        '''
        Return the resolution encoded in a Sentinel image id.

        The text after the first '_' up to the first 'm' is parsed as an
        integer, e.g. 'T10SEG20180101_20m_NBR' -> 20.

        :param img_id: Sentinel image id
        :returns: Resolution as an int
        '''
        after_underscore = img_id.split('_', 1)[1]
        meters_text = after_underscore.split('m')[0]
        return int(meters_text)

    def odd_even(self, oord):
        '''
        Snap a coordinate onto an odd multiple of HALFCELL (for gdal corners).

        The coordinate is floor-divided by HALFCELL. An even quotient is
        bumped up by one, so the result is an odd multiple of HALFCELL. A
        coordinate of exactly 0 is returned as 0.

        :param oord: Input coordinate value
        :returns: Snapped coordinate as an int
        '''
        steps = oord // self.HALFCELL
        is_even = math.modf(steps)[1] % 2 == 0
        if oord != 0 and is_even:
            steps += 1
        return int(self.HALFCELL * steps)

    def get_sensor(self, image_name):
        '''
        Look up the sensor of an image from its name prefix.

        :param image_name: Image name
        :returns: Second element of the last const.SENSOR_ARRAY entry whose
            first element is a prefix of image_name, or the string 'None'
            when no entry matches
        '''
        matches = [entry[1] for entry in const.SENSOR_ARRAY
                   if image_name.startswith(entry[0])]
        return matches[-1] if matches else 'None'

    def run_query(self, sql, multi=0, params=None):
        '''
        Run a query (e.g. SELECT) against the plugin database and fetch rows.

        :param sql: SQL statement to execute
        :param multi: 0 to return only the first row (fetchone); any other
            value to return all rows (fetchall)
        :param params: Optional sequence/mapping bound to '?' or named
            placeholders in sql; prefer this over formatting values into
            the SQL string
        :returns: A single row tuple (None if there are no rows) when
            multi == 0, otherwise a list of row tuples
        '''
        db_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               DB_NAME)
        db_conn = sqlite3.connect(db_file)
        # Close the connection even if execute/fetch raises; previously an
        # exception leaked the open connection.
        try:
            cursor_map = db_conn.cursor()
            if params is None:
                cursor_map.execute(sql)
            else:
                cursor_map.execute(sql, params)
            if multi == 0:
                return cursor_map.fetchone()
            return cursor_map.fetchall()
        finally:
            db_conn.close()

    def write_db(self, sql, params=None):
        '''
        Execute a data-modifying SQL statement on the plugin database and
        commit it.

        :param sql: SQL statement to execute
        :param params: Optional parameters bound to placeholders in sql;
            when empty or None the statement is executed without parameters
        '''
        db_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               DB_NAME)
        db_conn = sqlite3.connect(db_file)
        try:
            # The connection context manager commits on success and rolls
            # back if execute raises. It does not close the connection,
            # hence the finally.
            with db_conn:
                if params:
                    db_conn.execute(sql, params)
                else:
                    db_conn.execute(sql)
        finally:
            db_conn.close()

    def fix_os_sep_in_path(self, src_path):
        '''
        Normalise the directory separators in a path to os.sep.

        Both '/' and '\\' are rewritten to the separator of the running
        platform, so Windows-style and POSIX-style paths can be used
        interchangeably.

        :param src_path: Source path/filename
        :returns: The path with every '/' and '\\' replaced by os.sep
        '''
        # Replacing the character that already equals os.sep is a no-op, so
        # both replacements can be applied unconditionally.
        return src_path.replace('/', os.sep).replace('\\', os.sep)