
import numpy as np
import psutil
from scipy import stats
import utility_functions as utility_functions
from joblib import Parallel, delayed
import os
from osgeo import gdal

import glob
import datetime as dt

import gc

def get_stats_for_3d_raster_stack_from_file_list(file_path_list, n_jobs=5, nodata=-9999):
    """Fit a per-pixel linear regression through a stack of raster layers.

    Every file in ``file_path_list`` is read with
    ``utility_functions.get_tifflayer`` and stacked along axis 0 in list
    order.  For each (row, col) position the values along axis 0 are
    regressed against the layer index ``0 .. len(file_path_list) - 1`` with
    ``scipy.stats.linregress``, using only the layers whose value is finite
    and different from ``nodata``.

    Args:
        file_path_list: Non-empty sequence of raster paths.  The first file
            determines the grid shape and supplies ``trans`` / ``prj`` (via
            ``utility_functions.get_geo``); every layer must be assignable
            into that shape.
        n_jobs: Number of joblib workers.  Defaults to 5; the comments in the
            original code report out-of-memory failures with 10 and with the
            full logical core count.
        nodata: Sentinel value excluded from the regression.  Pass ``None``
            to exclude only non-finite values.

    Returns:
        A 9-tuple ``(data_annual, data_slope, data_rvalue, data_pvalue,
        t_res, dim, trans, prj, data_numobservations)`` where

        * ``data_annual`` is the stacked ``(layers, rows, cols)`` array,
        * ``data_slope`` / ``data_rvalue`` / ``data_pvalue`` are
          ``(rows, cols)`` grids of the regression slope, r-value and p-value,
        * ``t_res`` is the number of layers,
        * ``dim`` is ``[cols, rows, 1]``,
        * ``trans`` and ``prj`` are whatever ``utility_functions.get_geo``
          returned for the first file (presumably geotransform and
          projection -- TODO confirm),
        * ``data_numobservations`` counts the layers used per pixel.

    Raises:
        IndexError: If ``file_path_list`` is empty.

    NOTE(review): pixels without any usable observation keep the initial 0 in
    every output grid, including the p-value grid.  A pixel with exactly one
    usable observation is still passed to ``linregress``, as before.
    """
    print('file list: ', file_path_list[0])

    n_layers = len(file_path_list)
    first_layer = utility_functions.get_tifflayer(file_path_list[0])
    data_annual = np.zeros((n_layers, first_layer.shape[0], first_layer.shape[1]))
    trans, prj = utility_functions.get_geo(file_path_list[0])

    for layer_index, layer_path in enumerate(file_path_list):
        data_layer = utility_functions.get_tifflayer(layer_path)
        print(layer_path)
        print(data_layer.shape)
        data_annual[layer_index, :, :] = data_layer

    n_rows, n_cols = data_annual.shape[1], data_annual.shape[2]

    # The regression x-axis is simply the position of each layer in the list.
    reg_x = np.arange(n_layers)

    def _fit_pixel(rc_tuple):
        """Return (slope, r, p, row, col, n_obs) for one pixel, or None."""
        irow, icol = rc_tuple
        series = data_annual[:, irow, icol]

        # Usable observations are finite and not the nodata sentinel.
        valid = np.isfinite(series)
        if nodata is not None:
            valid &= series != nodata

        num_observations = int(np.count_nonzero(valid))
        if num_observations == 0:
            return None

        # linregress yields slope, intercept, r_value, p_value, std_err.
        ds, _, dr, dp, _ = stats.linregress(reg_x[valid], series[valid])
        return (ds, dr, dp, irow, icol, num_observations)

    # Reported for information only; the worker count comes from ``n_jobs``.
    print(psutil.cpu_count(logical=True))

    # Row-major pixel order, generated lazily instead of building a full list
    # of (row, col) tuples up front.
    pixel_indices = (
        (irow, icol)
        for irow in range(n_rows)
        for icol in range(n_cols)
    )
    lr_tuples = Parallel(n_jobs=n_jobs)(
        delayed(_fit_pixel)(rc_tuple) for rc_tuple in pixel_indices
    )

    print("passed_parallel")

    data_slope = np.zeros((n_rows, n_cols))
    data_pvalue = np.zeros((n_rows, n_cols))
    data_rvalue = np.zeros((n_rows, n_cols))
    data_numobservations = np.zeros((n_rows, n_cols))

    # Workers return None for pixels with no usable data; those cells stay 0.
    for value_tuples in lr_tuples:
        if value_tuples is None:
            continue
        ds, dr, dp, irow, icol, num_observations = value_tuples
        data_slope[irow, icol] = ds
        data_rvalue[irow, icol] = dr
        data_pvalue[irow, icol] = dp
        data_numobservations[irow, icol] = num_observations

    print('Exporting!')
    t_res, y_res, x_res = data_annual.shape
    dim = [x_res, y_res, 1]
    return (data_annual, data_slope, data_rvalue, data_pvalue, t_res, dim, trans, prj,
            data_numobservations)


def trend_stats(workpath, HV, IATA_CODE, data_product, start_year_inclusive, end_year_inclusive,
                processing_year_label):
    """Compute per-pixel trend statistics for one product and save them as GeoTIFFs.

    All files in ``<workpath>/<HV>/ANNUAL`` whose name contains
    ``data_product`` are sorted by path and passed to
    ``get_stats_for_3d_raster_stack_from_file_list``.  Four rasters are then
    written to ``<workpath>/<HV>/TOTAL`` (created if missing) through
    ``utility_functions.saveGeoTiff``: ``SLOPE``, ``PVAL``, ``R2`` (the
    squared r-value) and ``NUMCLEAR`` (the observation count).

    Output names follow
    ``UHI_<grid>_<IATA_CODE>_<start>_<end>_<processing_year_label>_<collection>_<data_product>_<STAT>.tif``
    where ``<grid>`` and ``<collection>`` are the 2nd and 6th
    underscore-separated pieces of the first input file name.

    Args:
        workpath: Root directory containing the ``HV`` sub-directory.
        HV: Sub-directory name under ``workpath``.
        IATA_CODE: Label inserted into the output file names.
        data_product: Substring used to select input files and to label
            outputs.
        start_year_inclusive: First year, used only in output file names.
        end_year_inclusive: Last year, used only in output file names.
        processing_year_label: Processing-year label for output file names.

    Raises:
        IndexError: If no input file matches, or if the first file name has
            fewer than six underscore-separated pieces.

    NOTE(review): the year arguments do not filter the input files; every
    matching file in ANNUAL is used.  The glob also matches any file type
    containing ``data_product``, not only ``.tif`` -- confirm that no sidecar
    files can end up in that directory.
    """
    inputdir = os.path.join(workpath, HV, 'ANNUAL')
    outDir = os.path.join(workpath, HV, 'TOTAL')

    # exist_ok makes a separate existence check unnecessary.
    os.makedirs(outDir, exist_ok=True)

    print(inputdir)
    file_path_list = sorted(glob.glob(inputdir + os.sep + "*" + data_product + "*"))
    if not file_path_list:
        # Same exception type as the bare index failure, with a usable message.
        raise IndexError(
            "no files matching '*{}*' found in {}".format(data_product, inputdir))

    # Naming components are taken from the first (sorted) input file name.
    pieces = os.path.basename(file_path_list[0]).split("_")
    regional_grid_name = pieces[1]
    collection_number = pieces[5]

    print('file list: ', file_path_list[0])

    (_, data_slope, data_rvalue, data_pvalue, _, dim, trans,
     prj, data_numobservations) = get_stats_for_3d_raster_stack_from_file_list(file_path_list)
    print('Exporting!')

    # (file-name suffix, grid) pairs, written in this order.
    outputs = (
        ("SLOPE", data_slope),
        ("PVAL", data_pvalue),
        ("R2", data_rvalue ** 2),
        ("NUMCLEAR", data_numobservations),
    )
    for stat_suffix, grid in outputs:
        save_data_product = data_product + "_" + stat_suffix
        out_filename = os.path.join(outDir, '{}_{}_{}_{}_{}_{}_{}_{}.tif'.format(
            "UHI", regional_grid_name, IATA_CODE, start_year_inclusive, end_year_inclusive,
            processing_year_label, collection_number, save_data_product))
        utility_functions.saveGeoTiff(out_filename, grid, dim, trans, prj,
                                      dtype=gdal.GDT_Float32)
