#!/usr/bin/env python
'''
This script implements a 1D version of the VoxCov alignment algorithm
for validating the inverse-Hessian-based covariance derivation.

The 1D line is split into equal sized bins which are then populated
by Gaussian point distributions. Two such sets of binned points
are then aligned by minimizing the weighted sum of squared, variance-normalized
differences between the means of corresponding bins.
'''

from __future__ import division

import copy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats

# Geometry of the 1D line: how many bins, and the distance between them.
NUM_BINS = 100
BIN_SPACING = 1.0

# Number of bootstrap draws used for the Monte Carlo stddev estimate.
MC_SAMPLES = 50000

# Point-noise sweep: NUM_POINT_SIGMA values spaced from 0 to MAX_POINT_SIGMA.
MAX_POINT_SIGMA = 0.3
NUM_POINT_SIGMA = 50

# Stddev of the per-bin surface positions shared by both clouds.
SURFACE_SIGMA = 0.02

# Box-filter radius, in bins, passed to BinnedPoints.smooth.
KERNEL_RADIUS = 1

# Block length, in bins, for the block bootstrap in compute_error.
BLOCK_LENGTH = 5

# Seed numpy's global RNG so that every run is repeatable.
np.random.seed(seed=12345)

class Match(object):
    '''
    Statistics for one pair of corresponding bins taken from two binned
    clouds. Every field starts at zero; ``BinnedPoints.match`` fills them in.
    '''
    def __init__(self):
        # Combined weight of the pair, then the mean of each side.
        self.weight = self.mean_a = self.mean_b = 0

        # Scalar used to whiten the residual. In higher dimensions this
        # would be a surface normal; in 1D it is just a number.
        self.normal = 0

        # Noise term of the whitened residual. NOTE(review): despite the
        # name, BinnedPoints.match fills this with a variance-like value
        # (scaled by point_sigma**2), and compute_error uses it unsquared
        # as sigma_z^2.
        self.sigma_z = 0


class AlignmentResult(object):
    '''
    Cost value and derivative terms accumulated while evaluating the
    alignment cost at one alignment.
    '''
    def __init__(self):
        # Scalar accumulators, all starting at zero:
        #   grad         -- gradient of the cost w.r.t. the translation
        #   JTWJ         -- J^T * W * J
        #   JTsigmaz2W2J -- J^T * sigma_z^2 * W^2 * J
        #   total_error  -- weighted sum of squared residuals
        #   total_weight -- sum of the residual weights
        for name in ('grad', 'JTWJ', 'JTsigmaz2W2J', 'total_error', 'total_weight'):
            setattr(self, name, 0)

        # Raw residuals, kept for plotting (a fresh list per instance).
        self.residuals = []


class ResidualSampler(object):
    '''
    Iterator over residual indices, optionally as a block bootstrap.

    Iterating yields ``(residual_index, sign)`` tuples:

    * ``bootstrap=False``: every index in ``range(n_samples)`` once, in
      order, with sign 1.0. This mode requires ``block_size == 1``.
    * ``bootstrap=True``: indices are grouped into
      ``n_samples // block_size`` contiguous, non-overlapping blocks. That
      many blocks are drawn uniformly with replacement from numpy's global
      RNG, and each drawn block's indices are yielded in order. Trailing
      indices that do not fill a complete block are never drawn.
    * ``wild=True`` (used only when bootstrapping): each yielded residual
      gets an independent random sign of +1.0 or -1.0.
    '''
    def __init__(self, n_samples, bootstrap=False, block_size=1, wild=False):
        self.n_samples = int(n_samples)
        self.bootstrap = bootstrap
        self.block_size = int(block_size)
        self.wild = wild
        # Blocks only make sense when resampling.
        assert bootstrap or block_size == 1

    def __iter__(self):
        n_blocks = self.n_samples // self.block_size
        for count in range(n_blocks):
            if not self.bootstrap:
                yield count, 1.0
                continue

            # randint's upper bound is exclusive, so this draws from
            # [0, n_blocks - 1]: the same range as the deprecated
            # np.random.random_integers(0, n_blocks - 1).
            block_index = np.random.randint(0, n_blocks)
            start = block_index * self.block_size
            for res_idx in range(start, start + self.block_size):
                if self.wild:
                    sign = 1.0 if np.random.uniform(0, 1.0) > 0.5 else -1.0
                else:
                    sign = 1.0

                yield res_idx, sign


class BinnedPoints(object):
    '''
    Simple 1D version of a VoxCovLayer. Linear space is divided into equal
    sized "bins"; each bin holds a list of points (``create`` draws them
    from a normal distribution around that bin's surface position).

    Per-bin statistics kept alongside the points, one entry per bin:
        means   -- mean of the bin's points
        M2s     -- sum of squared deviations from that mean
        weights -- number of points (a fractional average after ``smooth``)
        vars    -- M2 / weight
    '''

    def __init__(self, bins, points, point_sigma):
        '''
        Args:
            bins: bin coordinates, one entry per bin.
            points: per-bin sequences of point positions, parallel to
                ``bins``.
            point_sigma: standard deviation of the point noise; ``match``
                uses it to predict the residual noise.
        '''
        self.bins = bins
        self.points = points
        self.point_sigma = point_sigma

        # Compute statistics in each bin.
        means = []
        M2s = []
        weights = []
        for binned_points in self.points:
            mean = np.mean(binned_points)
            means.append(mean)
            M2s.append(np.sum((np.asarray(binned_points, dtype=float) - mean)**2))
            weights.append(len(binned_points))

        self.means = np.array(means)
        self.M2s = np.array(M2s)
        self.weights = np.array(weights)
        self.vars = self.M2s / self.weights

    def __len__(self):
        '''
        Total number of bins.
        '''
        return len(self.bins)

    def match(self, other, index):
        '''
        Pair bin ``index`` of this cloud with the same bin of ``other``.

        ``self`` supplies ``mean_b`` and ``other`` supplies ``mean_a``
        (``compute_error`` calls ``self.match(points_a, index)``).

        Returns:
            Match with
              weight  -- n_a * n_b / (n_a + n_b) for bin weights n_a, n_b
              normal  -- 1 / sqrt(var_a + var_b), which whitens the mean
                         difference
              sigma_z -- point_sigma**2 / ((var_a + var_b) * weight). This
                         is the variance of the whitened mean difference,
                         assuming independent points with standard deviation
                         point_sigma.
        '''
        assert index < len(self.bins)
        # Keeps the whitening scale finite when both bins have zero spread.
        EPS = 1e-12

        variance = self.vars[index] + other.vars[index] + EPS

        weight_a = self.weights[index]
        weight_b = other.weights[index]

        match = Match()
        match.weight = (weight_a * weight_b) / (weight_a + weight_b)
        match.mean_b = self.means[index]
        match.mean_a = other.means[index]
        match.normal = 1.0 / np.sqrt(variance)
        match.sigma_z = (self.point_sigma**2) / (variance * match.weight)

        return match

    @staticmethod
    def create(bins, surface_dists, bin_spacing, point_sigma, offset=0):
        '''
        Factory method that builds a ``BinnedPoints`` with one bin per entry
        of ``bins``.

        Bin ``i`` receives ``int(50 * exp(-|bins[i] / (20 * bin_spacing)|))``
        points, so bins far from coordinate 0 are sparser. Each point is
        drawn from numpy's global RNG as a normal sample with mean
        ``offset + surface_dists[i]`` and standard deviation ``point_sigma``.
        '''
        # Point density decays over a length of 20 bin widths.
        decay_constant = 20 * bin_spacing

        points = []
        for bin_coord, surface_dist in zip(bins, surface_dists):
            points_per_bin = int(50 * np.exp(-np.abs(bin_coord / decay_constant)))

            # Populate the bin with normally distributed points around its
            # (offset) surface position.
            points.append([np.random.normal(offset + surface_dist, point_sigma)
                           for _ in range(points_per_bin)])

        return BinnedPoints(bins, points, point_sigma)

    def compute_error(self, points_a, alignment, sample=False):
        '''
        Evaluate the cost of aligning this cloud to ``points_a`` at the
        translation ``alignment`` and accumulate Gauss-Newton terms.

        For every visited bin where both clouds are non-empty, the whitened
        residual is ``sign * normal * (mean_b + alignment - mean_a)``, and
        its Jacobian entry is taken as ``normal``.

        Args:
            points_a: the other ``BinnedPoints``; must have as many bins as
                ``self``.
            alignment: scalar translation added to this cloud's bin means.
            sample: if True, bins are visited by a block bootstrap (block
                length ``BLOCK_LENGTH``) instead of once each, in order.

        Returns:
            AlignmentResult holding the accumulated terms and raw residuals.
        '''
        assert len(points_a) == len(self)
        result = AlignmentResult()

        if sample:
            window_len = BLOCK_LENGTH
        else:
            window_len = 1

        sampler = ResidualSampler(len(points_a), bootstrap=sample, wild=False, block_size=window_len)

        for index, sign in sampler:
            # Skip if either bin is empty.
            if self.weights[index] == 0 or points_a.weights[index] == 0:
                continue

            match = self.match(points_a, index)

            error = sign * match.normal * (match.mean_b + alignment - match.mean_a)
            result.residuals.append(error)

            # We could reweight here, but just pass it through for testing
            # with an L2 loss.
            weight = match.weight

            J_err = match.normal
            result.grad += J_err * weight * error
            result.JTWJ += weight * J_err**2
            result.JTsigmaz2W2J += match.sigma_z * weight**2 * J_err**2
            result.total_error += weight * error**2
            result.total_weight += weight

        result.residuals = np.array(result.residuals)
        return result

    def smooth(self, kernel_radius=1):
        '''
        Smooth the bin statistics in place with a centered box filter
        spanning ``2 * kernel_radius + 1`` bins.

        For each interior bin:
          - the new weight is the window's average weight;
          - the new mean is the weight-averaged mean over the window;
          - the new M2 is the sum of the window's M2s.

        The ``kernel_radius`` bins at each end have no full window and are
        left at zero (weight 0), so ``compute_error`` skips them.
        '''
        kernel_radius = int(kernel_radius)

        # Explicitly float: the averaged weight is generally fractional.
        # zeros_like() of the integer point counts built by __init__ would
        # silently truncate it.
        new_means = np.zeros_like(self.means, dtype=float)
        new_M2s = np.zeros_like(self.M2s, dtype=float)
        new_weights = np.zeros_like(self.weights, dtype=float)

        # Kernel radius of 1 gives a 3 element centered moving average.
        window_norm = 2.0 * kernel_radius + 1.0

        for center_idx in range(kernel_radius, len(self.bins) - kernel_radius):
            total_weight = 0
            weighted_mean_sum = 0
            M2 = 0
            for idx in range(center_idx - kernel_radius, center_idx + kernel_radius + 1):
                bin_weight = self.weights[idx]
                if bin_weight == 0:
                    # Empty bin: carries no weight, and its mean may be NaN,
                    # which would otherwise poison the whole window.
                    continue
                total_weight += bin_weight
                weighted_mean_sum += self.means[idx] * bin_weight
                # NOTE(review): summing M2s ignores the spread between the
                # bin means inside the window; presumably acceptable when
                # that spread is small relative to the point noise.
                M2 += self.M2s[idx]

            new_weights[center_idx] = total_weight / window_norm
            new_M2s[center_idx] = M2
            # An all-empty window has no mean; leave it at zero like the
            # edge bins instead of dividing by zero.
            if total_weight > 0:
                new_means[center_idx] = weighted_mean_sum / total_weight

        self.weights = new_weights
        self.means = new_means
        self.M2s = new_M2s
        self.vars = self.M2s / (self.weights + 1e-6)


def _gauss_newton_align(points_b, points_a, tol, max_iters=20):
    '''
    Estimate the translation that aligns ``points_b`` to ``points_a`` with
    Gauss-Newton iterations, printing one ``iter / RMSE / alignment`` line
    per iteration.

    Starts from zero. Stops after ``max_iters`` iterations, or as soon as
    the RMSE (evaluated before each step) decreases by less than ``tol``.

    Returns:
        (alignment, result): the final translation, and the AlignmentResult
        of the last cost evaluation. That result was computed at the
        translation *before* the final step was applied.
    '''
    alignment = 0
    prev_rms = 1e12
    result = None
    for iteration in range(max_iters):
        result = points_b.compute_error(points_a, alignment)

        # Gauss-Newton step.
        d_alignment = -result.grad / result.JTWJ
        alignment += d_alignment

        rms = np.sqrt(result.total_error / result.total_weight)
        print('{:d}\t{:.6f}\t{:.3f}'.format(iteration, rms, alignment))

        if prev_rms - rms < tol:
            break
        prev_rms = rms

    return alignment, result


def _monte_carlo_std(points_b, points_a, alignment, n_samples):
    '''
    Bootstrap estimate of the alignment standard deviation.

    For each of ``n_samples`` draws, residuals at ``alignment`` are
    resampled (``compute_error(..., sample=True)``) and a single
    Gauss-Newton step is taken. Returns sqrt(sum(step**2) / (n_samples - 1)),
    i.e. the spread of those steps about zero.
    '''
    mc_var = 0
    for _ in range(n_samples):
        mc_result = points_b.compute_error(points_a, alignment, sample=True)
        d_alignment = -mc_result.grad / mc_result.JTWJ
        mc_var += d_alignment**2

    mc_var /= n_samples - 1
    return np.sqrt(mc_var)


def main(plot=False):
    '''
    Sweep the point noise level. For each level, compare the analytical
    standard deviation of the alignment with a Monte Carlo (block bootstrap)
    estimate. Finally, fit a line to each curve and print the slopes.

    Args:
        plot: if True, show diagnostic figures. These cover the last noise
            level and the stddev-vs-noise curves.
    '''
    # Stopping tolerance on delta RMSE.
    TOL = 1e-12

    # Range of point noise standard deviations to test with.
    sigma_range = np.linspace(0, MAX_POINT_SIGMA, NUM_POINT_SIGMA)

    # Exactly NUM_BINS bin coordinates, BIN_SPACING apart, starting at
    # -(NUM_BINS // 2) * BIN_SPACING. The former
    # np.arange(-N/2, N/2, BIN_SPACING) produced a different count unless
    # BIN_SPACING == 1 and NUM_BINS was even.
    bins = (np.arange(NUM_BINS) - NUM_BINS // 2) * BIN_SPACING
    surface_dists = np.random.normal(0, SURFACE_SIGMA, size=bins.shape)

    # Solution standard deviations for runs with different data noise parameters.
    analytic_stds = []
    mc_stds = []

    # Sweep through point noise values and measure the analytical and
    # Monte Carlo standard deviation of the solution.
    for point_sigma in sigma_range:
        # Create the two binned point "clouds": same surface, independent noise.
        binned_points_a = BinnedPoints.create(bins, surface_dists, BIN_SPACING, point_sigma)
        binned_points_b = BinnedPoints.create(bins, surface_dists, BIN_SPACING, point_sigma, offset=0)

        # Unsmoothed copy of A, kept for the weight comparison plot.
        copy_a = copy.deepcopy(binned_points_a)

        print('Weight before smoothing: {:f}'.format(np.sum(binned_points_a.weights)))
        binned_points_a.smooth(KERNEL_RADIUS)
        binned_points_b.smooth(KERNEL_RADIUS)

        print('Weight after smoothing: {:f}'.format(np.sum(binned_points_a.weights)))
        print('')
        print(50*'=')
        print('point_sigma: {:.3f} m'.format(point_sigma))
        print('filter radius: {:d} m'.format(KERNEL_RADIUS))
        print('')
        print('iter\tRMSE\t\talignment')

        alignment, result = _gauss_newton_align(binned_points_b, binned_points_a, TOL)

        # Analytical variance: (J^T sigma_z^2 W^2 J) / (J^T W J)^2.
        analytic_std = np.sqrt(result.JTsigmaz2W2J / (result.JTWJ**2))

        print('')
        print("Analytical Stddev: {:.6f}".format(analytic_std))

        mc_std = _monte_carlo_std(binned_points_b, binned_points_a, alignment, MC_SAMPLES)
        print("Monte Carlo Stddev: {:.6f}".format(mc_std))

        analytic_stds.append(analytic_std)
        mc_stds.append(mc_std)

    # Linear fit (slope, intercept) of solution stddev vs. point noise.
    an_fit = np.polyfit(sigma_range, analytic_stds, 1)
    mc_fit = np.polyfit(sigma_range, mc_stds, 1)

    print(50*'=')
    print('m_analytic: {:.3f}'.format(an_fit[0]))
    print('m_mc: {:.3f}'.format(mc_fit[0]))
    print('ratio: {:.3f}'.format(an_fit[0] / (mc_fit[0] + 1e-9)))

    if plot:
        # Smoothed bin means with 1-sigma bars, last noise level.
        plt.figure(1, figsize=(10, 8))
        plt.errorbar(bins, binned_points_a.means, yerr=np.sqrt(binned_points_a.vars), fmt='ob', label='A')
        plt.errorbar(bins, binned_points_b.means, yerr=np.sqrt(binned_points_b.vars), fmt='or', label='B')
        plt.grid()
        plt.legend()
        plt.xlabel('Distance [m]')
        plt.tight_layout()

        # Raw strings: '\s' is an invalid escape sequence in a plain literal.
        plt.figure(2)
        plt.plot(sigma_range, analytic_stds, '.-b', label='analytical, m={:.3f}, b={:.3f}'.format(an_fit[0], an_fit[1]))
        plt.plot(sigma_range, mc_stds, '.-r', label='monte carlo, m={:.3f}, b={:.3f}'.format(mc_fit[0], mc_fit[1]))
        plt.grid()
        plt.xlabel(r'Point $\sigma$ [m]')
        plt.ylabel(r'Alignment $\sigma$ [m]')
        plt.legend()
        plt.title('Solution Stddev vs. Point Noise')

        plt.figure(3, figsize=(10, 20))
        plt.subplot(2,1,1)
        sns.distplot(result.residuals)
        plt.title('Solution Residual Distribution')

        plt.subplot(2,1,2)
        scipy.stats.probplot(result.residuals, plot=plt)
        plt.title('Solution Residual Q-Q Plot')

        # Bin weights of cloud A before vs. after smoothing.
        plt.figure(4)
        plt.plot(bins, binned_points_a.weights, 'ob', alpha=0.5, label='smoothed')
        plt.plot(bins, copy_a.weights, 'or', alpha=0.5, label='original')
        plt.legend()
        plt.grid()

        plt.show()


if __name__ == "__main__":
    # Run the full noise sweep and show the diagnostic figures.
    main(plot=True)
