#!/usr/bin/env python3
"""Extract PISCES data."""
import argparse
import logging
import os
import subprocess
import time

from base.logging import zoox_logger
from data.chum import chumpy
from mined_metric.builder.metrics_impl.pisces.utils.utils import (
    get_pipedream_task_info,
)
from vehicle.planner.metrics.python.utils.metric_hub_utils import (
    get_experiment_id_from_validation,
    get_experiment_data,
)

LOG = zoox_logger.ZooxLogger(__name__)


def parse_args():
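    """Parse command-line arguments for the extraction script."""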
    parser = argparse.ArgumentParser("extract_pisces_data")
    parser.add_argument(
        "--validation-id",
        "--id",
        help="MetricHub validation ID to use",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--output-dir",
        "-o",
        help="Directory to write results to",
        required=True,
        type=str,
    )
    parser.add_argument(
        "--debug",
        "--verbose",
        "-v",
        help="Whether to print debug statements",
        required=False,
        action="store_true",
    )
    args = parser.parse_args()

    if args.debug:
        LOG.set_level(logging.DEBUG)

    return args


def check_sha_matches(sha1, sha2):
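    """Return True if the shorter SHA is a prefix of the longer one.

    This lets an abbreviated SHA compare equal to its full-length
    counterpart:

    >>> check_sha_matches("abc123", "abc123def456")
    True
    >>> check_sha_matches("abc124", "abc123def456")
    False
    """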
    min_length = min(len(sha1), len(sha2))
    return sha1[:min_length] == sha2[:min_length]


class Experiment:
    def __init__(self, data, queue_length=16):
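        """
        Parameters
        ----------
        data : list of dict
            Per-run records for one experiment, e.g. as returned by
            get_experiment_data
        queue_length : int
            Maximum number of extraction subprocesses kept in flight
        """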
        self.data = data
        self.queue_length = queue_length

    def read_data(self, results_dir, task_index=0, total_task_count=1):
        r"""Extracts data and stores results in results_dir

        Maps the extractor over self.data to pull out relevant data for PISCES
        comparisons. Results are stored in results_dir. This method is trivially
        parallelizable.

        Parameters
        ----------
        results_dir : str
            Location to write results of extraction
        task_index : int
            Index of current task on pipedream
        total_task_count : int
            Total tasks for this stage on pipedream
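
        Notes
        -----
        A minimal sketch of a sharded invocation, given run records
        ``data``; the directory and shard values are illustrative:

        >>> exp = Experiment(data, queue_length=8)
        >>> exp.read_data("/tmp/pisces_results", task_index=0,
        ...               total_task_count=2)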
        """
        shells = []
        for idx, datum in enumerate(self.data):
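            # round-robin sharding: each pipedream task handles every
            # total_task_count-th datum, offset by its task_index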
            if idx % total_task_count != task_index:
                continue
            if datum.get("sim_failed", False):
                LOG.warn("Skipping %s as simulation failed.", datum["chum_uri"])
                continue

            # use the bazel target to match _or_ use the chum uri
            if "bazel_target" in datum:
                # TODO: the bazel target alone is not sufficient to match
                # runs; we should match across sim invocations called with
                # the _same_ args. When the comparison is between
                # variations, matching is harder still.
                filename = datum["scenario_hash"]
                sdl_target = datum["sdl_target"]
            else:
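                # parse the URI twice: once into a proto we can edit, and
                # once to recover its start/end time range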
                chum_uri = chumpy.parseChumUriToProto(datum["chum_uri"])
                chum_range = chumpy.Range()
                chumpy.parseChumUri(datum["chum_uri"], range=chum_range)

                # clear input/variant/topic so the rendered URI reduces to
                # the log identity plus the time range set below
                chum_uri.ClearField("input")
                chum_uri.ClearField("variant")
                chum_uri.ClearField("topic")

                chum_uri.start_time_ns = chum_range.start_time
                chum_uri.end_time_ns = chum_range.end_time

                uri = chumpy.renderChumUri(chum_uri)
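                # the third "/"-separated component of the rendered URI is
                # used as the output filename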
                filename = uri.split("/")[2]
                sdl_target = ""
            file_path = os.path.join(results_dir, filename)
            cmd = (
                "./mined_metric/builder/metrics_impl/pisces/extract_data "
                f"--chum_uri '{datum['chum_uri']}' "
                f"--filepath {file_path} "
                f"--sdl_target '{sdl_target}'"
            )
            # ensure the queue of outstanding processes does not exceed
            # queue_length in size
            #
            # we poll all processes until at least one finishes
            while len(shells) >= self.queue_length:
                time.sleep(1)
                # rebuild the list rather than removing entries while
                # iterating over it, which would skip elements
                shells = [s for s in shells if s.poll() is None]
            p = subprocess.Popen(cmd, shell=True)
            LOG.info("Spawned pid %s to process %s", p.pid, datum["chum_uri"])
            shells.append(p)

        for s in shells:
            s.wait()


def extract_data_from_metrichub(
    validation_id, output_dir, task_index, total_task_count
):
    r"""Extracts all data from MetricHub validation run

    A MetricHub "validation" is a collection of simulation runs; in this case,
    we should only have two (the control and candidate). This function maps
    the extractor over all the experiments and stores the results in the
    provided output directory.

    Parameters
    ----------
    validation_id : str
        MetricHub validation ID; typically a pre-generated UUID
    output_dir : str
        Location to write results of extraction
    task_index : int
        Index of current task on pipedream
    total_task_count : int
        Total tasks for this stage on pipedream
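
    Notes
    -----
    A minimal sketch of a single-task (unsharded) run; the validation ID
    and output directory are illustrative:

    >>> extract_data_from_metrichub(
    ...     "d2f0c9e4-0000-0000-0000-000000000000",
    ...     "/tmp/pisces",
    ...     task_index=0,
    ...     total_task_count=1,
    ... )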
    """
    experiments = get_experiment_id_from_validation(validation_id)

    for i, experiment in enumerate(experiments):
        state = experiment.get("state", "UNKNOWN")
        if state != "DONE":
            LOG.warn(
                "Skipping experiment %s in state %s", experiment["id"], state
            )
            continue

        msg = experiment.get("additionalMsg", f"{i:02}")
        msg = "_".join(msg.split(" "))
        try:
            data = get_experiment_data(experiment["id"])
        except Exception:
            LOG.error(
                "Failed to get experiment data for %s", experiment["id"]
            )
            raise

        results_dir = os.path.join(output_dir, validation_id, msg)
        os.makedirs(results_dir, exist_ok=True)

        Experiment(data).read_data(results_dir, task_index, total_task_count)


def main():
    zoox_logger.configureLogging("extract_pisces_data")
    args = parse_args()

    total_task_count, task_index = get_pipedream_task_info()

    extract_data_from_metrichub(
        args.validation_id, args.output_dir, task_index, total_task_count
    )


if __name__ == "__main__":
    main()
