#!/usr/bin/env python

import argparse
import collections
import logging

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

import base.geometry.euler_py as euler
import vehicle.localization.pose_estimation.sliding_window_filter.metrics.logio as logio

class Trajectory(object):
    """Column-oriented store of per-pose values, one Python list per field.

    Every field starts as an empty list in __init__; read_traj appends exactly
    one value to each field per pose, so the fields of a populated trajectory
    all have the same length. Use traj['field'] to get a field as a numpy
    array.
    """

    def __init__(self):
        # Pose timestamps (read_traj adds its delay_ms * 1e-3 offset here).
        self.t = []

        # Smooth-frame position (pose.smooth_x / smooth_y / smooth_z).
        self.x = []
        self.y = []
        self.z = []

        # Position taken from pose.pose.translation.
        self.utm_x = []
        self.utm_y = []
        self.utm_z = []

        # Attitude angles in degrees (np.rad2deg of the Euler angles).
        self.roll = []
        self.pitch = []
        self.yaw = []

        # Angular rates after np.rad2deg (plotted as 'Degree / sec').
        self.rate_roll = []
        self.rate_pitch = []
        self.rate_yaw = []

        # Linear accelerations as reported by the pose (plotted as m/s^2).
        self.accel_x = []
        self.accel_y = []
        self.accel_z = []

        # Euclidean norm of (vel_east, vel_north, vel_up).
        self.speed = []

    def __getattr__(self, key):
        """Raise AttributeError for a field that does not exist.

        Python only calls __getattr__ after normal attribute lookup fails, so
        it never runs for the fields created in __init__. Raising here (rather
        than performing another attribute lookup on self) keeps hasattr(),
        copy and pickle working instead of recursing without end.
        """
        raise AttributeError('%r object has no attribute %r'
                             % (type(self).__name__, key))

    def __getitem__(self, key):
        """Return field `key` as a new numpy array, e.g. ``traj['roll']``.

        Raises:
            KeyError: if `key` is not a field of this trajectory.
        """
        return np.array(vars(self)[key])


def read_traj(fn, delay_ms=0.0):
    """Load every pose of a log file into a Trajectory.

    Files ending in '.log.gz' are read with logio.read_dgc; any other file is
    read with logio.read_bag from the '/driving/ApplanixPose' topic.

    Args:
        fn: path of the log file.
        delay_ms: offset added to every pose timestamp, after scaling by 1e-3
            (presumably converting milliseconds to the seconds of pose.time).

    Returns:
        A Trajectory with one entry per pose in each field. Euler angles and
        angular rates are converted with np.rad2deg; speed is the norm of
        (vel_east, vel_north, vel_up).
    """
    source = (logio.read_dgc(fn) if fn.endswith('.log.gz')
              else logio.read_bag(fn, ['/driving/ApplanixPose']))

    time_shift = delay_ms * 1e-3
    result = Trajectory()

    for sample in source.poses:
        translation = sample.pose.translation
        quat = sample.pose.quaternion_rotation
        # The helper is fed [w, x, y, z] and its result unpacked as
        # (yaw, pitch, roll).
        yaw, pitch, roll = euler.EulerAnglesFromQuaternion(
            [quat.w, quat.x, quat.y, quat.z])

        row = (
            ('t', sample.time + time_shift),
            ('x', sample.smooth_x),
            ('y', sample.smooth_y),
            ('z', sample.smooth_z),
            ('utm_x', translation.x),
            ('utm_y', translation.y),
            ('utm_z', translation.z),
            ('roll', np.rad2deg(roll)),
            ('pitch', np.rad2deg(pitch)),
            ('yaw', np.rad2deg(yaw)),
            ('rate_roll', np.rad2deg(sample.rate_roll)),
            ('rate_pitch', np.rad2deg(sample.rate_pitch)),
            ('rate_yaw', np.rad2deg(sample.rate_yaw)),
            ('accel_x', sample.accel_x),
            ('accel_y', sample.accel_y),
            ('accel_z', sample.accel_z),
            ('speed', np.sqrt(sample.vel_east**2
                              + sample.vel_north**2
                              + sample.vel_up**2)),
        )
        for field, value in row:
            getattr(result, field).append(value)

    return result


def _smooth_to_utm_offset(traj):
    """Return (utm - smooth) for x, y and z, taken at traj's first pose."""
    return (traj.utm_x[0] - traj.x[0],
            traj.utm_y[0] - traj.y[0],
            traj.utm_z[0] - traj.z[0])


def _normalized_speed(traj):
    """Return traj.speed divided by its maximum, for overlaying on plots."""
    speed = np.asarray(traj.speed, dtype=float)
    peak = speed.max() if speed.size else 0.0
    if peak > 0.0:
        return speed / peak
    # A stationary (or empty) log would otherwise divide by zero.
    return np.zeros_like(speed)


def _wrap_degrees(angle):
    """Wrap angles in degrees into [-180, 180); NaN stays NaN."""
    return (np.asarray(angle, dtype=float) + 180.0) % 360.0 - 180.0


def _resample(src_t, src_values, dst_t, angular=False):
    """Linearly interpolate src_values(src_t) at dst_t; NaN outside src_t."""
    values = np.asarray(src_values, dtype=float)
    if angular:
        # Unwrap first so interpolating across the +/-180 deg seam does not
        # sweep through 0 deg.
        values = np.rad2deg(np.unwrap(np.deg2rad(values)))
    return interp1d(src_t, values, bounds_error=False)(dst_t)


def _field_diff(apx_traj, mfo_traj, field, angular=False):
    """Return MoFo's `field` minus Applanix's, sampled at MoFo timestamps."""
    apx_interp = _resample(apx_traj.t, getattr(apx_traj, field), mfo_traj.t,
                           angular=angular)
    diff = np.asarray(getattr(mfo_traj, field), dtype=float) - apx_interp
    return _wrap_degrees(diff) if angular else diff


def _plot_overlay(position, apx_traj, mfo_traj, field, ylabel, title):
    """Plot `field` of both trajectories against time in subplot `position`."""
    plt.subplot(*position)
    plt.plot(apx_traj.t, getattr(apx_traj, field), '-b', label='Applanix')
    plt.plot(mfo_traj.t, getattr(mfo_traj, field), '-r', label='MoFo')
    plt.grid()
    plt.legend()
    plt.xlabel('Time [s]')
    plt.ylabel(ylabel)
    plt.title(title)


def _plot_diff(position, t, diff, ylabel, title, speed_overlay=None):
    """Plot a MoFo - Applanix difference series in subplot `position`."""
    plt.subplot(*position)
    plt.plot(t, diff, '-b', label='MoFo - Applanix')
    if speed_overlay is not None:
        plt.plot(t, speed_overlay, '--k', label='MoFo speed / max')
    plt.grid()
    plt.legend()
    plt.xlabel('Time [s]')
    plt.ylabel(ylabel)
    plt.title(title)


def main(applanix_log, mofo_log, mofo_delay_ms=-19.0):
    """Plot a MoFo log against an Applanix reference log and show the plots.

    Figure 1: smooth trajectory (x/y, and z over time); figure 2: roll,
    pitch, yaw; figure 3: attitude differences; figure 4: angular rates and
    their differences; figure 5: accelerations and their differences. Every
    difference is MoFo minus Applanix, with Applanix linearly interpolated at
    the MoFo timestamps (NaN where MoFo lies outside the Applanix time span).

    Args:
        applanix_log: path of the reference log, read with read_traj.
        mofo_log: path of the MoFo log, read with read_traj.
        mofo_delay_ms: timestamp offset passed to read_traj for the MoFo log;
            defaults to the previously hard-coded -19.
    """
    apx_traj = read_traj(applanix_log)
    mfo_traj = read_traj(mofo_log, delay_ms=mofo_delay_ms)
    mfo_t = mfo_traj.t
    mfo_speed = _normalized_speed(mfo_traj)
    sources = ((apx_traj, '-b', 'Applanix'), (mfo_traj, '-r', 'MoFo'))

    # Figure 1: each log's smooth trajectory shifted by its own smooth->utm
    # offset at the first pose.
    plt.figure(1)
    plt.subplot(1, 2, 1)
    for traj, style, label in sources:
        off_x, off_y, _ = _smooth_to_utm_offset(traj)
        plt.plot(np.array(traj.x) + off_x, np.array(traj.y) + off_y,
                 style, label=label)
    plt.grid()
    plt.axis('equal')
    plt.xlabel('x [m]')
    plt.ylabel('y [m]')
    plt.title('Smooth Trajectory')
    plt.legend()

    plt.subplot(1, 2, 2)
    for traj, style, label in sources:
        _, _, off_z = _smooth_to_utm_offset(traj)
        plt.plot(traj.t, np.array(traj.z) + off_z, style, label=label)
    plt.plot(mfo_t, mfo_speed, '--k', label='MoFo speed / max')
    plt.grid()
    # This panel is z over time, not an x/y plot.
    plt.xlabel('Time [s]')
    plt.ylabel('z [m]')
    plt.title('Smooth Z')
    plt.legend()

    attitude = (('roll', 'Roll'), ('pitch', 'Pitch'), ('yaw', 'Yaw'))

    # Figure 2: raw attitude overlays.
    plt.figure(2)
    for row, (field, name) in enumerate(attitude, 1):
        _plot_overlay((3, 1, row), apx_traj, mfo_traj, field, 'Degrees', name)

    # Figure 3: attitude differences, wrapped into [-180, 180) so a heading
    # crossing the +/-180 deg seam does not appear as a 360 deg spike.
    plt.figure(3)
    for row, (field, name) in enumerate(attitude, 1):
        diff = _field_diff(apx_traj, mfo_traj, field, angular=True)
        _plot_diff((3, 1, row), mfo_t, diff, 'Degrees',
                   'Mofo - Applanix %s Diff' % name, speed_overlay=mfo_speed)

    # Figures 4 and 5: overlays in the left column, differences on the right.
    side_by_side = (
        (4, 'Degree / sec', (('rate_roll', 'Roll Rate'),
                             ('rate_pitch', 'Pitch Rate'),
                             ('rate_yaw', 'Yaw Rate'))),
        (5, 'm/s^2', (('accel_x', 'accel_x'),
                      ('accel_y', 'accel_y'),
                      ('accel_z', 'accel_z'))),
    )
    for fig, ylabel, fields in side_by_side:
        plt.figure(fig)
        for row, (field, name) in enumerate(fields):
            _plot_overlay((3, 2, 2 * row + 1), apx_traj, mfo_traj, field,
                          ylabel, name)
            _plot_diff((3, 2, 2 * row + 2), mfo_t,
                       _field_diff(apx_traj, mfo_traj, field),
                       ylabel, 'Mofo - Applanix %s Diff' % name)

    plt.show()


if __name__ == '__main__':
    # Command-line entry point: both logs are positional arguments and are
    # forwarded to main() by keyword.
    parser = argparse.ArgumentParser(
        description='Plot a MoFo pose log against an Applanix reference log.')
    parser.add_argument('applanix_log', type=str, help='Applanix adjusted log')
    parser.add_argument('mofo_log', type=str,
                        help='MoFo log to compare against the Applanix log')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    main(**vars(args))
