#!/usr/bin/env python

import os
import argparse
import logging
import matplotlib.pyplot as plt
import numpy as np
import sys
from glob import glob

from base.geometry.isometry3d_py import Isometry3d
from base.proto import ReadProto, WriteProto
from mapping.distributed_mapping.common import utils
from mapping.distributed_mapping.data_loading.file_utils import getLastProtoSuffixes
from mapping.distributed_mapping.proto.dataset_pb2 import DatasetList
from mapping.distributed_mapping.proto.eval_suite_results_pb2 import EvalSuiteResults

# Names of the six pose components. Translation names index into
# pose.translation() and rotation names into pose.rotation_rpy(), by position.
TRANSLATION_DOFS = list('xyz')
ROTATION_DOFS = list('RPY')
# All six components, translation first, then rotation.
DOFS = TRANSLATION_DOFS + ROTATION_DOFS
# Dots per inch used to convert pixel screen sizes into matplotlib figure sizes.
DPI = 96

def loadTrajectoryData(filename):
    '''Load a DatasetList proto and return all of its poses as a flat list of Isometry3d.

    Within each dataset, every pose_smooth translation is shifted so that the
    dataset's first pose sits at the origin (orientation is not changed). The
    shifted poses of all datasets are concatenated in file order.

    Args:
        filename: path to a serialized DatasetList proto (e.g. a datasets.pb file).

    Returns:
        list of Isometry3d, one per vehicle_state entry across all datasets.
        Datasets whose trajectory has no vehicle_state entries contribute nothing.
    '''
    output = []
    data = DatasetList()
    ReadProto(filename, data)
    for dataset in data.datasets:
        states = dataset.trajectory.vehicle_state
        if not states:
            # No first pose to anchor the origin on; states[0] would raise IndexError.
            continue
        first = states[0].pose_smooth.translation
        # Copy the origin into plain floats first: the loop mutates states[0]
        # in place, which would otherwise zero the origin after one iteration.
        origin_x, origin_y, origin_z = first.x, first.y, first.z
        for state in states:
            translation = state.pose_smooth.translation
            translation.x -= origin_x
            translation.y -= origin_y
            translation.z -= origin_z
            output.append(Isometry3d(state.pose_smooth))
    return output


def get_component(component, trajectory):
    '''Return a single component slice (x, y, z, R, P, or Y) of a full 6dof trajectory.

    Args:
        component: one of 'x', 'y', 'z' (indexes pose.translation()) or
            'R', 'P', 'Y' (indexes pose.rotation_rpy()).
        trajectory: iterable of poses exposing translation() and rotation_rpy(),
            e.g. the Isometry3d list returned by loadTrajectoryData.

    Returns:
        list holding the requested scalar for each pose, in trajectory order.

    Raises:
        ValueError: if component is not one of the six names above.
    '''
    if component in TRANSLATION_DOFS:
        i = TRANSLATION_DOFS.index(component)
        return [pose.translation()[i] for pose in trajectory]
    if component in ROTATION_DOFS:
        i = ROTATION_DOFS.index(component)
        return [pose.rotation_rpy()[i] for pose in trajectory]
    # Raise explicitly: an `assert` is stripped under `python -O`, after which
    # the caller would get a confusing UnboundLocalError instead.
    raise ValueError(
        "No such component {!r}, valid options are 'x', 'y', 'z', 'R', 'P', and 'Y'".format(component))


def bigFigure(screen_w=2560, screen_h=1440):
    '''Create and return a matplotlib figure whose canvas is screen_w x screen_h pixels.

    The size in inches is the pixel size divided by the module-level DPI, and
    the figure is created at that same DPI.

    Args:
        screen_w: target width in pixels (default 2560).
        screen_h: target height in pixels (default 1440).
    '''
    # float() keeps this true division under Python 2, where 2560/96 would
    # truncate to 26 inches and yield a 2496-px-wide figure.
    figsize = (float(screen_w) / DPI, float(screen_h) / DPI)
    return plt.figure(figsize=figsize, dpi=DPI)


def _load_ordered_trajectories(input_dir):
    '''Load every *.pb file in input_dir, ordered by the integer index in its base name.

    Entry k of the result is loadTrajectoryData() of the first file (in sorted
    order) whose base name contains a digit run equal to k (leading zeros
    allowed), for k = 0 .. N-1, where N is the number of *.pb files found.

    Raises:
        ValueError: if some index in 0 .. N-1 has no matching file.
    '''
    import re

    digit_run = re.compile(r'\d+')
    fnames = sorted(glob(os.path.join(input_dir, '*.pb')))
    # Compare whole digit runs of the base name only: a plain substring test
    # would let '1' match '10.pb', or any digit that appears in input_dir.
    indices = dict((f, set(int(tok) for tok in digit_run.findall(os.path.basename(f))))
                   for f in fnames)
    data = []
    for i in range(len(fnames)):
        match = next((f for f in fnames if i in indices[f]), None)
        if match is None:
            raise ValueError('No *.pb file in {} is numbered {}'.format(input_dir, i))
        data.append(loadTrajectoryData(match))
    return data


def _wrap_angles(angles):
    '''Wrap angle errors (radians) into [-pi, pi]; values already in range are returned unchanged.'''
    angles = np.asarray(angles, dtype=float)
    # (a + pi) mod 2pi - pi is correct for either sign and any magnitude.
    # Note that `v * sign(v) - 2*pi` would flip the sign of negative errors.
    return np.where(np.abs(angles) <= np.pi, angles,
                    np.mod(angles + np.pi, 2 * np.pi) - np.pi)


def _compute_errors(data):
    '''Return {dof: [data[i] - data[0] error array for i = 1 .. len(data)-1]}.

    data[0] must be groundtruth, so error = data[i] - groundtruth. Rotation
    errors are wrapped into [-pi, pi]. Each trajectory is assumed to be
    pose-aligned with the groundtruth (same number of poses).
    '''
    errors = {}
    for dof in DOFS:
        # The groundtruth component is identical for every iteration; compute it once.
        truth = np.array(get_component(dof, data[0]))
        diffs = [np.array(get_component(dof, traj)) - truth for traj in data[1:]]
        if dof in ROTATION_DOFS:
            diffs = [_wrap_angles(d) for d in diffs]
        errors[dof] = diffs
    return errors


def _error_stats(errors):
    '''Return (rms, mean, stdev, max_abs) dicts, each mapping dof -> one value per iteration.'''
    def per_dof(reduce_fn):
        # Apply reduce_fn to every iteration's error array, for each DOF.
        return dict((dof, [reduce_fn(e) for e in errors[dof]]) for dof in DOFS)

    return (per_dof(lambda e: np.sqrt(np.mean(np.square(e)))),
            per_dof(np.mean),
            per_dof(np.std),
            per_dof(lambda e: np.max(np.abs(e))))


def _plot_stat_vs_iteration(stats, label):
    '''Show a 3x2 grid plotting stats[dof] against 1-based iteration number, one subplot per DOF.

    Translation DOFs fill the left column (y label in meters) and rotation
    DOFs the right column (y label in radians); then calls plt.show().
    '''
    # Float numerators: under Python 2, 2560/DPI would truncate to 26 inches.
    _, axarr = plt.subplots(3, 2, figsize=(2560.0 / DPI, 1440.0 / DPI), dpi=DPI)
    for n, dof in enumerate(DOFS):
        ax = axarr[n % 3][n // 3]
        values = stats[dof]
        ax.set_title(dof + ' ' + label + ' vs iteration number')
        ax.set_xlabel('iteration #')
        ax.set_ylabel('error [meters]' if dof in TRANSLATION_DOFS else 'error [radians]')
        ax.grid(True)
        ax.plot(range(1, len(values) + 1), values)
        # left/right instead of the xmin/xmax aliases, which newer matplotlib rejects.
        ax.set_xlim(left=0, right=len(values))
    # No legend: the single line per subplot is unlabeled, so ax.legend() would
    # only emit a "no handles" warning.
    plt.show()


def main(input_dir):
    '''Compare iterated trajectories against groundtruth and plot error statistics.

    Loads the *.pb files in input_dir; the file numbered 0 must be the
    groundtruth. Prints each trajectory's pose count and the per-DOF RMS
    errors of every later trajectory. Then it shows one 3x2 figure each for
    RMS, mean, std dev and max absolute error versus iteration number.
    '''
    data = _load_ordered_trajectories(input_dir)
    for i, trajectory in enumerate(data):
        print('{} {}'.format(i, len(trajectory)))
    if len(data) < 2:
        print('Need a groundtruth file plus at least one iteration in {}; found {}.'.format(
            input_dir, len(data)))
        return

    errors = _compute_errors(data)
    stats_rms, stats_mean, stats_stdev, stats_max = _error_stats(errors)

    for dof in DOFS:
        print('RMS ERRORS FOR {}:'.format(dof))
        print(stats_rms[dof])

    for stats, label in ((stats_rms, 'RMS error'),
                         (stats_mean, 'MEAN error'),
                         (stats_stdev, 'std dev'),
                         (stats_max, 'MAX error')):
        _plot_stat_vs_iteration(stats, label)

if __name__ == '__main__':
    # Script entry point: a single required option naming the input directory,
    # validated by utils.is_valid_dir before being handed to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--input_dir', required=True, type=utils.is_valid_dir)
    cli_args = arg_parser.parse_args()
    main(cli_args.input_dir)
