import json

from google.protobuf.descriptor import FieldDescriptor
import numpy as np
import pandas as pd

from vehicle.common.proto.controller_pb2 import LowLevelControl
import mined_metric.builder.metrics_impl.pisces.utils.pcp_matcher_py as pcp_matcher


def is_comparable_low_level_control_field(field):
    """Return whether a LowLevelControl field should be diffed numerically.

    A field qualifies when it is not ``tracing_tags``, its type is not
    message, string or bytes, and it is not repeated.

    This logic should be kept compatible with what's in extract_data.cpp.

    Args:
        field: a protobuf ``FieldDescriptor`` (callers in this file pass
            entries of ``LowLevelControl.DESCRIPTOR.fields``).

    Returns:
        True if the field should be compared, False otherwise.
    """
    if field.name == "tracing_tags":
        return False

    excluded_types = (
        FieldDescriptor.TYPE_MESSAGE,
        FieldDescriptor.TYPE_STRING,
        FieldDescriptor.TYPE_BYTES,
    )
    if field.type in excluded_types:
        return False

    return field.label != FieldDescriptor.LABEL_REPEATED


def find_sequence(df, weight=None):
    """Collapse a per-row argmax into a list of contiguous segments.

    For every row of ``df`` the column holding the largest value is taken as
    the active element; consecutive rows with the same element are merged
    into one segment. Rows that are entirely NaN are ignored.

    Args:
        df: DataFrame whose index differences support ``.total_seconds()``
            (e.g. a DatetimeIndex) and whose columns are the candidate
            elements.
        weight: optional Series indexed like ``df``. When given, each
            segment's ``"weight"`` is the sum of ``weight`` over the segment's
            rows after its first one, divided by the segment duration.

    Returns:
        A list of dicts with keys ``"action"`` (column label), ``"duration"``
        (seconds from the segment's first to its last row) and ``"weight"``
        (0 when ``weight`` is None or the segment has zero duration).
    """
    # Drop all-NaN rows up front: newer pandas warns/raises in idxmax on them,
    # older pandas returned NaN which was then dropped anyway. Also covers a
    # frame with no rows or no columns.
    rows = df.dropna(how="all")
    if rows.empty:
        return []
    series = rows.idxmax(axis=1).dropna()

    seq = []
    last_elem = None
    last_ts = None
    # Series.items() replaces iteritems(), which was removed in pandas 2.0.
    for ts, elem in series.items():
        if elem != last_elem:
            seq.append({"action": elem, "duration": 0, "weight": 0})
        else:
            seq[-1]["duration"] += (ts - last_ts).total_seconds()
            if weight is not None:
                seq[-1]["weight"] += weight.loc[ts]
        last_elem = elem
        last_ts = ts

    if weight is not None:
        for segment in seq:
            # A single-row segment has zero duration (and no accumulated
            # weight); avoid dividing 0 by 0.
            if segment["duration"]:
                segment["weight"] = segment["weight"] / segment["duration"]
            else:
                segment["weight"] = 0.0
    return seq


def diff_sequence(seqA, seqB, seqA_weights=None, seqB_weights=None, window=1):
    """Return a banded, DTW-style alignment cost between two sequences.

    Callers in this file pass lists of action/subgoal labels built from
    ``find_sequence`` output. A cost table of shape (len(seqA), len(seqB)) is
    filled only inside the band ``abs(i - j) <= window``; cells outside it
    stay ``np.inf``.

    Each in-band cell takes the smallest of its left ("delete"), upper
    ("insert") and upper-left ("match") neighbours. Ties go to "delete", then
    "insert", because only strict improvements replace an earlier choice. The
    cell then adds 1 when the chosen neighbour is "delete" or "insert".

    Args:
        seqA, seqB: sequences whose elements are compared with ``==``.
        seqA_weights, seqB_weights: accepted but not used by this function.
        window: half-width of the band around the diagonal.

    Returns:
        0 if either sequence is empty; otherwise the value of the last table
        cell. Because the element comparison below assigns a cost of 0 in
        both branches, this counts only insert/delete steps on the chosen
        path, not mismatched elements. If
        ``abs(len(seqA) - len(seqB)) > window`` the last cell lies outside the
        band and the result is ``np.inf``.
    """
    if len(seqA) == 0 or len(seqB) == 0:
        return 0
    # Every cell starts as unreachable.
    dynamic_time_warping_costs = np.zeros((len(seqA), len(seqB)))
    for i in range(len(dynamic_time_warping_costs)):
        for j in range(len(dynamic_time_warping_costs[i])):
            dynamic_time_warping_costs[i][j] = np.inf

    dynamic_time_warping_costs[0][0] = 0
    # Zero the in-band cells. NOTE(review): every in-band cell appears to be
    # recomputed by the next loop before it is read, so this pass looks
    # redundant -- confirm before removing.
    for i in range(len(dynamic_time_warping_costs)):
        for j in range(i - window, i + window + 1):
            if j < 0:
                continue
            if j >= len(seqB):
                continue
            dynamic_time_warping_costs[i][j] = 0

    # Fill the band row by row, left to right.
    for i in range(len(dynamic_time_warping_costs)):
        for j in range(i - window, i + window + 1):
            if j < 0:
                continue
            if j >= len(seqB):
                continue
            # NOTE(review): both branches assign 0, so mismatched elements
            # never add cost. The call sites' "number of different actions"
            # comments suggest the else-branch may have been meant to be
            # non-zero; confirm the intent before changing, since it alters
            # the reported metrics.
            if seqA[i] == seqB[j]:
                cost = 0
            else:
                cost = 0

            best_action = None
            best_cost = np.inf
            # "<=" here (vs "<" below) makes "delete" win ties.
            if (j - 1) >= 0:
                deletion_cost = dynamic_time_warping_costs[i][j - 1]
                if deletion_cost <= best_cost:
                    best_action = "delete"
                    best_cost = deletion_cost

            if (i - 1) >= 0:
                insertion_cost = dynamic_time_warping_costs[i - 1][j]
                if insertion_cost < best_cost:
                    best_action = "insert"
                    best_cost = insertion_cost

            if (i - 1) >= 0 and (j - 1) >= 0:
                match_cost = dynamic_time_warping_costs[i - 1][j - 1]
                if match_cost < best_cost:
                    best_action = "match"
                    best_cost = match_cost

            # Non-diagonal steps are charged one unit.
            if best_action == "delete" or best_action == "insert":
                best_cost += 1

            # No neighbour was selected (e.g. the (0, 0) cell): start from 0.
            if best_action is None:
                best_cost = 0

            dynamic_time_warping_costs[i][j] = cost + best_cost

    return dynamic_time_warping_costs[len(seqA) - 1][len(seqB) - 1]


class Comparator:
    """Compares planner data from two "SHAS" (time-indexed dataframes).

    ``dfA`` is the candidate run and ``dfB`` the control run. Unless both
    contain exactly one row, each is resampled to 100ms with forward fill.
    Both are then restricted to the timestamps at which each has at least one
    non-NaN value.
    """

    # Maybe just compare last couple seconds
    #
    # dfA = candidate
    # dfB = control
    def __init__(self, name, dfA, dfB):
        self.name = name

        # resample will produce NaNs when the timestamps don't fall on 100ms
        # boundaries for a dataframe with a *single* row
        if len(dfA) == len(dfB) == 1:
            self.dfA = dfA
            self.dfB = dfB
        else:
            self.dfA = dfA.resample("100ms").ffill()
            self.dfB = dfB.resample("100ms").ffill()

        all_columns = set(dfA.columns).union(dfB.columns)
        # Sorted rather than in set iteration order (which varies between
        # interpreter runs for strings). find_sequence's idxmax picks the
        # first maximal column on ties, so column order must be deterministic.
        self.actions = sorted(
            c for c in all_columns if self._is_action_column(c)
        )
        self.subgoals = sorted(
            c for c in all_columns if self._is_subgoal_column(c)
        )

        # track overlapping time indexes
        ts_a = self.dfA.dropna(how="all").index
        ts_b = self.dfB.dropna(how="all").index
        ts_overlap = ts_a.intersection(ts_b)

        self.dfA = self.dfA.loc[ts_overlap]
        self.dfB = self.dfB.loc[ts_overlap]

    @staticmethod
    def _is_action_column(column):
        """Return True for ``ACTION*`` columns that are not ``*_COST``."""
        return column.startswith("ACTION") and not column.endswith("_COST")

    @staticmethod
    def _is_subgoal_column(column):
        """Return True for all-uppercase non-action, non-cost, non-flag columns."""
        return (
            not column.startswith("ACTION")
            and not column.endswith("_COST")
            and not column.endswith("_is_active")
            and column.isupper()
        )

    def correlate(self, field):
        """Spearman correlation of ``field`` between candidate and control."""
        # add a very small amount of noise (same) to both so that there is
        # *some* variance
        eps = pd.Series(
            1e-6 * np.random.randn(len(self.dfA[field])), index=self.dfA.index
        )
        a = self.dfA[field] + eps
        b = self.dfB[field] + eps

        return a.corr(b, method="spearman")

    def argmax_diff(self, field):
        """ Find the location of the max delta between the series. """
        signal = (self.dfA[field] - self.dfB[field]).fillna(0).abs()
        idx = signal.idxmax()
        return idx

    def max_abs_diff(self, field):
        """Max absolute per-tick difference; NaN if either frame lacks ``field``.

        NaN differences are treated as 0.
        """
        if field in self.dfA and field in self.dfB:
            signal = (self.dfA[field] - self.dfB[field]).fillna(0).abs()
            return signal.max()
        return np.nan

    def rmse(self, field):
        """Root-mean-square per-tick difference of ``field``.

        NaN differences are skipped by the sum but still counted in the
        denominator, i.e. treated as 0 (consistent with ``max_abs_diff``).
        Returns NaN if either frame lacks ``field`` or there are no ticks.
        """
        if field not in self.dfA or field not in self.dfB or len(self.dfA) == 0:
            return np.nan
        residuals = self.dfA[field] - self.dfB[field]
        return np.sqrt(residuals.apply(np.square).sum() / len(self.dfA))

    def diff_accel(self):
        """Spearman correlation of ``decision_accel`` (not an RMSE)."""
        return self.correlate("decision_accel")

    def diff_arclen(self):
        """Spearman correlation of ``decision_arc_len``."""
        return self.correlate("decision_arc_len")

    def diff_action_sequence(self):
        seqA = find_sequence(self.dfA[self.actions])
        seqB = find_sequence(self.dfB[self.actions])

        # diff_sequence alignment cost between the two action sequences
        return diff_sequence(
            [x["action"] for x in seqA], [x["action"] for x in seqB], window=10
        )

    def diff_subgoal_sequence(self):
        # ensure most "conservative" subgoals always match (we agree on how hard
        # to decelerate); negating makes idxmax pick the smallest value
        seqA = find_sequence(-self.dfA[self.subgoals])
        seqB = find_sequence(-self.dfB[self.subgoals])

        # diff_sequence alignment cost between the two subgoal sequences
        return diff_sequence(
            [x["action"] for x in seqA], [x["action"] for x in seqB], window=10
        )

    def corridor_max_rmse(self, field):
        """
        Compute the maximum RMSE between corridor boundaries.

        Over every tick, if the current_action value matches, compute the
        RMSE between a corridor boundary.  The "residuals" are the euclidean
        distance between boundary points of matching index.  If the corridor
        boundaries consist of differing numbers of points at that tick,
        ignore the extra points in the longer.  Ticks whose corridor value
        is missing or is not valid JSON are skipped.
        """
        max_rmse = 0.0
        # Iterating a Series yields its values; ticks are aligned positionally
        # (both frames share the overlap index built in __init__).
        for action_a, action_b, corr_a, corr_b in zip(
            self.dfA["current_action"],
            self.dfB["current_action"],
            self.dfA[field],
            self.dfB[field],
        ):
            if action_a != action_b:
                continue
            try:
                a_traj = np.array(json.loads(corr_a))
                b_traj = np.array(json.loads(corr_b))
            except (TypeError, ValueError):
                # TypeError: non-string values such as None/NaN.
                # ValueError: malformed JSON (JSONDecodeError) or ragged lists.
                continue
            num_pts = min(len(a_traj), len(b_traj))
            if num_pts == 0:
                continue
            residuals = np.linalg.norm(
                a_traj[:num_pts] - b_traj[:num_pts], axis=1
            )
            rmse = np.sqrt(np.mean(np.square(residuals)))
            max_rmse = max(max_rmse, rmse)
        return max_rmse

    def diff_avg_sampling_resolution(self):
        """Mean ``ars_avg_s_spacing`` of candidate minus that of control."""
        avg_sampling_a = self.dfA["ars_avg_s_spacing"]
        avg_sampling_b = self.dfB["ars_avg_s_spacing"]
        return np.mean(avg_sampling_a) - np.mean(avg_sampling_b)

    def diff_avg_num_states(self):
        """Mean per-tick ``ars_num_states`` difference (candidate - control)."""
        num_ars_states_a = self.dfA["ars_num_states"]
        num_ars_states_b = self.dfB["ars_num_states"]
        return np.mean(num_ars_states_a - num_ars_states_b)

    def diff(self):
        """Return ``{"diffs": {...}}`` with a header and all comparison metrics."""
        ctrl_uri = self.dfB["chum_uri"].iloc[0]
        cand_uri = self.dfA["chum_uri"].iloc[0]

        diffs = {
            "action_seq": self.diff_action_sequence(),
            "subgoal_seq": self.diff_subgoal_sequence(),
            "accel_rmse": self.rmse("decision_accel"),
            "accel_max_diff": self.max_abs_diff("decision_accel"),
            "arclen_rmse": self.rmse("decision_arc_len"),
            "final_s_rmse": self.rmse("decision_final_s"),
            "arbiter_ax_rmse": self.rmse("weighted_accel"),
            "ave_steering_angle_rmse": self.rmse("decision_ave_steering_angle"),
            "argmax_diff_accel": self.argmax_diff("decision_accel"),
            "vx_rmse": self.rmse("decision_vel"),
            "vx_max_diff": self.max_abs_diff("decision_vel"),
            "ey_rmse": self.rmse("decision_ey"),
            "ey_max_diff": self.max_abs_diff("decision_ey"),
            "ars_accel_rmse": self.rmse("ars_accel"),
            "ars_max_diff": self.max_abs_diff("ars_accel"),
            "ars_vx_rmse": self.rmse("ars_vel"),
            "ars_vx_max_diff": self.max_abs_diff("ars_vel"),
            "ars_avg_s_sampling_diff": self.diff_avg_sampling_resolution(),
            "ars_avg_num_states_diff": self.diff_avg_num_states(),
            "lane_ref_accel_rmse": self.rmse("lane_ref_accel"),
            "lane_ref_max_diff": self.max_abs_diff("lane_ref_accel"),
            "lane_ref_vx_rmse": self.rmse("lane_ref_vel"),
            "lane_ref_vx_max_diff": self.max_abs_diff("lane_ref_vel"),
            "corridor_left_max_rmse": self.corridor_max_rmse("corridor_left"),
            "corridor_right_max_rmse": self.corridor_max_rmse("corridor_right"),
            "min_dist_diff": self.max_abs_diff("dist_next_state_to_ar"),
            "ctrl_dist_to_ar": self.dfB["dist_next_state_to_ar"].max(),
            "cand_dist_to_ar": self.dfA["dist_next_state_to_ar"].max(),
            "latprox_rmse": np.sqrt(
                0.5 * np.square(self.rmse("ars_latprox_left"))
                + 0.5 * np.square(self.rmse("ars_latprox_right"))
            ),
            "ctrl_latprox_left": self.dfB["ars_latprox_left"].max(),
            "ctrl_latprox_right": self.dfB["ars_latprox_right"].max(),
            "cand_latprox_left": self.dfA["ars_latprox_left"].max(),
            "cand_latprox_right": self.dfA["ars_latprox_right"].max(),
            "ctrl_num_actions": self.dfB["num_actions_considered"].max(),
            "cand_num_actions": self.dfA["num_actions_considered"].max(),
            "actions_max_diff": self.max_abs_diff("num_actions_considered"),
            "ctrl_num_references": self.dfB["num_references_considered"].max(),
            "cand_num_references": self.dfA["num_references_considered"].max(),
            "references_max_diff": self.max_abs_diff(
                "num_references_considered"
            ),
            "registry_num_diff": self.max_abs_diff("num_registry_entry_size"),
            "registry_ped_diff": self.max_abs_diff(
                "num_registry_ped_entry_size"
            ),
        }

        # Tracker output.
        for field in LowLevelControl.DESCRIPTOR.fields:
            if is_comparable_low_level_control_field(field):
                diffs[field.name + "_max_diff"] = self.max_abs_diff(field.name)

        header = {
            "scenario": self.name,
            "ctrl_uri": ctrl_uri,
            "cand_uri": cand_uri,
        }

        return {"diffs": {**header, **diffs}}


class SteeringComparator:
    """Comparator for steering mode scenarios.

    Each input frame carries a ``results`` column that is re-packaged, per
    side, together with a small header (scenario name, SDL target, URIs).
    """

    def __init__(self, name, dfA, dfB):
        # Both runs must share the same scenario hash and SDL target
        # (checked on the first row only).
        for column in ("hash_id", "sdl_target"):
            assert dfA[column].iloc[0] == dfB[column].iloc[0]

        self.name = name
        self.dfA, self.dfB = dfA, dfB

    def _diff(self, df):
        """Build the header + ``results`` dict for a single run."""
        # The sdl_target column holds the whole SDL command line; only its
        # first whitespace-separated token (the main target) is reported.
        target_tokens = df["sdl_target"].fillna("").iloc[0].split()
        sdl_target = target_tokens[0] if target_tokens else ""

        header = {
            "scenario": self.name,
            "sdl_target": sdl_target,
            "ctrl_uri": df.control_chum_uri.iloc[0],
            "cand_uri": df.candidate_chum_uri.iloc[0],
        }
        return {**header, **df.results.to_dict()}

    def diff(self):
        """Return ``{"candidate": ..., "control": ...}`` per-run summaries."""
        sides = (("candidate", self.dfA), ("control", self.dfB))
        return {label: self._diff(frame) for label, frame in sides}


class PredComparator:
    """Compares prediction outputs of two runs matched on (timestamp, track_id).

    Both inputs must provide ``timestamp`` (values supporting ``.ceil``, e.g.
    pandas Timestamps), ``track_id``, ``pos_at_5s`` (strings formatted as
    ``"(x|y)"``), ``num_trajs``, ``traj_weight`` and ``chum_uri`` columns.
    The inputs are copied, not modified.
    """

    def __init__(self, name, dfA, dfB):
        self.name = name
        # Work on copies so that the caller's DataFrames are left untouched.
        self.dfA = self._prepare(dfA)
        self.dfB = self._prepare(dfB)

        # Combine the two dataframes on the same time to facilitate
        # aggregations. Overlapping columns get "_x" (dfA) / "_y" (dfB)
        # suffixes.
        self.df = pd.merge(
            self.dfA, self.dfB, on=["timestamp", "track_id"], how="inner"
        )
        self.df["L2_diff"] = np.sqrt(
            np.square(self.df["pos_at_5s_x_x"] - self.df["pos_at_5s_x_y"])
            + np.square(self.df["pos_at_5s_y_x"] - self.df["pos_at_5s_y_y"])
        )

    @staticmethod
    def _prepare(df):
        """Return a copy of ``df`` with ceiled timestamps and parsed positions."""
        df = df.copy()
        # To avoid floating point difference, round the timestamp up to the next
        # 10ms. Note, shouldn't do floor, because this would cause the argus
        # always load one prediction message before the event time.
        df["timestamp"] = df["timestamp"].apply(lambda x: x.ceil(freq="10ms"))
        # "(x|y)" -> ["x", "y"]
        coords = df["pos_at_5s"].apply(lambda x: x.strip("()").split("|"))
        df["pos_at_5s_x"] = coords.apply(lambda xy: xy[0]).astype("float32")
        df["pos_at_5s_y"] = coords.apply(lambda xy: xy[1]).astype("float32")
        return df

    def max_abs_diff(self, field, diff_ops=None):
        """Return ``(max, idxmax)`` of the per-row difference of ``field``.

        By default the difference is ``abs(A - B)`` with NaN treated as 0;
        ``diff_ops(col_a, col_b)`` may supply a custom difference column.
        Returns ``(None, None)`` when there are no matched rows.
        """
        field_x = field + "_x"
        field_y = field + "_y"
        if diff_ops is None:
            diff_col = (self.df[field_x] - self.df[field_y]).fillna(0).abs()
        else:
            diff_col = diff_ops(self.df[field_x], self.df[field_y])
        return (
            (diff_col.max(), diff_col.idxmax())
            if len(diff_col)
            else (None, None)
        )

    def get_iou_of_matches(self):
        # This function obtains the intersection over union for matching records
        # between the two dataframes. This is a metric to recognize prediction
        # from 2 releases may not generate the same number of output at exactly
        # the same time.
        # High mismatch would have a low iou score. A good threshold for
        # acceptable comparison is yet to be determined.
        outer_joined_df = pd.merge(
            self.dfA, self.dfB, on=["track_id", "timestamp"], how="outer"
        )
        if outer_joined_df.shape[0] > 0:
            return self.df.shape[0] / outer_joined_df.shape[0]
        return 0

    def diff(self):
        """Return a flat dict of prediction-diff metrics for this scenario."""

        def gen_key_metric_cols(metric_name, val=0, val_thresh=0, row_idx=0):
            """Build value/url_ts/info/entity_id entries for one metric.

            The location entries are filled only when ``val`` exceeds
            ``val_thresh``; ``row_idx`` is a row label of ``self.df``.
            """
            url_ts_field = metric_name + "_url_ts"
            info_field = metric_name + "_info"
            entity_id_field = metric_name + "_entity_id"
            # Emit every key unconditionally so the output schema does not
            # depend on whether a difference was found.
            cols = {
                metric_name: val,
                url_ts_field: None,
                info_field: None,
                entity_id_field: None,
            }
            if val is not None and val > val_thresh:
                track_id = self.df.loc[row_idx, "track_id"]
                cols[url_ts_field] = self.df.loc[row_idx, "timestamp"].timestamp()
                cols[info_field] = "EntityID:{}".format(track_id)
                cols[entity_id_field] = track_id
            return cols

        metrics = {"scenario": self.name}
        if not self.df.empty:
            # NOTE(review): ctrl_uri is read from dfA and cand_uri from dfB,
            # whereas Comparator treats dfA as the candidate; confirm the
            # caller's argument order for this comparator.
            metrics.update(
                {
                    "ctrl_uri": self.dfA["chum_uri"].iloc[0],
                    "cand_uri": self.dfB["chum_uri"].iloc[0],
                }
            )
            metrics.update(
                gen_key_metric_cols(
                    metric_name="max_most_likely_at_5s_l2_diff",
                    val=self.df["L2_diff"].max(),
                    val_thresh=1e-6,
                    row_idx=self.df["L2_diff"].idxmax(),
                )
            )
            max_num_trajs_diff, max_num_trajs_diff_idx = self.max_abs_diff(
                "num_trajs"
            )
            metrics.update(
                gen_key_metric_cols(
                    metric_name="max_trajectory_count_diff",
                    val=max_num_trajs_diff,
                    val_thresh=0,
                    row_idx=max_num_trajs_diff_idx,
                )
            )
            max_weight_diff, max_weight_diff_idx = self.max_abs_diff(
                "traj_weight"
            )
            metrics.update(
                gen_key_metric_cols(
                    metric_name="max_trajectory_weight_diff",
                    val=max_weight_diff,
                    val_thresh=0,
                    row_idx=max_weight_diff_idx,
                )
            )
        else:
            metrics.update(
                gen_key_metric_cols(metric_name="max_most_likely_at_5s_l2_diff")
            )
            metrics.update(
                gen_key_metric_cols(metric_name="max_trajectory_count_diff")
            )
            metrics.update(
                gen_key_metric_cols(metric_name="max_trajectory_weight_diff")
            )

        # Add some final metrics.
        metrics.update({"comparison_valid": self.get_iou_of_matches() == 1.0})
        return metrics


class PCPComparator:
    """Compares the PCP tracks of two runs via ``pcp_matcher.compare_pcp``."""

    def __init__(self, name, dfA, dfB):
        self.name, self.dfA, self.dfB = name, dfA, dfB

    # Comparator should return a self defined dictionary of results that we want
    # to visualize in the HTML page. The standard way is to return simple key
    # value pairs. But more complicated ways can be done, such as use some
    # self-defined formats and parse it later during viz.
    def diff(self):
        """Return the fraction of differing track instances plus the diffs."""
        # Per-track columns, in the order compare_pcp expects them after each
        # side's timestamp list.
        track_columns = (
            "track_id",
            "center_x",
            "center_y",
            "center_z",
            "extent_x",
            "extent_y",
            "extent_z",
            "yaw",
            "class",
        )

        def matcher_inputs(frame):
            # Integer timestamps divided by 1e9; presumably datetime64[ns] ->
            # seconds (TODO confirm the column dtype).
            seconds = frame["timestamp"].astype(int) / 10 ** 9
            return [seconds.tolist()] + [
                frame[column].tolist() for column in track_columns
            ]

        diff_tracks_A, diff_tracks_B = pcp_matcher.compare_pcp(
            *matcher_inputs(self.dfA), *matcher_inputs(self.dfB)
        )

        num_track_instances = len(self.dfA) + len(self.dfB)
        num_diffs = sum(len(tracks) for tracks in diff_tracks_A.values()) + sum(
            len(tracks) for tracks in diff_tracks_B.values()
        )
        fraction_diff = (
            num_diffs / float(num_track_instances) if num_track_instances > 0 else 0
        )

        data = {
            "scenario": self.name,
            "fraction_diff": fraction_diff,
            "diff_tracks_A": diff_tracks_A,
            "diff_tracks_B": diff_tracks_B,
        }
        if self.dfA.empty or self.dfB.empty:
            return data

        # NOTE(review): ctrl_uri is read from dfA and cand_uri from dfB,
        # whereas Comparator treats dfA as the candidate; confirm the caller's
        # argument order for this comparator.
        data.update(
            {
                "ctrl_uri": self.dfA["chum_uri"].iloc[0],
                "cand_uri": self.dfB["chum_uri"].iloc[0],
            }
        )
        return data
