#!/usr/bin/env python3
"""Compare PISCES steering mode results.

This script extracts the 2ws/4ws log test scenario simulation results from
Marvel, computes the errors using PlannerMotionComparator, and writes the
results to a JSON file.
"""
import argparse
from collections import namedtuple
import json
import os

from base.logging import zoox_logger
from mined_metric.builder.metrics_impl.pisces.utils.marvel import get_marvel_run
from mined_metric.builder.metrics_impl.pisces.utils.utils import (
    get_pipedream_task_info,
)
from vis.data_analysis.data_comparator.data_extractor import (
    PlannerDataExtractor,
)
from vis.data_analysis.data_comparator.planner_motion_comparator import (
    InvalidComparisonError,
    PlannerMotionComparator,
)

# Minimal args object with just the fields PlannerMotionComparator reads
# (stand-in for a parsed argparse namespace).
DummyArgs = namedtuple(
    "DummyArgs", ["output_metrics_path", "plot_trk_errors", "decision_only"]
)
# Fixed comparator options used by compute_2ws_4ws_errors: no metrics output
# path, no tracking-error plots, decision-only comparison.
PLANNER_MOTION_COMPARATOR_ARGS = DummyArgs(
    output_metrics_path="", plot_trk_errors=False, decision_only=True
)

# Module-level logger for this script.
LOG = zoox_logger.ZooxLogger(__name__)


def parse_args():
    """Parse command-line arguments for the 2ws/4ws comparison.

    Returns:
        argparse.Namespace with ``pisces_id``, ``marvel_2ws_id``,
        ``marvel_4ws_id``, ``label``, and ``output_dir`` attributes.
    """
    # Pass the text via ``description=``: the original code passed it as the
    # first positional argument, which argparse treats as ``prog`` (the
    # program name) and would display in place of the script name in usage
    # and help output.
    parser = argparse.ArgumentParser(
        description="Compare two and four wheel steering"
    )

    parser.add_argument(
        "--pisces-id", "--id", help="ID of PISCES job", required=True, type=str
    )

    parser.add_argument(
        "--marvel-2ws-id", help="ID for Marvel 2ws run", required=True, type=str
    )
    parser.add_argument(
        "--marvel-4ws-id", help="ID for Marvel 4ws run", required=True, type=str
    )

    parser.add_argument(
        "--label",
        "-l",
        help="Branch label for the comparison",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--output-dir",
        "-o",
        help="Directory to write results to",
        required=True,
        type=str,
    )

    return parser.parse_args()


def get_marvel_results(results):
    """Yield (scenario_hash, chum_uri, sdl_target) for each successful sim.

    Entries flagged with a truthy ``sim_failed`` are logged and skipped.
    """
    for entry in results:
        if not entry.get("sim_failed", False):
            yield entry["scenario_hash"], entry["chum_uri"], entry["sdl_target"]
        else:
            LOG.warn("Skipping %s as simulation failed.", entry["chum_uri"])


def compute_2ws_4ws_errors(ctrl_uri, cand_uri):
    """Compute planner-motion error metrics between a 2ws and a 4ws run.

    Args:
        ctrl_uri: Chum URI of the control (2ws) run.
        cand_uri: Chum URI of the candidate (4ws) run.

    Returns:
        The metrics dict from PlannerMotionComparator.compute_metrics(), or a
        dict of infinite "Projected errors" if the comparison was invalid.
    """
    bl_extractor = PlannerDataExtractor(ctrl_uri, decision_only=True)
    bl_dataset = bl_extractor.get_dataset()
    comp_extractor = PlannerDataExtractor(cand_uri, decision_only=True)
    comp_dataset = comp_extractor.get_dataset()
    # Read vehicle wheelbase for baseline and comparison runs.
    bl_wheel_base = bl_extractor.get_config().dimensions().wheel_base
    comp_wheel_base = comp_extractor.get_config().dimensions().wheel_base

    # Compare the baseline and comparison runs.
    comparator = PlannerMotionComparator(
        bl_dataset,
        comp_dataset,
        bl_wheel_base,
        comp_wheel_base,
        PLANNER_MOTION_COMPARATOR_ARGS,
    )

    try:
        comparator.compare()
        # Compute error metrics between the runs from the comparison.
        results = comparator.compute_metrics()
    except InvalidComparisonError as err:
        # Comparison failed - indicate infinite error but keep going so that
        # this stage completes successfully.
        error_metrics = {"Max absolute": float("inf"), "RMS": float("inf")}
        results = {
            "Projected errors": {
                "decision": {
                    "Decision s error(m)": error_metrics,
                    "Projected Lateral error(m)": error_metrics,
                    "Projected Heading error(deg)": error_metrics,
                    "Projected Velocity vx error(m/s)": error_metrics,
                    "Expected heading error diff(deg)": error_metrics,
                    "Non-expected heading error diff(deg)": error_metrics,
                }
            }
        }
        # Lazy %-formatting for consistency with the rest of this file's
        # logging (the original used an eagerly-interpolated f-string).
        LOG.error("planner_motion_comparator.py: %s", err.message)
    return results


def extract_data_from_marvel(
    pisces_id,
    marvel_2ws,
    marvel_4ws,
    output_dir,
    label,
    task_index,
    total_task_count,
):
    """Compare matching 2ws/4ws Marvel scenarios and write per-scenario JSON.

    Only scenario hashes present in both Marvel runs are compared. Work is
    sharded across PipeDream tasks: this task handles every key whose index
    satisfies ``index % total_task_count == task_index``.
    """
    results_dir = os.path.join(output_dir, pisces_id, label)
    os.makedirs(results_dir, exist_ok=True)

    # Map scenario hash -> (chum_uri, sdl_target) for each run.
    control = {}
    for scenario_hash, chum, target in get_marvel_results(
        get_marvel_run(marvel_2ws)
    ):
        control[scenario_hash] = (chum, target)

    candidate = {}
    for scenario_hash, chum, target in get_marvel_results(
        get_marvel_run(marvel_4ws)
    ):
        candidate[scenario_hash] = (chum, target)

    # Sorting the shared keys guarantees every PipeDream task enumerates them
    # in the same order, so the modulo sharding below stays consistent.
    shared_keys = sorted(control.keys() & candidate.keys())

    for index, scenario_hash in enumerate(shared_keys):
        if index % total_task_count != task_index:
            continue

        ctrl_chum, ctrl_target = control[scenario_hash]
        cand_chum, cand_target = candidate[scenario_hash]

        # TODO: store the results
        LOG.info("%s 2ws %s", label, ctrl_target)
        LOG.info("%s 4ws %s", label, cand_target)
        comparison_results = compute_2ws_4ws_errors(ctrl_chum, cand_chum)

        # While we could write out results into a _single_ big file, we choose
        # to write the results into individual json files in keeping with the
        # PISCES convention of having 1 file per result.
        payload = dict(
            hash_id=scenario_hash,
            control_chum_uri=ctrl_chum,
            candidate_chum_uri=cand_chum,
            sdl_target=ctrl_target,
            results=comparison_results,
        )
        out_path = os.path.join(results_dir, f"{scenario_hash}.steering.json")
        with open(out_path, "w") as out_file:
            json.dump(payload, out_file)


if __name__ == "__main__":
    # Set up logging for this script before doing any work.
    zoox_logger.configureLogging("compare_2ws_4ws")

    args = parse_args()
    # PipeDream sharding info: this process handles scenarios whose index
    # modulo total_task_count equals task_index.
    total_task_count, task_index = get_pipedream_task_info()

    extract_data_from_marvel(
        args.pisces_id,
        args.marvel_2ws_id,
        args.marvel_4ws_id,
        args.output_dir,
        args.label,
        task_index,
        total_task_count,
    )
