import itertools
import string

from bokeh.plotting import figure
from bokeh.layouts import gridplot
from bokeh.models import (
    ColumnDataSource,
    HoverTool,
    DatetimeTickFormatter,
    Range1d,
)
from bokeh.palettes import Category10_10
from bokeh.resources import INLINE
from bokeh.embed import file_html
import pandas as pd

from argus.utils.links import render_link
from base.logging import zoox_logger
from mined_metric.builder.metrics_impl.pisces.utils.comparator import (
    is_comparable_low_level_control_field,
)
from vehicle.common.proto.controller_pb2 import LowLevelControl


LOG = zoox_logger.ZooxLogger(__name__)


def _format_sdl_target(value):
    if pd.isna(value):
        return None
    return value.split(" ")[0]


class Plotter:
    """Collect per-release scenario data and render comparison plots.

    Each :meth:`add_data` call registers one release, keyed by its name, as
    a Bokeh ``ColumnDataSource``. The ``plot_*`` methods overlay every
    registered release on shared figures. :meth:`render_html` writes the
    combined report to ``<scenario name>.html``.
    """

    # Columns plotted whenever they are present in a scenario DataFrame.
    _BASE_PLOT_COLUMNS = (
        "chum_uri",
        "sdl_target",
        "sha",
        "ars_eyaw",
        "ars_latprox_left",
        "decision_accel",
        "lane_ref_ave_steering_angle",
        "weighted_accel",
        "ars_accel",
        "ars_vel",
        "lane_ref_vel",
        "decision_final_s",
        "decision_vel",
        "lane_ref_eyaw",
        "lane_ref_ey",
        "decision_ey",
        "ars_long0_iters",
        "decision_eyaw",
        "lane_ref_arc_len",
        "ars_steering_angle",
        "ars_latprox_right",
        "zci_tick_all",
        "ars_final_s",
        "lane_ref_accel",
        "decision_steering_angle",
        "min_cost",
        "ars_arc_len",
        "lane_ref_final_s",
        "decision_ave_steering_angle",
        "lane_ref_steering_angle",
        "ars_ey",
        "corridor_left_cnt",
        "corridor_right_cnt",
        "ars_ave_steering_angle",
        "ars_long1_iters",
        "decision_arc_len",
        "num_actions_considered",
        "dist_next_state_to_ar",
        "ars_lat_iters",
    )

    # A set of experiments to compare
    def __init__(self):
        # TODO: The order is not guaranteed in this plotter
        # release name -> ColumnDataSource with that release's columns.
        self.data = {}
        # Union over all added releases of the action / subgoal columns
        # that were non-zero somewhere in the scenario.
        self.actions = set()
        self.subgoals = set()
        self.count = 0  # NOTE(review): not read anywhere in this class.
        # Cached HTML template text; loaded lazily by make_template().
        self.template = None
        # Scenario name; set by add_data() and used for the output file.
        self.name = None

        self.plot_columns = set(self._BASE_PLOT_COLUMNS)
        # Tracker output.
        for field in LowLevelControl.DESCRIPTOR.fields:
            if is_comparable_low_level_control_field(field):
                self.plot_columns.add(field.name)

    def add_data(self, release_name, release_sha, scenario):
        """Register one release's scenario data for plotting.

        Args:
            release_name: Label for the release. It is the key in
                ``self.data`` and is shown in legends and tooltips.
            release_sha: Commit SHA stored in the ``sha`` column.
            scenario: Object exposing ``df`` (a DataFrame that is modified in
                place by adding ``label`` and ``sha`` columns), ``actions``
                and ``subgoals`` (column names in ``df``), and ``name``.
        """
        df = scenario.df
        df["label"] = release_name
        df["sha"] = release_sha

        # Keep only the action / subgoal columns that are ever non-zero.
        nonzero_actions = df[scenario.actions].sum() > 0
        actions = nonzero_actions[nonzero_actions].index.values

        nonzero_subgoals = df[scenario.subgoals].abs().sum() > 0
        subgoals = nonzero_subgoals[nonzero_subgoals].index.values

        self.subgoals = self.subgoals.union(subgoals)
        self.actions = self.actions.union(actions)
        self.name = scenario.name

        wanted = (
            self.plot_columns.union(self.subgoals)
            .union(self.actions)
            .union(x + "_COST" for x in self.actions)
            .union(x + "_is_active" for x in self.subgoals)
        )
        # "label" is referenced by the hover tooltip in plot_field().
        wanted.add("label")

        # self.actions / self.subgoals accumulate across releases, so some
        # of them may be absent from this release's DataFrame (e.g. renamed
        # fields). Only select columns that actually exist. A sorted list is
        # used because pandas rejects a set as a column indexer, and the
        # order becomes deterministic.
        plotted_fields = sorted(field for field in wanted if field in df)

        self.data[release_name] = ColumnDataSource(df[plotted_fields])

    def plot_field(self, field, title):
        """Return a datetime-x figure overlaying ``field`` for every release.

        Releases whose data source lacks ``field`` are skipped. Each release
        is drawn as a line plus small circles in its own Category10 color.
        """
        p = figure(
            plot_width=600, plot_height=400, title=title, x_axis_type="datetime"
        )

        for key, color in zip(self.data, itertools.cycle(Category10_10)):
            source = self.data[key]
            if field not in source.data:
                # Sometimes fields get renamed; this prevents us from
                # attempting to plot a nonexistent field in one SHA
                continue

            p.line(
                x="timestamp",
                y=field,
                source=source,
                legend_label=key,
                color=color,
                line_width=2,
                alpha=0.8,
            )
            p.circle(
                x="timestamp",
                y=field,
                source=source,
                legend_label=key,
                color=color,
                size=2,
                alpha=0.8,
            )

        if len(p.legend) > 0:
            p.legend.location = "bottom_left"
            p.legend.click_policy = "hide"

        p.add_tools(
            HoverTool(
                tooltips=[
                    ("time", "@timestamp{F}"),
                    # Braced @{...} form so that field names containing
                    # spaces or other special characters still resolve.
                    ("value", "@{" + field + "}{0.2f}"),
                    ("label", "@label"),
                ]
            )
        )

        p.xaxis[0].formatter = DatetimeTickFormatter(
            microseconds="%H:%M:%S.%f",
            milliseconds="%H:%M:%S.%3N",
            seconds="%H:%M:%S",
            minsec="%H:%M:%S",
            minutes="%H:%M",
        )

        return p

    def plot_accel(self):
        """Return the ``weighted_accel`` figure, used as a shared x-range."""
        return self.plot_field(
            "weighted_accel", title="Longitudinal Weighted ax"
        )

    def plot_subgoal(self, subgoal):
        """Return a figure for a subgoal column, titled with its name."""
        return self.plot_field(subgoal, subgoal)

    def make_template(self):
        """Return the HTML template text, reading it from disk on first use.

        The path is relative to the current working directory.
        """
        if self.template is None:
            template = "mined_metric/builder/metrics_impl/pisces/utils/pisces_plots.tpl"
            with open(template, encoding="utf-8") as f:
                self.template = f.read()
        return self.template

    def plot_dist_to_traj_and_show(self):
        """Return a grid with the hero-state-to-action-ref distance plot."""
        p = self.plot_field(
            "dist_next_state_to_ar", "Min Dist of Hero State to Action Ref Traj"
        )
        plots = [[p]]

        gp = gridplot(plots, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Trajectory Metrics"]
        return gp

    def plot_controls_and_show(
        self,
        fields=(
            "accel",
            "vel",
            "steering_angle",
            "ave_steering_angle",
            "arc_len",
            "final_s",
            "ey",
            "eyaw",
        ),
    ):
        """Return a grid with one row per field: Decision, ARS and Lane Ref.

        Within a row, the three plots share the Decision plot's y-range. All
        plots share the x-range of the weighted-accel figure.

        Args:
            fields: Column suffixes, prefixed with ``decision_``, ``ars_``
                and ``lane_ref_`` to form the plotted column names.
        """
        accel = self.plot_accel()

        plots = []
        for f in fields:
            decision = self.plot_field("decision_" + f, "Decision " + f)
            decision.x_range = accel.x_range

            ars = self.plot_field("ars_" + f, "ARS " + f)
            ars.x_range = accel.x_range
            ars.y_range = decision.y_range

            lane_ref = self.plot_field("lane_ref_" + f, "Lane Ref " + f)
            lane_ref.x_range = accel.x_range
            lane_ref.y_range = decision.y_range

            plots.append([decision, ars, lane_ref])

        gp = gridplot(plots, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Controls and States"]
        return gp

    def plot_action_ref_metrics_and_show(self):
        """Return a grid of ARS lateral-proximity and iteration-count plots."""
        accel = self.plot_accel()

        latprox_left = self.plot_field(
            "ars_latprox_left", "Lateral Safety (Left)"
        )
        latprox_left.x_range = accel.x_range

        latprox_right = self.plot_field(
            "ars_latprox_right", "Lateral Safety (Right)"
        )
        latprox_right.x_range = accel.x_range
        latprox_right.y_range = latprox_left.y_range

        long0 = self.plot_field("ars_long0_iters", "ARS Long0 #Iters")
        long0.x_range = accel.x_range

        lat = self.plot_field("ars_lat_iters", "ARS Lat #Iters")
        lat.x_range = accel.x_range
        lat.y_range = long0.y_range

        long1 = self.plot_field("ars_long1_iters", "ARS Long1 #Iters")
        long1.x_range = accel.x_range
        long1.y_range = long0.y_range

        plots = [[latprox_left, latprox_right], [long0, lat, long1]]

        gp = gridplot(plots, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Action Reference Metrics"]
        return gp

    def plot_subgoals_and_show(self):
        """Return a grid with two accel plots, then each subgoal's plots.

        Each subgoal row shows its ``<subgoal>_is_active`` plot next to its
        value plot. Subgoals are sorted so that the report layout is stable
        between runs.
        """
        accel = self.plot_accel()
        accel_copy = self.plot_accel()
        accel_copy.x_range = accel.x_range
        accel_copy.y_range = accel.y_range

        plots = [[accel, accel_copy]]

        for subgoal in sorted(self.subgoals):
            value_plot = self.plot_subgoal(subgoal)
            value_plot.x_range = accel.x_range

            active_plot = self.plot_field(
                subgoal + "_is_active", subgoal + " active?"
            )
            active_plot.x_range = accel.x_range
            plots.append([active_plot, value_plot])

        gp = gridplot(plots, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Subgoals"]
        return gp

    def plot_actions_and_show(self):
        """Return a grid of decision-cost summary plots plus per-action rows.

        Each action row pairs the action's plot (y fixed to [-0.01, 1.1])
        with its ``<action>_COST`` plot, which shares the ``min_cost``
        y-range. Actions are sorted so that the report layout is stable.
        """
        root = self.plot_field("min_cost", "Minimum cost")
        root_copy = self.plot_field("min_cost", "Minimum cost")
        root_copy.x_range = root.x_range
        root_copy.y_range = root.y_range

        num_actions = self.plot_field(
            "num_actions_considered", "Number actions"
        )
        num_actions.x_range = root.x_range

        ztrace = self.plot_field("zci_tick_all", "Decision tick time")
        ztrace.x_range = root.x_range
        ztrace.yaxis.axis_label = "Time (s)"

        plots = [[num_actions, ztrace], [root, root_copy]]

        x_range = root.x_range
        y_range = root.y_range

        for action in sorted(self.actions):
            p = self.plot_field(action, action)
            p.x_range = x_range
            p.y_range = Range1d(-0.01, 1.1)

            s = self.plot_field(action + "_COST", action + " Cost")
            s.x_range = p.x_range
            s.y_range = y_range

            plots.append([p, s])

        gp = gridplot(plots, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Actions"]
        return gp

    def plot_corridor_and_show(self):
        """
        Time series (vector) of number of corridor points per tick.
        """
        left_corr = self.plot_field(
            "corridor_left_cnt", "Corridor Num Pts (Left)"
        )
        right_corr = self.plot_field(
            "corridor_right_cnt", "Corridor Num Pts (Right)"
        )
        plots = [[left_corr, right_corr]]
        gp = gridplot(plots, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Corridor"]
        return gp

    def plot_tracker_output_and_show(self):
        """Plots Tracker output.

        Each named field is plotted with a title derived from its name, and
        all plots share the weighted-accel figure's x-range.
        """
        layout = [
            ["steering_angle_front", "steering_angle_rear"],
            ["steering_angle_rate_front", "steering_angle_rate_rear"],
            ["velocity", "acceleration"],
            ["tracker_state", "velocity_control_state"],
            ["drive_direction", "drive_gear"],
            ["gravity_acceleration", "current_odd"],
        ]

        ref = self.plot_accel()
        plot_grid = []
        for row in layout:
            plot_row = []
            for col in row:
                plot = self.plot_field(
                    col, string.capwords(col.replace("_", " "))
                )
                plot.x_range = ref.x_range
                plot_row.append(plot)

            plot_grid.append(plot_row)

        gp = gridplot(plot_grid, merge_tools=True, sizing_mode="stretch_width")
        gp.tags = ["Tracker Output"]
        return gp

    def _resampled_series(self, release, column):
        """Return ``column`` of ``release`` on a forward-filled 100 ms grid."""
        source = self.data[release].data
        series = pd.Series(source[column], index=source["timestamp"])
        return series.resample("100ms").ffill()

    def render_html(self, argus_host):
        """Render all comparison plots to ``<scenario name>.html``.

        At least two releases must have been added. The first two releases,
        in insertion order, are compared to find the timestamp with the
        largest absolute accel difference and to build the Argus links.

        Args:
            argus_host: Hostname for the Argus comparison links. The links
                are marked secure only for ``argus.zooxlabs.com``.

        Returns:
            The name of the written HTML file.
        """
        LOG.info("Rendering HTML for %s", self.name)
        plots = [
            self.plot_controls_and_show(),
            self.plot_subgoals_and_show(),
            self.plot_actions_and_show(),
            self.plot_corridor_and_show(),
            self.plot_tracker_output_and_show(),
        ]

        releases = list(self.data)

        # Absolute difference (second release minus first) of each accel
        # signal on a common 100 ms grid. Gaps count as zero difference.
        signals = []
        for column in ("decision_accel", "ars_accel", "lane_ref_accel"):
            primary = self._resampled_series(releases[0], column)
            comparison = self._resampled_series(releases[1], column)
            signals.append((comparison - primary).fillna(0).abs())

        # Timestamp at which any of the three signals differs the most.
        idx = pd.DataFrame(signals).max().idxmax()
        max_diff_ts = idx.timestamp()

        uris = [self.data[r].data["chum_uri"][0] for r in releases]

        link_kwargs = dict(
            primary_nickname=releases[0],
            comparison_nickname=releases[1],
            hostname=argus_host,
            secure=(argus_host == "argus.zooxlabs.com"),
        )
        argus_link = render_link(
            uris[0], uris[1], "planner_comparison", **link_kwargs
        )
        argmax_argus_link = render_link(
            uris[0],
            uris[1],
            "planner_comparison",
            time=max_diff_ts,
            **link_kwargs,
        )

        file_name = f"{self.name}.html"

        s = file_html(
            plots,
            INLINE,
            title=file_name,
            template=self.make_template(),
            template_variables={
                "uris": list(zip(uris, releases)),
                "experiments": [
                    (
                        f"chum://{self.name}",
                        self.data[key].data["sha"][0],
                        _format_sdl_target(
                            self.data[key].data["sdl_target"][0]
                        ),
                        key,
                    )
                    for key in self.data
                ],
                "argus_link": argus_link,
                "argmax_argus_link": argmax_argus_link,
                "ts": max_diff_ts,
            },
        )
        with open(file_name, "wb") as file_obj:
            file_obj.write(s.encode("utf-8"))
        return file_name