"""
Validate a dataset by ensuring every TFRecord example contains exactly the
expected keys, and diff the TFRecords of a control and a candidate dataset.
"""
import argparse
import itertools
import os
import sys

import numpy as np
import tensorflow as tf

from glob import glob

from base.file.utils.file_utils import mkdir_p_public
from base.proto import ReadProtoAsText
from vehicle.perception.learning.safetynet.dataset_generator.dataset_generator_pb2 import DatasetGeneratorOptions
from vehicle.perception.learning.safetynet.networks.constants import *
from vehicle.perception.learning.safetynet.utils.key_utils import *
from vehicle.perception.learning.safetynet.utils.testing_utils import tfrecord_to_numpy_generator



def get_expected_keys(dataset_generator_options):
    """
    Build the set of TFRecord feature keys that every example of a dataset
    is expected to contain, based on the given DatasetGeneratorOptions.

    Args:
        dataset_generator_options: DatasetGeneratorOptions
            Proto specifying the format of the dataset (which input
            modalities are enabled and how many input/output frames exist).

    Returns: set of string
    """
    options = dataset_generator_options
    pipeline = options.safetynet_pipeline_options

    # Keys present regardless of the dataset configuration.
    keys = {
        'key', 'po2_ts', 'run_id', 'hero_velocity', 'hero_yaw_rate',
        'baselink_t0_from_smooth',
    }

    # Per-input-frame keys; which ones apply depends on the enabled inputs.
    for frame in range(pipeline.input_spec.num_frames):
        if options.use_camera_views:
            for camera in pipeline.camera_name:
                keys.update(image_keys(frame_key(frame, camera)))
        if options.use_lidar_features:
            keys.update(lidar_handcrafted_keys(frame))
        if options.use_radar:
            keys.update(radar_handcrafted_keys(frame))
        if options.use_motion_correction:
            keys.add(translation_offset_key(frame))
            keys.add(rotation_offset_key(frame))
        # Every input frame also carries a past po2 image.
        keys.update(image_keys(frame_key(frame, 'past_po2')))

    # Per-output-frame po2 image plus its velocity components.
    for frame in range(pipeline.output_spec.num_frames):
        keys.update(image_keys(frame_key(frame, 'po2')))
        keys.add(frame_key(frame, 'vx_po2'))
        keys.add(frame_key(frame, 'vy_po2'))

    if options.use_zrn:
        # zrn is only stored for a single frame.
        keys.update(image_keys(frame_key(0, 'zrn')))
    if options.use_lidar_point_features:
        keys.update(lidar_point_feature_keys())
    return keys


def validate_dataset(dataset_dir):
    """
    Validate a dataset by checking that every sample contains exactly the
    expected keys (no missing keys and no extra keys).

    The expected keys are derived from the dataset's own
    `dataset_generator.pbtxt`. Samples are read from the GZIP-compressed
    `records.*` TFRecord files under each of SUBDIRS. Validation stops at the
    first sample whose key set differs from the expected one, and prints
    which file and record failed together with the key differences.

    Args:
        dataset_dir: string
            Base path of dataset.

    Returns: bool
        True iff every sample's keys equal the expected keys. Note that this
        is also True when no `records.*` files are found.
    """
    dataset_generator_options = DatasetGeneratorOptions()
    ReadProtoAsText(os.path.join(dataset_dir, 'dataset_generator.pbtxt'),
                    dataset_generator_options)
    expected_keys = get_expected_keys(dataset_generator_options)
    # The reader options are the same for every file, so build them once.
    record_options = tf.io.TFRecordOptions(
        compression_type=tf.compat.v1.io.TFRecordCompressionType.GZIP)
    for sub_dir in SUBDIRS:
        tfrecords = glob(os.path.join(dataset_dir, sub_dir, "records.*"))
        for tfrecord in tfrecords:
            records = tf.compat.v1.io.tf_record_iterator(
                tfrecord, record_options)
            # Any mismatch returns immediately, so `count` is also the number
            # of records validated so far in this file.
            for count, record in enumerate(records, start=1):
                example = tf.train.Example()
                example.ParseFromString(record)
                actual_keys = {str(x) for x in example.features.feature}
                if actual_keys != expected_keys:
                    # Report both directions: extra keys are as much a
                    # failure as missing ones.
                    print(
                        "Keys mismatch in {} at record {}. Expected: {}, Actual: {}, More than expected: {}, Missing: {}"
                        .format(tfrecord, count, expected_keys, actual_keys,
                                actual_keys - expected_keys,
                                expected_keys - actual_keys))
                    return False
                print("Reading a record: {}".format(count))
    return True


# Intensity is based on the lidar point with minimum z in a given pixel, which
# may have ties, so up to this fraction of intensity elements may differ.
_INTENSITY_COMPARISON_THRESHOLD = 0.01

# Fill value marking that one of two zipped example streams ran out first.
_EXHAUSTED = object()


def _feature_values_match(key, control_value, candidate_value):
    """
    Decide whether a candidate feature value matches the control value.

    The caller must already have checked that both values have the same shape.

    Args:
        key: string
            Feature name. Keys containing 'intensity' tolerate up to
            _INTENSITY_COMPARISON_THRESHOLD fraction of differing elements.
        control_value: array-like
            Value from the control example.
        candidate_value: array-like
            Value from the candidate example.

    Returns: bool
        True iff the values are considered equal.
    """
    control_value = np.asarray(control_value)
    candidate_value = np.asarray(candidate_value)
    if 'intensity' in key:
        total_count = control_value.size
        if total_count == 0:
            # Nothing to compare; avoid dividing by zero.
            return True
        # `!=` works for every dtype and for 0-d arrays. Subtracting and
        # calling .nonzero() fails for bool/bytes data and 0-d arrays.
        different_count = np.count_nonzero(candidate_value != control_value)
        return (float(different_count) / total_count <=
                _INTENSITY_COMPARISON_THRESHOLD)
    if (np.issubdtype(control_value.dtype, np.number) and
            np.issubdtype(candidate_value.dtype, np.number)):
        return bool(np.allclose(candidate_value, control_value))
    # A tolerance comparison is undefined for non-numeric data (e.g. bytes or
    # bool), so require exact equality.
    return bool(np.array_equal(candidate_value, control_value))


def diff_tfrecords(control_dataset_dir, candidate_dataset_dir):
    """Diff every example of a control dataset against a candidate dataset.

    Reads the `records.*` tfrecords under each of SUBDIRS in the control
    dataset directory. Each one is paired with the same-named file in the
    candidate dataset directory, and the examples are compared feature by
    feature. This can be used to test any changes in the dataset gen
    pipeline.

    The keys to read are derived from the control dataset's
    `dataset_generator.pbtxt`. Every difference found is printed, and
    comparison continues so that all differences are reported.

    Args:
        control_dataset_dir: Path to control dataset directory
        candidate_dataset_dir: Path to candidate dataset directory
    Returns:
        True if all examples are equal, False otherwise. Each of the
        following counts as a difference:
          - the candidate tfrecord is missing;
          - it holds a different number of examples;
          - it lacks a feature.
    """
    dataset_generator_options = DatasetGeneratorOptions()
    ReadProtoAsText(
        os.path.join(control_dataset_dir, 'dataset_generator.pbtxt'),
        dataset_generator_options)
    expected_keys = get_expected_keys(dataset_generator_options)
    all_same = True
    for sub_dir in SUBDIRS:
        control_tfrecords = glob(
            os.path.join(control_dataset_dir, sub_dir, "records.*"))
        for control_tfrecord in control_tfrecords:
            candidate_tfrecord = os.path.join(
                candidate_dataset_dir, sub_dir,
                os.path.basename(control_tfrecord))
            # tf.io.gfile resolves the same filesystems TensorFlow's readers
            # use, not just local paths.
            if not tf.io.gfile.exists(candidate_tfrecord):
                all_same = False
                print(f"{candidate_tfrecord} does not exist in candidate dataset")
                continue
            # zip() would silently stop at the shorter file and hide a
            # difference in the number of examples.
            example_pairs = itertools.zip_longest(
                tfrecord_to_numpy_generator(control_tfrecord, expected_keys),
                tfrecord_to_numpy_generator(candidate_tfrecord,
                                            expected_keys),
                fillvalue=_EXHAUSTED)
            for control_example, candidate_example in example_pairs:
                if (control_example is _EXHAUSTED or
                        candidate_example is _EXHAUSTED):
                    all_same = False
                    print(f"{control_tfrecord} and {candidate_tfrecord} have a different number of examples")
                    break
                for key, control_value in control_example.items():
                    if key not in candidate_example:
                        all_same = False
                        print(f"{key} is missing from candidate tfrecord example")
                        continue
                    candidate_value = candidate_example[key]
                    if np.shape(candidate_value) != np.shape(control_value):
                        all_same = False
                        print(f"{key} has a different shape in control and candidate tfrecord example")
                        continue
                    if not _feature_values_match(key, control_value,
                                                 candidate_value):
                        all_same = False
                        print(f"{key} is different in control and candidate tfrecord example")
    return all_same
