"""Plot the pose error between a chum run and a control run.

Each run's solver pose sigmas are drawn as confidence bands around the error.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from base.geometry.isometry3d_py import Isometry3d
from base.plotting.plot_utils import (
    DPI,
    HEIGHT_2K,
    WIDTH_2K,
)
from clams.trajectory.trajectory_py import Pose, Trajectory
from data.tools import zed
from localization.common.arg_utils_py import get_parser_for_any_input
from localization.common.dofs import ROTATION_DOFS, TRANSLATION_DOFS, DOFS
from localization.common.logging_utils import configure_logging_module
from localization.common.pandas_utils import interpolate_dataframe
from localization.data_loading.argparse_utils import chum_uris_from_args
from localization.data_loading.load_poses_py import load_hero_global_poses
from vehicle.localization.data_loading.topics_py import swf_solver_state_topic


# Line/band colors, paired positionally with the translation DoFs and again
# with the rotation DoFs when the subplots are populated.
COLORS = list("rgb")


def get_global_pose_covariances(reader):
    """Collect the solver's per-DoF global pose sigmas into a DataFrame.

    Reads the ``driving_mofo_swf_solver_state`` topic from ``reader`` and, for
    every message, extracts the pose time plus the x/y/z components of
    ``sigmas.bl_sig__map_trans_bl`` and ``sigmas.bl_sig__map_rot_bl``.

    NOTE: despite the function name, the values are the solver's ``sigmas``
    fields. They are used directly as +/- band half-widths by the plotting
    code, so presumably they are standard deviations rather than variances.
    TODO confirm against the solver state definition.

    Args:
        reader: zed reader opened on the SWF solver state topic (see
            ``get_data``).

    Returns:
        pandas.DataFrame with columns ``("t",) + DOFS``. The sigma columns are
        laid out as translation x, y, z followed by rotation x, y, z. This
        assumes DOFS lists the translation DoFs before the rotation DoFs in
        that order; TODO confirm.
    """
    # Slice the topic once and reuse it; each ``[:]`` may materialize the
    # whole topic, so repeating it per field is wasteful.
    pose = reader.driving_mofo_swf_solver_state[:].pose
    trans_sigma = pose.sigmas.bl_sig__map_trans_bl
    rot_sigma = pose.sigmas.bl_sig__map_rot_bl

    # np.transpose stacks the columns into one 2-D array (rows = messages),
    # so every column shares a single, possibly upcast, dtype.
    data = np.transpose([
        pose.time,
        trans_sigma.x,
        trans_sigma.y,
        trans_sigma.z,
        rot_sigma.x,
        rot_sigma.y,
        rot_sigma.z,
    ])
    return pd.DataFrame(data=data, columns=("t",) + DOFS)


def get_data(uri):
    """Load the global poses and the solver pose sigmas for one chum URI.

    Args:
        uri: chum URI to load from.

    Returns:
        Tuple ``(poses, covs)``. ``poses`` comes from
        ``load_hero_global_poses``; ``covs`` is the DataFrame built by
        ``get_global_pose_covariances`` from the SWF solver state topic.
    """
    solver_reader = zed.from_chum(uri, [swf_solver_state_topic()])
    global_poses = load_hero_global_poses(uri, wait_for_first_valid_hero=True)
    return global_poses, get_global_pose_covariances(solver_reader)


def plot_error_and_covariance(ax, dof, t, error, cov, color='0.5', alpha=0.2):
    """Plot an error trace with a symmetric ``error +/- cov`` band around it.

    Args:
        ax: matplotlib Axes to draw on.
        dof: name of the degree of freedom being plotted. Currently unused;
            kept for interface compatibility.
        t: x-axis values, same length as ``error``.
        error: error values; any array-like (list, ndarray, pandas Series).
        cov: band half-width per sample; array-like, same length as ``error``.
        color: matplotlib color used for both the line and the band.
        alpha: opacity of the band.
    """
    # Convert to float ndarrays so the band arithmetic works for plain lists
    # too (``list - list`` raises TypeError) and is positional rather than
    # pandas-index aligned.
    error = np.asarray(error, dtype=float)
    cov = np.asarray(cov, dtype=float)
    ax.plot(t, error, color=color)
    ax.fill_between(t, error - cov, error + cov, color=color, alpha=alpha)


def plot_cov_with_error(uri, control_uri):
    """Plot per-DoF pose errors of ``uri`` against ``control_uri`` with sigma bands.

    Six stacked subplots share the time axis. The first three hold the
    translation DoFs (metres); the last three hold the rotation DoFs (degrees,
    converted with ``np.degrees``, so the RPY values are assumed to be
    radians). Each subplot shows:

      * the error ``traj - control_traj`` for that DoF (colored line), with a
        band from ``uri``'s solver sigmas, and
      * a grey band around zero from ``control_uri``'s solver sigmas.

    Args:
        uri: chum URI of the run being evaluated.
        control_uri: chum URI of the reference run; required.

    Raises:
        ValueError: if ``control_uri`` is empty or None.
    """
    # This script requires a control to compute errors against. Raise instead
    # of asserting so the check is not stripped under ``python -O``.
    if not control_uri:
        raise ValueError("control chum URI is required!")

    # Load the data.
    traj, covs = get_data(uri)
    control_traj, control_covs = get_data(control_uri)

    # Get the pose errors.
    # NOTE: this interpolates the pose timestamps into the ground truths.
    pose_errors = traj - control_traj
    # timestamp() is scaled by 1e-9 to get seconds (assumes nanoseconds).
    t_error = [p.timestamp() * 1.0e-9 for p in pose_errors]

    # Resample the sigma data onto the error timestamps.
    # NOTE(review): assumes the "t" column of covs uses the same units
    # (seconds) as t_error. TODO confirm.
    covs = interpolate_dataframe(covs, t_error)
    control_covs = interpolate_dataframe(control_covs, t_error)

    # Extract each pose's translation and RPY once, not once per DoF.
    translations = [p.translation() for p in pose_errors]
    rotations_rpy = [p.isometry().rotation_rpy() for p in pose_errors]
    # The control's band is drawn around a zero error line.
    zeros = [0] * len(pose_errors)

    # Set up the plot structure: rows 0-2 translation, rows 3-5 rotation.
    fig, axes = plt.subplots(
        nrows=6,
        sharex=True,
        figsize=(WIDTH_2K / DPI, HEIGHT_2K / DPI),
        dpi=DPI,
    )
    trans_axes, rot_axes = axes[:3], axes[3:]

    # Populate the translation DoFs.
    for i, (ax, dof, color) in enumerate(zip(trans_axes, TRANSLATION_DOFS, COLORS)):
        error = [translation[i] for translation in translations]
        plot_error_and_covariance(ax, dof, t_error, error, covs[dof], color)
        plot_error_and_covariance(ax, dof, t_error, zeros, control_covs[dof])
        ax.set_ylabel("{} error [m]".format(dof))

    # Populate the rotation DoFs.
    for i, (ax, dof, color) in enumerate(zip(rot_axes, ROTATION_DOFS, COLORS)):
        error = np.degrees([rpy[i] for rpy in rotations_rpy])
        plot_error_and_covariance(ax, dof, t_error, error, np.degrees(covs[dof]), color)
        plot_error_and_covariance(ax, dof, t_error, zeros, np.degrees(control_covs[dof]))
        ax.set_ylabel("{} error [deg]".format(dof))

    # Finishing touches.
    fig.suptitle("Online vs Batch solution with covariances for:\n{}\n{}".format(uri, control_uri))
    axes[-1].set_xlabel("timestamp [s]")
    plt.show()
    # Release the figure so repeated calls (one per URI pair in main) do not
    # accumulate open figures, e.g. under a non-interactive backend.
    plt.close(fig)


def main(argv=None):
    """Plot errors with sigma bands for every (uri, control_uri) pair given.

    Args:
        argv: optional argument list to parse. When None (the default), the
            parser falls back to the process command line.
    """
    # Get chum URIs from the command line arguments. Forward argv so callers
    # (and tests) can supply arguments explicitly; None keeps the default.
    args = get_parser_for_any_input().parse_args(argv)
    uris, control_uris = chum_uris_from_args(args)

    # Loop over the URIs and do the plotting.
    # NOTE: zip() stops at the shorter sequence, so any URI without a
    # matching control entry is skipped silently.
    for uri, control_uri in zip(uris, control_uris):
        plot_cov_with_error(uri, control_uri)



if __name__ == "__main__":
    # Configure logging for this tool before running main().
    configure_logging_module("plot_cov_with_error")
    main()

