import os
import glob
import tarfile
import calendar
import numpy as np
import rasterio
from rasterio.vrt import WarpedVRT
import rasterio.shutil as rio_shutil
from rasterio.enums import Resampling
from datetime import datetime, timedelta
import itertools
import sys
import uuid

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ process ESPA related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def cloudmask(qa, date, pr, landsat, testing=False, temp=None):
    """Create a clear/flagged mask from a Landsat QA_PIXEL raster and save it as GeoTIFF.

    A pixel is flagged (1) when any of the QA bits with value 2 (dilated cloud),
    8 (cloud), 16 (cloud shadow) or 32 (snow) is set. For Landsat 8 and 9 the
    value-4 bit (cirrus) is also used. All other pixels are 0 (clear).
    The mask is written to ``cmask_<pr>_<date>.tif`` in the current working
    directory, using the QA raster's metadata.

    Args:
        qa: path to the QA_PIXEL raster.
        date: date string used in the output file names.
        pr: path/row string used in the output file name.
        landsat: Landsat mission number. '8' or '9' (str or int) enables the cirrus bit.
        testing: if True, also write the individual QA component rasters.
        temp: folder for the testing rasters. Defaults to the current working
            directory (previously ``None`` crashed in ``os.path.join``).

    Returns:
        numpy.ndarray: mask with the same (bands, rows, cols) layout as
        ``src.read()``; 0 = clear, 1 = flagged.
    """
    with rasterio.open(qa) as src:
        outmeta = src.meta
        qa_arr = src.read().astype(int)

    # isolate each QA flag by its bit value
    dilated = np.bitwise_and(qa_arr, 2)
    cloud = np.bitwise_and(qa_arr, 8)
    cloudshadow = np.bitwise_and(qa_arr, 16)
    snow = np.bitwise_and(qa_arr, 32)
    group = [dilated, cloud, cloudshadow, snow]
    if str(landsat) in ('8', '9'):
        group.append(np.bitwise_and(qa_arr, 4))  # cirrus

    # any set flag -> 1, otherwise 0
    mask = np.where(np.maximum.reduce(group) == 0, 0, 1)

    outfile = os.path.join(os.getcwd(), f'cmask_{pr}_{date}.tif')
    with rasterio.open(outfile, 'w', **outmeta) as wrast:
        wrast.write(mask)

    if testing:
        temp_dir = os.getcwd() if temp is None else temp
        components = (('dilated', dilated), ('cloud', cloud),
                      ('cloudshadow', cloudshadow), ('snow', snow))
        for label, component in components:
            with rasterio.open(os.path.join(temp_dir, f'{label}{date}.tif'), 'w', **outmeta) as wrast:
                wrast.write(component)

    return mask

# unzip tar.gz files
def _tar_path_is_within(base_dir, path):
    """Return True if *path* resolves to a location inside *base_dir*."""
    base_real = os.path.realpath(base_dir)
    path_real = os.path.realpath(path)
    try:
        return os.path.commonpath([base_real, path_real]) == base_real
    except ValueError:
        # e.g. paths on different drives on Windows: certainly not inside base_dir
        return False


def unziptar(gzfiles, tempdir):
    """Extract ESPA ``.tar.gz`` bundles into per-scene folders under *tempdir*.

    Each archive is extracted into ``<tempdir>/espa_<id>``, where ``<id>`` is
    characters 4-17 of the archive's basename. For a name like
    ``LC080260292023010802T1-SC20240507180436.tar.gz`` that is ``02602920230108``.

    Every member, and every link target, is checked before anything is written.
    An archive that would write outside its scene folder (``../`` names, absolute
    paths, escaping links) is rejected as a whole.

    Side effect: the working directory is changed to each extraction folder in
    turn and stays at the last one (kept from the original implementation;
    ``cloudmask`` writes its output to the cwd).

    Args:
        gzfiles: iterable of ``.tar.gz`` paths.
        tempdir: root folder for the extracted scenes (created if missing).

    Raises:
        ValueError: if an archive member would be written outside its extraction folder.
    """
    for gz in gzfiles:
        name = os.path.basename(gz)[4:18]
        scratch = os.path.join(tempdir, 'espa_' + name)
        os.makedirs(scratch, exist_ok=True)
        with tarfile.open(gz, 'r:gz') as tar:
            members = tar.getmembers()
            # validate the whole archive first so a bad archive extracts nothing
            for member in members:
                target = os.path.join(scratch, member.name)
                if member.issym():
                    link_target = os.path.join(os.path.dirname(target), member.linkname)
                elif member.islnk():
                    link_target = os.path.join(scratch, member.linkname)
                else:
                    link_target = None
                escapes = not _tar_path_is_within(scratch, target)
                if link_target is not None and not _tar_path_is_within(scratch, link_target):
                    escapes = True
                if escapes:
                    raise ValueError(f'Refusing to extract {member.name!r} from {gz}: '
                                     f'it would be written outside {scratch}')
            os.chdir(scratch)
            for member in members:
                tar.extract(member, scratch)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ END process ESPA related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#=================================================================================================

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Gapfill_ETf related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def _lint_fill_round(base_raster, raster1, raster2, time_diff, prev_gap):
    """Run one linear-interpolation round over the gaps of *base_raster*.

    Gap pixels in the base raster are those equal to -9999 or NaN. Reference
    pixels below 0 (e.g. the -9999 fill) are invalid. Wherever either reference
    is invalid, the gap stays NaN so that a later, wider reference pair can
    still fill it. Negative results are clipped to 0.

    Args:
        base_raster: raster whose gaps are filled (the target scene or the
            previous round's temp file).
        raster1: reference raster dated before the target.
        raster2: reference raster dated after the target.
        time_diff: days between raster1 and raster2.
        prev_gap: days between raster1 and the target date.

    Returns:
        tuple: (filled band 1 array, base raster metadata updated to a float32 GeoTIFF).

    Raises:
        Exception: if a reference raster's shape differs from the base raster's.
    """
    with rasterio.open(base_raster) as base_src, rasterio.open(raster1) as r1_src, rasterio.open(raster2) as r2_src:
        outmeta = base_src.meta.copy()
        outmeta.update(driver='GTiff', dtype='float32')
        target = base_src.read(1)
        ref1 = r1_src.read(1)
        ref2 = r2_src.read(1)

    if ref1.shape != target.shape or ref2.shape != target.shape:
        raise Exception(f'rasters are not in the same shape.')

    ra1 = np.ma.masked_less(ref1, 0)
    ra2 = np.ma.masked_less(ref2, 0)
    slope = (ra2 - ra1) / time_diff
    # Turn masked (invalid) positions into NaN explicitly. np.where ignores the mask
    # and would otherwise use the raw value numpy.ma left there (e.g. -9999). The
    # clip below then turned that into a fake 0, which later rounds could not fill.
    lin_interp = np.ma.filled(ra1 + slope * prev_gap, np.nan)

    target = np.where(target == -9999, np.nan, target)
    filled = np.where(np.isnan(target), lin_interp, target)
    return np.where(filled < 0, 0, filled), outmeta


def gapfilling(rfile, int_file_raster_list, limit, outdir, tempdir, testing=False):
    """Gap-fill one ETf scene by linear interpolation between neighbouring scenes.

    The scene date is the 8 characters before the extension of *rfile*'s basename
    (``YYYYMMDD``). Reference scenes are taken from *int_file_raster_list* when
    their date lies strictly within *limit* days before (previous) or after
    (following) the target date.

    Both reference lists are ordered nearest date first and paired with ``zip``,
    so the number of rounds is the length of the shorter list. Round 0 fills the
    target's gaps from the nearest pair. Each further round fills what is still
    missing from the next pair, through ``<tempdir>/gap_temp/gap_temp<date>.tif``.
    The result is written to ``<outdir>/gapfilled_ETf<date>.tif``. With no
    reference on one side, band 1 of the target is copied there unchanged.

    Args:
        rfile: target raster path.
        int_file_raster_list: candidate reference raster paths (aligned to rfile).
        limit: search window in days on each side of the target date.
        outdir: folder for the final gap-filled raster.
        tempdir: folder under which ``gap_temp`` is created.
        testing: unused; kept for interface compatibility.

    Raises:
        Exception: if a reference raster's shape differs from the target's.
    """
    targetraster = rfile
    print('targetraster: ', targetraster)
    name = os.path.basename(rfile)[-12:-4]
    print(name)
    year, month, day = name[:4], name[4:6], name[6:8]
    print(f'Year: {year}, Month: {month}, Day:{day}')
    scenedate = year + month + day
    DATE_FORMAT = '%Y%m%d'
    targetdate = datetime.strptime(scenedate, DATE_FORMAT)
    targetdateInt = int(targetdate.strftime(DATE_FORMAT))
    start_date = (targetdate - timedelta(days=limit)).strftime(DATE_FORMAT)
    end_date = (targetdate + timedelta(days=limit)).strftime(DATE_FORMAT)
    print(f'Interpolation period between {start_date} and {end_date}')

    temploc = os.path.join(tempdir, 'gap_temp')
    os.makedirs(temploc, exist_ok=True)

    def _date_key(path):
        # YYYYMMDD stamp just before the '.tif' extension
        return os.path.basename(path)[-12:-4]

    start_int, end_int = int(start_date), int(end_date)
    prevrasters = []
    afterrasters = []
    for r in int_file_raster_list:
        r_int = int(_date_key(r))
        if start_int < r_int < targetdateInt:
            prevrasters.append(r)
        if targetdateInt < r_int < end_int:
            afterrasters.append(r)
    # sort on the file-name date (not the full path, which may mix folders):
    # nearest date first on both sides
    prevrasters.sort(key=_date_key, reverse=True)
    afterrasters.sort(key=_date_key)
    print(f'There are {len(prevrasters)} potential Reference Arrays prior to the Target Date')
    print(f'There are {len(afterrasters)} potential Reference Arrays after the Target Date')

    out_gf_etf = os.path.join(outdir, 'gapfilled_ETf' + scenedate + '.tif')

    if not prevrasters or not afterrasters:
        print('Not enough Reference Rasters within the Max Limit')
        print(f'No interpolation for: {scenedate}')
        with rasterio.open(targetraster) as tras:
            outmeta = tras.meta.copy()
            outmeta.update(driver='GTiff', dtype='float32')
            band = tras.read(1)
        with rasterio.open(out_gf_etf, 'w', **outmeta) as wrast:
            wrast.write(band, 1)
        print("Finished with No Interpolation")
        return

    combos = list(zip(prevrasters, afterrasters))
    last = len(combos) - 1
    temp_gap = os.path.join(temploc, 'gap_temp' + scenedate + '.tif')
    base = targetraster
    for j, (raster1, raster2) in enumerate(combos):
        print(raster1)
        print(raster2)
        refDate1 = datetime.strptime(_date_key(raster1), DATE_FORMAT)
        refDate2 = datetime.strptime(_date_key(raster2), DATE_FORMAT)
        timeDiff = (refDate2 - refDate1).days
        prevGap = (targetdate - refDate1).days
        afterGap = (refDate2 - targetdate).days
        print(f'Time Difference: {timeDiff}')
        print(f'Time prior to Target Date: {prevGap}')
        print(f'Time after Target Date:{afterGap}')
        print(f'Starting Linear Interpolation Round: {j}')
        if j > 0:
            print('doing end' if j == last else 'doing middle')

        ras_final, outmeta = _lint_fill_round(base, raster1, raster2, timeDiff, prevGap)
        # The last round writes the final product. Earlier rounds overwrite the temp
        # file, which the next round reads back as its base (inputs are closed by then).
        dest = out_gf_etf if j == last else temp_gap
        with rasterio.open(dest, 'w', **outmeta) as wrast:
            wrast.write(ras_final, 1)
        base = temp_gap

        if j == last and j > 0:
            print(f'wrote the final combo, Year: {year}, Month: {month}, Day:{day} is complete')
        elif j > 0:
            print('middle lint Complete')
        print('j', j, 'len combos', len(combos))
        print('--------------------------------------------------------')

def mutual_extent(etf_dir, mosaic_dir):
    """Align every ``*.tif`` in *etf_dir* to the grid of the first one (by sorted name).

    The first raster is returned as-is. Each remaining raster is warped with
    warp_based_on_sample (nearest neighbour, nodata -9999, float32) into
    *mosaic_dir* under its own basename.

    Returns:
        list[str]: the reference raster path followed by the warped raster paths.

    Raises:
        IndexError: if *etf_dir* contains no ``.tif`` files.
    """
    tif_paths = sorted(glob.glob(os.path.join(etf_dir, '*.tif')))
    print(tif_paths)

    reference = tif_paths[0]
    warped = [
        warp_based_on_sample(path, temp_folder=mosaic_dir, nodata=-9999, sample_file=reference,
                             resamplemethod='nearest', outdtype='float32')
        for path in tif_paths[1:]
    ]
    # the reference raster is part of the aligned set too
    return [reference] + warped

def warp_based_on_sample(input_raster: str, temp_folder: str, nodata=None, sample_file=None,
                         resamplemethod='average', outdtype='float32', generic_temp=False):
    """Warp *input_raster* onto the grid of *sample_file* and save it as a GeoTIFF.

    The output takes the sample raster's CRS, transform, width and height.

    Args:
        input_raster: raster to warp.
        temp_folder: folder the warped GeoTIFF is written to.
        nodata: optional nodata value for the warped output.
        sample_file: raster that defines the target grid.
        resamplemethod: 'nearest', 'average' or 'bilinear'.
        outdtype: output data type.
        generic_temp: if True, the output is named ``temp_<8 hex chars>.tif`` (a new
            unique name on every call). Otherwise it keeps the input's basename.

    Returns:
        str: path of the warped GeoTIFF.

    Raises:
        ValueError: for an unsupported *resamplemethod*. This is checked before any
            file is opened.
    """
    if resamplemethod not in ('nearest', 'average', 'bilinear'):
        raise ValueError('only nearest-neighbor, average and bi-linear resampling are supported at this time')
    rs = Resampling[resamplemethod]

    if generic_temp:
        outname = f'temp_{uuid.uuid4().hex[:8]}.tif'
    else:
        outname = os.path.basename(input_raster)
    outwarp = os.path.join(temp_folder, outname)

    with rasterio.open(sample_file) as sample:
        crs = sample.crs
        transform = sample.transform
        cols = sample.width
        rows = sample.height

    vrt_kwargs = dict(
        resampling=rs,
        crs=crs,
        transform=transform,
        height=rows,
        width=cols,
        dtype=outdtype,
    )
    if nodata is not None:
        vrt_kwargs["nodata"] = nodata

    with rasterio.open(input_raster, 'r') as src:
        # a virtual warped view on the sample grid, materialised as a GeoTIFF
        with WarpedVRT(src, **vrt_kwargs) as vrt:
            rio_shutil.copy(vrt, outwarp, driver='GTiff')

    return outwarp
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End Gapfill_ETf related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#=================================================================================================
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ interpolate_espa_ETf related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# N/A
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End interpolate_espa_ETf related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#=================================================================================================
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ aggregation related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def check_complete_year(yearly_eta_list, yr):
    """Return True when a list of daily ETa files starts on Jan 1 and ends on Dec 31 of *yr*.

    Only the first and last entries are inspected. The list is assumed to be
    sorted chronologically (TODO confirm with callers). Gaps in between are not
    detected.
    """
    start = parse_eta_date(yearly_eta_list[0])
    end = parse_eta_date(yearly_eta_list[-1])

    if start != datetime(year=yr, month=1, day=1):
        return False
    return end == datetime(year=yr, month=12, day=31)

def get_year_month_tuples(start_year, start_month, end_year, end_month):
    """List every (year, month) pair from the start month to the end month, inclusive.

    Arguments may be ints or numeric strings. Returns an empty list when the end
    month precedes the start month. Invalid months raise ValueError (via datetime).
    """
    cursor = datetime(int(start_year), int(start_month), 1)
    stop = datetime(int(end_year), int(end_month), 1)

    months = []
    while not cursor > stop:
        months.append((cursor.year, cursor.month))
        # December rolls over to January of the following year
        cursor = datetime(cursor.year + cursor.month // 12, cursor.month % 12 + 1, 1)
    return months

def parse_eta_date(eta_filepath):
    """Return the date encoded at characters 8-15 (``YYYYMMDD``) of the file's basename.

    Matches names such as ``dailyETa20230115.tif``.
    """
    stamp = os.path.basename(eta_filepath)[8:16]
    return datetime.strptime(stamp, '%Y%m%d')

def check_complete_month(monthly_eta_list):
    """Return True when a list of daily ETa files starts on the 1st and ends on the month's last day.

    The first entry must fall on day 1. The last entry must fall on the final day
    of its own month (per ``calendar.monthrange``). Only these two entries are
    inspected; the list is assumed to be sorted chronologically.
    """
    start = parse_eta_date(monthly_eta_list[0])
    end = parse_eta_date(monthly_eta_list[-1])

    try:
        _, days_in_month = calendar.monthrange(end.year, end.month)
    except ValueError:
        return False
    return start.day == 1 and end.day == days_in_month

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End aggregation related ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#=================================================================================================

#================================ Main Tab Functions =============================================

def process_espa_files(input_dir, update_progress, update_status, root):
    """Unpack ESPA tar.gz bundles and write scaled, cloud-masked ETf rasters.

    Archives in *input_dir* are extracted into ``<input_dir>/../unzipped``. For
    each extracted scene, the ETF band is multiplied by 0.0001. Pixels flagged by
    ``cloudmask`` or holding the -9999 fill value are set to NaN. The result is
    written to ``<input_dir>/../data/<pathrow>/ETf/etf_<date>.tif``.

    Side effect: changes the working directory (``os.chdir``).

    Args:
        input_dir: folder containing the ESPA ``*.gz`` archives.
        update_progress, update_status, root: GUI hooks, called once on completion.

    Returns:
        str: the ``data`` output folder.
    """
    filedir = input_dir
    gzfiles = sorted(glob.glob(os.path.join(filedir, '*.gz')))
    print(f'There are {len(gzfiles)} tar.gz files in the folder.')

    os.chdir(filedir)
    dir_one_up = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
    unzipdir = os.path.join(dir_one_up, 'unzipped')
    print(f'Unzipping tar.gz files to directory: {unzipdir}')
    unziptar(gzfiles, tempdir=unzipdir)

    etfiledir = os.path.join(dir_one_up, 'data')
    etf_scale_factor = 0.0001
    etf_fill_value = -9999

    if os.path.isdir(unzipdir):
        scene_dirs = sorted(os.listdir(unzipdir))
    else:
        print(f'No unzipped scenes found in {unzipdir}')
        scene_dirs = []

    for i in scene_dirs:
        actdir = os.path.join(unzipdir, i)
        rasters = sorted(glob.glob(os.path.join(actdir, '*.tif')))
        if not rasters:
            print(f'No .tif files found in {actdir}; skipping.')
            continue

        path_row = f'p{i[6:8]}r{i[9:11]}'
        os.makedirs(os.path.join(etfiledir, path_row, 'ETf'), exist_ok=True)
        print(f'start time: {datetime.now()}')
        print(rasters)

        name2 = os.path.basename(rasters[0])[:41]  # common product-name prefix
        landsat = name2[3:4]
        date = name2[17:25]
        print('working on:', date)

        etfband = os.path.join(actdir, name2 + 'ETF.tif')
        qaband = os.path.join(actdir, name2 + 'QA_PIXEL.tif')

        cmask = cloudmask(qaband, date, path_row, landsat)

        with rasterio.open(etfband) as src:
            outmeta = src.meta
            outmeta.update(dtype=rasterio.float32, compress='lzw')
            raw_etf = src.read()

        # Detect the fill value on the raw integers. Comparing the scaled float with
        # -0.9999 depends on how -9999 * 0.0001 happens to round.
        scaled_etf = raw_etf * etf_scale_factor
        keep = (cmask == 0) & (raw_etf != etf_fill_value)
        masked_etf = np.where(keep, scaled_etf, np.nan)

        outetf = os.path.join(etfiledir, path_row, 'ETf', 'etf_' + date + '.tif')
        with rasterio.open(outetf, 'w', **outmeta) as wrast:
            wrast.write(masked_etf)

        print(f'Created ETf raster for {date}')

    update_status('Preprocessing of ESPA files completed.')
    root.update_idletasks()
    update_progress(100)

    return etfiledir

def gapfill_ETf(input_dir, update_progress, update_status, root, limit=48):
    """Gap-fill all ETf scenes of one path/row folder using linear interpolation.

    Reads ``<input_dir>/ETf/*.tif``. ``mutual_extent`` aligns every scene to the
    grid of the first one (aligned copies go to ``<input_dir>/temp``). Then
    ``gapfilling`` runs on each aligned scene, writing to ``<input_dir>/gapfilledETf``.

    Side effect: changes the working directory to *input_dir*.

    Args:
        input_dir: path/row folder, e.g. ``.../p43r33``.
        update_progress, update_status, root: GUI hooks; accepted for interface
            compatibility, not used here.
        limit: window, in days, searched before and after each target date for
            reference scenes.
    """
    dirpath = input_dir
    os.chdir(dirpath)
    parentdir = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
    print(parentdir)
    pr = os.path.basename(os.path.normpath(dirpath))  # e.g. 'p43r33', even with a trailing slash
    print(pr)
    cat = 'ETf'

    outdir = os.path.join(dirpath, 'gapfilled' + cat)
    temp = os.path.join(dirpath, 'temp')
    int_dir = os.path.join(temp, 'Int' + cat)
    # IntETf/, IntETf/temp/ and MutualExtent/ belong to the disabled integerize step.
    # They are still created so the on-disk layout stays the same.
    # TODO: drop if nothing else uses them.
    for folder in (outdir, temp, int_dir, os.path.join(int_dir, 'temp'),
                   os.path.join(temp, 'MutualExtent')):
        os.makedirs(folder, exist_ok=True)

    etf_dir = os.path.join(dirpath, cat)
    etf_rasters_aligned = mutual_extent(etf_dir=etf_dir, mosaic_dir=temp)

    newstartTime = datetime.now()
    for i in etf_rasters_aligned:
        print('working on file:', i)
        gapfilling(rfile=i, int_file_raster_list=etf_rasters_aligned, limit=limit, outdir=outdir,
                   tempdir=temp, testing=False)

    print(f'ET Fraction has completed Linear Interpolation')

    LinttimeEnd = str(datetime.now() - newstartTime)
    print(f'Total Time to Interpolate {pr}: {LinttimeEnd}')

def _pet_raster_path(petdir, date, climatology, pet_name_fmt):
    """Return the reference-ET raster path for *date*: ``pet_JJJ.tif``, JJJ = 3-digit day of year.

    Raises:
        Exception: if *climatology* is falsy (only daily climatology rasters are supported).
        KeyError: if *pet_name_fmt* is not ``'pet_ddd'``.
    """
    if not climatology:
        raise Exception('Only Daily Climatology Rasters are Accepted at this time.')
    if pet_name_fmt != 'pet_ddd':
        raise KeyError(f"No supported pet name format {pet_name_fmt}, only 'pet_ddd' is supported at this time")
    return os.path.join(petdir, 'pet_' + date.strftime('%j') + '.tif')


def _read_warped_pet_band(pet_path, temp_root, sample_file):
    """Warp *pet_path* onto *sample_file*'s grid, return its band 1, then delete the warped copy.

    With ``generic_temp=True``, warp_based_on_sample writes a uniquely named file
    on every call. Deleting it here stops the temp folder from growing by one full
    raster per processed day.
    """
    warped = warp_based_on_sample(input_raster=pet_path, temp_folder=temp_root,
                                  sample_file=sample_file, generic_temp=True)
    try:
        with rasterio.open(warped) as src:
            return src.read(1)
    finally:
        try:
            os.remove(warped)
        except OSError as err:
            print(f'Could not remove temporary file {warped}: {err}')


def interpolate_espa_ETf(input_dir, pet_dir, pet_name_fmt, climatology, update_progress,
                         update_status, root, etf_out=True):
    """Build daily ETf and ETa rasters by interpolating between gap-filled ETf scenes.

    Scenes in ``<input_dir>/gapfilledETf`` are ordered by the ``YYYYMMDD`` stamp
    before the extension. For each consecutive pair (current, next):

    - the current scene is written as ``dailyETf/dailyETf<date>.tif``;
    - ETa = ETf * (PET * 0.01) is written as ``dailyETa/dailyETa<date>.tif``;
    - every day strictly between the two dates gets a linearly interpolated ETf
      (written only if that file does not exist yet) and the matching ETa.

    PET rasters come from *pet_dir* (``pet_JJJ.tif``) and are warped to the
    current scene's grid. The last scene only serves as the end of the final pair.

    Args:
        input_dir: path/row folder containing ``gapfilledETf``.
        pet_dir: folder of daily climatology PET rasters.
        pet_name_fmt: PET naming format; only ``'pet_ddd'`` is supported.
        climatology: must be truthy (only daily climatology is supported).
        update_progress, update_status, root, etf_out: accepted for interface
            compatibility; not used here.

    Raises:
        Exception / KeyError: for unsupported *climatology* / *pet_name_fmt*. These
            are raised only once there are at least two scenes to process.
    """
    prstartTime = datetime.now()
    dirpath = input_dir
    petdir = pet_dir
    DATE_FORMAT = '%Y%m%d'

    epath = os.path.join(dirpath, 'gapfilledETf')
    # glob order depends on the filesystem; consecutive pairs must be chronological
    rasters = sorted(glob.glob(os.path.join(epath, '*.tif')),
                     key=lambda path: os.path.basename(path)[-12:-4])
    print(f'Rasters: \n{rasters}')

    etf_root = os.path.join(dirpath, 'dailyETf')
    eta_root = os.path.join(dirpath, 'dailyETa')
    temp_root = os.path.join(dirpath, 'temp')  # should already exist
    for folder in (etf_root, eta_root, temp_root):
        os.makedirs(folder, exist_ok=True)

    for current, nextone in zip(rasters[:-1], rasters[1:]):
        cur_stamp = os.path.basename(current)[-12:-4]
        next_stamp = os.path.basename(nextone)[-12:-4]
        print(f'current raster: {cur_stamp}')
        print(f'next raster: {next_stamp}')
        cur_date = datetime.strptime(cur_stamp, DATE_FORMAT)
        next_date = datetime.strptime(next_stamp, DATE_FORMAT)

        pet = _pet_raster_path(petdir, cur_date, climatology, pet_name_fmt)
        print(f'pet raster: {pet}')
        petras = _read_warped_pet_band(pet, temp_root, current)

        with rasterio.open(current) as cur_src, rasterio.open(nextone) as next_src:
            curras = cur_src.read(1)
            nextras = next_src.read(1)
            outmeta = cur_src.meta
        outmeta.update({'driver': 'GTiff', 'dtype': 'float32'})

        timeDiff = (next_date - cur_date).days
        targdate = cur_date.strftime(DATE_FORMAT)

        # daily ETf for the scene date itself
        etfout = os.path.join(etf_root, f'dailyETf{targdate}.tif')
        with rasterio.open(etfout, 'w', **outmeta) as wras:
            wras.write(curras, 1)
        print(f'Created daily ETf {targdate}')

        # daily ETa for the scene date itself
        etr = petras * 0.01
        eta = curras * etr
        etaout = os.path.join(eta_root, f'dailyETa{targdate}.tif')
        with rasterio.open(etaout, 'w', **outmeta) as wras:
            wras.write(eta, 1)
        print(f'Created daily ETa {targdate}')

        if timeDiff > 0:
            slope = (nextras - curras) / timeDiff
            for n in range(1, timeDiff):
                interpDate = cur_date + timedelta(days=n)
                interpolateddate = interpDate.strftime(DATE_FORMAT)
                lint = curras + (slope * n)
                pet2 = _pet_raster_path(petdir, interpDate, climatology, pet_name_fmt)
                etr2 = _read_warped_pet_band(pet2, temp_root, current) * 0.01

                outfile1 = os.path.join(etf_root, f'dailyETf{interpolateddate}.tif')
                if not os.path.isfile(outfile1):
                    with rasterio.open(outfile1, 'w', **outmeta) as wras:
                        wras.write(lint, 1)
                print(f'Created daily ETf {interpolateddate}')

                interETa = lint * etr2
                outfile2 = os.path.join(eta_root, f'dailyETa{interpolateddate}.tif')
                with rasterio.open(outfile2, 'w', **outmeta) as wras:
                    wras.write(interETa, 1)
                print(f'Created daily ETa {interpolateddate}')

    TimeEnd = str(datetime.now() - prstartTime)
    print(f'Total Time to Process: {TimeEnd}')

def monthly_ETa(input_dir, year, progress, status_msg, root):
    """
    Aggregate daily ETa rasters into one cumulative raster per calendar month.

    Every ``*.tif`` in ``input_dir`` is considered. The 4-digit year and 2-digit
    month are read from filename characters 8-13, i.e. directly after an 8-char
    prefix such as ``dailyETa``. Files whose name does not carry digits there
    are ignored.

    For each month of ``year`` that has rasters and whose file list passes
    ``check_complete_month``, the pixel-wise sum of band 1 is written to
    ``<parent of input_dir>/monthlyETA/eta_{year}{MM}.tif``. The raster metadata
    is taken from the month's first file. Months with no rasters or an
    incomplete set are skipped with a message.

    :param input_dir: directory holding the daily ETa GeoTIFFs.
    :param year: year to aggregate, as an int or a numeric string.
    :param progress: unused here; kept for interface compatibility.
    :param status_msg: unused here; kept for interface compatibility.
    :param root: unused here; kept for interface compatibility.
    """
    # normpath strips a trailing separator; otherwise dirname() would return
    # input_dir itself instead of its parent.
    input_dir = os.path.normpath(input_dir)
    parent_dir = os.path.dirname(input_dir)
    outdir = os.path.join(parent_dir, 'monthlyETA')
    os.makedirs(outdir, exist_ok=True)

    # Accept '2020' as well as 2020; a string year never matched the parsed int before.
    year = int(year)

    eta_list = sorted(glob.glob(os.path.join(input_dir, '*.tif')))
    print(eta_list)

    # Single pass: bucket this year's daily rasters by month (sorted order is preserved).
    by_month = {month: [] for month in range(1, 13)}
    for et_path in eta_list:
        stamp = os.path.basename(et_path)[8:14]  # YYYYMM after the 8-char prefix
        if len(stamp) != 6 or not stamp.isdigit():
            continue  # not a daily-ETa style name; skip instead of crashing in int()
        file_year, file_month = int(stamp[:4]), int(stamp[4:])
        if file_year == year and file_month in by_month:
            by_month[file_month].append(et_path)

    for month in range(1, 13):
        print('Month to be accumulated is: ' + str(month))
        month_files = by_month[month]
        print('rasters for month ' + str(month) + ': ', month_files)
        if not month_files:
            print('No data for month ' + str(month) + ' available..continue to next month')
            continue
        if not check_complete_month(month_files):
            print(f'data for month {month} is incomplete. Continue to next month')
            continue

        # The first raster supplies the output georeferencing and seeds the running sum.
        with rasterio.open(month_files[0], 'r') as src:
            outmeta = src.meta
            cum_arr = src.read(1)
        for path in month_files[1:]:
            with rasterio.open(path, 'r') as src:
                cum_arr += src.read(1)

        # Save the summed ET raster for this month.
        output_tif = os.path.join(outdir, f'eta_{year}{month:02d}.tif')
        with rasterio.open(output_tif, 'w', **outmeta) as wrast:
            wrast.write(cum_arr, 1)

def annual_ETa(input_dir, year, progress, status_msg, root):
    """
    Aggregate daily ETa rasters into one cumulative raster for a whole year.

    All ``dailyETa{year}*.tif`` files in ``input_dir`` are summed pixel-wise
    (band 1). The result is written to
    ``<parent of input_dir>/yearlyETA/eta_{year}.tif``, and that directory is
    created if needed. The output metadata is taken from the first daily raster.

    :param input_dir: directory holding the daily ETa GeoTIFFs.
    :param year: year to aggregate, as an int or a numeric string.
    :param progress: unused here; kept for interface compatibility.
    :param status_msg: unused here; kept for interface compatibility.
    :param root: unused here; kept for interface compatibility.
    :raises Exception: if no daily rasters exist for ``year`` or the set is
        incomplete according to ``check_complete_year``.
    """
    eta_list = sorted(glob.glob(os.path.join(input_dir, f'dailyETa{year}*.tif')))
    print(eta_list)

    # normpath strips a trailing separator so dirname() really yields the parent.
    parent_dir = os.path.dirname(os.path.normpath(input_dir))
    outdir = os.path.join(parent_dir, 'yearlyETA')

    if not eta_list or not check_complete_year(eta_list, int(year)):
        raise Exception(f'Incomplete set of daily rasters for year {year}, you must process '
                        f'more rasters prior to aggregating.')

    # The output folder was previously never created, so the write below failed on a fresh tree.
    os.makedirs(outdir, exist_ok=True)

    # The first raster supplies the output georeferencing and seeds the running sum.
    with rasterio.open(eta_list[0], 'r') as src:
        outmeta = src.meta
        cum_arr = src.read(1)
    for path in eta_list[1:]:
        with rasterio.open(path, 'r') as src:
            cum_arr += src.read(1)

    # Write the summed raster once, after accumulation is finished. It used to
    # be rewritten inside the loop on every iteration.
    output_tif = os.path.join(outdir, f'eta_{year}.tif')
    with rasterio.open(output_tif, 'w', **outmeta) as wrast:
        wrast.write(cum_arr, 1)
    print('Created annual raster!')

def seasonal_ETa(input_dir, start_month, end_month, year, progress, status_msg, root):
    """Aggregates Seasonal ETa. If Start month is greater than end month, this function
     assumes the season stretches into the next calendar year."""

    parent_dir = os.path.dirname(input_dir)
    outdir = os.path.join(parent_dir, 'seasonalETA')
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    number_of_month = (end_month + 1) - start_month
    print(number_of_month)
    if start_month < end_month:
        # month_delta = end_month - start_month
        end_year = year
    elif start_month <= end_month:
        # we assume the season crosses calendar year.
        # first_tranche = 12 - start_month
        # month_delta = first_tranche + end_month
        end_year = year + 1

    yr_mo_lst = get_year_month_tuples(start_year=year, start_month=start_month, end_year=end_year, end_month=end_month)

    print('year month list: \n', yr_mo_lst)

    eta_list_of_lists = []
    for ym in yr_mo_lst:
        # unpack tuple
        y, m = ym
        etas_ym = sorted(glob.glob(os.path.join(parent_dir, 'dailyETa', f'dailyETa{y:04}{m:02}*.tif')))
        if not check_complete_month(etas_ym):
            raise Exception(f'Month {m} of year {y} does not have a complete set of daily ETa rasters, '
                            f'therefore the season cannot be aggregated. Try a different date range or '
                            f'adjust your interpolation.')
        eta_list_of_lists.append(etas_ym)

    eta_list = list(itertools.chain(*eta_list_of_lists))  # flattens the list of lists

    print('eta list of files for season: \n', eta_list)

    for j, ras in enumerate(eta_list):

        if j == 0:
            # open file for geoproperties to use for outputting raster
            with rasterio.open(eta_list[j], 'r') as src:
                outmeta = src.meta
                # instantiate the cumulative raster.
                cum_arr = src.read(1)
        else:
            with rasterio.open(eta_list[j], 'r') as src:
                daily_arr = src.read(1)
                cum_arr += daily_arr

    # save the summed up ET raster
    prodname = f'eta_{year}{start_month:02d}_{end_year}{end_month:02d}.tif'
    output_Tif = os.path.join(outdir, prodname)
    with rasterio.open(output_Tif, 'w', **outmeta) as wrast:
        wrast.write(cum_arr, 1)

