"""
Search and download M2M data, skip those already downloaded
"""

import os
import sys
import requests
from requests.exceptions import HTTPError
import json
import getpass
import urllib3
import time
import datetime
import logging
import traceback
import multiprocessing
from argparse import ArgumentParser

# Silence urllib3's InsecureRequestWarning, which would otherwise be emitted
# for requests made with verify=False (download_url issues its GET that way).
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


# Module-wide logger: everything from DEBUG upwards is written to a stream
# handler using a "<timestamp> [<LEVEL>]: <message>" layout.
LOG_FORMAT = "%(asctime)s [%(levelname)s]: %(message)s"
formatter = logging.Formatter(LOG_FORMAT)

stream = logging.StreamHandler()
stream.setLevel(logging.DEBUG)
stream.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(stream)


class M2MError(Exception):
    """Error raised when an M2M API request fails or its response is unusable."""


class M2M(object):
    """
    Web-Service interface for EarthExplorer JSON Machine-to-Machine API

    https://m2m.cr.usgs.gov/api/docs/json/

    """

    def __init__(self, instance='ops'):
        """
        Select the API base URL and build the product-name lookup tables

        :param instance: str, one of 'ops', 'devmast' or 'devsys'. Any other
            value leaves ``baseurl`` as None, and the first API request then
            fails with M2MError.
        """
        url_lookup = dict(
            ops     = 'https://m2m.cr.usgs.gov/api/api/json/stable/',
            devmast = 'https://m2mdevmast.cr.usgs.gov/api/api/json/stable/',
            devsys  = 'https://m2mdev.cr.usgs.gov/devsys/api/api/json/stable/'
        )
        self.baseurl = url_lookup.get(instance)
        # Maps a dataset-name prefix to the download product requested for it
        self.product_name_dict = dict(LANDSAT='Level-1 GeoTIFF Data Product',
                                      landsat_tm_c2_l1='Landsat Collection 2 Level-1 Product Bundle',
                                      landsat_etm_c2_l1='Landsat Collection 2 Level-1 Product Bundle',
                                      landsat_ot_c2_l1='Landsat Collection 2 Level-1 Product Bundle',
                                      landsat_tm_c2_l2='Landsat Collection 2 Level-2 Product Bundle',
                                      landsat_etm_c2_l2='Landsat Collection 2 Level-2 Product Bundle',
                                      landsat_ot_c2_l2='Landsat Collection 2 Level-2 Product Bundle',
                                      ARD_TILE='Surface Reflectance',
                                      SENTINEL='L1C Tile in JPEG2000 format'
                                      )
        # Maps the short ARD product codes accepted from the user to product names
        self.ard_product_lookup = dict(
            SR='Surface Reflectance',
            QA='Quality Assessment',
            TOA='Top of Atmosphere',
            ST='Provisional Surface Temperature',
            BT='Brightness Temperature'
        )

    @staticmethod
    def _parse(response):
        """
        Attempt to parse the JSON response, which always contains additional
        information that we might not always want to look at (except on error)

        :param response: requests.models.Response
        :return: dict, the full decoded response (guaranteed to have a 'data' key)
        :raises M2MError: on HTTP error status, undecodable JSON, an API
            'errorCode', or a response without a 'data' entry
        """
        try:
            response.raise_for_status()
        except HTTPError as e:
            err_msg = 'Unable to reach M2M API: {}'.format(e)
            logger.error(err_msg)
            raise M2MError(err_msg) from e

        try:
            data = response.json()
        except ValueError as e:
            err_msg = ('unable to parse JSON response. {}\n'
                       'traceback:\n{}'.format(e, traceback.format_exc()))
            logger.error(err_msg)
            raise M2MError(err_msg) from e

        # Valid JSON that is not an object (e.g. a list or null) carries no 'data'
        if not isinstance(data, dict):
            err_msg = 'no data found in response:\n{}'.format(data)
            logger.error(err_msg)
            raise M2MError(err_msg)

        if data.get('errorCode'):
            # .get() so a missing errorMessage cannot mask the API error with a KeyError
            err_msg = '{}: {}'.format(data.get('errorCode'), data.get('errorMessage'))
            logger.error(err_msg)
            raise M2MError(err_msg)

        if 'data' not in data:
            err_msg = 'no data found in response:\n{}'.format(data)
            logger.error(err_msg)
            raise M2MError(err_msg)

        return data

    def _api_request(self, verb, resource, data, headers=None):
        """
        Send a JSON-encoded request to an API resource and parse the reply

        :param verb: str, name of the ``requests`` function to use (e.g. 'post')
        :param resource: str, endpoint name appended to the base URL
        :param data: dict, request payload (sent as a JSON body)
        :param headers: dict or None, HTTP headers (e.g. the auth token)
        :return: dict, the decoded response as returned by ``_parse``
        :raises M2MError: if no base URL is configured or the request fails
        """
        if not self.baseurl:
            err_msg = 'No M2M base URL configured (unknown instance?)'
            logger.error(err_msg)
            raise M2MError(err_msg)

        url = self.baseurl + resource
        # Never write the password to the log
        redacted = {k: ('xxxxx' if k == 'password' else v) for k, v in data.items()}
        logger.debug('{} {} {}'.format(verb.upper(), url, redacted))
        response = getattr(requests, verb)(url, headers=headers, data=json.dumps(data))
        return self._parse(response)

    def login(self, username, password=None, **kwargs):
        """
        Authenticate and return the API token

        :param username: str, ERS username
        :param password: str or None; prompted for interactively when None
        :return: the 'data' entry of the login response (used as X-Auth-Token)
        """
        if password is None:
            password = getpass.getpass('Password (%s): ' % username)
        payload = {'username': username, 'password': password}
        return self._api_request('post', 'login', payload).get('data')

    def scene_search(self, headers, data):
        """
        Run a 'scene-search' request

        :param headers: dict, HTTP headers carrying the auth token
        :param data: dict, search parameters
        :return: the 'data' entry of the search response
        """
        return self._api_request('post', 'scene-search', data, headers).get('data')

    def get_product_names(self, dataset, products):
        """
        Resolve the download product names to request for a dataset

        :param dataset: str, dataset name
        :param products: str or None, comma-delimited ARD product codes
            (only used when dataset is 'ARD_TILE')
        :return: list of product-name strings
        :raises M2MError: if the dataset matches no known dataset prefix
        """
        if dataset == 'ARD_TILE':
            if not products:
                return ['Surface Reflectance']
            names = []
            for code in products.split(','):
                # Tolerate 'SR, QA' and lower-case codes
                code = code.strip().upper()
                name = self.ard_product_lookup.get(code)
                if name is None:
                    logger.warning('Unknown ARD product code "{}" - ignoring'.format(code))
                    continue
                names.append(name)
            return names

        # First prefix (in table order) that the dataset name starts with wins
        for prefix, name in self.product_name_dict.items():
            if dataset.startswith(prefix):
                return [name]

        err_msg = 'No download product known for dataset "{}"'.format(dataset)
        logger.error(err_msg)
        raise M2MError(err_msg)

    def download_options(self, headers, entity_ids, dataset, products):
        """
        Look up the download options matching the wanted products

        :param headers: dict, HTTP headers carrying the auth token
        :param entity_ids: list of entity ids to query
        :param dataset: str, dataset name
        :param products: str or None, comma-delimited ARD product codes
        :return: list of (entityId, id) tuples, one per matching option
        """
        params = {
            'entityIds': entity_ids,
            'datasetName': dataset
        }
        resp = self._api_request('post', 'download-options', params, headers)
        product_names = self.get_product_names(dataset, products)
        options = resp.get('data') or []
        return [(opt.get('entityId'), opt.get('id')) for opt in options
                if (opt.get('productName') or '').strip() in product_names]

    def download_request(self, headers, entity_ids, download_ids):
        """
        Request download URLs for entity/product pairs

        :param headers: dict, HTTP headers carrying the auth token
        :param entity_ids: list of entity ids
        :param download_ids: list of product ids, paired positionally with entity_ids
        :return: list with the 'url' of every available and preparing download
        """
        params = {
            'downloads': [
                {'entityId': e,
                 'productId': i}
                for e, i in zip(entity_ids, download_ids)
            ],
            'downloadApplication': 'EE'
        }
        resp = self._api_request('post', 'download-request', params, headers)

        # Either list may be absent or null in the response
        data = resp.get('data') or {}
        avail = data.get('availableDownloads') or []
        prep = data.get('preparingDownloads') or []

        return [entry.get('url') for entry in list(avail) + list(prep) if entry]

    @staticmethod
    def additionalCriteriaValues(h=None, v=None, s=None, gd=None, p=None, r=None, sc=None, tile_number=None):
        """
        Build the 'metadataFilter' part of a scene filter

        A criterion is included when it is neither None nor an empty string,
        so a numeric 0 (e.g. tile grid h=0 or v=0) is a valid filter value.

        :param h: tile grid horizontal
        :param v: tile grid vertical
        :param s: sensor identifier
        :param gd: generation (production) date
        :param p: WRS-2 path
        :param r: WRS-2 row
        :param sc: spacecraft identifier
        :param tile_number: Sentinel-2 tile number (matched with 'like')
        :return: dict ``{'metadataFilter': {'filterType': 'and', 'childFilters': [...]}}``
        """
        def is_set(value):
            return value is not None and value != ''

        def between(filter_id, value):
            return {"filterType": "between", "filterId": filter_id, "firstValue": value, "secondValue": value}

        def equals(filter_id, value):
            return {"filterType": "value", "filterId": filter_id, "value": value}

        # (value, filter builder, filterId) in the order the filters are emitted
        criteria = (
            (h, between, "5e83a38b7e5e872c"),
            (v, between, "5e83a38b5e3a58ab"),
            (s, equals, "5e83a38b9bdb525f"),
            (sc, equals, "5e83a38b7b117c3d"),
            (gd, between, "5e83a38b4357a01b"),
            (p, between, "5e83d0b81d20cee8"),
            (r, between, "5e83d0b849ed5ee7"),
        )
        child_filters = [build(filter_id, value) for value, build, filter_id in criteria if is_set(value)]

        if is_set(tile_number):  # For Sentinel-2 A,B
            child_filters.append({"filterType": "value", "filterId": "5e83a42cc36e732d", "value": tile_number,
                                  "operand": "like"})

        return {'metadataFilter': {"filterType": "and", "childFilters": child_filters}}

    @staticmethod
    def temporalCriteria(ad):
        """
        Build the 'acquisitionFilter' part of a scene filter

        :param ad: str, a single date or 'start,end'
        :return: dict ``{'acquisitionFilter': {'start': ..., 'end': ...}}``;
            a single date is used as both start and end
        :raises ValueError: if more than two comma-separated dates are given
        """
        dates = [d.strip() for d in ad.split(',')]
        if len(dates) == 1:
            dates = dates * 2
        if len(dates) != 2:
            raise ValueError('expected "date" or "start,end", got: {!r}'.format(ad))
        sd, ed = dates
        return {"acquisitionFilter": {"start": sd, "end": ed}}


def download_url(x):
    """
    Download one file into a local directory, resuming a partial download

    Data is streamed into '<name>.part' and renamed to '<name>' once the
    transfer is complete. A file that already exists locally is skipped.

        Args:
            x: Sequence of (fileurl, directory)

        Raises:
            requests.exceptions.RequestException: on connection problems or an
                HTTP error status from the download request
            KeyError: if the server sends no Content-Disposition header
    """
    fileurl, directory = x[0], x[1]

    # We need to get the redirect URL first
    head = requests.head(fileurl, timeout=60)
    location = head.headers.get('Location')
    if location:
        if location.startswith(('http://', 'https://')):
            # Absolute redirect target: use it as-is
            fileurl = location
        else:
            base_url = fileurl.split('/download-staging')[0]
            fileurl = base_url + location
        head = requests.head(fileurl, timeout=60)
    # Without a Location header there is no redirect; keep using fileurl/head

    filename = head.headers['Content-Disposition'].split('filename=')[-1].strip('"')
    # The name comes from the server: never let it point outside `directory`
    filename = os.path.basename(filename)

    local_fname = os.path.join(directory, filename)
    part_fname = local_fname + '.part'
    if os.path.exists(local_fname):
        logger.warning('Already exists - skipping: %s \n' % local_fname)
        return

    file_size = None
    if 'Content-Length' in head.headers:
        file_size = int(head.headers['Content-Length'])

    bytes_recv = 0
    if os.path.exists(part_fname):
        bytes_recv = os.path.getsize(part_fname)
        if file_size is not None and bytes_recv == file_size:
            # A previous run fetched everything but did not get to rename it
            os.rename(part_fname, local_fname)
            return
        if file_size is not None and bytes_recv > file_size:
            # Partial file is larger than the real one: discard and start over
            os.remove(part_fname)
            bytes_recv = 0

    logger.info("Downloading %s ... \n" % local_fname)
    resume_header = {'Range': 'bytes=%d-' % bytes_recv}
    # NOTE(review): verify=False disables TLS certificate checking for the
    # download; kept as in the original behaviour - confirm it is still needed.
    sock = requests.get(fileurl, headers=resume_header, timeout=10,
                        stream=True, verify=False, allow_redirects=True)

    bytes_in_mb = 1024*1024
    start = time.time()
    try:
        # Do not write an HTTP error page into the .part file
        sock.raise_for_status()

        mode = 'ab'
        if sock.status_code != 206:
            # Range was not honoured, the body is the whole file: overwrite
            mode = 'wb'
            bytes_recv = 0

        with open(part_fname, mode) as f:
            for block in sock.iter_content(chunk_size=bytes_in_mb):
                if block:
                    f.write(block)
                    bytes_recv += len(block)
    finally:
        sock.close()

    ns = time.time() - start
    mb = bytes_recv/float(bytes_in_mb)
    rate = mb/ns if ns > 0 else 0.0
    logger.info("%s (%3.2f (MB) in %3.2f (s), or  %3.2f (MB/s)) \n" % (filename, mb, ns, rate))

    # With no Content-Length to compare against, a stream that ended without
    # an error is treated as complete
    if file_size is None or bytes_recv >= file_size:
        os.rename(part_fname, local_fname)
    else:
        logger.warning('Incomplete download (%d of %d bytes), keeping %s \n'
                       % (bytes_recv, file_size, part_fname))


def download_url_wrapper(x):
    """
    Run download_url for one (url, directory) item without ever raising

    Any exception from the download is logged as a warning and swallowed, so
    a single failed item does not abort the other downloads.

        Args:
            x: The (url, directory) item handed on to download_url
    """
    try:
        download_url(x)
    except Exception as err:
        message = '\n\n *** Failed download %s: %s \n' % (x, str(err))
        logger.warning(message)


def chunkify(iterable, n=1):
    """
    Yield consecutive slices of a sequence, each holding at most n items

        Args:
            iterable: Sequence supporting len() and slicing
            n: Maximum number of items per slice
    """
    total = len(iterable)
    for start in range(0, total, n):
        # Slicing clamps at the end of the sequence, so the last chunk may be shorter
        yield iterable[start:start + n]


def get_product_ids(results):
    """
    Split search results into parallel lists of display ids and entity ids

        Args:
            results: Sequence of dicts, each with 'displayId' and 'entityId' keys

        Returns:
            Tuple (display_ids, entity_ids), both lists in result order
    """
    # One full pass per key, display ids first
    return tuple([record[key] for record in results] for key in ('displayId', 'entityId'))


def download_files(directory, h=None, v=None, p=None, r=None, username=None, dataset=None, tile_number=None,
                   sensor=None, spacecraft=None, N=50000, products=None, acq_date=None, gen_date=None,
                   threads=40, search_only=False, instance='ops'):
    """
    Search for and download files to local directory

    With search_only set, nothing is downloaded; the matching display ids are
    logged and written to 'results_<timestamp>.txt' in `directory` (or in the
    current directory when no directory is given).

        Args:
            directory: Relative path to local directory (will be created)
            h: Tile Grid Horizontal [Optional]
            v: Tile Grid Vertical [Optional]
            p: The WRS2 path [Optional]
            r: The WRS2 row [Optional]
            username: ERS Username (with full M2M download access) [Optional]
            dataset: EarthExplorer Catalog datasetName [e.g. ARD_TILE or SENTINEL_2A]
            tile_number: For SENTINEL-2
            sensor: Satellite instrument [All, OLI_TIRS, ETM, TM]
            spacecraft: Satellite platform [LANDSAT_4, LANDSAT_5, LANDSAT_7, LANDSAT_8]
            N: Maximum number of search results to return
            products: Comma-delimited list of download products as a single string [e.g 'TOA,BT,SR,QA']
            acq_date: Search Date image acquired [Format: %Y-%m-%d]
            gen_date: Tile production date [Format: %Y-%m-%d]
            threads: Number of download processes to launch in parallel
            search_only: Boolean, if true only show results, don't download
            instance: What instance of M2M to use (devsys, devmast, ops)

        Returns:
            None. Exits the interpreter (status 1) when no download directory
            is given, and (status 0) when the search returns no results.

    """
    if not directory and not search_only:
        logger.error('Must specify download directory')
        # Non-zero status: this is a usage error, not a successful run
        sys.exit(1)

    if directory:
        os.makedirs(directory, exist_ok=True)

    m2m = M2M(instance)

    token = m2m.login(username=username or input('Enter ERS username: '))
    header = {'X-Auth-Token': token}

    scene_filter = dict(
        ingestFilter      = None,
        spatialFilter     = None,
        metadataFilter    = None,
        cloudCoverFilter  = {'max': 100, 'min': 0},
        acquisitionFilter = None
    )

    # Let the filter builder decide which criteria are set (this also covers a
    # spacecraft-only search); only send a metadata filter that has content.
    metadata = m2m.additionalCriteriaValues(h=h, v=v, p=p, r=r, s=sensor, sc=spacecraft,
                                            gd=gen_date, tile_number=tile_number)
    if metadata['metadataFilter']['childFilters']:
        scene_filter.update(metadata)

    if acq_date:
        scene_filter.update(m2m.temporalCriteria(ad=acq_date))

    search = dict(
        datasetName = dataset,
        maxResults  = N,
        sceneFilter = scene_filter)

    logger.info('Full M2M Search Parameters: {}\n'.format(search))
    results = m2m.scene_search(header, search) or {}

    n_results = results.get('totalHits')
    display_ids, entity_ids = get_product_ids(results.get('results') or [])

    # %s rather than %d so a missing totalHits cannot raise here
    logger.info('Total search results: %s \n' % n_results)

    if not display_ids:
        logger.warning('No results found!')
        sys.exit(0)

    if not search_only:
        download_info = m2m.download_options(header, entity_ids, dataset, products)
        if not download_info:
            logger.warning('No matching download products found - nothing to download')
            return None
        e_ids = [x[0] for x in download_info]
        d_ids = [x[1] for x in download_info]

        urls = [(url, directory) for url in m2m.download_request(header, e_ids, d_ids)]

        pool = multiprocessing.Pool(threads)
        try:
            # NOTE(review): the 600 s limit applies to the whole batch of
            # downloads, not to each file - confirm that this is intended.
            pool.map_async(download_url_wrapper, urls).get(600)
        except BaseException:
            # Timeout, Ctrl-C, ...: stop the workers instead of leaking them
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    else:
        now = datetime.datetime.now()
        name = datetime.datetime.strftime(now, '%Y-%m-%d_%H:%M:%S')
        name = 'results_{}.txt'.format(name)
        # search_only is allowed without a directory: fall back to the cwd
        text_output = os.path.join(directory or '.', name)
        with open(text_output, 'w') as f:
            f.writelines('{}\n'.format(i) for i in display_ids)
        logger.info('SEARCH RESULTS: ')
        logger.info('{}'.format(display_ids))
        logger.info('Writing results to {}'.format(text_output))

    return None


def build_command_line_arguments(argv=None):
    """
    Define and parse the command line options of this script

    '-h' is used for the ARD tile grid horizontal index, so the automatic
    '-h/--help' option is disabled and help is offered as '--help' only.

        Args:
            argv: List of argument strings to parse [Default: None, which
                  makes argparse read sys.argv[1:]]

        Returns:
            argparse.Namespace whose attribute names match the keyword
            arguments of download_files
    """
    description = __doc__
    parser = ArgumentParser(description=description, add_help=False)
    req_parser = parser.add_argument_group('required arguments')
    parser.add_argument('--help', action='help', help='show this help message and exit')
    req_parser.add_argument('-d', '--directory', type=str, dest='directory',
                            help='Relative path to download all data')
    req_parser.add_argument('-u', '--username', type=str, dest='username', default=None,
                            help='ERS Username (with full M2M download access)')
    parser.add_argument('-t', '--threads', type=int, dest='threads', default=40,
                        help='Number of parallel download threads [Default: 40]')
    parser.add_argument('-m', '--max', type=int, dest='N', default=50000,
                        help='Maximum number of Tile results to return [Default: 50000]')
    parser.add_argument('-p', '--path', type=int, dest='p', default=None,
                        help='WRS-2 Path [Default: None]')
    parser.add_argument('-r', '--row', type=int, dest='r', default=None,
                        help='WRS-2 Row [Default: None]')
    parser.add_argument('-h', '--horizontal', type=int, dest='h', default=None,
                        help='ARD Tile Grid Horizontal [Default: All]')
    parser.add_argument('-v', '--vertical', type=int, dest='v', default=None,
                        help='ARD Tile Grid Vertical [Default: All]')
    parser.add_argument('--tile-number', type=str, dest='tile_number', default=None,
                        help='S2 Tile Number e.g. T19TDK')
    parser.add_argument('-s', '--sensor', type=str, dest='sensor', default=None,
                        choices=['All', 'OLI_TIRS', 'ETM', 'TM'],
                        help='Landsat sensor Identifier [Default: All]')
    parser.add_argument('--spacecraft', dest='spacecraft', default=None,
                        choices=[f'LANDSAT_{x}' for x in (4, 5, 7, 8)],
                        help='Landsat spacecraft identifier')
    parser.add_argument('--products', dest='products', default=None,
                        help='M2M ARD product names (e.g. "SR,TOA,BT")')
    req_parser.add_argument('--dataset', type=str, dest='dataset', required=True,
                            help='EE Catalog dataset [e.g. SENTINEL_2A]')
    parser.add_argument('--acq-date', type=str, dest='acq_date', default=None,
                        help='Search Date Acquired (YYYY-MM-DD)')
    parser.add_argument('--gen-date', type=str, dest='gen_date', default=None,
                        help='ARD Tile Production Date (YYYY/MM/DD)')
    parser.add_argument('--search-only', action='store_true', dest='search_only',
                        help='Only return search results, do not download')
    parser.add_argument('--instance', type=str, dest='instance', choices=['devsys', 'devmast', 'ops'], default='ops',
                        help='Which instance of M2M to use (devsys, devmast, ops)')
    args = parser.parse_args(argv)
    return args


if __name__ == '__main__':
    # Script entry point: the parsed option names match download_files' keywords
    cli_options = vars(build_command_line_arguments())
    download_files(**cli_options)
