#!/usr/bin/env python

import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn
import sklearn.linear_model

import base.proto
from mapping.distributed_mapping.common import utils
from mapping.distributed_mapping.eval_suite.scan_alignment.alignment_eval_suite_pb2 import (
        AlignmentEvalResultList,
        VOXCOV,
        ANXIOUS_SEARCH)

# Translation-error magnitude (metres, per the `_M` suffix and the plot axis
# labels) at or above which an alignment result is counted as an outlier.
OUTLIER_TRANSLATION_THRESH_M = 0.5


def extract_metrics(results):
  """Extract per-result metrics from an AlignmentEvalResultList.

  Args:
    results: Object with a `results` sequence whose items expose
      `transform_error.translation.{x,y,z}`, `fraction_overlap` and
      `num_iterations`.

  Returns:
    `(errors, fraction_overlaps, iteration_counts)` as numpy arrays with one
    entry per result; `errors` is the Euclidean norm of each translation error.
  """
  errors, fraction_overlaps, iteration_counts = [], [], []
  for result in results.results:
    trans_error = result.transform_error.translation
    errors.append(np.sqrt(trans_error.x**2 + trans_error.y**2 + trans_error.z**2))
    fraction_overlaps.append(result.fraction_overlap)
    iteration_counts.append(result.num_iterations)
  return (np.array(errors, dtype=float),
          np.array(fraction_overlaps, dtype=float),
          np.array(iteration_counts))


def split_outliers(errors, thresh=OUTLIER_TRANSLATION_THRESH_M):
  """Return index arrays `(inliers, outliers)` splitting `errors` at `thresh`.

  Inliers satisfy `error < thresh` and outliers `error >= thresh`. NaN errors
  satisfy neither comparison, so they appear in neither array.
  """
  errors = np.asarray(errors, dtype=float)
  return np.flatnonzero(errors < thresh), np.flatnonzero(errors >= thresh)


def iteration_count_cdf(iteration_counts):
  """Return `(x, cdf)` where `cdf[i]` is the fraction of counts `<= x[i]`.

  `x` covers every integer from min(0, smallest count) through the largest
  observed count, so no sample is dropped however many iterations a method
  needed. Empty input yields two empty arrays.
  """
  counts = np.sort(np.asarray(iteration_counts).ravel())
  if counts.size == 0:
    return np.array([], dtype=int), np.array([], dtype=float)
  x = np.arange(min(0, int(counts[0])), int(counts[-1]) + 1)
  # side='right' counts samples equal to x as well, giving P[count <= x].
  cdf = np.searchsorted(counts, x, side='right') / float(counts.size)
  return x, cdf


def method_label(method):
  """Map an alignment-method enum value to the label used in plots/logs."""
  if method == VOXCOV:
    return 'VoxCov'
  if method == ANXIOUS_SEARCH:
    return 'AnxiousSearch'
  return 'Method {}'.format(method)


def main():
  """Plot alignment error statistics for every `--results` proto file.

  Figures (one series per results file):
    1: translation error vs. fraction overlap, split into inliers/outliers.
    2: translation error vs. iteration count, split into inliers/outliers.
    3: logistic-regression outlier probability vs. iteration count.
    4: empirical CDF of iteration counts.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--results',
                      required=True,
                      nargs='+',
                      type=utils.is_valid_file)
  args = parser.parse_args()

  for result_file in args.results:
    results = AlignmentEvalResultList()
    base.proto.ReadProto(result_file, results)

    errors, fraction_overlaps, iteration_counts = extract_metrics(results)
    label = method_label(results.options_used.method)
    if errors.size == 0:
      # Nothing to plot; also avoids a division by zero in outlier_pct.
      print('{}: no results in {}, skipping'.format(label, result_file))
      continue

    inds_inliers, inds_outliers = split_outliers(errors)
    outlier_pct = 100.0 * len(inds_outliers) / len(errors)
    print('{} outlier ratio: {:d}/{:d} ({:.2f}%)'.format(label,
                                                         len(inds_outliers),
                                                         len(errors),
                                                         outlier_pct))
    inlier_label = '{}: inliers'.format(label)
    outlier_label = '{}: outliers ({:.2f}%)'.format(label, outlier_pct)

    plt.figure(1)
    plt.plot(fraction_overlaps[inds_inliers], errors[inds_inliers], 'o',
             label=inlier_label)
    plt.plot(fraction_overlaps[inds_outliers], errors[inds_outliers], 'o',
             label=outlier_label)

    plt.figure(2)
    plt.plot(iteration_counts[inds_inliers], errors[inds_inliers], 'o',
             label=inlier_label, alpha=0.2)
    plt.plot(iteration_counts[inds_outliers], errors[inds_outliers], 'o',
             label=outlier_label, alpha=0.75)

    plt.figure(3)
    # LogisticRegression.fit raises ValueError unless y contains both classes.
    if 0 < len(inds_outliers) < len(errors):
      y = np.zeros_like(errors)
      y[inds_outliers] = 1.0
      lr = sklearn.linear_model.LogisticRegression().fit(
          iteration_counts.reshape(-1, 1), y)
      niter = np.linspace(0, iteration_counts.max(), 100).reshape(-1, 1)
      plt.plot(niter, lr.predict_proba(niter)[:, 1], label=label)
    else:
      print('{}: only one class (all inliers or all outliers); '
            'skipping outlier-probability fit'.format(label))

    plt.figure(4)
    cdf_x, cdf = iteration_count_cdf(iteration_counts)
    # Counts are discrete, so draw the CDF as a right-continuous step.
    plt.step(cdf_x, cdf, where='post', label=label)

  # Decorate each figure once after all files are plotted. A bare plt.grid()
  # toggles visibility, so pass True explicitly.
  plt.figure(1)
  plt.grid(True)
  plt.xlabel('Fraction Overlap')
  plt.ylabel('Translation Error [m]')
  plt.legend()

  plt.figure(2)
  plt.grid(True)
  plt.xlabel('Iteration Count')
  plt.ylabel('Translation Error [m]')
  plt.legend()

  plt.figure(3)
  plt.grid(True)
  plt.xlabel('Iteration Count')
  plt.ylabel('Outlier Probability')
  plt.title('Outlier Probability vs. Iteration Count')
  plt.legend()

  plt.figure(4)
  plt.grid(True)
  plt.xlabel('$x$, Iteration Count')
  plt.ylabel(r'$\mathrm{Prob}\left[\mathrm{Iteration\;Count} \leq x\right]$')
  plt.title('Iteration Count CDF')
  plt.legend()

  plt.show()


if __name__ == '__main__':
  main()