from __future__ import print_function
from __future__ import division
import botocore
import boto3
import datetime
import errno
import lmdb
import os
import re
import sys
import tarfile
import yaml

from collections import namedtuple

from base.file.utils.file_utils import mkdir_p
from data.chum import chumpy, uri_pb2
from infra.data_catalog.client import data_rest_api
from labeling.lt3d.webserver.util.run_files_sync import (
        download_all_files_from_run,
)
from mapping.component.offline.chum_reader_py import (
        chumReaderInstantiateZrnIdentifiers,
)
from mapping.zrn.client.common_py import (
        identifierToFilepath,
)

# Where generated perception (PCP) log tests are written. The SDL (.sdl.py)
# versions live directly under SDL_OUTPUT_DIR, which is resolved relative to
# $ZOOX_WORKSPACE_ROOT (or to the current directory when that variable is
# unset); the pbtxt versions live in its 'test_protos' subdirectory.
SDL_OUTPUT_DIR = os.path.join(
    os.getenv('ZOOX_WORKSPACE_ROOT', ''), 'log_tests', 'pcp_log_tests')
PBTXT_OUTPUT_DIR = os.path.join(SDL_OUTPUT_DIR, 'test_protos')

# Chum root under which triage writes the variants produced when vision or
# perception is regenerated.
TRIAGE_CHUM_ROOT = os.path.expanduser('~/.cache/chum')

# Base directory for the cached params.yaml and ZRN files of each run (the same
# cache is also used by base::initFromLog).
PARAMS_CACHE_DIR_BASE = os.path.expanduser('~/.cache/zoox/accessor')

# BagID identifies a run (meta_id string) and, optionally, a crop of it:
#   incident   - float incident time; when it is the only time given, the crop
#                implicitly spans 35 seconds before to 10 seconds after it.
#   start_time - float crop start time.
#   end_time   - float crop end time.
# A BagID with no incident time and no complete (start_time, end_time) pair
# refers to the full, uncropped run.
BagID = namedtuple('BagID', 'meta_id incident start_time end_time')

def add_perception_runner_args(subparser):
    """
    Add the options shared by all commands that could involve running
    perception (i.e. view, edit, and run_tests).

    :param subparser: argparse parser (or subparser) to add the arguments to.
    """
    subparser.add_argument(
        "--run_perception",
        action="store_true",
        help="If specified, display the obstacles generated by "
        "offline_perception_main rather than those saved in the logged "
        "Chum.")
    subparser.add_argument(
        "--cached_perception",
        action="store_true",
        help="If specified, reuse the previously regenerated offline "
        "perception Chum variant rather than re-running "
        "offline_perception_main.")
    subparser.add_argument(
        "--zcache_vision_dataset",
        type=str,
        default=None,
        help="Name of the ZCache vision dataset to use for ZCache "
             "regenerated vision rather than re-running vision. ZCache "
             "vision is used only when all of --zcache_vision_dataset, "
             "--zcache_vision_build and --zcache_vision_version are "
             "specified."
    )
    subparser.add_argument(
        "--zcache_vision_build",
        type=str,
        default=None,
        help="Build used to generate the ZCache vision dataset, for using "
             "ZCache regenerated vision rather than re-running vision. "
             "ZCache vision is used only when all of "
             "--zcache_vision_dataset, --zcache_vision_build and "
             "--zcache_vision_version are specified."
    )
    subparser.add_argument(
        "--zcache_vision_version",
        type=int,
        default=None,
        help="Integer version of the ZCache vision dataset, for using "
             "ZCache regenerated vision rather than re-running vision. "
             "ZCache vision is used only when all of "
             "--zcache_vision_dataset, --zcache_vision_build and "
             "--zcache_vision_version are specified."
    )
    subparser.add_argument(
        "--cached_vision",
        action="store_true",
        # Note the trailing space: adjacent literals are concatenated as-is.
        help="Whether to use the previously regenerated vision rather than "
        "re-running vision.")
    subparser.add_argument(
        "--run_vision",
        action="store_true",
        help="If specified, use an updated vision bag (either re-running "
        "vision from scratch or using a cached vision bag locally or on "
        "flashblade) to regenerate the VisionTrack3DPs used by triage and "
        "offline_perception_main. Will also set --run_perception to true. "
        "Vision will not re-run if you have already triaged this "
        "incident/run and also specified --cached_vision.")
    subparser.add_argument(
        "--custom_root",
        type=str,
        default=None,
        help="If specified, use a different chum root to store the new data "
        "created by run_perception and run_vision.")
    subparser.add_argument(
        "--custom_variant_name",
        type=str,
        default=None,
        help="If specified, use a different variant to store the new data "
        "created by run_perception.")

def add_tool_args(subparser):
    """
    Add the options that both "tool" commands (view and edit) share, on top of
    the perception-runner options added by add_perception_runner_args.

    :param subparser: argparse parser (or subparser) to add the arguments to.
    """
    add_perception_runner_args(subparser)
    subparser.add_argument(
        "--camera_triage",
        action="store_true",
        help="If specified, launch the camera triage tool.")
    subparser.add_argument(
        "--validator_errors",
        action="store_true",
        help="If specified, display error messages from PO2 validator.")
    subparser.add_argument(
        "--alsologtostderr",
        action="store_true",
        help="If specified, display info logs.")
    subparser.add_argument(
        "--ground_truth_input_source",
        type=str,
        default=None,
        help="Path to ground truth LMDB or chum URI containing labeled "
        "scale data (optional). If provided, ground truth annotations will "
        "be displayed. LMDB file can contain either annotated spin or "
        "annotated frame protos.")
    subparser.add_argument(
        "--ground_truth_association_frame_lmdb_path",
        type=str,
        default=None,
        help="Path to ground truth association frame LMDB (optional). If "
        "provided, ground truth associations will be displayed. LMDB file "
        "must contain GroundTruthAssocFrameProtos.")
    subparser.add_argument(
        "--test_pbtxt_path",
        type=str,
        default=None,
        help="Path to the pbtxt file defining the PCP log tests to visualize. "
        "If not specified, infer this automatically from the BagID.")

    subparser.add_argument(
        "--tracking_metrics_lmdb_path",
        type=str,
        default=None,
        help="Path to TrackingMetricsFrame LMDB (optional). If provided, "
        "all matches will be shown.")
    subparser.add_argument(
        "--occlusion_ground_truth_lmdb_path",
        type=str,
        default=None,
        help="Path to OcclusionGroundTruth LMDB (optional). If provided, "
        "occlusion ground truth will be shown.")
    subparser.add_argument(
        "--segmentation_metrics_pb_path",
        type=str,
        default=None,
        help="Path to SegmentationEvalList pb. If provided, evaluated object "
        "boxes will be shown.")
    subparser.add_argument(
        "--window_name",
        type=str,
        default=None,
        help="If specified, give the triage tool GUI a different window "
        "name than the default 'Triage Tool'.")
    subparser.add_argument(
        "--force_ll3d",
        action="store_true",
        help="If specified, rerun ll3d components, e.g. aligner, instead of "
        "using logged results. Will also set --run_perception to true.")
    subparser.add_argument(
        "--show_notes",
        action="store_true",
        help="If specified, show the notes pane in the triage tool.")
    # --no_prediction stores False into `load_prediction` (which therefore
    # defaults to True).
    subparser.add_argument(
        "--no_prediction",
        action="store_false",
        dest="load_prediction",
        help="If specified, do not load prediction messages.")

    subparser.add_argument(
        "--use_ground_heights",
        action="store_true",
        help="If specified, use Lidar estimated ground heights for 3D "
        "visualization of layers including ZRN, planner-interactions "
        "and OcclusionGrid. Requires CUDA.")

    # --start_time and --start_offset are two ways of choosing the initial
    # playback time, so argparse rejects them when both are given.
    start_time_args = subparser.add_mutually_exclusive_group()
    start_time_args.add_argument(
        "--start_time",
        type=float,
        default=None,
        help="Start the triage tool at a specific timestamp in the log. This "
        "does not affect which data are loaded in Chum or how vision or "
        "perception are re-run. Cannot be used with --start_offset.")
    start_time_args.add_argument(
        "--start_offset",
        type=float,
        default=-5.0,
        help="Start the triage tool at a specific offset from the incident "
        "time (either positive or negative, in seconds). This does not affect "
        "which data are loaded in Chum or how vision or perception are "
        "re-run. Ignored if the incident string is of any form other than "
        "<meta_id|JIRA issue>@<incident_time>. Cannot be used with "
        "--start_time.")

    subparser.add_argument(
        "--non_debug_mode",
        action="store_true",
        help="If specified, turn off debug mode when running offline "
        "perception.")
    subparser.add_argument(
        "--dtn_debug_mode",
        action="store_true",
        help="If specified, perform stricter checks for debugging DTN.")
    subparser.add_argument(
        "--log_track_attr_model_features",
        action="store_true",
        help="If specified, log track attributes features for debugging.")
    subparser.add_argument(
        "--no_yaw_computer",
        action="store_true",
        help="If specified, the yaw computer is disabled.")
    subparser.add_argument(
        "--log_yaw_computer_info",
        action="store_true",
        help="If specified, the yaw computer info is logged.")
    subparser.add_argument(
        "--log_extent_history",
        action="store_true",
        help="If specified, track extent history is logged.")
    subparser.add_argument(
        "--override_params_path",
        type=str,
        default="",
        help="Path to a pbtxt that will override specified tracker params.")
    subparser.add_argument(
        "--argus",
        action="store_true",
        help="If specified, creates an argus link for the final chum URI "
        "instead of opening the triage tool GUI.")
    subparser.add_argument(
        "--argus_port",
        type=int,
        default=8080,
        help="Port at which local argus server is running.")

def __raise_data_catalog_error(response, error):
    """
    Raise a ValueError describing a failed data catalog request.

    :param response: dict-like data catalog response. Its optional "message"
                     entry, when present and non-empty, is appended to the
                     error text.
    :param error: string describing what the caller was trying to do.

    :raises ValueError: always.
    """
    message = response.get("message", None)
    if message:
        # Append the catalog's own explanation of the failure.
        error = "{}: {}".format(error, message)
    else:
        error += "."
    raise ValueError(error)

def is_full_log(bag_id):
    """
    Return True if the BagID refers to the whole run rather than a crop.

    A BagID is a crop when it carries an incident time, or when it carries
    both a start and an end time. Anything else refers to the full log.

    :param bag_id: BagID named tuple.
    :return: bool
    """
    has_incident = bag_id.incident is not None
    has_range = bag_id.start_time is not None and bag_id.end_time is not None
    return not (has_incident or has_range)

def bag_id_to_chum_uri(bag_id):
    """
    Build the Chum URI for a BagID.

    An explicit (start_time, end_time) pair wins. Otherwise an incident time
    expands to the window [incident - 35, incident + 10]. With neither, the
    URI is the bare meta_id (i.e. the full run).

    :param bag_id: BagID named tuple to convert.

    :return: string Chum URI of the form '<meta_id>@<start>-<end>' or
             '<meta_id>'.
    """
    has_range = bag_id.start_time is not None and bag_id.end_time is not None
    if has_range:
        start, end = bag_id.start_time, bag_id.end_time
    elif bag_id.incident is not None:
        start, end = bag_id.incident - 35.0, bag_id.incident + 10.0
    else:
        # No crop information at all: refer to the whole run.
        return bag_id.meta_id
    return '{meta_id}@{start}-{end}'.format(
        meta_id=bag_id.meta_id, start=start, end=end)

def get_start_and_end_times(bag_id):
    """
    Return the (start, end) crop times of a BagID.

    An explicit (start_time, end_time) pair wins. Otherwise an incident time
    expands to (incident - 35, incident + 10). A BagID with neither is a full
    log and yields (None, None).

    :param bag_id: BagID named tuple to convert.

    :return: tuple of two floats (start and end time) or tuple of two Nones
    """
    start, end = bag_id.start_time, bag_id.end_time
    if start is not None and end is not None:
        return start, end

    incident = bag_id.incident
    if incident is None:
        return None, None
    return incident - 35.0, incident + 10.0

def _lookup_meta_id_from_jira_issue(jira_issue):
    """
    Given a JIRA issue, return the meta_id of the single run it references.

    :param jira_issue: string JIRA issue key.
    :return: string meta_id, taken from the run's "hr_id" field.
    :raises ValueError: if the data catalog query fails, or if the issue
                        references zero runs or more than one run.
    """
    response = data_rest_api.get_runs_from_issue(jira_issue)
    if not response["success"]:
        error = "Could not query JIRA issue '{}' from data catalog".format(jira_issue)
        __raise_data_catalog_error(response, error)

    runs = response["runs"]
    if not runs:
        raise ValueError("Could not find references to JIRA issue '{}'".format(jira_issue))
    if len(runs) > 1:
        # Each run is a dict (its "hr_id" is read below), so join the run IDs
        # rather than the dicts themselves, which would raise a TypeError.
        run_ids = ", ".join(str(run.get("hr_id", run)) for run in runs)
        error = "Found multiple runs referenced in JIRA issue '{}': {}".format(jira_issue, run_ids)
        raise ValueError(error)

    return runs[0]["hr_id"]

def _make_float_or_none(float_str):
    """
    Try to convert a value to a float, but return None if this is impossible.

    :param float_str: a string, ideally of numeric format (None and other
                      non-convertible values also yield None).
    :return: a float or None
    """
    try:
        return float(float_str)
    except (TypeError, ValueError):
        # ValueError: non-numeric string; TypeError: e.g. None.
        return None

def make_time_str(bag_id):
    """
    Given a BagID, returns a string identifying that BagID's time. This is
    either "<incident timestamp>", "<start time>-<end time>", or "all", with
    timestamps formatted to two decimal places. The incident time takes
    precedence when both it and a start/end pair are set.

    :param bag_id: BagID named tuple.
    :return: string
    """
    # Compare against None (as bag_id_to_chum_uri does) rather than relying on
    # truthiness, so a 0.0 timestamp is not mistaken for a missing one.
    if bag_id.incident is not None:
        return '{:.2f}'.format(bag_id.incident)
    if bag_id.start_time is not None and bag_id.end_time is not None:
        return '{0:.2f}-{1:.2f}'.format(bag_id.start_time, bag_id.end_time)
    return 'all'

def parse_meta_id_from_tar_gz_path(tar_gz_path):
    """
    Given a path to a tar.gz for a run typically stored in flashblade or
    data_cache, return the meta_id of that run.

    The meta_id is taken to be the first two dash-separated fields of the
    path's final '/'-separated component, e.g.
    ".../20200101T000000-vh001-params.tar.gz" gives "20200101T000000-vh001".

    :param tar_gz_path: string path.
    :return: string meta_id, or None if tar_gz_path is not a str path
             (e.g. None or bytes).
    """
    try:
        basename = tar_gz_path.split('/')[-1]
    except (AttributeError, TypeError):
        # Only a non-str input can fail here. Catching these specifically
        # keeps KeyboardInterrupt/SystemExit from being swallowed.
        return None
    return '-'.join(basename.split('-')[0:2])

def parse_s3_path(s3_path):
    """
    Split an 's3://bucket/key' path into its bucket and key.

    The key may be empty (e.g. 's3://bucket' or 's3://bucket/').

    :param s3_path: string S3 path.
    :return: tuple of (bucket, key) strings.
    :raises ValueError: if s3_path does not start with 's3://'.
    """
    s3_match = re.match(r"^s3://([^/]*)(?:/(.*))?$", s3_path)
    if s3_match is None:
        raise ValueError("Invalid s3 path: {}".format(s3_path))

    bucket = s3_match.group(1)
    # The key group does not participate when there is no '/' after the
    # bucket; normalise that case to the empty string.
    key = s3_match.group(2) or ""
    return bucket, key

def _get_incident_time(start_time, end_time):
    """
    Recover the incident time of a 30-second crop.

    If the crop lasts 30 seconds (to within 1e-5 s), the incident is taken to
    be 20 seconds after the start (10 seconds before the end). Any other
    duration yields None.

    NOTE(review): this 30 s / start + 20 convention differs from the 45 s
    window (incident - 35 to incident + 10) used by bag_id_to_chum_uri and
    parse_cli_chum_uri; confirm which callers still rely on it.

    :param start_time: float
    :param end_time: float
    :return: float or None
    """
    duration = end_time - start_time
    is_thirty_second_crop = abs(duration - 30.0) < 1e-5
    return start_time + 20.0 if is_thirty_second_crop else None

def _parse_time_range_component(component, timestamp_str):
    """
    Convert one side of a "<a>-<b>" or "<a>+<b>" time string to a float.

    :param component: the matched substring (may be empty, e.g. for "-5").
    :param timestamp_str: the full time string, used in the error message.
    :return: float
    :raises ValueError: naming the offending time string if the component is
                        not a number.
    """
    try:
        return float(component)
    except ValueError:
        raise ValueError(
            "Invalid time range '{}': '{}' is not a number.".format(
                timestamp_str, component))

def parse_input_string(input_str):
    """
    Creates a BagID from an incident string of the following format:
    "<meta_id | JIRA_issue>@<timestring>". The first part is either a meta ID
    (or an ID the data catalog can map to one) or a JIRA issue. The
    timestring is one of the following formats:
    * "<start_time>-<end_time>"
    * "<start_time>+<duration>"
    * "<incident_time>" (the crop then implicitly starts 35 seconds before the
      incident time and ends 10 seconds after it)
    * "<anything_else>" (in which case there is no start/end/incident time and
      the entire run is triaged)
    If there is no "@<timestring>" part, the entire run is triaged.

    Returns a BagID containing the meta ID and the start/end/incident times.
    When a JIRA issue is given, it is only used to look up the meta ID of the
    single run it references; the issue itself is not stored in the BagID.

    :param input_str: a string
    :return: BagID
    :raises ValueError: if a "-"/"+" time range has an empty or non-numeric
                        side, or if the JIRA issue lookup fails.
    """
    split_input = input_str.split('@')
    meta_id = split_input[0]
    incident = None
    start_time = None
    end_time = None

    # If you cannot parse the meta_id, treat it as a JIRA issue instead.
    try:
        hr_data_id = data_rest_api.data_id_to_human_readable(meta_id)
        if hr_data_id is not None:
            meta_id = hr_data_id
        parse_meta_id(meta_id)
    except ValueError:
        jira_issue = meta_id
        meta_id = _lookup_meta_id_from_jira_issue(jira_issue)

    if len(split_input) > 1:
        timestamp_str = split_input[1]
        # The start_end time format is "<start_time>-<end_time>".
        start_end_match = re.match(r'(\d*\.?\d*)-(\d*\.?\d*)', timestamp_str)
        # The duration time format is "<start_time>+<duration>".
        duration_match = re.match(r'(\d*\.?\d*)\+(\d*\.?\d*)', timestamp_str)

        if start_end_match is not None:
            start_time = _parse_time_range_component(
                start_end_match.group(1), timestamp_str)
            end_time = _parse_time_range_component(
                start_end_match.group(2), timestamp_str)
        elif duration_match is not None:
            start_time = _parse_time_range_component(
                duration_match.group(1), timestamp_str)
            end_time = start_time + _parse_time_range_component(
                duration_match.group(2), timestamp_str)
        else:
            # Anything that is not a number yields None, i.e. the full run.
            incident = _make_float_or_none(timestamp_str)

    return BagID(meta_id, incident, start_time, end_time)

def parse_cli_chum_uri(chum_uri):
    """
    Parses an input Chum URI string into a BagID for that time range, plus the
    lists of Chum variants and remote Chum input roots.

    The BagID's start/end times come from the URI's run range divided by 1e9
    (presumably nanoseconds to seconds). If the range is 45 seconds long
    (rounded to 2 decimals), it is treated as a standard incident crop and the
    incident time is set to end_time - 10. Otherwise incident is None.

    :param chum_uri: string Chum URI

    :return: tuple of (BagID, list of string Chum variants,
             list of string Chum input roots)
    :raises ValueError: if the URI does not specify a meta ID.
    """
    # Get the meta ID from the Chum URI proto. An explicit check (rather than
    # an assert) keeps the validation active under `python -O`.
    uri_proto = chumpy.parseChumUriToProto(chum_uri)
    if not uri_proto.HasField('meta_id'):
        raise ValueError(
            'Did not specify a meta ID in the chum URI {}'.format(chum_uri))
    meta_id = uri_proto.meta_id

    # Get the start/end times and variant/input lists directly from the store
    # and range produced by this URI.
    store, run_range = chumpy.parseChumUri(chum_uri)
    start_time = run_range.start_time / 1e9
    end_time = run_range.end_time / 1e9
    variants = [v.name for v in store.variants()]
    inputs = [i.root for i in store.inputs()]

    # A 45 s window matches the incident - 35 / incident + 10 convention.
    if round(end_time - start_time, 2) == 45.0:
        incident_time = end_time - 10
    else:
        incident_time = None

    return BagID(meta_id, incident_time, start_time, end_time), variants, inputs

def bag_id_to_input_string(bag_id):
    """
    Render a BagID back into the '<meta_id>@<time>' input-string form, where
    <time> is produced by make_time_str.

    :param bag_id: a BagID named tuple
    :return: string
    """
    time_str = make_time_str(bag_id)
    return '{meta_id}@{time}'.format(meta_id=bag_id.meta_id, time=time_str)

def parse_meta_id(meta_id):
    """
    Given a meta_id in the 'YYYYMMDDTHHMMSS-VEHICLENAME' format, returns a
    parsed (datetime, vehiclename) pair.

    :param meta_id: string meta ID.
    :return: tuple of (datetime.datetime, string vehicle name).
    :raises ValueError: if meta_id does not split into exactly two parts on
                        '-', or if its first part does not match
                        %Y%m%dT%H%M%S.
    """
    # A meta_id must consist of exactly two parts: the datetime and the
    # vehicle name.
    parts = meta_id.split('-')
    if len(parts) != 2:
        raise ValueError(
            "meta_id must contain exactly two parts separated by a dash: "
            "'{}'".format(meta_id))

    date_str, vehicle_name = parts
    try:
        date = datetime.datetime.strptime(date_str, "%Y%m%dT%H%M%S")
    except ValueError:
        raise ValueError("Invalid timestamp in meta_id: '{}'".format(meta_id))
    return (date, vehicle_name)

def get_s3_object(s3, path, check=True):
    """
    Try to find a specific file on s3 and return an s3 object representing it.

    By default (check=True), a HEAD request is issued via obj.load() and None
    is returned if the object does not exist. If check=False, no request is
    made and the returned object may refer to a non-existent s3 object, which
    is useful for uploading data.

    :param s3: boto3 S3 resource (anything with an Object(bucket, key) method).
    :param path: string 's3://bucket/key' path.
    :param check: bool, whether to verify that the object exists.
    :return: the s3 Object, or None if check=True and it does not exist.
    :raises ValueError: if path is not a valid s3 path.
    :raises botocore.exceptions.ClientError: for errors other than a missing
                                             object (e.g. access denied).
    """
    (bucket, key) = parse_s3_path(path)
    obj = s3.Object(bucket, key)

    if check:
        # Initiate a HEAD request to ensure that the file exists.
        try:
            obj.load()
        except botocore.exceptions.ClientError as e:
            error = e.response.get("Error", {})
            # A missing object on HEAD is reported with code "404" (message
            # "Not Found"); accept either form.
            if (error.get("Code") in ("404", "NoSuchKey") or
                    error.get("Message") == "Not Found"):
                print("S3 Object does not exist: {}".format(path))
                return None
            # Bare raise keeps the original traceback (also on Python 2).
            raise

    return obj

def get_lmdb_start_end_times(metrics_data):
    """
    Fill in missing start/end times on a MetricsInput proto from its annotated
    spin LMDB.

    The LMDB at metrics_data.annospin_lmdb_path is opened read-only, and its
    first and last keys are parsed as float times. metrics_data.start_time and
    metrics_data.end_time are each set only if not already present, so
    explicitly configured values win. The proto is modified in place and
    nothing is returned.

    NOTE(review): no warmup offset (e.g. kVisionWarmupTimeInSeconds) is
    subtracted from the first key here; confirm this is intended.

    :param metrics_data: MetricsInput proto with annospin_lmdb_path set.
    :raises ValueError: if the LMDB is empty or a key is not numeric.
    """
    lmdb_path = metrics_data.annospin_lmdb_path
    annospin_lmdb = lmdb.Environment(path=lmdb_path,
                                     subdir=False,
                                     readonly=True,
                                     create=False)
    try:
        # The transaction is finished when the with-block exits; the
        # environment is closed in `finally` so its handle is not leaked.
        with annospin_lmdb.begin() as txn:
            cursor = txn.cursor()
            # first() returns False when the database has no entries.
            if not cursor.first():
                raise ValueError(
                    'Annotated spin LMDB {} is empty.'.format(lmdb_path))
            # Keys are assumed to be times stored as numeric strings.
            start_time = float(cursor.key())
            cursor.last()
            end_time = float(cursor.key())
    finally:
        annospin_lmdb.close()

    if not metrics_data.HasField('start_time'):
        metrics_data.start_time = start_time
    if not metrics_data.HasField('end_time'):
        metrics_data.end_time = end_time

def extract_params_yaml(tarball_path, output_path, force=False):
    """
    Extract the params yaml from a params tarball and write it to the output
    path. If force is True, will do this even if the output path already exists.

    The 'params.yaml' member is parsed and re-serialized with yaml.dump before
    being written.

    :param tarball_path: string path of the params tarball.
    :param output_path: string path to write the params.yaml file to.
    :param force: optional bool defaulting to False. If False, will no-op if
                  the output path already exists; if True will run regardless.
    """
    if os.path.isfile(output_path) and not force:
        print('{} already exists; returning'.format(output_path))
        return

    # Close the tarball (and the member stream) once the YAML text is read.
    with tarfile.open(tarball_path) as tar:
        params_file = tar.extractfile('params.yaml')
        try:
            raw_yaml = params_file.read()
        finally:
            params_file.close()

    # PyYAML >= 6 requires an explicit Loader. FullLoader is what a bare
    # yaml.load() used in 5.x; yaml.Loader was the default before 5.1.
    # NOTE(review): neither loader is safe for untrusted input; the params
    # tarball is assumed to come from our own logs.
    loader = getattr(yaml, 'FullLoader', yaml.Loader)
    yaml_obj = yaml.load(raw_yaml, Loader=loader)

    with open(output_path, 'w') as f:
        f.write(yaml.dump(yaml_obj))

# TODO(sabeek): Have this take in a meta_id/chum URI.
def populate_cache_directory(bag_id, cache_base_dir=None):
    """
    Create, fill, and return a cache directory for this BagID that contains
    the params.yaml and ZRN files for that run. Used to get params and ZRN
    files to run triage/PCP for a particular BagID.

    Files already present in the cache directory are left untouched. The
    params tarball is only fetched (via get_params_tarball) if at least one
    file is missing.

    :param bag_id: BagID named tuple.
    :param cache_base_dir: optional string base directory; the cache lives in
                           <cache_base_dir>/<meta_id>. Defaults to
                           PARAMS_CACHE_DIR_BASE when not given or empty.

    :return: string path to cache directory.
    :raises KeyError: if the tarball lacks one of the target files.
    :raises ValueError: if a target member is not a regular file.
    """
    if not cache_base_dir:
        cache_dir = os.path.expanduser(
            os.path.join(PARAMS_CACHE_DIR_BASE, bag_id.meta_id))
    else:
        cache_dir = os.path.join(cache_base_dir, bag_id.meta_id)
    mkdir_p(cache_dir)

    # The files we want to store in this cache directory, namely the params.yaml
    # and ZRN files.
    target_files = ['params.yaml', 'zrn.ZooxRoadNetwork.pb']
    missing_files = [basename for basename in target_files
                     if not os.path.exists(os.path.join(cache_dir, basename))]

    # Only download the params tarball if at least one targeted file is missing.
    if missing_files:
        tarball_path = get_params_tarball(bag_id)
        with tarfile.open(tarball_path) as tar:
            for basename in missing_files:
                print('Extracting {}'.format(basename))
                member = tar.extractfile(basename)
                if member is None:
                    raise ValueError('{} in {} is not a regular file'.format(
                        basename, tarball_path))
                try:
                    data = member.read()
                finally:
                    member.close()
                # Read before opening the output, so a failed extraction does
                # not leave an empty file that later calls treat as cached.
                # The member data is bytes, hence binary mode.
                with open(os.path.join(cache_dir, basename), 'wb') as f:
                    f.write(data)

    return cache_dir

def get_params_tarball(bag_id):
    """
    For a given BagID, download its params tarball into
    PARAMS_CACHE_DIR_BASE/<meta_id> and return the path to it.

    :param bag_id: BagID named tuple.

    :return: string path to params tarball.
    :raises ValueError: if the run does not have exactly one ROSParamsTar file.
    """
    cache_dir = os.path.expanduser(
        os.path.join(PARAMS_CACHE_DIR_BASE, bag_id.meta_id))
    mkdir_p(cache_dir)
    downloaded_files = download_all_files_from_run(
            run_identifier=bag_id.meta_id,
            local_dir=cache_dir,
            file_types=['ROSParamsTar'])
    # Explicit check rather than an assert, so it still runs under `python -O`.
    tarballs = downloaded_files.get('ROSParamsTar', [])
    if len(tarballs) != 1:
        raise ValueError(
            'Expected exactly one ROSParamsTar for run {}, found {}.'.format(
                bag_id.meta_id, len(tarballs)))
    return tarballs[0]

def get_relative_zrn_path_from_chum(chum_uri):
    """
    Return the relative path of the ZRN file referenced by a Chum URI,
    e.g. "vehicle/zrn/sf.zrn.release".

    When several ZRN identifiers are present (for instance if the ZRN changed
    while logging), the last one is used.

    :param chum_uri: string Chum URI.
    :return: string relative ZRN path.
    :raises ValueError: if no ZRN identifier is found.
    """
    zrn_identifiers = chumReaderInstantiateZrnIdentifiers(chum_uri)
    if not zrn_identifiers:
        raise ValueError(
            "No ZRN found from the Chum URI '{}'.".format(chum_uri))

    latest = zrn_identifiers[-1]
    return identifierToFilepath(latest.identifier)

def make_test_name(bag_id):
    """
    Name of the triage test created for a BagID: '<meta_id>-<time>', where
    <time> comes from make_time_str.

    :param bag_id: BagID named tuple.

    :return: string name of triage test.
    """
    time_str = make_time_str(bag_id)
    return '{meta_id}-{time}'.format(meta_id=bag_id.meta_id, time=time_str)

def get_test_pbtxt_path(bag_id):
    """
    Path of the pbtxt test file for a BagID: PBTXT_OUTPUT_DIR/<test name>.pbtxt.

    :param bag_id: BagID named tuple.

    :return: string path to pbtxt.
    """
    filename = make_test_name(bag_id) + '.pbtxt'
    return os.path.join(PBTXT_OUTPUT_DIR, filename)

def get_test_sdl_path(bag_id):
    """
    Returns the path where the SDL test file for this BagID should be stored:
    SDL_OUTPUT_DIR/<test name>.sdl.py.

    :param bag_id: BagID named tuple.

    :return: string path to the .sdl.py file.
    """
    return os.path.join(
            SDL_OUTPUT_DIR,
            '{}.sdl.py'.format(make_test_name(bag_id)))
