import argparse
import logging
import os
import re
import requests
import traceback

from datetime import datetime
from dateutil.tz import tzutc
from dateutil.parser import parse
from multiprocessing.pool import ThreadPool

import infra.data_catalog.client.data_rest_api as data_rest_api
import pipedream.api

from base.logging import zoox_logger
from infra.data_catalog.common import id_util
from zclient.client import make_client_from_url

logger = logging.getLogger('retry_bin')

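# Terminal pipedream states that mark a pipeline (or stage) as needing a retry.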
BAD_PIPEDREAM_STATES = ['cancelled', 'failed', 'timeout']
PIPEDREAM_SCHEDULING_URL = 'https://pipedream-scheduling.zooxlabs.com'
PIPEDREAM_ACCOUNTING_URL = 'https://pipedream-accounting.zooxlabs.com'

MAX_THREAD_COUNT = 8

DEFAULT_VEHICLES = ['Kitt', 'VH6']
INGEST_STAGES = ['registration', 'primary', 'secondary', 'camera']

def get_run_pipelines(meta_id, stage):
    """
    :param meta_id: identifier for the run
    :param stage: stage of ingest, one of INGEST_STAGES
    """
    job_name = "*ingest*-{}".format(meta_id)
    logger.info('Querying pipedream for job {}'.format(job_name))
    pipelines_url = os.path.join(PIPEDREAM_ACCOUNTING_URL, "v0/pipelines")
    pipelines_resp = requests.get(pipelines_url, params={'name': job_name})
    pipelines_resp.raise_for_status()
    # Sort newest first so callers can retry the most recently launched pipeline.
    pipelines = sorted(
        pipelines_resp.json().get('pipelines', []),
        key=lambda x: x['launchedTimestamp'],
        reverse=True)
    return [p for p in pipelines if stage in p['name']]

def retry_pipeline(pipe_id, task, meta_id, partitions, skip_coredump, dry_run):
    """
    :param pipe_id: identifier of pipedream job
    :param task: task that is part of an ingest stage
    :param meta_id: identifier for the run, or None when retrying by pipe id only
    :param partitions: slurm partition to assign retried jobs to
    :param skip_coredump: bool to skip coredump job in primary ingest
    :param dry_run: bool to run as dry run
    """
    if not task:
        # No task given: retry every failed stage of the pipeline.
        logger.info('Attempting to retry failed stages of pipeline {}'.format(pipe_id))
        pipedef_url = os.path.join(
            PIPEDREAM_SCHEDULING_URL,
            "v1/pipeline/{}/retry_pipedef?only_retry_failed_stages=true".format(pipe_id))
    else:
        # Retry only the named task within the pipeline.
        logger.info('Attempting to retry task {} of pipeline {}'.format(task, pipe_id))
        pipedef_url = os.path.join(
            PIPEDREAM_SCHEDULING_URL,
            "v1/pipeline/{}/retry_pipedef?stage_name={}".format(pipe_id, task))
    response = requests.get(pipedef_url)
    response.raise_for_status()
    retry_pipedef = response.json()
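    # Override the slurm partition on every job in the retry pipedef, if requested.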
    if partitions:
        for job in retry_pipedef["job"]:
            opts = job["slurmOption"]
            for i, opt in enumerate(opts):
                if opt.startswith("--partition="):
                    opts[i] = "--partition={}".format(partitions)
    if skip_coredump:
        def bad_prefix(name):
            return name.startswith('ingest_vehicle_coredumps')
        retry_md = retry_pipedef['retryMetadata']
        retry_md['jobMetadata'] = [
            md for md in retry_md['jobMetadata'] if not bad_prefix(md['jobName'])]
        retry_pipedef['job'] = [
            job for job in retry_pipedef['job'] if not bad_prefix(job['name'])]
        for job in retry_pipedef['job']:
            job['dependency'] = [
                dep for dep in job['dependency'] if not bad_prefix(dep['jobName'])]

    if dry_run:
        logger.info('(Dry-run): would schedule retry for job {} '
                    '- would retry tasks: {}'.format(
                        retry_pipedef['retryMetadata']['parentPipeId'],
                        [md['jobName'] for md in retry_pipedef['retryMetadata']['jobMetadata']]))
        return
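    # Schedule the retry and, when the run's meta id is known, register the new
    # pipeline against the run in the data catalog.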
    retry_pipe_id, _ = pipedream.api.schedule(retry_pipedef, as_user='ingest')
    if meta_id:
        data_id = id_util.human_readable_to_data_id(meta_id)
        data_rest_api.add_pipeline(data_id, retry_pipe_id)

def retry_run(meta_id, stage, task, retry_reruns, partitions, skip_coredump, dry_run):
    """
    :param meta_id: identifier of run
    :param stage: stage of ingest, one of INGEST_STAGES
    :param task: task that is part of an ingest stage
    :param retry_reruns: bool to allow retried jobs to be retried again
    :param partitions: partition to assign jobs
    :param skip_coredump: bool to skip coredump job in primary ingest
    :param dry_run: bool to run as dry run
    """
    try:
        pipelines = get_run_pipelines(meta_id, stage)
        if not pipelines:
            logger.warning('Did not find any pipelines to retry for run {} with stage {}'\
                .format(meta_id, stage))
            return
        if any(not pl['terminated'] for pl in pipelines):
            logger.warning('Pipeline for run {} with stage {} found running/pending. Will not retry.'\
                .format(meta_id, stage))
            return
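        # Only the most recently launched pipeline for this stage is considered.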
        pipeline = pipelines[0]
        pipe_id = pipeline['pipeId']
        if not any(state in BAD_PIPEDREAM_STATES for state in pipeline['state'].split(' ')):
            logger.info('Pipeline for run {} did not fail. Will not retry.'.format(meta_id))
            return
        retry_pipeline(pipe_id, task, meta_id, partitions, skip_coredump, dry_run)
    except Exception as e:
        logger.exception(e)
        # Return the traceback so parallelize() can count and report the failure.
        return traceback.format_exc()


def parallelize(f, args):
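    """
    Run f over each element of args on a thread pool.

    f should return a truthy error value (e.g. a traceback string) on failure
    and None on success; collected errors are logged and raised at the end.
    """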
    pool = ThreadPool(MAX_THREAD_COUNT)
    errors = pool.map(f, args)
    pool.close()
    pool.join()
    errors = [error for error in errors if error]
    if errors:
        for error in errors:
            logger.error(error)
        raise Exception('Found {} errors.'.format(len(errors)))

def parallelize_pipe_ids(pipe_ids):
    logger.info('Processing {} pipelines'.format(len(pipe_ids)))
    # No run meta id is available when retrying by pipe id alone, so pass None.
    arg_list = [(pipe_id, args.task, None, args.partitions, args.skip_coredump, args.dry_run)
                for pipe_id in pipe_ids]
    parallelize(lambda params: retry_pipeline(*params), arg_list)

def parallelize_meta_ids(meta_ids):
    logger.info('Processing {} runs'.format(len(meta_ids)))
    arg_list = [(meta_id, args.stage, args.task, args.retry_reruns, args.partitions,
                 args.skip_coredump, args.dry_run) for meta_id in meta_ids]
    parallelize(lambda params: retry_run(*params), arg_list)

def valid_date(s):
    try:
        parsed = parse(s)
        return parsed.replace(tzinfo=tzutc())
    except ValueError:
        msg = "Not a valid date: '{0}'.".format(s)
        raise argparse.ArgumentTypeError(msg)

def valid_ingest_stage(s):
    if s in INGEST_STAGES:
        return s
    else:
        msg = "--stage must be one of {}".format(INGEST_STAGES)
        raise argparse.ArgumentTypeError(msg)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Retry ingest tasks based on certain parameters.")
    parser.add_argument("--runs_file", type=str, required=False, help='file path for file with line separated list of run meta ids')
    parser.add_argument("--pipelines_file", type=str, required=False, help='file path for file with line separated list of pipe ids')
    parser.add_argument("--ingest_meta_id", type=str, required=False, help='ingest meta id for single run to process')
    parser.add_argument("--stage", type=valid_ingest_stage, required=False, default='primary', help='stage to retry')
    parser.add_argument("--task", type=str, required=False, help='task to retry, has to be part of stage')
    parser.add_argument("--starts_at", type=valid_date, required=False, help='start date range to search in UTC')
    parser.add_argument("--ends_at", type=valid_date, required=False, default=datetime.now(tz=tzutc()), help='end date range to search in UTC (default: now)')
    parser.add_argument("--include-simulation", action='store_true',
                        help='boolean flag to include simulation runs (only used when querying dates, default: False)')
    parser.add_argument("--only-simulation", action='store_true',
                        help='boolean flag to only run for simulation runs (only used when querying dates, default: False)')
    parser.add_argument("--retry-reruns", action='store_true',
                        help='boolean flag to schedule retries for stages that have failed manual retries (default: False)')
    parser.add_argument("--partitions", required=False, default=None, help='assign jobs to partition')
    parser.add_argument("--skip-coredump", action='store_true', help='Skip coredump in primary ingest')
    parser.add_argument('--dry-run', action='store_true',
                        help='boolean flag to print instead of storing faultevents (default: False)')
    args = parser.parse_args()
    if not (args.runs_file or args.pipelines_file or args.ingest_meta_id or args.starts_at):
        parser.error('Must include --runs_file, --pipelines_file, --ingest_meta_id, or --starts_at')
    if args.include_simulation and args.only_simulation:
        parser.error('Use at most one of --include-simulation and --only-simulation')
    return args
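
# Example invocations (script name and argument values are illustrative):
#   python retry_bin.py --ingest_meta_id <meta_id> --stage primary --dry-run
#   python retry_bin.py --runs_file runs.txt --stage secondary --skip-coredump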

def main(args):
    logger.info('Running with args: {}'.format(dict(filter(lambda x: x[1] is not None, args.__dict__.items()))))

    if args.ingest_meta_id:
        retry_run(args.ingest_meta_id, args.stage, args.task, args.retry_reruns, args.partitions, args.skip_coredump, args.dry_run)
    elif args.runs_file:
        with open(args.runs_file) as f:
            meta_ids = [line.strip() for line in f]
        parallelize_meta_ids(meta_ids)
    elif args.pipelines_file:
        with open(args.pipelines_file) as f:
            pipe_ids = [line.strip() for line in f]
        parallelize_pipe_ids(pipe_ids)
    else:
        starts_at = args.starts_at
        ends_at = args.ends_at
        # Copy so that appending does not mutate the module-level DEFAULT_VEHICLES list.
        vehicle_types = list(DEFAULT_VEHICLES)
        if args.include_simulation:
            vehicle_types.append('SimulatedKitt')
        elif args.only_simulation:
            vehicle_types = ['SimulatedKitt']
        run_result = data_rest_api.get_runs(
            window_start=starts_at,
            window_end=ends_at,
            limit=10000,
            vehicle_type_list=vehicle_types,
        )
        if not run_result.get('success'):
            logger.error('Runs request returned error {}'.format(run_result))
            return
        meta_ids = sorted(map(lambda x: x['meta_id'], run_result['runs']))
        parallelize_meta_ids(meta_ids)

if __name__ == '__main__':
    zoox_logger.configureLogging('retry_bin')
    args = parse_args()
    main(args)

