#!/usr/bin/env python

import argparse
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt

import base.proto
from mapping.distributed_mapping.common import utils
from mapping.distributed_mapping.proto.aligner_pb2 import VoxCovStack

# One (x, y) column of the plotted grid: the weight and z of the sample kept
# for that column (process keeps the lowest-z sample).
GridCell = namedtuple('GridCell', 'weight z')

def _build_grid(stack, resolution):
  """Bin every mean of the layers at `resolution` into a {(ix, iy): GridCell} dict.

  For each (ix, iy) cell only the sample with the smallest z is kept.
  """
  grid = {}
  for semantic_layer in stack.layers:
    # Labels are not needed here. zip() is kept so that layers without a
    # paired label are skipped exactly as before.
    for layer, _ in zip(semantic_layer.layers, semantic_layer.labels):
      # Tolerant compare: exact float equality is fragile if the proto
      # stores resolution at a different precision than the Python float
      # passed in (NOTE(review): confirm the proto field type).
      if not np.isclose(layer.resolution, resolution):
        continue
      for mean, weight in zip(layer.means, layer.weights):
        # floor() rather than int(): int() truncates toward zero, which would
        # merge the cells on both sides of 0 into one double-width cell.
        coords = (int(np.floor(mean.x / resolution)),
                  int(np.floor(mean.y / resolution)))
        current = grid.get(coords)
        if current is None or mean.z < current.z:
          grid[coords] = GridCell(weight=weight, z=mean.z)
  return grid


def _plot_grid(grid, resolution, extent):
  """Scatter-plot grid cells in world units: blue if weight == 0, red otherwise."""
  zero_x, zero_y, nonzero_x, nonzero_y = [], [], [], []
  # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
  for (ix, iy), cell in grid.items():
    if cell.weight == 0:
      xs, ys = zero_x, zero_y
    else:
      xs, ys = nonzero_x, nonzero_y
    # Convert integer cell indices back to world coordinates.
    xs.append(resolution * ix)
    ys.append(resolution * iy)

  plt.figure(1)
  plt.plot(np.array(zero_x, dtype=np.float32),
           np.array(zero_y, dtype=np.float32), 'ob')
  plt.plot(np.array(nonzero_x, dtype=np.float32),
           np.array(nonzero_y, dtype=np.float32), 'or')
  plt.grid()
  plt.axis('equal')
  if extent is not None:
    plt.xlim(extent)
    plt.ylim(extent)
  plt.show()


def process(voxcov, resolution, extent=(0, 64)):
  """Plot a top-down view of one resolution level of a VoxCovStack.

  Reads the serialized stack at `voxcov` and bins every mean of the layers
  whose resolution matches `resolution` into an (x, y) grid of that cell
  size, keeping the lowest-z sample per cell. Zero-weight cells are then
  drawn in blue and all other cells in red.

  Args:
    voxcov: Path of the VoxCovStack proto, passed to base.proto.ReadProto.
    resolution: Resolution of the layers to show; also the grid cell size.
    extent: (min, max) limits applied to both plot axes, or None to let
      matplotlib autoscale. Defaults to the previously hard-coded (0, 64).
  """
  stack = VoxCovStack()
  base.proto.ReadProto(voxcov, stack)

  grid = _build_grid(stack, resolution)
  _plot_grid(grid, resolution, extent)

if __name__ == '__main__':
  # CLI entry point: --voxcov is validated by utils.is_valid_file, and
  # --resolution selects the layer resolution to plot (default 1.0).
  cli = argparse.ArgumentParser()
  cli.add_argument('--voxcov', required=True, type=utils.is_valid_file)
  cli.add_argument('--resolution', type=float, default=1.0)
  options = cli.parse_args()

  process(options.voxcov, options.resolution)