#!/usr/bin/env python

import argparse
import rospy 
import rosbag
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

# Wheel speed below which the vehicle is treated as stopped. The unit is that
# of the ESP wheel-speed signal (not visible here) -- NOTE(review): confirm.
STOP_SPEED_THRESHOLD = 1e-3

IMU_BIAS_TOPIC = '/driving/ImuBias'
WHEEL_SPEED_TOPIC = '/autobox/esp_wheel_speeds'
IMU_TOPIC = '/driving/mofo/imu'


def parse_args(argv=None):
    """Parse the command line; both bag paths are required positionals."""
    parser = argparse.ArgumentParser('plot_applanix')
    parser.add_argument('batch_mofo_bag', type=str, help='Batch Mofo bag.')
    # A plain positional argument is always required, so a ``default`` would
    # never be used; it is intentionally omitted.
    parser.add_argument('sideband_bag', type=str, help='Sideband bag.')
    return parser.parse_args(argv)


def read_imu_bias(bag_path):
    """Read gyro bias estimates from the IMU-bias topic of a bag.

    Returns:
        (t, gyro_x, gyro_y, gyro_z) as NumPy arrays; ``t`` holds header
        stamps in seconds. The bias values are plotted through ``rad2deg``,
        so they are presumably in rad/s.
    """
    t, g_x, g_y, g_z = [], [], [], []
    with rosbag.Bag(bag_path, 'r') as bag:
        # Only one topic is requested, so no per-message topic check is needed.
        for _, msg, _ in bag.read_messages(topics=[IMU_BIAS_TOPIC]):
            t.append(msg.header.stamp.to_sec())
            g_x.append(msg.gyro_x)
            g_y.append(msg.gyro_y)
            g_z.append(msg.gyro_z)
    return np.array(t), np.array(g_x), np.array(g_y), np.array(g_z)


def read_sideband(bag_path):
    """Read rear-axle wheel speed and IMU angular rates from a sideband bag.

    Returns:
        (ws_t, ws, g_t, g_x, g_y, g_z) as NumPy arrays. ``ws`` is the mean of
        the rear-right and rear-left ESP wheel speeds, and ``g_*`` are the IMU
        angular-velocity components. Timestamps are header stamps in seconds.
    """
    ws_t, ws = [], []
    g_t, g_x, g_y, g_z = [], [], [], []
    with rosbag.Bag(bag_path, 'r') as bag:
        for topic, msg, _ in bag.read_messages(topics=[WHEEL_SPEED_TOPIC,
                                                       IMU_TOPIC]):
            if topic == WHEEL_SPEED_TOPIC:
                ws_t.append(msg.header.stamp.to_sec())
                ws.append(0.5 * (msg.ESP_wheelSpeedRR + msg.ESP_wheelSpeedRL))
            elif topic == IMU_TOPIC:
                g_t.append(msg.header.stamp.to_sec())
                g_x.append(msg.angular_velocity.x)
                g_y.append(msg.angular_velocity.y)
                g_z.append(msg.angular_velocity.z)
    return (np.array(ws_t), np.array(ws), np.array(g_t),
            np.array(g_x), np.array(g_y), np.array(g_z))


def estimate_stationary_gyro_bias(ws_t, ws, g_t, gyro_axes,
                                  threshold=STOP_SPEED_THRESHOLD):
    """Average each gyro axis over the samples taken while stopped.

    Wheel speed is linearly interpolated onto the gyro timestamps ``g_t``.
    Gyro samples outside the wheel-speed time range interpolate to NaN. NaN
    compares False against the threshold, so those samples are excluded.

    Args:
        ws_t, ws: wheel-speed timestamps and values.
        g_t: gyro timestamps.
        gyro_axes: iterable of per-axis gyro sample sequences aligned with g_t.
        threshold: speeds strictly below this count as stopped.

    Returns:
        A list with one mean per axis, or None when there are fewer than two
        wheel-speed samples (interpolation is impossible) or no gyro sample
        falls in a stopped interval.
    """
    if len(ws_t) < 2:
        return None
    ws_at_gyro = interp1d(ws_t, ws, bounds_error=False)(np.asarray(g_t))
    stopped = ws_at_gyro < threshold
    if not np.any(stopped):
        return None
    return [np.mean(np.asarray(axis)[stopped]) for axis in gyro_axes]


def find_stop_transitions(ws, threshold=STOP_SPEED_THRESHOLD):
    """Return indices ``i`` where the stopped state differs between i and i+1.

    A sample is "stopped" when ``ws < threshold`` and "moving" otherwise. A
    single threshold is used for both directions, so a drop from a tiny
    (already stopped) speed to zero is not reported as a transition.
    """
    stopped = np.asarray(ws, dtype=float) < threshold
    return np.flatnonzero(stopped[1:] != stopped[:-1])


def plot_bias_with_transitions(b_t, b_g_x, b_g_y, b_g_z, ws_t,
                               transition_inds):
    """Figure 1: gyro bias (deg/s) with dashed lines at stop/start events."""
    plt.figure(1)
    plt.plot(b_t, np.rad2deg(b_g_x), '-r', label='Gyro Bias x')
    plt.plot(b_t, np.rad2deg(b_g_y), '-g', label='Gyro Bias y')
    plt.plot(b_t, np.rad2deg(b_g_z), '-b', label='Gyro Bias z')

    # Capture the autoscaled limits before drawing the vertical markers, then
    # restore them so the markers do not change the scale.
    ylim = plt.ylim()
    for ind in transition_inds:
        plt.plot([ws_t[ind], ws_t[ind]], [ylim[0], ylim[1]], '--k')

    plt.xlabel('Time, s')
    plt.ylabel('Angular Rate, deg/s')
    plt.legend()
    plt.grid()
    plt.ylim(ylim)


def plot_bias_change_vs_speed(b_t, b_g_y, ws_t, ws):
    """Figure 2: |change| of standardized |bias y| against normalized speed."""
    plt.figure(2)
    abs_bias = np.abs(np.rad2deg(b_g_y))
    abs_bias = abs_bias - np.mean(abs_bias)
    std = np.std(abs_bias)
    if std > 0:  # a constant bias would otherwise produce NaNs
        abs_bias = abs_bias / std
    plt.plot(b_t[1:], np.abs(np.diff(abs_bias)), '-b')

    ws = np.asarray(ws, dtype=float)
    peak = ws.max() if ws.size else 0.0
    # A vehicle that never moved has peak 0; plot raw speeds rather than inf/NaN.
    plt.plot(ws_t, ws / peak if peak > 0 else ws, '--k')
    plt.grid()


def main(argv=None):
    """Print stationary gyro-bias estimates and plot bias against wheel speed."""
    args = parse_args(argv)

    b_t, b_g_x, b_g_y, b_g_z = read_imu_bias(args.batch_mofo_bag)
    ws_t, ws, g_t, g_x, g_y, g_z = read_sideband(args.sideband_bag)

    biases = estimate_stationary_gyro_bias(ws_t, ws, g_t, (g_x, g_y, g_z))
    if biases is None:
        print('No stationary IMU samples found; cannot estimate gyro bias.')
    else:
        for axis, bias in zip('xyz', biases):
            print('Gyro_{} bias: {}'.format(axis, bias))

    transition_inds = find_stop_transitions(ws)

    plot_bias_with_transitions(b_t, b_g_x, b_g_y, b_g_z, ws_t, transition_inds)
    plot_bias_change_vs_speed(b_t, b_g_y, ws_t, ws)

    plt.show()


if __name__ == '__main__':
    main()


