import argparse
import concurrent.futures
import os
import sys
import time

from functools import partial
from getpass import getpass
from importlib.metadata import version
from tqdm import tqdm
from typing import List
from typing import Tuple

from m2m import download
from m2m import filesystem
from m2m import persist
from m2m import prune
from m2m import stdout_logger
from m2m.api import M2M


# Module-level logger for this file, built by the m2m package's stdout_logger
# helper (presumably writes to stdout, per its name -- confirm in m2m).
log = stdout_logger(__name__)


def _str_to_bool(value):
    """
    Convert a command-line token into a bool.

    `type=bool` is a trap with argparse: bool("False") is True because any
    non-empty string is truthy.  This converter accepts the usual spellings
    instead and rejects anything it does not recognise.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value

    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, as it was under the old `type=bool`.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_parser() -> argparse.ArgumentParser:
    """
    Build the argument parser for the `ls-m2m` command.

    Positional arguments are the tile `h` and `v` indices and the `region`
    (one of CU, AK, HI).  Every other setting is optional.  The positional
    `v` and the `-v/--version` flag coexist because the flag stores nothing
    on the namespace.

    Returns:
        The configured argparse.ArgumentParser.
    """
    from importlib.metadata import PackageNotFoundError

    # The distribution metadata is missing when running from a source
    # checkout; that must not prevent the parser from being built.
    try:
        pkg_version = version('ls-m2m')
    except PackageNotFoundError:
        pkg_version = 'unknown'

    parser = argparse.ArgumentParser(
        prog='ls-m2m',
        description='Order and download landsat imagery through the M2M interface')

    parser.add_argument('h', type=int)
    parser.add_argument('v', type=int)
    parser.add_argument('region', type=str, choices=['CU', 'AK', 'HI'])
    parser.add_argument('-outroot', type=str, default=os.path.abspath('.'))
    parser.add_argument('-ceph', action='store_true', help='Download directly to Ceph (ignores outroot)')
    parser.add_argument('-endpoint_url', type=str, required=False, help="Ceph endpoint URL, only required if passing -ceph.")
    parser.add_argument('-aws_profile', type=str, default="ceph", help="A named profile in .aws/credentials for authenticating with Ceph.  Only required if passing -ceph.")
    parser.add_argument('-cpu', type=int, required=False, default=5)
    parser.add_argument('-batch_size', type=int, required=False, default=100)
    parser.add_argument('-username', type=str, required=False, default=None)
    parser.add_argument('-password', type=str, required=False, default=None)
    parser.add_argument('-resume', type=_str_to_bool, required=False, default=True,
                        help='Reuse a saved product list and skip existing products (true/false).')

    parser.add_argument("-v", "--version", action="version", version=pkg_version)

    return parser


def get_userpass() -> Tuple[str, str]:
    """
    Prompt interactively for credentials.

    Both values are read with `getpass`, so neither is echoed to the terminal.

    Returns:
        A `(username, password)` tuple, in the order they were prompted for.
    """
    username = getpass('Username: ')
    password = getpass('Password: ')
    return username, password


def request_links(api: M2M, products: List[dict]) -> List[str]:
    """
    Ask the API for download links for `products` in the 'c2ard' dataset.

    Only the entries under the response's 'availableDownloads' key are used;
    the 'url' of each one is returned, in response order.
    """
    response = api.download_request(products, 'c2ard')

    links = []
    for entry in response['availableDownloads']:
        links.append(entry['url'])
    return links


def batch(things, size):
    """
    Yield consecutive slices of `things`, each holding at most `size` items.

    `things` must support `len()` and slicing; the final slice may be shorter
    than `size`.
    """
    starts = range(0, len(things), size)
    for lo in starts:
        hi = lo + size
        yield things[lo:hi]


def download_products(username, password, fs, outroot, cpu, batch_size, product_list, download_fn):
    """
    Download every product in `product_list`, one batch at a time.

    For each batch of `batch_size` products a fresh M2M session is opened to
    request the download links, then the links are fetched concurrently on a
    thread pool.  A tqdm progress bar on stdout advances once per finished
    download and shows the running average speed.

    Args:
        username: M2M username used to open each API session.
        password: M2M password used to open each API session.
        fs: filesystem object forwarded to `download_fn` as `fs`.
        outroot: output root forwarded to `download_fn` as `outroot`.
        cpu: requested number of worker threads; capped at the host CPU count
            when that count is known.
        batch_size: number of products per link request.
        product_list: sliceable sequence of products to download.
        download_fn: callable invoked as `download_fn(url, fs=..., outroot=...)`.
            Its numeric return value is divided by 1024**2 for the speed
            readout (presumably a byte count -- confirm against `m2m.download`).
    """
    func = partial(download_fn, fs=fs, outroot=outroot)

    # os.cpu_count() returns None when the count cannot be determined; in that
    # case trust the requested value instead of letting min() raise TypeError.
    available = os.cpu_count()
    workers = cpu if available is None else min(available, cpu)

    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        # A monotonic clock keeps the elapsed time immune to wall-clock changes.
        started = time.monotonic()
        tot = 0.0
        with tqdm(total=len(product_list), file=sys.stdout) as pbar:

            for b in batch(product_list, batch_size):
                # The session is only needed to obtain links; close it before
                # the (potentially long) downloads start.
                with M2M(username, password) as api:
                    urls = request_links(api, b)

                for res in executor.map(func, urls):
                    tot += res / 1024 / 1024
                    elapsed = time.monotonic() - started
                    # Guard against a zero interval on coarse timers.
                    speed = tot / elapsed if elapsed > 0 else 0.0
                    pbar.set_postfix({'Speed': speed})
                    pbar.update()


def download_hv(username, password, h, v, region, ceph, endpoint_url, aws_profile, outroot, resume, cpu, batch_size):
    """
    Build (or reload) the product list for one h/v tile and download it.

    The storage backend is "ceph" when `ceph` is truthy and "local" otherwise;
    the matching helpers are looked up by that name on the `filesystem`,
    `download`, `persist` and `prune` modules.

    When `resume` is truthy, a previously saved product list is reused if one
    exists, and the list is pruned before downloading.  Otherwise the list is
    fetched through the M2M API and saved.

    Exits the process with status 1 if `ceph` is requested without an
    `endpoint_url`.
    """
    # An endpoint URL is mandatory in ceph mode; abort before doing any work.
    if ceph and not endpoint_url:
        log.error("Must specify endpoint_url if using ceph.")
        sys.exit(1)
    backend = "ceph" if ceph else "local"

    # Resolve the backend-specific helpers up front, in a fixed order.
    fs = getattr(filesystem, backend)(profile=aws_profile, endpoint_url=endpoint_url)
    download_fn = getattr(download, backend)
    write_list = getattr(persist, f"{backend}_w")

    if backend == "local":
        if not fs.exists(outroot):
            log.info('Creating output directory')
            fs.makedirs(outroot)

    # Keyword arguments shared by the persist read/write helpers.
    tile = dict(h=h, v=v, region=region, outroot=outroot)

    log.info('Retrieving product list')
    reuse_saved = resume and fs.exists(persist.product_list_path(fs, h, v, region, outroot))
    if reuse_saved:
        log.info('Using existing product list')
        read_list = getattr(persist, f"{backend}_r")
        products = read_list(fs, **tile)
    else:
        with M2M(username, password) as api:
            products = persist.create_product_list(api, h, v, region)
        log.info('Saving product list')
        write_list(fs, products, **tile)

    if resume:
        log.info(f'Pruning product list: {len(products)}')
        products = getattr(prune, backend)(fs, products, outroot=outroot)

    log.info(f'Number products to download: {len(products)}')
    log.info('Beginning downloads')
    download_products(username, password, fs, outroot, cpu, batch_size, products, download_fn)


def cli():
    """
    Command-line entry point: parse arguments, collect credentials, download.

    Any credential not supplied on the command line is prompted for
    interactively; a value that was supplied is kept as given rather than
    being asked for again.  The parsed arguments are then forwarded to
    `download_hv` as keyword arguments.
    """
    pargs = get_parser().parse_args()

    if pargs.username is None and pargs.password is None:
        pargs.username, pargs.password = get_userpass()
    elif pargs.username is None:
        # Same prompt as get_userpass, but only for the missing value.
        pargs.username = getpass('Username: ')
    elif pargs.password is None:
        pargs.password = getpass('Password: ')

    download_hv(**vars(pargs))


if __name__ == "__main__":
    try:
        cli()
    except Exception as e:
        # Report the failure, then exit non-zero so shells and schedulers can
        # tell the run did not succeed (logging alone would exit with 0).
        log.error(e)
        sys.exit(1)
