"""PISCES utilities."""
import collections
import json
import multiprocessing as mp
import os

import pandas as pd

from base.logging import zoox_logger
import mined_metric.builder.metrics_impl.pisces.errors as errs
from mined_metric.builder.metrics_impl.pisces.utils import comparator
from mined_metric.builder.metrics_impl.pisces.utils import plotter

# Module-level logger used by the helpers and classes in this file.
LOG = zoox_logger.ZooxLogger(__name__)


def get_pipedream_task_info():
    """Report the pipedream task count and the index of the current task.

    Both values are read from environment variables
    (``PIPEDREAM_TASK_COUNT`` and ``PIPEDREAM_TASK_INDEX``). When a variable
    is not set the fallback describes a single task with index 0, i.e. this
    process is responsible for all of the work.

    Returns
    -------
    task_count : int
        Total tasks for this stage
    task_index  : int
        Current index for this task
    """

    def _env_int(key, fallback):
        # Environment values are strings; the fallback is already an int.
        return int(os.environ.get(key, fallback))

    total = _env_int("PIPEDREAM_TASK_COUNT", 1)
    current = _env_int("PIPEDREAM_TASK_INDEX", 0)
    return total, current


def is_subgoal_label(label):
    """Return True when ``label`` names a subgoal column.

    A subgoal label is fully upper-case and is neither an action
    (``ACTION`` prefix), an action cost (``_COST`` suffix) nor an activation
    indicator (``_is_active`` suffix).
    """
    if label.startswith("ACTION"):
        return False
    if label.endswith(("_COST", "_is_active")):
        return False
    return label.isupper()


def is_action_label(label):
    """Return True when ``label`` starts with ``ACTION`` and is not a cost.

    Labels carrying the ``_COST`` suffix are treated as action costs and are
    therefore rejected even when they start with ``ACTION``.
    """
    if label.endswith("_COST"):
        return False
    return label.startswith("ACTION")


def add_subgoal_activation(df, subgoals):
    """Add a 0/1 ``<subgoal>_is_active`` column for every given subgoal.

    A subgoal counts as active on a row when its column holds a non-null
    value. The frame is modified in place and also returned.
    """
    for subgoal in subgoals:
        has_value = df[subgoal].notna()
        df["".join((subgoal, "_is_active"))] = has_value.astype(int)
    return df


def add_corridor_pt_count(df):
    """Add point-count columns for the left and right corridor boundaries.

    ``corridor_left`` and ``corridor_right`` are parsed as JSON and the
    length of the decoded value is stored in ``corridor_left_cnt`` and
    ``corridor_right_cnt``. Cells that cannot be parsed or measured yield a
    missing value. The frame is modified in place and also returned.
    """

    def count_pts(corr_json):
        try:
            return len(json.loads(corr_json))
        except (TypeError, ValueError):
            # TypeError: the cell is not a string (e.g. NaN) or the decoded
            # JSON value has no length. ValueError: malformed JSON
            # (json.JSONDecodeError is a ValueError subclass).
            # Anything else, such as KeyboardInterrupt, must propagate.
            return None

    df["corridor_left_cnt"] = df["corridor_left"].map(count_pts)
    df["corridor_right_cnt"] = df["corridor_right"].map(count_pts)
    return df


def __read_data(filename):
    """Load a ``.csv`` or ``.json`` file into a DataFrame.

    When the loaded frame has a ``timestamp`` column, rows without a
    timestamp are dropped and the remaining values are converted to
    ``pd.Timestamp`` objects, interpreting each value as seconds.

    Raises
    ------
    errs.PiscesInvalidFiletypeError
        If the file extension is neither ``.csv`` nor ``.json``.
    pd.errors.ParserError, pd.errors.EmptyDataError
        Logged and re-raised when pandas fails to read the file.
    ValueError
        Logged and re-raised when a timestamp cannot be converted.
    """
    readers = {".csv": pd.read_csv, ".json": pd.read_json}
    extension = os.path.splitext(filename)[1]
    reader = readers.get(extension)
    if reader is None:
        msg = f"File {filename} is not json or csv"
        LOG.error(msg)
        raise errs.PiscesInvalidFiletypeError(msg)

    try:
        df = reader(filename)
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        LOG.error("Error reading %s (%s)", filename, e)
        raise

    if "timestamp" not in df:
        return df

    def to_timestamp(seconds):
        return pd.Timestamp(float(seconds), unit="s")

    df = df.dropna(subset=["timestamp"])
    try:
        df["timestamp"] = df["timestamp"].apply(to_timestamp)
    except ValueError as e:
        LOG.error("Error converting timestamp in %s (%s)", filename, e)
        LOG.error("Timestamp: %s", df["timestamp"].iloc[0])
        raise
    return df


def read_planner_data(filename):
    """Read planner data and return it indexed and sorted by timestamp.

    The ``ars_svec``, ``ars_eyvec``, ``lane_ref_svec`` and ``lane_ref_eyvec``
    columns are removed before the frame is indexed.
    """
    dropped_columns = (
        "ars_svec",
        "ars_eyvec",
        "lane_ref_svec",
        "lane_ref_eyvec",
    )
    planner = __read_data(filename)
    for column in dropped_columns:
        planner = planner.drop(column, axis=1)
    return planner.set_index("timestamp").sort_index()


def read_steering_data(filename):
    """Read steering data from ``filename``.

    No steering-specific cleaning is applied yet; this is the place to add
    it for data read from the file.
    """
    return __read_data(filename)


def read_pred_data(filename):
    """Read prediction data from ``filename``.

    No prediction-specific cleaning is applied yet; this is the place to add
    it for data read from the file.
    """
    return __read_data(filename)


def read_pcp_data(filename):
    """Read perception (pcp) data from ``filename``.

    No perception-specific cleaning is applied yet; this is the place to add
    it for data read from the file.
    """
    return __read_data(filename)


class Scenario:
    """Data loaded for one named scenario.

    ``df`` holds either planner or steering data (``steering_mode_only``
    tells which), while ``df_pred`` and ``df_pcp`` hold the prediction and
    perception frames. ``subgoals`` and ``actions`` are the column labels
    recognised in the planner data.
    """

    def __init__(self, name):
        self.name = name
        self.steering_mode_only = False
        # Planner (or steering) frame and the labels found in it.
        self.df = None
        self.subgoals = set()
        self.actions = set()
        # Prediction and perception frames.
        self.df_pred = None
        self.df_pcp = None

    def has_data(self):
        """Return True when a non-empty main frame has been loaded."""
        if self.df is None:
            return False
        return len(self.df) > 0

    def load_planner_data(self, filename):
        """Load planner data and derive subgoal/action labels and counts.

        Does nothing when ``filename`` is None.
        """
        if filename is None:
            return
        # TODO: validate file contains proper contents
        planner = read_planner_data(filename)

        # Labels are classified before any derived columns are added.
        self.subgoals = {
            c for c in planner.columns.values if is_subgoal_label(c)
        }
        self.actions = {
            c for c in planner.columns.values if is_action_label(c)
        }
        # Both helpers mutate and return the frame they are given.
        self.df = add_subgoal_activation(planner, self.subgoals)
        self.df = add_corridor_pt_count(self.df)

        self.steering_mode_only = False

    def load_steering_data(self, filename):
        """Load steering data as the main frame; no-op for None."""
        if filename is None:
            return
        self.df = read_steering_data(filename)
        self.steering_mode_only = True

    def load_prediction_data(self, filename):
        """Load the prediction frame; no-op for None."""
        if filename is None:
            return
        self.df_pred = read_pred_data(filename)

    def load_pcp_data(self, filename):
        """Load the perception frame; no-op for None."""
        if filename is None:
            return
        self.df_pcp = read_pcp_data(filename)

    def describe_planner_stats(self):
        """Summarise tick count, subgoal activations and action sums.

        Returns None for steering-mode-only scenarios; otherwise a dict
        keyed by ``"ticks"`` followed by each subgoal and action label.
        """
        # Excludes stats for steering mode only
        if self.steering_mode_only:
            return None

        summary = {"ticks": len(self.df)}
        for subgoal in self.subgoals:
            summary[subgoal] = self.df[subgoal + "_is_active"].sum()
        for action in self.actions:
            summary[action] = self.df[action].sum()
        return summary

    def describe_prediction_stats(self):
        """Count unique ticks and entities in the prediction data.

        Returns None for steering-mode-only scenarios; otherwise a frame
        with ``tick_cnts`` and ``entity_cnts`` columns, one row per
        ``traj_type`` (or a zeroed ``STATIONARY`` placeholder row when that
        column is absent) and a closing ``Total`` row.
        """
        # Excludes stats for steering mode only
        if self.steering_mode_only:
            return None

        pred = self.df_pred
        totals = [
            len(pred["timestamp"].unique()),
            len(pred["track_id"].unique()),
        ]
        if "traj_type" not in pred:
            # dummy dataframe
            data = pd.DataFrame({0: [0], 1: [0]}, index=["STATIONARY"])
        else:
            data = pred.groupby(["traj_type"]).agg(
                {"timestamp": "nunique", "track_id": "nunique"}
            )
        data.columns = ["tick_cnts", "entity_cnts"]
        # Insert a final row of total
        data.loc["Total"] = totals
        return data


def read_scenario_data(args):
    """Build a Scenario from a tuple of data file paths.

    Parameters
    ----------
    args : tuple
        ``(pred_fn, plan_fn, pcp_fn, <unused>, steering_fn)``. When
        ``steering_fn`` is None the first three paths must all be given.

    Returns
    -------
    tuple or None
        ``(scenario_name, scenario)``, or None when one of the files is
        empty so that the scenario is skipped from processing.
    """
    pred_fn, plan_fn, pcp_fn, _, steering_fn = args

    if steering_fn is not None:
        data_loc = os.path.dirname(steering_fn)
        scenario_name = os.path.basename(steering_fn).split(".")[0]
    else:
        assert (
            (pred_fn is not None)
            and (plan_fn is not None)
            and (pcp_fn is not None)
        )
        data_loc = os.path.dirname(plan_fn)
        # assuming the path name is of the pattern of "/path/to/file_basename.plan.csv"
        name_parts = os.path.basename(plan_fn).split(".")
        scenario_name = ".".join(name_parts[:-2])

    scenario = Scenario(scenario_name)

    LOG.info("Loading data for scenario '%s' from %s", scenario_name, data_loc)

    # Each loader ignores a None path, so all four can always be called.
    loaders = (
        (scenario.load_planner_data, plan_fn),
        (scenario.load_steering_data, steering_fn),
        (scenario.load_prediction_data, pred_fn),
        (scenario.load_pcp_data, pcp_fn),
    )
    try:
        for load, path in loaders:
            load(path)
    except pd.errors.EmptyDataError:
        # Returning None skips this scenario from processing
        return None
    return (scenario.name, scenario)


class Release:
    """A release is a collection of scenarios loaded from one directory."""

    def __init__(self, name):
        self.name = name
        # Defined up front so both attributes exist before load_data() runs.
        self.scenarios = {}
        self.release_dir = None

    def load_data(self, root_dir):
        """Discover and load every scenario under ``root_dir/<name>``.

        Files are grouped per scenario by their base name. A scenario with a
        ``.json`` (steering) file is loaded from that file alone; otherwise
        it is loaded only when ``.pred.csv``, ``.plan.csv`` and ``.pcp.csv``
        are all present. Scenarios that end up without data are pruned.

        Parameters
        ----------
        root_dir : str
            Directory that contains this release's directory.
        """
        LOG.info("Loading scenarios for %s at %s", self.name, root_dir)

        self.scenarios = {}
        self.release_dir = os.path.join(root_dir, self.name)

        # Per scenario: [pred, plan, pcp, loc, steering] paths; any slot may
        # stay None when the corresponding file is missing.
        data_paths = {}
        csv_slots = {".pred": 0, ".plan": 1, ".pcp": 2}
        for dirpath, _, files in os.walk(self.release_dir):
            for f in files:
                root, extension = os.path.splitext(f)
                scenario_name, prefix = os.path.splitext(os.path.basename(root))
                slots = data_paths.setdefault(
                    scenario_name, [None, None, None, None, None]
                )
                # Join with the directory the file was actually found in, so
                # files in sub-directories resolve to an existing path.
                path = os.path.join(dirpath, f)
                if extension == ".csv":
                    if prefix in csv_slots:
                        slots[csv_slots[prefix]] = path
                elif extension == ".loc":
                    slots[3] = path
                elif extension == ".json":
                    slots[4] = path

        def paths_generator():
            for pred_path, plan_path, pcp_path, vid_path, steering_path in (
                data_paths.values()
            ):
                if steering_path is not None:
                    yield (None, None, None, None, steering_path)
                elif None not in (pred_path, plan_path, pcp_path):
                    yield (pred_path, plan_path, pcp_path, vid_path, None)

        pool = mp.Pool(mp.cpu_count())
        try:
            loaded = pool.map(read_scenario_data, paths_generator())
        finally:
            # Always release the worker processes, even when loading fails.
            pool.close()
            pool.join()

        # read_scenario_data returns None for scenarios that must be skipped.
        self.scenarios = dict(entry for entry in loaded if entry)

        self.__prune_scenarios()

    def __prune_scenarios(self):
        """Drop scenarios whose main frame is missing or empty."""
        self.scenarios = {
            name: scenario
            for name, scenario in self.scenarios.items()
            if scenario.has_data()
        }


# Work item handed to render_pisces_html: the argus host, the scenario name,
# and the release / sha / scenario triple for both control and candidate.
RenderPayload = collections.namedtuple(
    "RenderPayload",
    "argus_host scenario_name "
    "control_release control_sha control_scenario "
    "candidate_release candidate_sha candidate_scenario",
)


def render_pisces_html(args):
    """Render the HTML for one scenario from its render payload.

    The control data is added to the plotter first, then the candidate
    data. Returns a ``(scenario_name, rendered_html)`` tuple.
    """
    page = plotter.Plotter()
    host = args.argus_host
    name = args.scenario_name

    sides = (
        (args.control_release, args.control_sha, args.control_scenario),
        (args.candidate_release, args.candidate_sha, args.candidate_scenario),
    )
    for release, sha, scenario in sides:
        page.add_data(release, sha, scenario)

    return (name, page.render_html(host))


class Experiment:
    """An experiment contains multiple releases to compare.

    Releases are discovered under ``base_path/experiment_id`` by
    ``load_data``; the remaining methods compare, render or summarise the
    scenarios that all loaded releases have in common.
    """

    def __init__(self, base_path, experiment_id):
        self.base_path = base_path
        self.validation_id = experiment_id
        self.releases = {}
        # Populated by load_data(); defined here so the attributes exist.
        self.experiment_dir = None
        self.scenario_set = set()

    @property
    def id(self):
        """Identifier of this experiment."""
        return self.validation_id

    def load_data(self):
        """
        Loads all the PISCES experiment data from the given path.

        Assumes path has the following structure:
            path/RELEASE/data.csv

        Raises errs.PiscesNoDataFound when no release is found or when a
        release has no scenarios.

        TODO: ensure data stored at that path has fixed schema, e.g., protobufs
        """
        root_dir = os.path.join(self.base_path, self.validation_id)
        LOG.info("Loading data from %s", root_dir)

        self.releases = {}
        for root, dirs, _ in os.walk(root_dir):
            # Every directory without sub-directories is taken as a release,
            # named by its path relative to root_dir.
            if not dirs:
                release_name = root[len(root_dir) + 1 :]
                self.releases[release_name] = Release(release_name)

        # check that releases are nonempty
        if not self.releases:
            LOG.error("No releases found in %s", root_dir)
            raise errs.PiscesNoDataFound("No data found in %s" % root_dir)

        LOG.info("Found releases: %s", self.releases.keys())

        for release in self.releases.values():
            release.load_data(root_dir)

        for name, release in self.releases.items():
            if not release.scenarios:
                raise errs.PiscesNoDataFound(
                    "No scenarios found for release %s" % name
                )

        self.experiment_dir = root_dir
        self.__get_scenarios()

    def __get_scenarios(self):
        """Store the scenario names shared by every release."""
        # checks that all releases have the same scenarios
        scenarios = [
            set(release.scenarios.keys()) for release in self.releases.values()
        ]
        if not all(s == scenarios[0] for s in scenarios):
            LOG.warn("Not all scenarios in all releases; using intersection")
        scenario_set = set.intersection(*scenarios)

        assert (
            len(scenario_set) > 0
        ), "No common scenarios found between two SHAs"

        self.scenario_set = scenario_set

    def __check_steering_comparison(self, scenario_a, scenario_b):
        """Return the steering-only flag both scenarios must agree on.

        Raises AssertionError when the two scenarios disagree.
        """
        # The message names the scenario via scenario_a.name; no variable
        # called `scenario` exists in this method.
        assert (
            scenario_a.steering_mode_only == scenario_b.steering_mode_only
        ), "Scenario '{}' should agree whether we're checking steering mode only".format(
            scenario_a.name
        )
        return scenario_a.steering_mode_only

    def __get_steering_comparator(self, release_a, release_b, scenario):
        """Return a SteeringComparator, or None unless steering-only."""
        scenario_a = release_a.scenarios[scenario]
        scenario_b = release_b.scenarios[scenario]
        if not self.__check_steering_comparison(scenario_a, scenario_b):
            return None

        LOG.info("Getting steering comparator for scenario '%s'", scenario)
        return comparator.SteeringComparator(
            scenario, scenario_a.df, scenario_b.df
        )

    def __get_planner_comparator(self, release_a, release_b, scenario):
        """Return a planner Comparator, or None for steering-only data."""
        scenario_a = release_a.scenarios[scenario]
        scenario_b = release_b.scenarios[scenario]
        if self.__check_steering_comparison(scenario_a, scenario_b):
            return None

        LOG.info("Getting planner comparator for scenario '%s'", scenario)
        return comparator.Comparator(scenario, scenario_a.df, scenario_b.df)

    def __get_prediction_comparator(self, release_a, release_b, scenario):
        """Return a PredComparator, or None for steering-only data."""
        scenario_a = release_a.scenarios[scenario]
        scenario_b = release_b.scenarios[scenario]
        if self.__check_steering_comparison(scenario_a, scenario_b):
            return None

        LOG.info("Getting prediction comparator for scenario '%s'", scenario)
        return comparator.PredComparator(
            scenario, scenario_a.df_pred, scenario_b.df_pred
        )

    def __get_pcp_comparator(self, release_a, release_b, scenario):
        """Return a PCPComparator, or None for steering-only data."""
        scenario_a = release_a.scenarios[scenario]
        scenario_b = release_b.scenarios[scenario]
        if self.__check_steering_comparison(scenario_a, scenario_b):
            return None

        LOG.info("Getting perception comparator for scenario '%s'", scenario)
        return comparator.PCPComparator(
            scenario, scenario_a.df_pcp, scenario_b.df_pcp
        )

    def __is_valid_release(self, release):
        """Return True when ``release`` was loaded; log an error if not."""
        if self.releases.get(release, None) is None:
            LOG.error(
                "Could not find release %s in experiment %s",
                release,
                self.validation_id,
            )
            return False
        return True

    def __validate_releases(self, candidate_release, control_release):
        """Return True when both releases are part of this experiment."""
        return self.__is_valid_release(
            candidate_release
        ) and self.__is_valid_release(control_release)

    def render_html(
        self,
        candidate_release,
        candidate_sha,
        control_release,
        control_sha,
        argus_host,
    ):
        """Render HTML for every common scenario that is not steering-only.

        Returns the list of results produced by ``render_pisces_html``, one
        per rendered scenario.
        """
        work_load = []
        for s in self.scenario_set:
            control_scenario = self.releases[control_release].scenarios[s]
            candidate_scenario = self.releases[candidate_release].scenarios[s]

            assert (
                control_scenario.steering_mode_only
                == candidate_scenario.steering_mode_only
            )
            # don't append if steering mode only
            if control_scenario.steering_mode_only:
                continue

            work_load.append(
                RenderPayload(
                    argus_host=argus_host,
                    scenario_name=s,
                    control_release=control_release,
                    control_sha=control_sha,
                    control_scenario=control_scenario,
                    candidate_release=candidate_release,
                    candidate_sha=candidate_sha,
                    candidate_scenario=candidate_scenario,
                )
            )

        pool = mp.Pool(mp.cpu_count())
        try:
            return pool.map(render_pisces_html, work_load)
        finally:
            # Always release the worker processes, even when rendering fails.
            pool.close()
            pool.join()

    def compare(self, candidate_release, control_release):
        """Build comparators for every common scenario.

        Returns a dict with ``Planner``, ``Prediction``, ``Perception`` and
        ``Steering`` keys, each a tuple with one entry per scenario (None
        where that comparator does not apply), or None when either release
        is unknown.
        """
        if not self.__validate_releases(candidate_release, control_release):
            return None

        release_a = self.releases[candidate_release]
        release_b = self.releases[control_release]
        comparators = [
            (
                self.__get_planner_comparator(release_a, release_b, s),
                self.__get_prediction_comparator(release_a, release_b, s),
                self.__get_pcp_comparator(release_a, release_b, s),
                self.__get_steering_comparator(release_a, release_b, s),
            )
            for s in self.scenario_set
        ]
        plan_comps, pred_comps, pcp_comps, steering_comps = zip(*comparators)
        return {
            "Planner": plan_comps,
            "Prediction": pred_comps,
            "Perception": pcp_comps,
            "Steering": steering_comps,
        }

    def aggregate_zci_tick(self, release, input_data_column):
        """Concatenate two columns across all non-steering scenarios.

        The columns are ``input_data_column`` and ``zci_tick_all``. Returns
        None when ``release`` is unknown.
        """
        if not self.__is_valid_release(release):
            return None

        # columns for aggregation
        cols = [input_data_column, "zci_tick_all"]
        rel = self.releases[release]
        dataframes = (
            rel.scenarios[s].df[cols]
            for s in self.scenario_set
            if not rel.scenarios[s].steering_mode_only
        )
        return pd.concat(dataframes, ignore_index=True)

    def describe(self, release):
        """Collect planner and prediction statistics for ``release``.

        Returns a dict with a ``Planner`` DataFrame (indexed by scenario
        name when non-empty) and a ``Prediction`` list of
        ``{"name", "stat_df"}`` dicts, or None when ``release`` is unknown.
        Scenarios whose stats are None are left out.
        """
        if not self.__is_valid_release(release):
            return None

        rel = self.releases[release]

        # Each scenario's stats are computed once and reused.
        planner_rows = []
        for s in self.scenario_set:
            stats = rel.scenarios[s].describe_planner_stats()
            if stats is not None:
                planner_rows.append({"name": s, **stats})
        planner_stat_df = pd.DataFrame(planner_rows)
        if not planner_stat_df.empty:
            planner_stat_df = planner_stat_df.set_index("name")

        prediction_stat = []
        for s in self.scenario_set:
            stat_df = rel.scenarios[s].describe_prediction_stats()
            if stat_df is not None:
                prediction_stat.append({"name": s, "stat_df": stat_df})

        return {
            "Planner": planner_stat_df,
            "Prediction": prediction_stat,
        }
