#!/usr/bin/env python

"""
This is a wrapper for the Perception Triage Tool that specializes in triaging
batches from a Perception-CI suite. It accepts lmdb files stored in
/mnt/nautilus/3dbt_data/scale_api_nightly_download only, and lets you specify a
branch name for a perception ci run to use data saved from the perception ci
run. Lastly, it makes use of nightly cached vision data.
"""
from __future__ import print_function
import argparse
import lmdb
import os
import subprocess
import sys
import tempfile

from copy import deepcopy
from pick import pick

# NOTE: these environment variables are deliberately set before the project
# imports below; presumably some of those modules read them at import time,
# so keep this ordering.
# Set this to query the 3dbt (production) database when looking up run IDs.
os.environ["LABELING_ENV"] = "production"  # NOQA
# This flag prevents an issue where the triage viewer attempts to load
# logged metaspins, even though that isn't possible from the annotated spin
# point cloud.
os.environ["FLAGS_ignore_logged_metaspins"] = 'true'  # NOQA

from argus.utils.links import render_link
from base.proto import ReadProtoAsText
from data.triage import (
    launch_triage_tool,
    util)
from labeling.scale.util.util import job_id_tuple_from_x
from lidar.metrics.proto.metrics_data_pb2 import MetricsSuite
from lidar.metrics.proto.tracking_metrics_v2_pb2 import LmdbComparisons
from lidar.metrics.utils.variant_utils import (
        make_pcp_variant,
        make_vision_variant,
        sanitize_variant_name)
from lidar.metrics.v2.perception_ci_directory_utils import (
    choose_run_directories,
    copy_file_to_local_dir,
    pick_comparison_metric_and_filter,
    remove_results_dir)
from lidar.metrics.v2.run_metrics_single_log import (
    init_chum_uri_fields,
    gen_output_file_name_noext,
    gen_short_chum_uri,
    strip_last_extension)
from vehicle.perception.learning.utils.data_location import get_data_location, get_staging_dir


def get_start_end_times(lmdb_path):
    """
    Gets an appropriate start and end time from the annotated spin LMDB
    specified.

    The first and last keys of the LMDB (in the database's key order) are
    parsed as floats and used as the start and end times.

    :input lmdb_path: Absolute path to the labeled annotated spin LMDB file.

    :return: Tuple of (start_time, end_time) floats representing the beginning
             and end timestamps of the lmdb.

    :raises ValueError: If the LMDB contains no entries, or a key cannot be
                        parsed as a float.
    """
    annospin_lmdb = lmdb.Environment(path=lmdb_path,
                                     subdir=False,
                                     readonly=True,
                                     create=False,
                                     lock=False)
    try:
        cursor = annospin_lmdb.begin().cursor()
        # first() returns False on an empty database; key() would then be
        # b'' and float() would fail with an unhelpful message.
        if not cursor.first():
            raise ValueError(
                'Annotated spin LMDB {} is empty'.format(lmdb_path))
        start_time = float(cursor.key())
        cursor.last()
        end_time = float(cursor.key())
    finally:
        # Closing the environment also invalidates the read transaction and
        # cursor opened above, so nothing is left open on the file.
        annospin_lmdb.close()
    return start_time, end_time


def get_local_tracking_metrics_lmdb(lmdb_directory, job_id):
    """
    Given a directory containing several tracking metrics LMDBs, and the job ID
    for an LMDB, copies the tracking.lmdb for the given job ID (together with
    its .metadata file) to a new temporary directory and returns the path to
    the locally copied tracking metrics LMDB.

    :input lmdb_directory: Path to a directory containing tracking.lmdb
    :input job_id: Job ID for the LMDB to be copied

    :return: local tracking.lmdb path

    :raises AssertionError: If the tracking metrics LMDB or its metadata file
                            does not exist in lmdb_directory.
    """
    # Tracking metrics LMDBs follow the naming convention job_id.tracking.lmdb,
    # and each is accompanied by a job_id.tracking.lmdb.metadata file.
    tracking_metrics_lmdb = os.path.join(lmdb_directory,
                                         job_id + '.tracking.lmdb')
    remote_tracking_metrics_metadata = tracking_metrics_lmdb + '.metadata'

    # Validate both inputs before creating the temporary directory, so that a
    # missing metadata file doesn't leave behind an orphaned temp dir holding
    # a (potentially large) copy of the LMDB.
    assert os.path.exists(tracking_metrics_lmdb), (
            'Tracking metrics LMDB for job ID {} in directory '
            '{} doesn\'t exist').format(job_id, lmdb_directory)
    assert os.path.exists(remote_tracking_metrics_metadata), (
            "tracking metrics LMDB metadata file {} doesn't exist.".format(
                remote_tracking_metrics_metadata))

    # Make a local temporary dir for downloading tracking metrics LMDB.
    local_tracking_metrics_lmdb_dir = tempfile.mkdtemp(prefix='perception-ci')

    local_tracking_metrics_lmdb = copy_file_to_local_dir(
        tracking_metrics_lmdb, local_tracking_metrics_lmdb_dir)
    copy_file_to_local_dir(remote_tracking_metrics_metadata,
                           local_tracking_metrics_lmdb_dir)

    return local_tracking_metrics_lmdb


def get_local_segmentation_metrics_pb(pb_dir, job_id):
    """Copy a job's segmentation metrics file into a fresh temporary dir.

    :input pb_dir: Directory holding the per-job ``<job_id>.segmentation.pb``
                   files.
    :input job_id: Job ID whose segmentation metrics should be copied.

    :return: Path of the local copy of the segmentation metrics file.
    """
    remote_path = os.path.join(pb_dir, job_id + '.segmentation.pb')
    assert os.path.exists(remote_path), (
        'Segmentation metrics pbtxt for job ID {} in directory '
        '{} does not exist'.format(job_id, pb_dir))

    # Each call gets its own temporary directory for the copy.
    return copy_file_to_local_dir(remote_path,
                                  tempfile.mkdtemp(prefix='perception-ci'))


def get_sorted_comparisons(comparisons):
    """
    Orders the entries of an LmdbComparisons message by weighted_diff.

    :input comparisons: LmdbComparisons proto whose repeated ``comparison``
                        field holds the individual LmdbComparison entries.

    :return: New list of LmdbComparison entries in ascending weighted_diff
             order; ties keep their original relative order.
    """
    def _weighted_diff(entry):
        return entry.weighted_diff

    ordered = list(comparisons.comparison)
    ordered.sort(key=_weighted_diff)
    return ordered


def get_pick_str(i, comparison):
    """
    Builds the menu line shown for one LmdbComparison in the LMDB picker.

    :input i: Position of the comparison in the list presented to the user.
    :input comparison: The LmdbComparison being described.

    :return: The bracketed, width-2 index followed by the LMDB name, the diff
             and the weighted diff (both to four decimal places), separated
             by double tabs.
    """
    fields = (i, comparison.lmdb_name, comparison.diff,
              comparison.weighted_diff)
    return '[{:2d}] {}\t\tdiff: {:.4f}\t\tweighted: {:.4f}'.format(*fields)


def compare(control_directory,
            candidate_directory,
            filter_name,
            metric_name,
            diff_binary):
    """
    Given a control_directory and a candidate_directory along with a
    filter_name and a metric_name, runs diff_binary to generate a comparison
    of that metric across the two runs for each job in the run.

    :input control_directory: Suite directory corresponding to control run.
    :input candidate_directory: Suite directory for candidate run.
    :input filter_name: Name of the filter to be used for comparing.
    :input metric_name: Name of the metric to be used for comparing.
    :input diff_binary: Path to the binary used for comparison. This may vary
                        depending on the type of metric.

    :return: LmdbComparisons proto whose ``comparison`` field holds one
             LmdbComparison per job.

    :raises subprocess.CalledProcessError: If diff_binary exits non-zero.
    """
    # The diff binary writes its result as a text proto to this temp file.
    # The context manager closes (and thereby deletes) the file once it has
    # been parsed, or if the subprocess fails, instead of relying on GC.
    with tempfile.NamedTemporaryFile() as temp_file:
        subprocess.check_call([diff_binary,
                               '-control_dir', control_directory,
                               '-candidate_dir', candidate_directory,
                               '-metric_name', metric_name,
                               '-filter_name', filter_name,
                               '-output_lmdb_comparisons', temp_file.name,
                               '-alsologtostderr'])
        lmdb_comparisons = LmdbComparisons()
        ReadProtoAsText(temp_file.name, lmdb_comparisons)
    return lmdb_comparisons


def pick_lmdb(lmdb_comparisons, take_lowest):
    """
    Interactively asks the user to choose one job LMDB from a comparison.

    Every comparison is first echoed to stdout (name, diff, weighted diff),
    then offered in a ``pick`` menu built from get_pick_str lines.

    :input lmdb_comparisons: LmdbComparisons proto holding one LmdbComparison
                             per job.
    :input take_lowest: When False the menu lists comparisons in ascending
                        weighted_diff order; when True that order is reversed
                        so the largest weighted_diff is listed first.
                        NOTE(review): an earlier docstring claimed True puts
                        the lowest value on top, which contradicts the code;
                        confirm the intended meaning against
                        pick_comparison_metric_and_filter.

    :return: lmdb_name of the chosen comparison.
    """
    ordered = get_sorted_comparisons(lmdb_comparisons)
    if take_lowest:
        ordered.reverse()

    for entry in ordered:
        print(entry.lmdb_name, entry.diff, entry.weighted_diff)

    menu = [get_pick_str(pos, entry) for pos, entry in enumerate(ordered)]
    chosen = pick(menu,
                  'Pick the LMDB you want to triage',
                  indicator='->')[1]
    return ordered[chosen].lmdb_name


def get_diff_binary_for_component(metrics_component):
    """
    Maps a metrics component name to its per-LMDB diff binary.

    :input metrics_component: Either 'tracking' or 'segmentation'.

    :return: Path of the binary that diffs that component's metrics.

    :raises KeyError: For any other component name.
    """
    binary_by_component = dict(
        tracking='lidar/metrics/tracking_metrics_diff_per_lmdb',
        segmentation='lidar/metrics/segmentation_metrics_diff_per_lmdb',
    )
    return binary_by_component[metrics_component]


def get_files_for_component(candidate_directory,
                            control_directory,
                            lmdb_name,
                            metrics_component):
    """Get file path arguments for triage runner for given metrics component.

    The metrics file for ``lmdb_name`` is copied locally from the candidate
    run's ``lmdbs`` directory and, when a control run is given, from the
    control run's ``lmdbs`` directory.

    :input candidate_directory: Path to candidate perception-ci run.
    :input control_directory: Path to control run or None for no comparison.
    :input lmdb_name: Job lmdb for which to retrieve files.
    :input metrics_component: Metrics component for which to retrieve files;
                              'tracking' or 'segmentation'.

    :return: Dict mapping triage-runner keyword argument names to the local
             metrics file paths.

    Exits the process with status 1 for an unknown metrics component.
    """
    if metrics_component == 'tracking':
        get_metrics_file = get_local_tracking_metrics_lmdb
        candidate_arg = 'tracking_metrics_lmdb_path'
        control_arg = 'control_tracking_metrics_lmdb_path'
    elif metrics_component == 'segmentation':
        get_metrics_file = get_local_segmentation_metrics_pb
        candidate_arg = 'segmentation_metrics_pb_path'
        control_arg = 'control_segmentation_metrics_pb_path'
    else:
        print('Unknown metrics component {}'.format(metrics_component),
              file=sys.stderr)
        # sys.exit instead of the site builtin exit(), which is intended for
        # the interactive shell and is missing when Python runs with -S.
        sys.exit(1)

    runner_kwargs = {
        candidate_arg: get_metrics_file(
            os.path.join(candidate_directory, 'lmdbs'), lmdb_name),
    }
    if control_directory is not None:
        runner_kwargs[control_arg] = get_metrics_file(
            os.path.join(control_directory, 'lmdbs'), lmdb_name)

    return runner_kwargs


def make_custom_variant_name(suite_path, lmdb_file):
    """Make a custom variant name for the perception-ci triage job.

    The variant name is constructed from the suite path and lmdb file to
    uniquely identify a job lmdb in a perception-ci run. Non-unique
    information is removed to construct a variant name of the form:

        branch_name/run_datetime/suite_name/job_id

    :input suite_path: Path to the perception-ci suite results directory.
    :input lmdb_file: Basename of the job LMDB, e.g. "760.lmdb".

    :return: The variant name sanitized for special symbols.
    """
    # Trim non-unique information from the suite path.
    trimmed_path = remove_results_dir(suite_path)
    # Strip only a trailing ".lmdb"; str.replace would also delete any
    # ".lmdb" that happens to appear earlier in the name.
    trimmed_job = lmdb_file
    if trimmed_job.endswith('.lmdb'):
        trimmed_job = trimmed_job[:-len('.lmdb')]
    return sanitize_variant_name(os.path.join(trimmed_path, trimmed_job))


def _read_run_info_file(lmdb_dir, file_name):
    """Read a small text file from a run's lmdbs dir, without surrounding
    whitespace (e.g. a trailing newline)."""
    with open(os.path.join(lmdb_dir, file_name), 'r') as f:
        return f.read().strip()


def roots_and_variants_from_metrics_info(metrics_dir, lmdb_file, branch):
    """Get Chum roots and variants for a job from a perception-ci run.

    Reads the run's metrics suite plus the ``vision_root``, ``vision_hash``,
    ``pcp_root`` and ``datetime_str`` files stored under
    ``<metrics_dir>/lmdbs``.

    :input metrics_dir: Path to the perception-ci suite results directory.
    :input lmdb_file: Basename of the job LMDB, e.g. "760.lmdb".
    :input branch: Branch name of the perception-ci run.

    :return: Tuple (roots, variants) where roots is [vision_root, pcp_root]
             and variants is [vision variant, pcp variant].

    :raises ValueError: If no entry in the metrics suite matches lmdb_file.
    """
    metrics_suite = gen_metrics_suite_from_dir(metrics_dir)
    metrics_data = get_metrics_data_for_lmdb_name(metrics_suite, lmdb_file)
    if metrics_data is None:
        # Fail clearly here rather than passing None into the chum-uri
        # helpers below.
        raise ValueError('No metrics data for {} in the metrics suite of {}'
                         .format(lmdb_file, metrics_dir))

    lmdb_dir = os.path.join(metrics_dir, 'lmdbs')
    # Whitespace is stripped so a trailing newline in these files does not
    # end up inside the Chum roots/variants.
    vision_root = _read_run_info_file(lmdb_dir, 'vision_root')
    vision_hash = _read_run_info_file(lmdb_dir, 'vision_hash')
    pcp_root = _read_run_info_file(lmdb_dir, 'pcp_root')
    datetime_str = _read_run_info_file(lmdb_dir, 'datetime_str')

    roots = [vision_root, pcp_root]
    variants = [
        make_vision_variant(vision_hash, gen_short_chum_uri(metrics_data)),
        make_pcp_variant(
            branch,
            datetime_str,
            gen_output_file_name_noext(metrics_data)),
    ]
    return roots, variants


def make_chum_uri(id, roots, variants):
    """Build a chum URI of the form ``<id>?i=<roots>[&v=<variants>]``.

    :input id: Identifier placed before the query string (a run ID in this
               file's callers).
    :input roots: Iterable of root strings, comma-joined into ``i=``.
    :input variants: Iterable of variant strings, comma-joined into ``v=``;
                     the ``v=`` parameter is omitted when this is empty/falsy.

    :return: The assembled URI string.
    """
    base = "{}?i={}".format(id, ','.join(roots))
    return "{}&v={}".format(base, ','.join(variants)) if variants else base


def main(argv=None):
    """Entry point for perception-ci triage.

    Resolves the control/candidate perception-ci run directories and, when a
    control branch is given, picks the comparison metric and filter. Then it
    either launches the triage tool on ``--lmdb_file``, or diffs the control
    and candidate runs and lets the user pick LMDBs to triage one after
    another.

    :input argv: Command-line arguments excluding the program name; None
                 means ``sys.argv[1:]`` as of the call.

    :return: 1 if ``--lmdb_file`` was given but not found, otherwise None.
    """
    # argparse treats None as sys.argv[1:] at call time; the previous default
    # of sys.argv[1:] was frozen when the module was imported.
    args = parse_args(argv)

    control_directory, candidate_directory = choose_run_directories(
        args.control_branch,
        args.candidate_branch,
        args.suite_name,
        args.use_latest)

    take_lowest = False
    metric_name = None
    filter_name = None
    if args.control_branch is not None:
        (filter_name,
         metric_name,
         take_lowest) = pick_comparison_metric_and_filter(
             candidate_directory,
             args.triage_filter_name,
             args.triage_metric_name,
             args.metrics_component)

    if args.lmdb_file is not None:
        # Triage exactly the requested LMDB and report its status.
        return launch_triage_tool_on_lmdb(args.lmdb_file,
                                          args,
                                          metric_name,
                                          filter_name,
                                          candidate_directory,
                                          control_directory)

    # No LMDB given: rank the jobs by the control-vs-candidate metric diff
    # and let the user pick LMDBs to triage until they choose to stop.
    assert (control_directory is not None
            and candidate_directory is not None), \
            "Need a control and a candidate directory to autopick file."
    diff_binary = get_diff_binary_for_component(args.metrics_component)
    lmdb_comparisons = compare(control_directory,
                               candidate_directory,
                               filter_name,
                               metric_name,
                               diff_binary)
    continue_picking = True
    while continue_picking:
        lmdb_file = pick_lmdb(lmdb_comparisons, take_lowest)
        launch_triage_tool_on_lmdb(lmdb_file,
                                   args,
                                   metric_name,
                                   filter_name,
                                   candidate_directory,
                                   control_directory)
        answer, _ = pick(['Yes', 'No'],
                         'Continue triaging on another LMDB?',
                         indicator='->')
        continue_picking = (answer == 'Yes')
    return None


def _print_argus_link(run_id, roots, variants, start_time, lmdb_file, args,
                      control_directory):
    """Print an Argus link for the candidate run, shown alongside the control
    run when control_directory is given."""
    candidate_uri = make_chum_uri(run_id, roots, variants)
    layout = 'pcp_tracker'

    control_uri = None
    if control_directory:
        control_uri = make_chum_uri(
            run_id,
            *roots_and_variants_from_metrics_info(
                control_directory, lmdb_file, args.control_branch)
        )
        layout = 'pcp_tracker_comparison'
    link = render_link(
        candidate_uri,
        comparison_uri=control_uri,
        layout=layout,
        time=start_time)
    print("Argus link:")
    print(link)


def _build_runner_kwargs(args, lmdb_name, lmdb_path, metric_name,
                         filter_name, candidate_directory, control_directory):
    """Build TriageToolRunner kwargs from the parsed arguments, the job's
    ground-truth LMDB and locally copied metrics files."""
    # Get some basic kwargs, including nearly all the arguments from argparse.
    runner_kwargs = deepcopy(vars(args))
    runner_kwargs['edit'] = False
    runner_kwargs['perception_ci_lmdb_basename'] = lmdb_name
    runner_kwargs['ground_truth_input_source'] = lmdb_path
    if metric_name is not None:
        runner_kwargs['triage_metric_name'] = metric_name
    if filter_name is not None:
        runner_kwargs['triage_filter_name'] = filter_name

    # Remove flags that are unique to perception-ci triage and thus don't
    # correspond to a triage tool launcher kwarg.
    pcp_ci_triage_only_flags = (
        'lmdb_file',
        'suite_name',
        'candidate_branch',
        'control_branch',
        'use_latest',
    )
    for flag in pcp_ci_triage_only_flags:
        del runner_kwargs[flag]

    # Add the local metrics file arguments needed for the metrics component
    # (this copies the metrics files into temporary directories).
    runner_kwargs.update(
        get_files_for_component(candidate_directory,
                                control_directory,
                                lmdb_name,
                                args.metrics_component))
    return runner_kwargs


def launch_triage_tool_on_lmdb(lmdb_file,
                               args,
                               metric_name,
                               filter_name,
                               candidate_directory,
                               control_directory):
    """Triage a single job LMDB from a perception-ci run.

    When ``args.argus`` is set (presumably a flag registered by
    util.add_tool_args), this only prints an Argus link. Otherwise it launches
    the triage tool on the LMDB, using the run's metrics files and Chum
    outputs.

    :input lmdb_file: Basename of the ground-truth LMDB; must end in ".lmdb".
    :input args: Parsed command-line arguments (see parse_args).
    :input metric_name: Metric to triage, or None to keep the flag value.
    :input filter_name: Filter to triage, or None to keep the flag value.
    :input candidate_directory: Suite directory of the candidate run.
    :input control_directory: Suite directory of the control run, or None.

    :return: 1 if the LMDB file does not exist in ``args.lmdb_dir``,
             otherwise None.

    :raises ValueError: If no job ID can be parsed from the LMDB name.
    """
    # Make sure that the lmdb filepath is valid.
    assert lmdb_file.endswith('.lmdb')
    lmdb_path = os.path.join(args.lmdb_dir, lmdb_file)
    if not os.path.isfile(lmdb_path):
        print("LMDB file {} not found in {}".format(lmdb_file, args.lmdb_dir))
        return 1

    # Get the batch id from the lmdb file name.
    lmdb_name = os.path.splitext(lmdb_file)[0]

    start_time, end_time = get_start_end_times(lmdb_path)
    job_id_tuple = job_id_tuple_from_x(lmdb_name)
    if not job_id_tuple:
        raise ValueError("Couldn't parse job ID from lmdb name.")
    run_id = job_id_tuple.run_id

    # Get info about the Chum outputs from perception-ci.
    roots, variants = roots_and_variants_from_metrics_info(
        candidate_directory, lmdb_file, args.candidate_branch)

    if args.argus:
        # Argus mode never launches the runner, so the runner kwargs (and
        # the metrics-file copies they require) are not built here.
        _print_argus_link(run_id, roots, variants, start_time, lmdb_file,
                          args, control_directory)
        return None

    runner_kwargs = _build_runner_kwargs(args, lmdb_name, lmdb_path,
                                         metric_name, filter_name,
                                         candidate_directory,
                                         control_directory)
    runner_kwargs['remote_roots'] = roots
    runner_kwargs['variants'] = variants

    # Create a unique variant to store chum output.
    runner_kwargs['custom_variant_name'] = make_custom_variant_name(
        candidate_directory, lmdb_file)

    # Make the BagID from the run ID, start time, and end time.
    bag_id = util.BagID(run_id, None, start_time, end_time)
    runner = launch_triage_tool.TriageToolRunner(bag_id, **runner_kwargs)
    runner.main()
    return None


def get_current_perception_ci_suite_name():
    """Return the name of the metrics suite configured as the current
    perception-ci data set.

    :return: ``name`` field of the MetricsSuite text proto referenced by
             ``get_data_location().perception_ci_data_set``.
    """
    data_set_pbtxt = get_data_location().perception_ci_data_set
    suite = MetricsSuite()
    ReadProtoAsText(data_set_pbtxt, suite)
    return suite.name


def gen_metrics_suite_from_dir(metrics_dir):
    """Load the MetricsSuite stored in a perception-ci results directory.

    :input metrics_dir: Suite results directory containing
                        ``metrics_suite.pbtxt``.

    :return: The parsed MetricsSuite proto.
    """
    pbtxt_path = os.path.join(metrics_dir, "metrics_suite.pbtxt")
    suite = MetricsSuite()
    ReadProtoAsText(pbtxt_path, suite)
    return suite


def get_metrics_data_for_lmdb_name(metrics_suite, lmdb_name):
    """Get metrics data from metrics suite matching lmdb name.

    Entries of ``metrics_suite.data`` are visited in order and
    init_chum_uri_fields is called on each visited entry (presumably filling
    in fields in place, since its return value is ignored) before its output
    file name is compared.

    :input metrics_suite: MetricsSuite proto whose ``data`` entries are
                          searched.
    :input lmdb_name: LMDB file name; its last extension is dropped before
                      matching.

    :return: The first matching metrics data entry, or None if none match.
    """
    # Loop-invariant: compute the name being searched for only once.
    target_name = strip_last_extension(lmdb_name)
    for metrics_data in metrics_suite.data:
        init_chum_uri_fields(metrics_data)
        if gen_output_file_name_noext(metrics_data) == target_name:
            return metrics_data
    return None


def parse_args(argv):
    """Parse perception-ci triage command-line arguments.

    The shared triage tool flags are registered via util.add_tool_args.

    :input argv: List of argument strings excluding the program name, or None
                 to use sys.argv[1:].

    :return: argparse.Namespace of the parsed arguments. When --suite_name is
             not given it is filled with the current perception-ci suite name.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--lmdb_file',
        type=str,
        help=('The basename of an lmdb file that is in a perception_ci '
              'suite. Ex: "760.lmdb"'))
    parser.add_argument(
        '--candidate_branch',
        type=str,
        required=True,
        help=('If a branch name is given, you will be able to pick a '
              'perception ci run from which to retrieve lmdbs to view '
              'in the triage tool.'))
    parser.add_argument(
        '--suite_name',
        type=str,
        default=None,
        help=('Metrics suite to use for the perception-ci triaging. '
              'Defaults to the current perception-ci suite.'))
    parser.add_argument(
        '--lmdb_dir',
        type=str,
        default=get_staging_dir('bounding_box'),
        help='Directory containing ground truth lmdbs')
    parser.add_argument(
        '--control_branch',
        type=str,
        default=None,
        help=('If a control branch name is given, you will be able to pick '
              'a perception ci run from which to retrieve lmdbs to compare '
              'tracking metrics in the triage tool.'))
    parser.add_argument(
        '--use_latest',
        action='store_true',
        default=False,
        required=False,
        help=('Adding this flag will use the latest perception ci run for '
              'the specified branch.'))
    parser.add_argument(
        '--triage_metric_name',
        type=str,
        default=None,
        help=('Name of the metric for which to triage regressions.'))
    parser.add_argument(
        '--triage_filter_name',
        type=str,
        default=None,
        help=('Name of the filter for which to triage regressions.'))
    parser.add_argument(
        '--metrics_component',
        default='tracking',
        choices=[
            'tracking',
            'segmentation'
        ],
        help=('Which component of metrics to use when comparing control and '
              'candidate branches and for selecting lmdb to compare. '
              'Defaults to tracking metrics.'))

    util.add_tool_args(parser)
    args = parser.parse_args(argv)
    # Resolve the default suite lazily, so the suite proto is only read when
    # --suite_name was not supplied on the command line.
    if args.suite_name is None:
        args.suite_name = get_current_perception_ci_suite_name()
    return args


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status
    # (None maps to 0, an integer status is used as-is).
    sys.exit(main())
