#!/usr/bin/env python

import argparse
from glob import glob
import logging
import math
import os

import numpy as np
import simplekml
import utm

import base.proto
from mapping.distributed_mapping.common import utils
from mapping.distributed_mapping.proto.aligner_pb2 import VoxCovStack


def tile_size_from_file_path(file_path):
  """Return the tile edge length encoded in a tile file name.

  The base name is split on '-' and its third field is parsed as an int,
  e.g. '/dir/voxcov-0-100-3-4.pb' -> 100.
  """
  fields = os.path.basename(file_path).split('-')
  size_field = fields[2]
  return int(size_field)

def tile_inds_from_file_path(file_path):
  """Return the (x, y) tile indices encoded in a tile file name.

  The extension is dropped from the base name, the remainder is split on
  '-', and the fourth and fifth fields are parsed as ints,
  e.g. '/dir/voxcov-0-100-3-4.pb' -> (3, 4).
  """
  stem, _ext = os.path.splitext(os.path.basename(file_path))
  fields = stem.split('-')
  return int(fields[3]), int(fields[4])

def process_tile(kml, tile_path, semantic_label, resolution,
                 zone_number=10, zone_letter='S'):
  """Append one filled square polygon to `kml` per matching voxel of a tile.

  Reads a VoxCovStack proto from `tile_path`. Semantic layers are selected
  when their first sub-layer's resolution matches `resolution`. Within those,
  every voxel mean of a sub-layer whose label equals `semantic_label` becomes
  a blue, borderless polygon.

  Args:
    kml: simplekml.Kml document the polygons are added to.
    tile_path: Path to a serialized VoxCovStack tile. The file name encodes
      the tile size and tile indices.
    semantic_label: Integer label of the sub-layers to export.
    resolution: Voxel edge length, also used to select layers. Presumably
      in the same units as UTM easting/northing (metres) — TODO confirm.
    zone_number: UTM zone number of the tile coordinates (defaults to the
      previously hard-coded 10).
    zone_letter: UTM latitude band letter (defaults to 'S').
  """
  stack = VoxCovStack()
  base.proto.ReadProto(tile_path, stack)

  tile_size = tile_size_from_file_path(tile_path)
  ind_x, ind_y = tile_inds_from_file_path(tile_path)
  # UTM easting/northing of the tile origin. Voxel means are treated as
  # offsets from this point (assumed tile-local — TODO confirm).
  origin_x = ind_x * tile_size
  origin_y = ind_y * tile_size

  # One style shared by every polygon of this tile, instead of emitting a
  # separate <Style> element per polygon (which bloats the KML).
  style = simplekml.Style()
  style.linestyle.width = 0
  style.polystyle.color = simplekml.Color.blue

  for semantic_layer in stack.layers:
    # Empty semantic layers have no resolution to compare; skip them rather
    # than raising IndexError on layers[0].
    if not semantic_layer.layers:
      continue
    # Proto float fields are typically float32, so an exact == against a
    # Python float fails for values such as 0.1; compare with a tolerance.
    if not math.isclose(semantic_layer.layers[0].resolution, resolution,
                        rel_tol=1e-6):
      continue

    for layer, label in zip(semantic_layer.layers, semantic_layer.labels):
      if int(label) != semantic_label:
        continue
      # zip with covs keeps the original truncation to the shorter list.
      for mean, _cov in zip(layer.means, layer.covs):
        # floor (not int() truncation) so negative offsets map to the
        # correct voxel.
        voxel_x = math.floor(mean.x / resolution)
        voxel_y = math.floor(mean.y / resolution)

        # South-west corner: minimum easting and northing of the voxel.
        lat_sw, lon_sw = utm.to_latlon(origin_x + voxel_x * resolution,
                                       origin_y + voxel_y * resolution,
                                       zone_number, zone_letter)
        # North-east corner: maximum easting and northing of the voxel.
        lat_ne, lon_ne = utm.to_latlon(origin_x + (voxel_x + 1) * resolution,
                                       origin_y + (voxel_y + 1) * resolution,
                                       zone_number, zone_letter)

        pol = kml.newpolygon(name='')
        # KML LinearRings must be closed: the last vertex repeats the first.
        pol.outerboundaryis = [(lon_sw, lat_sw), (lon_sw, lat_ne),
                               (lon_ne, lat_ne), (lon_ne, lat_sw),
                               (lon_sw, lat_sw)]
        pol.style = style


def process(tile_dir, output, semantic_label, resolution):
  """Render all VoxCovStack tiles in `tile_dir` into a single KML file.

  Args:
    tile_dir: Directory whose top-level '*.pb' files are processed (not
      searched recursively).
    output: Path the KML document is saved to.
    semantic_label: Semantic label whose voxels are exported.
    resolution: Voxel resolution of the layers to export.

  Raises:
    ValueError: If `tile_dir` contains no '*.pb' files.
  """
  # glob() order is arbitrary; sort so logging and KML output are deterministic.
  tiles = sorted(glob(os.path.join(tile_dir, '*.pb')))

  # Explicit raise rather than assert: asserts are stripped under `python -O`.
  if not tiles:
    raise ValueError('No *.pb tiles found in {}'.format(tile_dir))

  kml = simplekml.Kml()

  for idx, tile in enumerate(tiles):
    logging.info('%d: %s', idx, tile)
    process_tile(kml, tile, semantic_label, resolution)

  kml.save(output)

def main():
  """Command-line entry point: parse flags, then write the KML file."""
  parser = argparse.ArgumentParser(
      description='Export voxels of one semantic label from VoxCov tiles '
                  'to a KML file.')
  parser.add_argument('--tile_dir', type=utils.is_valid_dir, required=True, help='Path to voxcov tile dir')
  parser.add_argument('--output', type=utils.is_writable_file, required=True, help='Path to output KML file')
  parser.add_argument('--semantic_label', type=int, default=2,
                      help='Semantic label whose voxels are exported '
                           '(default: %(default)s)')
  parser.add_argument('--resolution', type=float, default=2.0,
                      help='Voxel resolution of the layers to export '
                           '(default: %(default)s)')
  args = parser.parse_args()

  logging.basicConfig(level=logging.INFO)

  process(args.tile_dir, args.output, args.semantic_label, args.resolution)


if __name__ == '__main__':
  main()