import functools
import json
import logging
import random
import time
import weakref
from typing import Any
from typing import Callable
from typing import List
from typing import Tuple
from typing import Union

from cytoolz import merge
import requests


def weak_lru(maxsize=128, typed=False):
    """
    Decorator factory: an LRU cache for methods that holds only a weak
    reference to the instance, so the cache never keeps "self" alive.
    https://stackoverflow.com/a/68052994

    One ``functools.lru_cache`` is created per decorated method and shared by
    every instance, so ``maxsize`` bounds the entries across all instances.
    The instance must be weak-referenceable and hashable, and the remaining
    arguments must be hashable (``lru_cache`` requirement).
    """

    def decorator(method):
        # The cache key contains a weakref instead of the instance itself;
        # dereference it right before invoking the real method.
        @functools.lru_cache(maxsize, typed)
        def cached(self_ref, *args, **kwargs):
            instance = self_ref()
            return method(instance, *args, **kwargs)

        @functools.wraps(method)
        def proxy(self, *args, **kwargs):
            self_ref = weakref.ref(self)
            return cached(self_ref, *args, **kwargs)

        return proxy

    return decorator


def retry(retries: int, jitter: Tuple[int, int] = (1, 10),
          exceptions: Tuple[type, ...] = (Exception,)) -> Callable:
    """
    Simple retry decorator, for retrying any function that may
    throw an exception such as when trying to retrieve
    network resources

    :param retries: total number of attempts, including the first call;
        a value of 1 or less means the function is called only once
    :param jitter: inclusive (low, high) bounds, in whole seconds, of the
        random pause between attempts
    :param exceptions: exception types that trigger a retry; any other
        exception propagates immediately
    :return: a decorator; the last exception is re-raised once all attempts
        are used up
    """

    def retry_dec(func: Callable) -> Callable:
        # functools.wraps keeps the wrapped function's name and docstring.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempt = 1
            while True:
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    attempt += 1
                    if attempt > retries:
                        raise
                    time.sleep(random.randint(*jitter))

        return wrapper

    return retry_dec


class M2M:
    """
    Small client for the USGS M2M JSON API.

    The constructor logs in and stores the returned token as the
    ``X-Auth-Token`` header on a shared ``requests.Session``, so every later
    request is authenticated. Use the instance as a context manager to log
    out and close the session on exit.
    """

    def __init__(self, username: str, password: str, timeout: int = 20):
        """
        :param username: M2M account user name
        :param password: M2M account password
        :param timeout: per-request timeout in seconds, forwarded to ``requests``
        :raises Exception: whatever the login request raises; the session is
            closed before the error propagates
        """
        self.base_url = 'https://m2m.cr.usgs.gov/api/api/json/stable'
        self.timeout = timeout

        self.session = requests.Session()
        try:
            token = self._login(username=username, password=password)
        except Exception:
            # Don't leak the session's connection pool when login fails.
            self.session.close()
            raise
        self.session.headers.update({'X-Auth-Token': token})

    @retry(5)
    def _api_request(self, verb: str, resource: str, **kwargs) -> dict:
        """
        Make a request to the API, and check for errors

        Any exception (network error, HTTP error, in-body API error) is
        retried by ``retry(5)``. That means up to 5 attempts in total, with a
        random pause between them (1-10 s with retry's default jitter).
        """
        resp = self.session.request(verb, resource, timeout=self.timeout, **kwargs)
        return self._check_response(resp)

    @staticmethod
    def _check_response(resp: requests.Response) -> dict:
        """
        Check the response object for errors, raise appropriately, and
        return the payload's ``data`` member.

        :raises requests.HTTPError: on a 4xx/5xx status, or when the JSON body
            carries a non-empty ``errorCode``
        """
        # Check the status first: error pages (e.g. from a proxy or gateway)
        # are not necessarily JSON, and .json() would mask the real failure.
        if not resp.ok:
            resp.raise_for_status()

        d = resp.json()

        # The API also reports failures in the body via errorCode/errorMessage.
        # Without this check the caller would get d['data'] back
        # (presumably null) with no indication that anything went wrong.
        error_code = d.get('errorCode')
        if error_code:
            raise requests.HTTPError(f"{error_code}: {d.get('errorMessage')}", response=resp)

        return d['data']

    def _post(self, endpoint: str, data: Union[None, dict] = None) -> Union[dict, List[dict]]:
        """
        Make a post request to the API

        ``data`` is JSON-encoded as the request body (``None`` becomes ``null``)
        and sent to ``<base_url>/<endpoint>``.
        """
        return self._api_request('post', '/'.join([self.base_url, endpoint]), data=json.dumps(data))

    def _login(self, username: str, password: str) -> str:
        """
        Login and set the 2-hour token for all future API requests
        """
        data = {'username': username, 'password': password}
        return self._post('login', data=data)

    def logout(self):
        """
        Invalidates the API token
        """
        return self._post('logout')

    @weak_lru(maxsize=1)
    def available_datasets(self) -> List[dict]:
        """
        Retrieve the list of publicly available EarthExplorer datasets, along with metadata
        """
        return self._post('dataset-search',
                          data={'catalog': 'EE',
                                'publicOnly': True})

    @weak_lru(maxsize=1)
    def available_aliases(self) -> List[str]:
        """
        Just retrieve the alias list of publicly available EarthExplorer datasets
        """
        return [d['datasetAlias'] for d in self.available_datasets()]

    @weak_lru(maxsize=4)
    def available_filters(self, alias: str) -> List[dict]:
        """
        Retrieve the list of filters available for the given dataset alias
        """
        return self._post('dataset-filters',
                          data={'datasetName': alias})

    def available_products(self, alias: str, entityids: List[str]) -> List[dict]:
        """
        Available products for the given data set
        """
        return self._post('download-options',
                          data={'datasetName': alias,
                                'entityIds': entityids})

    def download_request(self, products: List[dict], label: str) -> List[dict]:
        """
        Request products to be staged for download

        [{'entityId': str,
          'productId': str}]
        """
        return self._post('download-request',
                          data={'downloads': products,
                                'label': label})

    def ard_query(self, start: str, end: str, region: str, horiz: int, vert: int, maxresults: int,
                  dataset: str = 'landsat_ard_tile_c2') -> dict:
        """
        Query for ARD products on a single tile

        :param start: acquisition start date (presumably 'YYYY-MM-DD'; confirm
            against the M2M DateRange docs). Applied only when both
            ``start`` and ``end`` are truthy.
        :param end: acquisition end date, see ``start``
        :param region: value matched against the dataset's "Tile Grid Region" filter
        :param horiz: value matched against "Tile Grid Horizontal"
        :param vert: value matched against "Tile Grid Vertical"
        :param maxresults: maximum number of results requested
        :param dataset: dataset alias to search
        """
        filters = self._ard_metadata_filters(dataset, region, horiz, vert)
        # Falsy bounds (callers pass 0) mean "no date restriction".
        if start and end:
            filters = merge(filters, {'acquisitionFilter': {'start': start, 'end': end}})

        return self._scene_search(dataset,
                                  maxresults=maxresults,
                                  scenefilters=filters)

    def _scene_search(self, dataset: str, scenefilters: Union[None, dict] = None,
                      maxresults: int = 100) -> dict:
        """
        Post request against the scene-search endpoint

        ``scenefilters`` entries override the defaults below (all filters
        unset, cloud cover 0-100).
        """
        base = {'ingestFilter': None,
                'spatialFilter': None,
                'metadataFilter': None,
                'cloudCoverFilter': {'max': 100, 'min': 0},
                'acquisitionFilter': None}

        if scenefilters:
            base = merge(base, scenefilters)

        return self._post('scene-search', data={'datasetName': dataset,
                                                'maxResults': maxresults,
                                                'sceneFilter': base
                                                })

    def _value(self, alias: str, value: Any, fieldlabel: str, operand: str = '=') -> dict:
        """
        Filter for a single value
        operand can be either '=' or 'like'
        """
        return {'filterId': self._find_filterid(alias, fieldlabel),
                'filterType': 'value',
                'value': value,
                'operand': operand}

    def _between(self, alias: str, value1: Any, value2: Any, fieldlabel: str) -> dict:
        """
        Filter for looking at a range of values
        """
        return {'filterId': self._find_filterid(alias, fieldlabel),
                'filterType': 'between',
                'firstValue': value1,
                'secondValue': value2}

    def _filter_metadata(self, filters: List[dict]) -> dict:
        """
        Put the different metadata filters together in the proper structure
        """
        return {'metadataFilter': {'filterType': 'and',
                                   'childFilters': filters}}

    def _ard_metadata_filters(self, alias: str, grid: str, horiz: int, vert: int) -> dict:
        """
        AND together the grid-region, horizontal and vertical tile filters
        """
        return self._filter_metadata([self._ard_grid(alias, grid),
                                      self._ard_horiz(alias, horiz),
                                      self._ard_vert(alias, vert)])

    def _ard_horiz(self, alias: str, horiz: int) -> dict:
        """
        Put together the h/v search parameters
        Typically used for ARD searches
        """
        return self._between(alias, horiz, horiz, 'Tile Grid Horizontal')

    def _ard_vert(self, alias: str, vert: int) -> dict:
        """
        Put together the h/v search parameters
        Typically used for ARD searches
        """
        return self._between(alias, vert, vert, 'Tile Grid Vertical')

    def _ard_grid(self, alias: str, grid: str) -> dict:
        """
        Put together the grid search parameters
        Typically used for ARD searches
        """
        return self._value(alias, grid, 'Tile Grid Region')

    def _filter_temporal(self, start: str, end: str):
        """
        Build a ``{'temporalFilter': {'start', 'end'}}`` mapping.

        NOTE(review): not used by this class; scene searches use
        ``acquisitionFilter`` instead (see ``ard_query``).
        """
        return {'temporalFilter': {'start': start, 'end': end}}

    def _find_filterid(self, alias: str, fieldlabel: str) -> str:
        """
        Find the filterid for the given dataset based on the fieldLabel

        :raises IndexError: if the dataset has no filter with that label
        """
        for f in self.available_filters(alias):
            if f['fieldLabel'] == fieldlabel:
                return f['id']
        # IndexError kept for compatibility with the previous `[...][0]` lookup.
        raise IndexError(f'No filter with fieldLabel {fieldlabel!r} for dataset {alias!r}')

    def __enter__(self):
        return self

    def __exit__(self, *args):
        """
        Log out and close the session; exceptions raised in the ``with`` body
        are not suppressed.
        """
        try:
            self.logout()
        finally:
            # Close the session even if the logout request fails.
            self.session.close()


def ls_sr_bandnumbers(sensor: str) -> List[int]:
    """
    SR band numbers for the given sensor
    LC09 LC08 LE07 LT05 LT04

    :param sensor: four-character sensor prefix of an entity id, e.g. 'LC08'
    :return: a new list of band numbers (safe for the caller to mutate)
    :raises ValueError: if ``sensor`` is not one of the codes above
    """
    # Tuples keep the table immutable; a fresh list is returned per call.
    table = {
        'LC09': (2, 3, 4, 5, 6, 7),
        'LC08': (2, 3, 4, 5, 6, 7),
        'LE07': (1, 2, 3, 4, 5, 7),
        'LT05': (1, 2, 3, 4, 5, 7),
        'LT04': (1, 2, 3, 4, 5, 7),
    }
    if sensor not in table:
        raise ValueError(f'Unsupported sensor: {sensor!r}')
    return list(table[sensor])


def ls_bt_bandnumbers(sensor: str) -> List[int]:
    """
    BT band numbers for the given sensor
    LC09 LC08 LE07 LT05 LT04

    :param sensor: four-character sensor prefix of an entity id, e.g. 'LE07'
    :return: a new list of band numbers (safe for the caller to mutate)
    :raises ValueError: if ``sensor`` is not one of the codes above
    """
    # Tuples keep the table immutable; a fresh list is returned per call.
    table = {
        'LC09': (10, 11),
        'LC08': (10, 11),
        'LE07': (6,),
        'LT05': (6,),
        'LT04': (6,),
    }
    if sensor not in table:
        raise ValueError(f'Unsupported sensor: {sensor!r}')
    return list(table[sensor])


def ls_bt_std_layers(entityid: str) -> List[str]:
    """
    Standard BT file names for an entity id.

    The sensor is the first four characters of ``entityid``. Each name has
    the form ``<entityid>_BT_B<band>.TIF``, one per band reported by
    ``ls_bt_bandnumbers``. That helper raises ValueError for an unknown
    sensor.
    """
    bands = ls_bt_bandnumbers(entityid[:4])
    return [f'{entityid}_BT_B{band}.TIF' for band in bands]


def ls_sr_std_layers(entityid: str) -> List[str]:
    """
    Standard list of needed file names for an entity id.

    The ``QA_PIXEL`` layer comes first, followed by one
    ``<entityid>_SR_B<band>.TIF`` per band from ``ls_sr_bandnumbers``. The
    sensor is the first four characters of ``entityid``.
    """
    sr_layers = [f'{entityid}_SR_B{band}.TIF'
                 for band in ls_sr_bandnumbers(entityid[:4])]
    return [f'{entityid}_QA_PIXEL.TIF'] + sr_layers


def ard_entityids(api: M2M, region: str, h: int, v: int) -> List[str]:
    """
    Query the M2M API for one ARD tile and return the base entity ids of
    the matching scenes.

    Requests up to 5000 results, passing start=0/end=0 (no date range).
    """
    search = api.ard_query(start=0, end=0, region=region, horiz=h, vert=v, maxresults=5000)
    return [scene['entityId'] for scene in search['results']]


def ard_order_ids(api: M2M, entity_ids: List[str]) -> List[dict]:
    """
    Identify the secondary downloads needed for Change Detection.

    Selects, for each ARD entity id:
      * every ``QA_PIXEL`` layer, and
      * every ``.TIF`` whose name contains ``ST``. These are the
        surface-temperature layers, e.g. ``..._ST_B10``,
        ``..._ST_TRAD.TIF``, ``..._ST_ATRAN.TIF`` and ``..._ST_QA.TIF``.

    Each secondary download is added at most once.

    :param api: logged-in M2M client
    :param entity_ids: ARD entity ids (e.g. from ``ard_entityids``)
    :return: ``[{'productId': str, 'entityId': str}, ...]``, the shape
        ``M2M.download_request`` documents for its ``products`` argument
    """
    products = api.available_products('landsat_ard_tile_c2', entity_ids)

    ret = []
    for p in products:
        # Tolerate a missing/None secondaryDownloads entry.
        for s in p.get('secondaryDownloads') or []:
            name = s['entityId']
            # NOTE(review): 'ST' is a loose substring match (the original TODO
            # wanted only ST_B layers); '_ST_' would be stricter. Confirm
            # against real product names before tightening.
            if 'QA_PIXEL' in name or ('ST' in name and '.TIF' in name):
                ret.append({'productId': s['id'], 'entityId': name})

    logging.getLogger(__name__).debug('ard_order_ids: selected %d secondary downloads', len(ret))
    return ret
