from abc import ABC, abstractmethod
import copy
import multiprocessing as mp
import re
import sys

from bokeh.layouts import gridplot, layout
from bokeh.models import Div, Panel, Tabs
import boto3
import matplotlib
import numpy as np
import pandas as pd

from argus.utils.links import (
    render_link,
    read_layout_to_dict,
    add_highlight_entities,
)
from base.logging import zoox_logger
import mined_metric.builder.metrics_impl.pisces.utils.bokeh_plots as bokeh_plots
from mined_metric.builder.metrics_impl.pisces.utils.comparator import (
    is_comparable_low_level_control_field,
)
from mined_metric.builder.metrics_impl.pisces.utils.data_tables import (
    DataTables,
)
from mined_metric.builder.metrics_impl.pisces.utils.steering_mode_scenario_thresholds import (
    STEERING_MODE_SCENARIO_THRESHOLDS,
)
from vehicle.common.proto.controller_pb2 import LowLevelControl
from vehicle.planner.metrics.pmav.tools import record_tracking_data

log = zoox_logger.ZooxLogger(__name__)

# Matches a single <td> cell of the HTML produced by pandas ``Styler.render()``,
# capturing the table id, the row and column numbers, the CSS classes and the
# rendered cell text. The greedy ``.*`` groups assume at most one cell per
# line, which is how ``set_data_attributes`` scans the HTML (line by line).
TD_REGEX = re.compile(
    r'\s*<td id="T_(?P<id>[_0-9a-fA-F]+)row(?P<row>\d+)_col(?P<col>\d+)" class="(?P<classes>.*)"\s>(?P<value>.*)</td>\s*'
)
# NOTE(review): presumably a numeric comparison tolerance; nothing in this part
# of the file compares against it — confirm its users elsewhere.
PRECISION_THRESH = 1e-5
# Host used to build links to the uploaded per-scenario HTML pages:
# http://<PISCES_URL>/<experiment id>/<filename>.
PISCES_URL = "pisces.web.zooxlabs.com"
# Common formats.
# str.format templates for table cells; the unit suffixes are HTML markup.
INT_FMT = "{:.0f}"
VEL_FMT = "{:.2f}m/s"
ACCEL_FMT = "{:.2f}m/s<sup>2</sup>"


def set_data_attributes(df, html, subset=None, urls_df=None):
    """Rewrite Styler ``<td>`` cells to carry sort keys and optional links.

    Scans the Styler-rendered ``html`` line by line. Each cell matched by
    ``TD_REGEX`` whose column (looked up by position in ``df``) is in
    ``subset`` (or any column when ``subset`` is None) is re-emitted with a
    ``data-order`` attribute holding the raw value from ``df``, so the table
    can sort on the underlying value rather than the formatted text.

    When ``urls_df`` has a ``<column>_url`` and/or ``<column>_info`` column,
    the cell text is suffixed with `` [info]`` and/or wrapped in a link.
    Missing values (None or NaN) in those columns are ignored instead of
    being rendered as ``"nan"``.

    :param df: The DataFrame that was rendered into ``html``; its row and
        column order must match the HTML table.
    :param html: HTML produced by ``Styler.render()``.
    :param subset: Optional collection of column names to rewrite.
    :param urls_df: Optional DataFrame with ``*_url``/``*_info`` columns,
        row-aligned with ``df`` by position.
    :return: Generator over the (possibly rewritten) HTML lines.
    """

    def is_present(value):
        # None and NaN both mean "no link / no info for this cell".
        return not (pd.api.types.is_scalar(value) and pd.isna(value))

    def html_td_format(
        cell_id, row, col, classes, order, value, url=None, info=None
    ):
        val_str = value
        if is_present(info):
            val_str += " [{}]".format(info)
        if is_present(url):
            val_str = '<a href="{}" target="_blank">{}</a>'.format(url, val_str)
        return """<td id="T_{}row{}_col{}" class="{}" data-order="{}">{}</td>""".format(
            cell_id, row, col, classes, order, val_str
        )

    # insert the data-order data
    for line in html.split("\n"):
        m = TD_REGEX.search(line)
        if m is not None:
            row_idx = int(m.group("row"))
            col_idx = int(m.group("col"))
            column_name = df.columns[col_idx]
            if subset is None or column_name in subset:
                data = dict(
                    cell_id=m.group("id"),
                    row=m.group("row"),
                    col=m.group("col"),
                    classes=m.group("classes"),
                    value=m.group("value"),
                    order=df.iloc[row_idx, col_idx],
                )
                if urls_df is not None:
                    url_column_name = column_name + "_url"
                    info_column_name = column_name + "_info"
                    # row_idx is the HTML row number, i.e. a position, so use
                    # iloc; a label lookup breaks for non-RangeIndex frames.
                    if url_column_name in urls_df.columns:
                        data["url"] = urls_df[url_column_name].iloc[row_idx]
                    if info_column_name in urls_df.columns:
                        data["info"] = urls_df[info_column_name].iloc[row_idx]
                line = html_td_format(**data)
        yield line


def color_str_ok(string, positive_color="#d9331f", negative_color="#f68e2e"):
    """Return the CSS for a ``"X/Y differences"`` summary cell.

    The integer before the first ``/`` decides the colour, exactly as
    ``color_ok`` would for that number.
    """
    count_text, _sep, _rest = string.partition("/")
    return color_ok(int(count_text), positive_color, negative_color)


def color_ok(value, positive_color="#d9331f", negative_color="#f68e2e"):
    """Return a CSS background/foreground style keyed on the sign of ``value``.

    Negative values get ``negative_color``; positive values get
    ``positive_color`` (both with white text). Anything else, including zero
    and NaN, gets a green background with black text.
    """
    if value < 0:
        background, foreground = negative_color, "white"
    elif value > 0:
        background, foreground = positive_color, "white"
    else:
        background, foreground = "#338f82", "black"
    return f"background-color: {background}; color: {foreground}"


def color_zero(value):
    """Return a yellow-background CSS style for zero values, else ``""``."""
    return "background-color: #ffd600" if value == 0 else ""


def highlight_violations(thresholds):
    """Build a ``Styler.apply`` callback that paints threshold violations red.

    :param thresholds: Mapping consumed by ``get_violations``.
    :return: ``highlight(s, targets)`` returning one CSS string per element of
        ``s``: red background for violations, ``""`` otherwise.
    """
    find_violations = get_violations(thresholds)
    violation_css = "background-color: #d9331f; color: white"

    def highlight(s, targets):
        flags = find_violations(s, targets)
        return [violation_css if flag else "" for flag in flags]

    return highlight


def get_violations(thresholds):
    """Build a per-column checker that flags values above their threshold.

    :param thresholds: Mapping from ``(target, *column_name)`` tuples to
        numeric limits. A missing key means "no limit" (``np.inf``).
    :return: ``violations(s, targets)``. ``s`` is a Series whose ``name`` is a
        tuple (a column key); ``targets`` maps each row label of ``s`` to its
        target. Returns a boolean Series aligned with ``s``.
    """

    def violations(s, targets):
        column_key = s.name

        def row_exceeds(row):
            # row.name is the row label in s; prepend its target to the
            # column key to form the threshold lookup key.
            limit = thresholds.get((targets.loc[row.name],) + column_key, np.inf)
            return row > limit

        flagged = s.to_frame().apply(row_exceeds, axis=1)
        return flagged[column_key]

    return violations


def build_links(row):
    """Return HTML with a PISCES link and an Argus link for one table row.

    :param row: Object exposing ``url`` (PISCES page) and ``argus`` (Argus
        deep link) attributes, e.g. a DataFrame row.
    """
    pisces_html = f'<p><a href="{row.url}" target="_blank">PISCES</a></p>'
    argus_html = f'<p><a href="{row.argus}" target="_blank">Argus</a></p>'
    return pisces_html + "\n        " + argus_html


def add_url_column_to_dataframe(df, argus_host, argus_layout, key_metric_cols):
    """Add an ``<metric>_url`` Argus deep-link column for each key metric.

    For every ``col`` in ``key_metric_cols`` that has a ``<col>_url_ts``
    column in ``df``, a ``<col>_url`` column is written:

    * Rows with a non-null timestamp get a link from ``render_link``
      comparing ``ctrl_uri`` with ``cand_uri`` at that time. If the row also
      has a non-null, truthy ``<col>_entity_id``, that entity is highlighted
      in the layout.
    * Rows with a null timestamp keep the null value.

    :param df: DataFrame with ``ctrl_uri``/``cand_uri`` columns. It is
        modified in place and also returned.
    :param argus_host: Argus hostname. Links are ``secure`` only for
        ``argus.zooxlabs.com``.
    :param argus_layout: Layout passed to ``read_layout_to_dict``.
    :param key_metric_cols: Metric column names to generate links for.
    :return: ``df`` with the added ``*_url`` columns.
    """
    secure = argus_host == "argus.zooxlabs.com"
    # The layout is read lazily (only if some link is rendered) and only
    # once, instead of once per row.
    layout_cache = {}

    def get_argus_layout(entity_id=None):
        if "base" not in layout_cache:
            layout_cache["base"] = read_layout_to_dict(argus_layout)
        # Each row gets its own copy, in case add_highlight_entities or
        # render_link mutate the layout they are given.
        layout_json = copy.deepcopy(layout_cache["base"])
        # NaN is truthy, so test for nulls explicitly before highlighting.
        if entity_id and not (
            pd.api.types.is_scalar(entity_id) and pd.isna(entity_id)
        ):
            return add_highlight_entities(layout_json, [entity_id])
        return layout_json

    def optionally_render_link(time_col, entity_id_col):
        def render(x):
            if pd.notnull(x[time_col]):
                return render_link(
                    x.ctrl_uri,
                    x.cand_uri,
                    # .get: a metric without an entity-id column simply gets
                    # no highlight instead of raising KeyError.
                    get_argus_layout(x.get(entity_id_col)),
                    time=x[time_col],
                    hostname=argus_host,
                    secure=secure,
                )
            return x[time_col]

        return render

    for col in key_metric_cols:
        time_col = col + "_url_ts"
        if time_col not in df.columns:
            continue
        renderer = optionally_render_link(time_col, col + "_entity_id")
        # result_type="reduce" guarantees a Series even for an empty frame.
        # A plain row-wise apply returns a DataFrame there, which cannot be
        # assigned to a single column.
        df[col + "_url"] = df.apply(renderer, axis=1, result_type="reduce")
    return df


def zci_tick_box_plot(df, data_input_column, branch_name):
    """Box plot of ``zci_tick_all`` grouped by ``data_input_column``.

    :param df: Aggregated ZCI tick data.
    :param data_input_column: Column to group by (shown on the x axis).
    :param branch_name: Plot title.
    :return: The 300px-high Bokeh figure.
    """
    grouped = df.groupby(data_input_column)
    plot = bokeh_plots.boxplot(grouped, "zci_tick_all", overlay_counts=True)
    x_axis, y_axis = plot.xaxis[0], plot.yaxis[0]
    x_axis.axis_label = "Number Considered"
    y_axis.axis_label = "Tick Time (s)"
    plot.title.text = branch_name
    plot.height = 300
    return plot


def upload(args):
    """Upload one rendered HTML file to the ``zoox-web`` S3 bucket.

    Takes a single ``(experiment, fn)`` tuple so it can be used with
    ``Pool.map``. The object key is ``pisces/<experiment>/<fn>`` and the
    content type is ``text/html``.
    """
    experiment_id, filename = args
    client = boto3.client("s3")
    object_key = f"pisces/{experiment_id}/{filename}"
    with open(filename, "rb") as handle:
        client.upload_fileobj(
            handle,
            "zoox-web",
            object_key,
            ExtraArgs={"ContentType": "text/html"},
        )


def format_argus_link(control, candidate, comparison_layout_str, argus_host):
    """Build a row-wise callback that renders a control-vs-candidate Argus URL.

    :param control: Nickname for the primary (control) log.
    :param candidate: Nickname for the comparison (candidate) log.
    :param comparison_layout_str: Argus layout passed to ``render_link``.
    :param argus_host: Argus hostname. Links are ``secure`` only for
        ``argus.zooxlabs.com``.
    :return: Function taking a row with ``ctrl_uri``/``cand_uri`` entries.
    """
    is_secure = argus_host == "argus.zooxlabs.com"

    def row_to_link(row):
        return render_link(
            row["ctrl_uri"],
            row["cand_uri"],
            comparison_layout_str,
            primary_nickname=control,
            comparison_nickname=candidate,
            hostname=argus_host,
            secure=is_secure,
        )

    return row_to_link


class ReportTab(ABC):
    """Abstract base for one top-level tab of the comparison report.

    Subclasses supply the key-metric definitions, display columns and cell
    formats. ``build`` assembles an optional "Test Coverage" sub-tab, the
    "Comparison" table sub-tab and any auxiliary sub-tabs into a Bokeh
    ``Panel``.
    """

    def __init__(self, dataframe, table_styles, argus_host):
        # Per-scenario comparison data backing the sub-tabs.
        self.df = dataframe
        # Styler table styles, used by subclasses when rendering tables.
        self.table_styles = table_styles
        # Argus hostname used when rendering deep links.
        self.argus_host = argus_host

    @staticmethod
    def _stretch_panel(children, title):
        """Wrap ``children`` in a width-stretching layout inside a ``Panel``."""
        stretched = layout(
            children,
            sizing_mode="stretch_width",
            align=("center", "start"),
        )
        return Panel(child=stretched, title=title)

    def get_coverage_sub_tab(self):
        """Return an optional coverage sub-tab; the base class has none."""
        return None

    def get_comparison_sub_tab(self):
        """Render the sortable "Comparison" table of per-scenario metrics."""
        display_df = self.df[self._get_display_cols()]
        # Styler.background_gradient needs matplotlib to be importable.
        if "matplotlib" not in sys.modules:
            raise ImportError("background_gradient requires matplotlib")
        key_metrics = self._get_key_metrics()
        styler = display_df.style.background_gradient(
            subset=key_metrics, cmap="Reds", low=0, high=0.5
        )
        styler = styler.highlight_null("yellow").hide_index().set_precision(3)
        styler = styler.set_properties(**{"font-family": "monospace"})
        styler = styler.format(self._get_metric_formats())
        # Add data-order attributes (and links) to the key-metric cells.
        html_lines = set_data_attributes(
            display_df, styler.render(), subset=key_metrics, urls_df=self.df
        )
        table = DataTables(text="\n".join(html_lines))
        return self._stretch_panel([table], "Comparison")

    def get_auxilary_sub_tabs(self):
        """Return extra sub-tabs appended after the comparison table."""
        return []

    def get_scenario_and_uris_df(self):
        """Return the scenario, control-URI and candidate-URI columns."""
        wanted = [("", "scenario"), ("", "ctrl_uri"), ("", "cand_uri")]
        return self.df[wanted]

    def get_key_metric_df(self):
        """Return only the key-metric columns of ``self.df``."""
        return self.df[self._get_key_metrics()]

    @abstractmethod
    def get_key_metric_definitions(self):
        """Map each key-metric column to its human-readable definition."""
        raise NotImplementedError("Need to implement this method.")

    def _get_key_metrics(self):
        return list(self.get_key_metric_definitions())

    @abstractmethod
    def _get_metric_formats(self):
        """Map column keys to ``Styler.format`` format strings."""
        raise NotImplementedError("Need to implement this method.")

    @abstractmethod
    def _get_display_cols(self):
        """List the columns shown in the comparison table, in order."""
        raise NotImplementedError("Need to implement this method.")

    @property
    @abstractmethod
    def type(self):
        """Short label naming the kind of report tab."""
        raise NotImplementedError("Need to implement this method.")

    def build(self, title):
        """Assemble all sub-tabs into a titled top-level ``Panel``."""
        coverage_tab = self.get_coverage_sub_tab()
        subtabs = [coverage_tab] if coverage_tab else []
        subtabs.append(self.get_comparison_sub_tab())
        subtabs += self.get_auxilary_sub_tabs()
        return self._stretch_panel([Tabs(tabs=subtabs)], title)


class PlannerReportTab(ReportTab):
    """Report tab comparing planner outputs of a control and a candidate.

    Besides the base comparison table, it renders a "Test Coverage" sub-tab
    with action/subgoal activity counts. Once ``make_action_search_plots``
    has run, it also renders an "Action Search Statistics" sub-tab.
    """

    def __init__(
        self, dataframe, styles, argus_host, control_stats, candidate_stats
    ):
        super().__init__(dataframe, styles, argus_host)
        # Stats frames summed column-wise by get_coverage_sub_tab().
        self.control_stats = control_stats
        self.candidate_stats = candidate_stats
        # Populated by make_action_search_plots().
        self.action_search_plots = None

    @property
    def type(self):
        return "Planner"

    def make_action_search_plots(
        self, experiment, candidate, candidate_sha, control, control_sha
    ):
        """Render and upload per-scenario pages, add links, and build plots.

        Steps:

        1. Ask ``experiment`` to render one HTML page per scenario.
        2. Inner-join the resulting ``(scenario, filename)`` pairs onto
           ``self.df``, so scenarios without a page are dropped.
        3. Upload the pages to S3 in parallel.
        4. Add ``url``/``argus``/``links`` columns.
        5. Store a grid of ZCI tick-time box plots in
           ``self.action_search_plots``.

        :param experiment: Object providing ``render_html``,
            ``aggregate_zci_tick`` and ``id``.
        :param candidate: Candidate name (plot title, Argus nickname).
        :param candidate_sha: Candidate commit passed to ``render_html``.
        :param control: Control name (plot title, Argus nickname).
        :param control_sha: Control commit passed to ``render_html``.
        """
        results = experiment.render_html(
            candidate, candidate_sha, control, control_sha, self.argus_host
        )
        files = pd.DataFrame(results)
        files.columns = ["scenario", "filename"]
        self.df = pd.merge(files, self.df, on="scenario")

        # TODO: put this in build()
        # TODO: stream to s3
        # The context manager terminates the pool if an upload raises, so
        # worker processes are never leaked. On success the pool is closed
        # and joined first, as before.
        with mp.Pool(mp.cpu_count()) as pool:
            pool.map(upload, ((experiment.id, fn) for fn in self.df["filename"]))
            pool.close()
            pool.join()

        def format_link(filename):
            return "http://{}/{}/{}".format(PISCES_URL, experiment.id, filename)

        self.df["url"] = self.df["filename"].apply(format_link)
        self.df["argus"] = self.df[["ctrl_uri", "cand_uri"]].apply(
            format_argus_link(
                control, candidate, "planner_comparison", self.argus_host
            ),
            axis=1,
        )
        self.df["links"] = self.df[["url", "argus"]].apply(build_links, axis=1)

        self.action_search_plots = self._build_action_search_plots(
            experiment, candidate, control
        )

    @staticmethod
    def _zci_plot_pair(experiment, candidate, control, column):
        """Return (control, candidate) ZCI tick box plots grouped by ``column``."""
        df_candidate = experiment.aggregate_zci_tick(candidate, column)
        df_control = experiment.aggregate_zci_tick(control, column)
        return (
            zci_tick_box_plot(df_control, column, control),
            zci_tick_box_plot(df_candidate, column, candidate),
        )

    def _build_action_search_plots(self, experiment, candidate, control):
        """Build the action/reference tick-time grid with shared y bounds."""
        action_details = Div(
            text="""
        <h3>Route Car Action Search Results</h3><p>This plot shows an aggregate of actions
        considered at a given Decision tick. An action is \"considered\" if it passes
        preconditions and has a rollout. Not all considered actions are terminal
        (ie. successful) because they may fail constraints. Actions are run in parallel
        so latency should not scale linearly with the number of actions. A step increase
        in latency is expected if more than seven actions are considered at a Decision
        tick.</p>
        """
        )
        reference_details = Div(
            text="""
        <h3>Action References Generator Results</h3><p>This plot shows an aggregate of
        references considered at a given Decision tick. A reference is \"considered\" if
        geometric planning and action reference smoothing is executed. Not all considered
        references are feasible and reference generation can fail. Not all feasible
        references are used by an action; contrarily, a given reference may be used by
        multiple actions. References are generated in parallel but do contend for GPU
        resources.</p>
        """
        )

        p_actions, d_actions = self._zci_plot_pair(
            experiment, candidate, control, "num_actions_considered"
        )
        p_refs, d_refs = self._zci_plot_pair(
            experiment, candidate, control, "num_references_considered"
        )

        # Align plot bounds so the four plots can be compared by eye. Each
        # plot gets its own Range1d so the ranges are not linked.
        plots = (p_actions, d_actions, p_refs, d_refs)
        max_y = max(plot.y_range.end for plot in plots)
        for plot in plots:
            plot.y_range = bokeh_plots.Range1d(0, max_y)

        return gridplot(
            [
                [action_details],
                [p_actions, d_actions],
                [reference_details],
                [p_refs, d_refs],
            ],
            sizing_mode="stretch_width",
        )

    def _render_coverage_table(self, stats, ticks):
        """Render one coverage table as Styler HTML.

        The table has a colour-coded ``Diff`` column and cells formatted as
        ``n/total ticks``.
        """
        with_diff = stats.assign(Diff=stats["Candidate"] - stats["Control"])
        tick_formats = {
            "Control": "{{:0,}}/{:0,} ticks".format(ticks["Control"]),
            "Candidate": "{{:0,}}/{:0,} ticks".format(ticks["Candidate"]),
        }
        return (
            with_diff.sort_values(by="Control", ascending=True)
            .style.applymap(color_ok, subset=["Diff"])
            .applymap(color_zero, subset=["Candidate", "Control"])
            .format(tick_formats)
            .set_table_styles(self.table_styles)
            .set_properties(**{"border-width": 1})
            .render()
        )

    def get_coverage_sub_tab(self):
        """Compare per-action and per-subgoal active-tick counts."""
        # Render stats
        stats = pd.concat(
            [self.control_stats.sum(), self.candidate_stats.sum()], axis=1
        )
        stats.columns = ["Control", "Candidate"]
        # Break into actions and subgoals; the "ticks" row holds the totals.
        ticks = stats.loc["ticks"]
        stats_actions = stats.filter(regex="^ACTION_.*$", axis=0).copy()
        stats_subgoals = stats.filter(regex="^((?!^ACTION_).)*$", axis=0).copy()
        stats_subgoals = stats_subgoals.drop(["ticks"]).copy()

        stats_actions.index.name = "Actions"
        stats_subgoals.index.name = "Subgoals"

        action_subgoal_help = Div(
            text="""
            <p>Compares the number of ticks where actions or subgoals are
            active. If there are many 0s in the candidate and control, it
            likely means the test suite does not exercise many of the planner
            features.</p>
        """
        )
        action_stats_table = Div(
            text=self._render_coverage_table(stats_actions, ticks)
        )
        subgoal_stats_table = Div(
            text=self._render_coverage_table(stats_subgoals, ticks)
        )
        coverage_layout = layout(
            [[action_subgoal_help], [action_stats_table, subgoal_stats_table]],
            sizing_mode="stretch_width",
            align=("center", "start"),
        )
        return Panel(child=coverage_layout, title="Test Coverage")

    def get_auxilary_sub_tabs(self):
        """Return the action-search sub-tab, if its plots have been built."""
        if self.action_search_plots is None:
            # make_action_search_plots() was not called; there is nothing to
            # show, and Panel(child=None) would not render.
            return []
        return [
            Panel(
                child=self.action_search_plots, title="Action Search Statistics"
            )
        ]

    def get_key_metric_definitions(self):
        return dict(
            vx_max_diff="""<p>
                Max absolute deviation in <b>decision longitudinal velocity</b> in
                its first state at all ticks.
            </p>""",
            accel_max_diff="""<p>
                Max absolute deviation in <b>decision longitudinal acceleration</b>
                in its first control at all ticks.
            </p>""",
            ars_max_diff="""<p>
                Max absolute deviation in <b>action reference longitudinal
                acceleration</b> in its first control at all ticks.
            </p>""",
            accel_rmse="""<p>
                Root mean-squared error in <b>decision longitudinal acceleration</b>
                in its first control at all ticks.
            </p>""",
            actions_max_diff="""<p>
                Max absolute deviation in the <b>number of actions</b> considered at
                every tick.
            </p>""",
            references_max_diff="""<p>
                Max absolute deviation in the <b>number of action references</b>
                considered at every tick.
            </p>""",
            ars_accel_rmse="""<p>
                Root mean-squared error in <b>action reference longitudinal
                acceleration</b> in its first control at all ticks.
            </p>""",
            corridor_left_max_rmse="""<p>
                Computes the root mean-squared error between the <b>left-hand
                corridor</b> of the two SHAs at every tick and reports the maximum.
            </p>""",
            corridor_right_max_rmse="""<p>
                Computes the root-mean-squared error between the <b>right-hand
                corridor</b> of the two SHAs at every tick and reports the maximum.
            </p>""",
            ey_max_diff="""<p>
                Max absolute deviation in the <b>decision lateral deviation</b> from
                the route reference in its first state at all ticks.
            </p>""",
            ey_rmse="""<p>
                Root mean-squared error in the <b>decision lateral deviation</b>
                from the route reference in its first state at all ticks.
            </p>""",
            ars_avg_num_states_diff="""<p>
                Average difference in the number of <b>action reference smoother
                states</b> between the two SHAs.
            </p>""",
            registry_num_diff="""<p>
                Max absolute deviation in the <b>number of registry entries</b>
                considered at every tick.
            </p>""",
            registry_ped_diff="""<p>
                Max absolute deviation in the <b>number of registry pedestrian entries</b>
                considered at every tick.
            </p>""",
        )

    def _get_metric_formats(self):
        pos_fmt = "{:.2f}m"
        num_fmt = "{:.3f}"
        return {
            "ctrl_dist_to_ar": pos_fmt,
            "cand_dist_to_ar": pos_fmt,
            "accel_max_diff": ACCEL_FMT,
            "ars_max_diff": ACCEL_FMT,
            "lane_ref_max_diff": ACCEL_FMT,
            "accel_rmse": ACCEL_FMT,
            "ars_accel_rmse": ACCEL_FMT,
            "lane_ref_accel_rmse": ACCEL_FMT,
            "vx_max_diff": VEL_FMT,
            "corridor_left_max_rmse": pos_fmt,
            "corridor_right_max_rmse": pos_fmt,
            "ars_avg_s_sampling_diff": pos_fmt,
            "ars_avg_num_states_diff": num_fmt,
            "ey_max_diff": pos_fmt,
            "ey_rmse": pos_fmt,
            "ctrl_latprox_left": num_fmt,
            "ctrl_latprox_right": num_fmt,
            "cand_latprox_left": num_fmt,
            "cand_latprox_right": num_fmt,
            "ctrl_num_actions": INT_FMT,
            "cand_num_actions": INT_FMT,
            "actions_max_diff": INT_FMT,
            "ctrl_num_references": INT_FMT,
            "cand_num_references": INT_FMT,
            "references_max_diff": INT_FMT,
            "registry_num_diff": INT_FMT,
            "registry_ped_diff": INT_FMT,
        }

    def _get_display_cols(self):
        return [
            # Meta data columns
            "scenario",
            "links",
            # Planner metrics columns
            "vx_max_diff",
            "accel_max_diff",
            "ars_max_diff",
            "ey_max_diff",
            "accel_rmse",
            "ars_accel_rmse",
            "ey_rmse",
            "ars_avg_num_states_diff",
            "corridor_left_max_rmse",
            "corridor_right_max_rmse",
            "actions_max_diff",
            "references_max_diff",
            "registry_num_diff",
            "registry_ped_diff",
        ]


class SteeringComparisonReportTab(ReportTab):
    """The tab for reporting VH6 2ws vs 4ws comparisons."""

    def __init__(
        self, dataframe, branch, sha, styles, argus_host, pipe_id, tracking
    ):
        # One summary row per input row, in the same order as ``dataframe``.
        # The targets, URIs and links below are attached by position.
        df_traj = self._build_summary_table(dataframe).reset_index(drop=True)
        # Metric columns that get violation highlighting.
        self.traj_gradient = df_traj.columns
        # Positional (0..n-1) index so labels match df_traj rows; the
        # violation checks look targets up by df_traj row label.
        self.targets = dataframe["sdl_target"].reset_index(drop=True)
        self.sha = sha
        self.branch = branch
        self.pipe_id = pipe_id
        self.tracking = tracking

        links = dataframe[["ctrl_uri", "cand_uri"]].apply(
            format_argus_link(
                f"{branch}/2ws",
                f"{branch}/4ws",
                "planner_comparison",
                argus_host,
            ),
            axis=1,
        )
        link_html = links.apply(
            lambda url: f"""<p><a href="{url}" target="_blank">Argus</a></p>"""
        )
        df_traj[("", "links")] = link_html.to_numpy()
        df_traj[("", "scenario")] = self.targets.to_numpy()  # use SDL target as scenario
        df_traj[("", "ctrl_uri")] = dataframe["ctrl_uri"].to_numpy()
        df_traj[("", "cand_uri")] = dataframe["cand_uri"].to_numpy()
        df_traj.sort_index(axis=1, inplace=True)
        super().__init__(df_traj, styles, argus_host)

    @property
    def type(self):
        return "VH6 2ws vs 4ws"

    def get_key_metric_definitions(self):
        """Returns a dict from Planner Data Comparator metric to definition."""
        return {
            ("decision", "Non-expected heading error diff(deg)"): (
                "Unaccounted heading error (error not due to change in "
                "wheelbase) when projecting VH6 4ws decision pose onto VH6 2ws "
                "decision pose"
            ),
            ("decision", "Projected Lateral error(m)"): (
                "Lateral error when projecting VH6 4ws decision pose onto VH6 "
                "2ws decision pose"
            ),
            ("decision", "Projected Velocity vx error(m/s)"): (
                "Velocity error when projecting VH6 4ws decision pose onto VH6 "
                "2ws decision pose"
            ),
        }

    def get_comparison_sub_tab(self):
        """Render the 2ws/4ws table, highlighting per-scenario threshold violations."""
        df_traj = self.df[self._get_display_cols()]

        # Create the trajectory projection errors
        traj_html = (
            df_traj.style.apply(
                highlight_violations(STEERING_MODE_SCENARIO_THRESHOLDS),
                subset=self.traj_gradient,
                targets=self.targets,
            )
            .hide_index()
            .set_precision(3)
            .set_properties(**{"font-family": "monospace"})
            .format(self._get_metric_formats())
            .render()
        )
        traj_html = "\n".join(
            set_data_attributes(
                df_traj, traj_html, subset=self._get_key_metrics()
            )
        )
        traj_table = DataTables(text=traj_html, page_length=25)

        comparison_help = Div(
            text=f"""
            <h1>VH6 2ws and 4ws Comparison</h1>
            <p>Compares <code>{self.branch}</code> to itself using VH6 2ws and
            4ws. See
            <a href="https://git.zooxlabs.com/zooxco/driving/blob/master/mined_metric/builder/metrics_impl/pisces/BUILD">mined_metric/builder/metrics_impl/pisces:steering_mode_log_tests</a>
            for the set of tests run.</p>
            <p>Definitions of the columns can be found on
            <a href="https://confluence.zooxlabs.com/display/PLAN/Planner+Motion+Comparator">Confluence</a>
            and are reproduced here for convenience.</p>
        """
        )
        comparison_defs = Div(
            text=f"""
            <h3><a id="error-desc"></a>Error Description</h3>

            <h4><a id="proj"></a>Projected Errors (lateral, heading, and
            longitudinal velocity errors)</h4>
            <p>Computes the projected lateral, heading and velocity
            <code>vx</code> errors by projecting poses from the VH6 4ws
            trajectory onto the VH6 2ws trajectory from
            <code>{self.branch}</code>.
            </p>

            <h4><a id="heading"></a>Expected Heading Error</h4>
            <p>It is the non-zero heading error difference expected to be
            present between the baseline VH6 2ws and comparison 4ws run. In this
            situation, a non-zero heading error difference is expected to be
            present by design for not being able to steer one of the axles on
            the 2ws run. This non-zero error is calculated based on the lever
            arm distance from the center of the non-steered axle to the
            wheelbase center.</p>

            <h4><a id="non-expected"></a>Non-Expected Heading Error</h4>
            <p>The non-expected heading error difference equates to the
            difference between the <a href="#proj">projected heading errors</a>
            and the <a href="#heading">expected heading error</a> difference.
            </p>

            <h4>Spatial Error (Decision s error)</h4>
            <p>The tool also computes errors in Decision trajectory's
            <code>s</code> values for longitudinal comparison between the VH6
            2ws run and the VH6 4ws run on <code>{self.branch}</code>.</p>

            <h3><a id="troubleshoot"></a>Troubleshooting Error Thresholds</h3>

            <p>If you are seeing unexpected delta in the 2ws/4ws comparison, you
            can investigate further by running the
            <a href="https://confluence.zooxlabs.com/pages/viewpage.action?pageId=153372042#PlannerMotionComparator-RunningtheTool">Planner Motion Comparator</a>
            tool with output metrics JSON logging enabled.

            <pre>
            git checkout {self.sha}
            ./sim/launch.sh local planner <b>SCENARIO</b> --save_chum \\
                --simulator_args="--params-kv sim/override_autonomous_mode=PURE_PLAYBACK_MODE \\
                    --params-kv sim/vehicle_type=vh6 \\
                    --params-kv sim/sim_model=vh6a \\
                    --params-kv planner/config_variant_id=vh6_2ws"
            ./sim/launch.sh local planner <b>SCENARIO</b> --save_chum \\
                --simulator_args="--params-kv sim/override_autonomous_mode=PURE_PLAYBACK_MODE \\
                    --params-kv sim/vehicle_type=vh6 \\
                    --params-kv sim/sim_model=vh6a \\
                    --params-kv planner/config_variant_id=vh6_4ws"
            bazel run //vis/data_analysis/data_comparator:planner_motion_comparator -- '<b><2ws chum uri></b>' '<b><4ws chum uri></b>' --decision_only
            </pre>

            If the issue exists on both the candidate and the control SHAs,
            merge the latest upstream changes to the control branch and repeat
            the above step to run the comparator tool on the control SHA to
            confirm this issue has not been fixed by an upstream change.

            If the issue still persists, post on #planner-metrics-help for
            visibility.
            </p>
            """
        )
        steering_mode_comparison = Panel(
            child=layout(
                [[comparison_help], [comparison_defs], [traj_table]],
                sizing_mode="stretch_width",
                align=("center", "start"),
            ),
            title="Comparison",
        )
        return steering_mode_comparison

    def get_key_metric_df(self):
        """Return key-metric columns as booleans: True where a threshold is exceeded."""
        key_df = super().get_key_metric_df()
        key_df = key_df.apply(
            get_violations(STEERING_MODE_SCENARIO_THRESHOLDS),
            targets=self.targets,
        )
        return key_df

    def record_steering_mode_comparison_data(self, comparison_data, stats):
        """
        Computes and records data for steering mode comparison that can be
        analyzed over time at
        https://pmav.zooxlabs.com/tracking?metric=VH6_2WS_4WS_DIFF&branch=<BRANCH>.

        Failures from ``record_tracking_data`` are logged, not raised.

        :param comparison_data: Comparison data that contains useful debugging
        and metadata information.
        :param stats: Bool to check whether a diff exists or not.
        """
        metric = "VH6_2WS_4WS_DIFF"
        # TODO(PLAN-16879): Update the hashableKeys to instead hash the pbtxt
        # files of the scenarios themselves since the underlying scenarios can
        # change.
        metadata = {"hashableKeys": ["scenario"], "details": []}
        if not comparison_data.empty:
            rows = zip(
                comparison_data[("", "scenario")],
                comparison_data[("", "ctrl_uri")],
                comparison_data[("", "cand_uri")],
                comparison_data[("decision", "Projected Lateral error(m)")],
                comparison_data[("decision", "Non-expected heading error diff(deg)")],
                comparison_data[("decision", "Projected Velocity vx error(m/s)")],
            )
            for scenario, ctrl_uri, cand_uri, lat_err, heading_err, vel_err in rows:
                metadata["details"].append(
                    {
                        "scenario": scenario,
                        "2wsUri": ctrl_uri,
                        "4wsUri": cand_uri,
                        # Convert np.bool_ to python bool to make JSON serializable
                        "latErrViolated": bool(lat_err),
                        "headingErrViolated": bool(heading_err),
                        "velErrViolated": bool(vel_err),
                    }
                )
        else:
            metadata["details"].append(
                {
                    "scenario": "",
                    "2wsUri": "",
                    "4wsUri": "",
                    "latErrViolated": False,
                    "headingErrViolated": False,
                    "velErrViolated": False,
                }
            )

        try:
            record_tracking_data(
                metric=metric,
                title="VH6 2ws 4ws Diff",
                stats=stats,
                units={"diff": ""},
                metadata=metadata,
                branch=self.tracking,
                pipe_id=self.pipe_id,
            )
        except Exception as e:
            # Best effort: tracking must not break report generation.
            log.error("Error recording tracking stats: %s", str(e))

    def _get_metric_formats(self):
        """Returns a dict from SteeringComparison metric to display format."""
        ang_fmt = "{:.2f}&deg;"
        pos_fmt = "{:.2f}m"
        return {
            ("decision", "Decision s error(m)"): pos_fmt,
            ("decision", "Projected Heading error(deg)"): ang_fmt,
            ("decision", "Non-expected heading error diff(deg)"): ang_fmt,
            ("decision", "Projected Lateral error(m)"): pos_fmt,
            ("decision", "Projected Velocity vx error(m/s)"): VEL_FMT,
        }

    def _get_display_cols(self):
        """Returns the list of SteeringComparison metrics to display."""
        return [
            ("", "scenario"),
            ("", "links"),
            ("decision", "Decision s error(m)"),
            ("decision", "Projected Heading error(deg)"),
            ("decision", "Projected Lateral error(m)"),
            ("decision", "Projected Velocity vx error(m/s)"),
            ("decision", "Non-expected heading error diff(deg)"),
        ]

    def _build_summary_table(self, df):
        """Builds the summary table: one row per row of ``df``, in ``df``'s order.

        Columns are ``("decision", <metric>)`` max-absolute errors.
        """
        subset = ["decision"]

        data = {
            row["scenario"]: self._extract_max_absolute_errors(
                row["Projected errors"], subset
            )
            for _, row in df[["scenario", "Projected errors"]].iterrows()
        }
        result = pd.concat(data.values(), keys=data.keys())
        # unstack() orders rows by sorted scenario name. Restore the input row
        # order so callers can attach columns taken from ``df`` by position.
        return result.unstack().reindex(index=df["scenario"].to_numpy())

    @staticmethod
    def _extract_max_absolute_errors(entry, subset):
        """Extracts the "Max absolute" error from the nested dict.

        The "Expected heading error diff(deg)" row is dropped when present.
        Cells that are not mappings (e.g. NaN) become NaN.
        """

        def max_absolute(stats):
            if hasattr(stats, "get"):
                return stats.get("Max absolute", np.nan)
            return np.nan

        df = pd.DataFrame(entry)
        return (
            df[subset]
            .dropna(how="all")
            .drop(["Expected heading error diff(deg)"], errors="ignore")
            .applymap(max_absolute)
        )


class PredictionReportTab(ReportTab):
    """Report tab for Prediction comparisons between control and candidate."""

    def __init__(
        self,
        dataframe,
        styles,
        argus_host,
        control,
        candidate,
        control_stats,
        candidate_stats,
    ):
        """
        Args:
            dataframe, styles, argus_host: forwarded to ReportTab.
            control, candidate: passed to format_argus_link to build the
                per-row Argus comparison link from "ctrl_uri"/"cand_uri".
            control_stats, candidate_stats: iterables of dicts whose
                "stat_df" entries are summed in get_coverage_sub_tab.
        """
        super().__init__(dataframe, styles, argus_host)
        self.df = add_url_column_to_dataframe(
            self.df,
            argus_host,
            "prediction_comparison",
            self._get_key_metrics(),
        )
        self.df["argus"] = self.df[["ctrl_uri", "cand_uri"]].apply(
            format_argus_link(
                control, candidate, "prediction_comparison", argus_host
            ),
            axis=1,
        )
        self.df["argus"] = self.df["argus"].apply(
            lambda x: """<p><a href="{}" target="_blank">Argus</a></p>""".format(
                x
            )
        )
        self.control_stats = control_stats
        self.candidate_stats = candidate_stats

    @property
    def type(self):
        return "Prediction"

    def get_coverage_sub_tab(self):
        """Builds the "Test Coverage" panel comparing tick/entity counts.

        Sums the "stat_df" frames of each branch and renders one table for
        the "tick_cnts" column and one for the "entity_cnts" column.
        """

        def get_sum_of_dataframes(list_of_data):
            # Element-wise sum of every data["stat_df"]; None when empty.
            total_df = None
            for data in list_of_data or []:
                df = data["stat_df"]
                if total_df is None:
                    total_df = df
                else:
                    total_df = total_df.add(df, fill_value=0)
            return total_df

        def render_stats_table(sum_df, table_title):
            # sum_df holds the control and candidate counts side by side with
            # a "Total" row; the totals become the denominators shown in each
            # cell ("count/total") and the row itself is removed.
            sum_df.columns = ["Control", "Candidate"]
            total_sum_row = sum_df.loc["Total"]
            sum_df = sum_df.drop(["Total"])
            sum_df.loc[:, "Diff"] = sum_df["Candidate"] - sum_df["Control"]
            sum_df_summary = (
                sum_df.sort_values(by="Control", ascending=True)
                .style.applymap(color_ok, subset=["Diff"])
                .applymap(color_zero, subset=["Candidate", "Control"])
                .format(
                    {
                        "Control": "{{:0,}}/{:0,}".format(
                            total_sum_row["Control"]
                        ),
                        "Candidate": "{{:0,}}/{:0,}".format(
                            total_sum_row["Candidate"]
                        ),
                    }
                )
                .set_table_styles(self.table_styles)
                .set_properties(**{"border-width": 1})
                .set_caption(table_title)
                .render()
            )
            return Div(text=sum_df_summary)

        sum_control = get_sum_of_dataframes(self.control_stats)
        sum_candidate = get_sum_of_dataframes(self.candidate_stats)
        if sum_control is None or sum_candidate is None:
            # Without stats for both branches the tables below would fail on
            # None["tick_cnts"]; show a note instead of crashing the report.
            return Panel(
                child=Div(text="<p>No coverage statistics available.</p>"),
                title="Test Coverage",
            )
        tick_cnt_stats = pd.concat(
            [sum_control["tick_cnts"], sum_candidate["tick_cnts"]], axis=1
        )
        entity_cnt_stats = pd.concat(
            [sum_control["entity_cnts"], sum_candidate["entity_cnts"]], axis=1
        )

        coverage_layout = layout(
            [
                [
                    render_stats_table(tick_cnt_stats, "Number of Ticks"),
                    render_stats_table(
                        entity_cnt_stats, "Number of Unique Entities"
                    ),
                ]
            ],
            sizing_mode="stretch_width",
            align=("center", "start"),
        )
        return Panel(child=coverage_layout, title="Test Coverage")

    def get_key_metric_definitions(self):
        """Returns a dict from Prediction key metric to HTML definition."""
        return dict(
            max_most_likely_at_5s_l2_diff="""<p>
                The maximum difference of L2 distance of two most likely trajectories
                between the control and candidate experiments  among all matched
                entities across the log range. The entity id and timestamp are the
                match keys.
            </p>""",
            max_trajectory_count_diff="""<p>
                The maximum difference of trajectory counts between the control and
                candidate experiments among all matched entities across the log range.
                The entity id and timestamp are the match keys.
            </p>""",
            max_trajectory_weight_diff="""<p>
                The maximum difference of trajectory weights between the control and
                candidate experiments among all matched entities across the log range.
                The entity id and timestamp are the match keys.
            </p>""",
        )

    def _get_metric_formats(self):
        """Returns a dict from Prediction key metric to display format."""
        num_fmt = "{:.5f}"
        pos_fmt = "{:.3f}m"
        return {
            "max_most_likely_at_5s_l2_diff": pos_fmt,
            # INT_FMT ("{:.0f}") accepts both ints and floats; "{:d}" raised
            # ValueError once the column became float64 (e.g. due to NaN).
            "max_trajectory_count_diff": INT_FMT,
            "max_trajectory_weight_diff": num_fmt,
        }

    def _get_display_cols(self):
        """Returns meta-data columns followed by the key metrics."""
        cols = ["scenario", "argus", "comparison_valid"]  # Meta data columns
        cols.extend(self._get_key_metrics())
        return cols


class PerceptionReportTab(ReportTab):
    """Report tab for Perception (track) comparisons between branches."""

    def __init__(self, dataframe, styles, argus_host, control, candidate):
        # `control` and `candidate` are accepted for interface parity with
        # the other tabs; they are not used here.
        super().__init__(dataframe, styles, argus_host)
        self.df = add_url_column_to_dataframe(
            self.df,
            argus_host,
            "prediction_comparison",
            self._get_key_metrics(),
        )
        self.add_pcp_track_highlight(
            self.df, argus_host, "prediction_comparison"
        )
        # Rows without diff tracks get no URL (None) from
        # add_pcp_track_highlight; render an empty cell for them rather than
        # a broken link to "None".
        self.df["argus"] = self.df["argus"].apply(
            lambda x: """<p><a href="{}" target="_blank">Argus</a></p>""".format(
                x
            )
            if pd.notnull(x)
            else ""
        )

    def add_pcp_track_highlight(
        self, df, argus_host, argus_layout="prediction_comparison"
    ):
        """Sets df["argus"] to Argus links highlighting the differing tracks.

        A link is rendered only for rows where both "diff_tracks_A" and
        "diff_tracks_B" are non-null; other rows get None.

        Args:
            df: DataFrame with "ctrl_uri", "cand_uri", "diff_tracks_A" and
                "diff_tracks_B" columns. Modified in place.
            argus_host: Argus hostname; HTTPS is used for
                "argus.zooxlabs.com".
            argus_layout: name of the layout passed to read_layout_to_dict.

        Returns:
            The same `df`.
        """

        def get_argus_layout(diff_tracks_A=None, diff_tracks_B=None):
            # diff_tracks_A / diff_tracks_B map the differing track ids to
            # their timestamps for SHA A (index 0) and SHA B (index 1).
            layout_json = read_layout_to_dict(argus_layout)
            if diff_tracks_A and diff_tracks_B:
                layout_json = add_highlight_entities(
                    layout_json,
                    list(diff_tracks_A.keys()),
                    list(diff_tracks_A.values()),
                    0,
                )
                layout_json = add_highlight_entities(
                    layout_json,
                    list(diff_tracks_B.keys()),
                    list(diff_tracks_B.values()),
                    1,
                )
            return layout_json

        def optionally_render_link(diff_tracks_A_key, diff_tracks_B_key):
            def render(x):
                if pd.notnull(x[diff_tracks_A_key]) and pd.notnull(
                    x[diff_tracks_B_key]
                ):
                    return render_link(
                        x["ctrl_uri"],
                        x["cand_uri"],
                        get_argus_layout(
                            x[diff_tracks_A_key], x[diff_tracks_B_key]
                        ),
                        time=None,
                        hostname=argus_host,
                        secure=(argus_host == "argus.zooxlabs.com"),
                    )
                return None

            return render

        renderer = optionally_render_link("diff_tracks_A", "diff_tracks_B")
        df["argus"] = df.apply(renderer, axis=1)
        return df

    def get_key_metric_definitions(self):
        """Returns a dict from Perception key metric to HTML definition."""
        return dict(
            fraction_diff="""<p>
                Fraction of tracks which are new in either of the branches.
            </p>"""
        )

    def _get_metric_formats(self):
        """Returns a dict from Perception key metric to display format."""
        return {"fraction_diff": "{:.5f}"}

    def _get_display_cols(self):
        """Returns meta-data columns followed by the key metrics."""
        cols = ["scenario", "argus"]  # Meta data columns
        cols.extend(self._get_key_metrics())
        return cols

    @property
    def type(self):
        return "Perception"


class TrackerReportTab(ReportTab):
    """Report tab presenting Tracker (LowLevelControl) comparison metrics."""

    def __init__(self, dataframe, styles, argus_host):
        super().__init__(dataframe, styles, argus_host)

    @property
    def type(self):
        return "Tracker"

    @staticmethod
    def _comparable_field_names():
        """Names of LowLevelControl fields accepted by
        is_comparable_low_level_control_field, in descriptor order."""
        return [
            field.name
            for field in LowLevelControl.DESCRIPTOR.fields
            if is_comparable_low_level_control_field(field)
        ]

    def get_key_metric_definitions(self):
        """Maps each "<field>_max_diff" Tracker metric to its HTML definition."""
        template = """<p>
                    Max absolute deviation in the <b>Tracker {}</b> considered
                    at every tick.
                </p>"""
        return {
            name + "_max_diff": template.format(name)
            for name in self._comparable_field_names()
        }

    def _get_metric_formats(self):
        """Maps each Tracker metric to its display format string."""
        radians = "{:.3f}rad"
        radians_per_s = "{:.3f}rad/s"
        metric_formats = (
            ("steering_angle_front_max_diff", radians),
            ("steering_angle_rear_max_diff", radians),
            ("steering_angle_rate_front_max_diff", radians_per_s),
            ("steering_angle_rate_rear_max_diff", radians_per_s),
            ("velocity_max_diff", VEL_FMT),
            ("acceleration_max_diff", ACCEL_FMT),
            ("tracker_state_max_diff", INT_FMT),
            ("velocity_control_state_max_diff", INT_FMT),
            ("drive_direction_max_diff", INT_FMT),
            ("drive_gear_max_diff", INT_FMT),
            ("gravity_acceleration_max_diff", ACCEL_FMT),
            ("current_odd_max_diff", INT_FMT),
        )
        return dict(metric_formats)

    def _get_display_cols(self):
        """Lists the Tracker columns to show: meta-data, then metrics."""
        metric_cols = [
            name + "_max_diff" for name in self._comparable_field_names()
        ]
        return ["scenario", "links"] + metric_cols


class SummaryTab:
    """Top-level tab summarizing key-metric differences across report tabs."""

    def __init__(self, report_tabs, styles):
        """
        Args:
            report_tabs: report tab objects to summarize. Order matters when
                two tabs share a key metric; see `build`.
            styles: table styles passed to Styler.set_table_styles.
        """
        self.report_tabs = report_tabs
        self.styles = styles
        # Populated by build(); None until then.
        self.summary_df = None

    def get_changed_metrics(self):
        """
        Get the counts of metrics that are changed
        if no changes 0's would be returned

        Returns:
            dict from "Metric Type" to the summed "Count" column.

        Raises:
            RuntimeError: if `build` has not been called yet.
        """
        if self.summary_df is None:
            raise RuntimeError(
                "Can't get summary metrics without building the summary tab first"
            )
        # The columns of self.summary_df should be the same as the ones defined
        # in the `self.build()` function
        return dict(self.summary_df.groupby(["Metric Type"])["Count"].sum())

    def build(self, title="Summary"):
        """Builds the summary Panel.

        For every key metric defined by each report tab, counts how many
        values exceed PRECISION_THRESH and renders the counts as a table.
        For SteeringComparisonReportTab instances with `tracking` set, the
        violating scenarios and a "diff" flag are also passed to
        `record_steering_mode_comparison_data`.

        Args:
            title: title of the returned Panel.

        Returns:
            bokeh Panel containing the summary table and help text.
        """
        summary = {}
        scenario_and_uris_df = pd.DataFrame()
        for report_tab in self.report_tabs:
            key_metric_df = report_tab.get_key_metric_df()
            exceeds_thresh = key_metric_df > PRECISION_THRESH
            tracks_steering = (
                isinstance(report_tab, SteeringComparisonReportTab)
                and report_tab.tracking
            )
            if tracks_steering:
                scenario_and_uris_df = report_tab.get_scenario_and_uris_df()
                # `axis` is passed by keyword: pandas 2.x made the arguments
                # of DataFrame.any keyword-only.
                violated_df = key_metric_df[exceeds_thresh.any(axis=1)]
                scenario_and_uris_df = scenario_and_uris_df[
                    scenario_and_uris_df.index.isin(violated_df.index)
                ]
                scenario_and_uris_df = pd.concat(
                    [scenario_and_uris_df, violated_df], axis=1
                )
            total_cnt_gt_thresh = exceeds_thresh.sum()
            total_cnt = key_metric_df.count()
            # Whether any key metric of *this* tab has differences. Previously
            # the tracking stat used the count of whichever metric happened to
            # be iterated last (or a stale value from an earlier tab).
            tab_has_diff = False
            for (
                def_key,
                def_description,
            ) in report_tab.get_key_metric_definitions().items():
                diff_cnt = total_cnt_gt_thresh[def_key]
                total = total_cnt[def_key]
                diff_summary = "{:0}/{:0} differences".format(diff_cnt, total)

                # If the key already exists, _diff_ the diff_cnt (diff of diffs)
                # for the key to display differences in the underlying
                # comparison.
                #
                # This is predominantly used for 2ws vs 4ws comparisons.
                #
                # This behavior will be undefined if we try to update a record
                # multiple times.
                if def_key in summary:
                    existing_record = summary[def_key]
                    #
                    # ********************** WARNING **************************
                    #
                    # This is super brittle. It relies on the fact that the
                    # candidate tab is _first_ in the list of report tabs.
                    # Thus, when there are overlaps, the existing record is
                    # on the candidate, while the new value of `diff_cnt` is
                    # from the control. So the code is trying to diff the number
                    # of metrics exceeding threshold on the candidate to those
                    # on the control.
                    diff_of_diff_cnt = existing_record[3] - diff_cnt
                    diff_summary = "{:0}/{:0} differences".format(
                        diff_of_diff_cnt, total
                    )
                    diff_cnt = diff_of_diff_cnt
                tab_has_diff = tab_has_diff or bool(diff_cnt > 0)
                summary[def_key] = [
                    report_tab.type,
                    def_description,
                    diff_summary,
                    diff_cnt,
                ]
            if tracks_steering:
                stats = {"diff": int(tab_has_diff)}
                report_tab.record_steering_mode_comparison_data(
                    scenario_and_uris_df, stats
                )
        self.summary_df = pd.DataFrame.from_dict(
            summary,
            orient="index",
            columns=["Metric Type", "Definition", "Summary", "Count"],
        )
        summary_df = self.summary_df[
            ["Metric Type", "Definition", "Summary"]
        ].sort_values(by="Metric Type")
        summary_html = (
            summary_df.style.applymap(
                color_str_ok, negative_color="#569edd", subset=["Summary"]
            )
            .set_table_styles(self.styles)
            .set_properties(**{"border-width": 1, "max-width": "600px"})
            .render()
        )
        summary_table = Div(text=summary_html)
        summary_help = Div(
            text="""
        <p>There is no behavior change between branches when all metrics are
        green (or blue, improvements to 2ws vs 4ws).</p>
        <p> For the <b>Planner</b>, <b>Prediction</b>, <b>Perception</b> or
        <b>Tracker</b> metric types, failures occur when the key metrics
        exceed a threshold of <code>{:.1e}</code>.</p>
        <p> For the <b>VH6 2ws vs 4ws</b> metric type, failures
        occur when the key metrics exceed the scenario-specific thresholds on
        the candidate SHA alone. The scenario-specific thresholds are
        specified in
        <a href="https://git.zooxlabs.com/zooxco/driving/blob/master/mined_metric/builder/metrics_impl/pisces/utils/steering_mode_scenario_thresholds.py">VH6 2ws vs 4ws Thresholds</a>.</p>
        <p> If there are any deviations, please use information in the
        corresponding tabs to investigate.</p>
        """.format(
                PRECISION_THRESH
            )
        )
        summary_layout = layout(
            [[summary_table], [summary_help]],
            sizing_mode="stretch_width",
            align=("center", "start"),
        )
        return Panel(child=summary_layout, title=title)
