""" Sample usage.

brun :ztrace_report -- --id 20190505T211409-kitt_01@1557093786379488256-1557094231020820480 --output $PWD/20190505T211409-kitt_01.csv
"""
import argparse
import datetime
import pandas as pd

from data.chum import chumpy

from vehicle.common.ztrace.metrics_pb2 import PerTickMetrics
from vehicle.common.proto.decision_pb2 import Decision
from vehicle.common.proto.decision_debug_pb2 import Decision as DecisionDebug
from scripts.tools.system_latency.latency_routes import find_latency_candidates

def extract_zci_tick_all(per_tick_metric):
    """Return the elapsed time in milliseconds between the two probe events.

    Selects the entries of ``per_tick_metric.events`` whose ``probe`` is 0 or
    1 and returns the difference of their ``timestamp_microsec`` fields
    (second minus first, in iteration order) divided by 1000.

    Raises:
        ValueError: if there are not exactly two events with probe 0 or 1.
    """
    begin_and_end = [event for event in per_tick_metric.events
                     if event.probe in (0, 1)]
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, which would let a malformed tick produce a wrong
    # duration (or an IndexError) instead of a clear error.
    if len(begin_and_end) != 2:
        raise ValueError(
            "expected exactly 2 events with probe 0 or 1, got %d"
            % len(begin_and_end))
    begin, end = begin_and_end
    return (end.timestamp_microsec - begin.timestamp_microsec) / 1e3


def extract_route_slen(decision):
    """Return the ``route_length`` field of a ``Decision`` message."""
    route_length = decision.route_length
    return route_length

def extract_route_s(decision):
    """Return the ``s`` coordinate stored in ``decision.route_stz``."""
    route_stz = decision.route_stz
    return route_stz.s


def extract_active_entities(decision_debug):
    """Count the agent statuses whose ``valid`` flag is truthy.

    Reads ``decision_debug.agent_relevance_filter.agent_status`` and returns
    how many of its entries have ``valid`` set.
    """
    statuses = decision_debug.agent_relevance_filter.agent_status
    return sum(1 for status in statuses if status.valid)


def extract_search_results_rollout(decision_debug):
    """Return the total number of rollout cost results across all actions.

    Sums ``len(rollout.rollout_cost_results)`` over every entry of
    ``decision_debug.search_results.DEPRECATED_action_rollout_results``;
    returns 0 when there are no rollouts.
    """
    rollouts = decision_debug.search_results.DEPRECATED_action_rollout_results
    return sum(len(rollout.rollout_cost_results) for rollout in rollouts)


def extract_best_search_results_rollout(decision_debug):
    """Return the rollout length of the best action, or -1 if it is absent.

    Looks up ``DEPRECATED_best_action_type`` in ``decision_debug.search_results``
    and returns ``len(rollout_cost_results)`` of the first rollout whose
    ``action_type`` matches it. Returns -1 when no rollout matches.
    """
    results = decision_debug.search_results
    wanted_action = results.DEPRECATED_best_action_type
    matching_lengths = (
        len(candidate.rollout_cost_results)
        for candidate in results.DEPRECATED_action_rollout_results
        if candidate.action_type == wanted_action
    )
    return next(matching_lengths, -1)


def _planner_row(message, per_tick_metric, decision, decision_debug):
    """Build one report row from the latest proto of each planner topic.

    ``message`` is the message that completed the row. Its
    ``message_timestamp`` is divided by 1e9 (presumably nanoseconds -- TODO
    confirm) and converted to a naive UTC ``datetime`` for the ``dt`` column.
    """
    # NOTE(review): ``utcfromtimestamp`` is deprecated since Python 3.12; kept
    # so the CSV's ``dt`` column stays naive UTC as before.
    return {
        'dt': datetime.datetime.utcfromtimestamp(
            message.message_timestamp / 1e9),
        'num_active_agents': extract_active_entities(decision_debug),
        'route_s': extract_route_s(decision),
        'route_len': extract_route_slen(decision),
        'zci_tick_all': extract_zci_tick_all(per_tick_metric),
        'total_rollout_timesteps':
            extract_search_results_rollout(decision_debug),
        'actual_rollout_timesteps':
            extract_best_search_results_rollout(decision_debug),
    }


def time_synchronize_planner_data(run_id):
    """Yield one row of planner metrics each time all three topics update.

    Reads the ``/planner/decision``, ``/planner/decision/debug`` and
    ``/planner/decision/metrics`` topics of the chum run ``run_id`` in reader
    order. It keeps the most recent message of each type. Once every type has
    received at least one new message since the previous row, it yields a row
    built from the latest message of each type and resets that tracking.

    Args:
        run_id: chum URI passed to ``chumpy.parseChumUri``.

    Yields:
        dict with keys ``dt``, ``num_active_agents``, ``route_s``,
        ``route_len``, ``zci_tick_all``, ``total_rollout_timesteps`` and
        ``actual_rollout_timesteps``.

    Raises:
        ValueError: if the reader returns a message of any other type.
    """
    # ``chum_range`` rather than ``range`` so the builtin is not shadowed.
    store, chum_range = chumpy.parseChumUri(run_id)
    chum_range.topics = {
        '/planner/decision',
        '/planner/decision/debug',
        '/planner/decision/metrics',
    }
    reader = chumpy.Reader.create(store, chum_range)

    # Message type name -> (row slot, proto class used to parse it).
    parsers = {
        'zoox.metrics.proto.PerTickMetrics': ('per_tick_metric', PerTickMetrics),
        'zoox.common.proto.Decision': ('decision', Decision),
        'zoox.planner.proto.debug.Decision': ('decision_debug', DecisionDebug),
    }
    latest = {}      # slot -> most recently parsed proto of that type
    updated = set()  # slots that received a message since the last row

    for message in reader:
        entry = parsers.get(message.type_name)
        if entry is None:
            raise ValueError(
                "Got unexpected message of type '%s'" % message.type_name)
        slot, proto_cls = entry
        proto = proto_cls()
        proto.ParseFromString(message.getData())
        latest[slot] = proto
        updated.add(slot)

        # Emit only once every topic has fresh data; if a topic receives
        # several messages first, the newest one overwrites the others.
        if len(updated) == len(parsers):
            yield _planner_row(message,
                               latest['per_tick_metric'],
                               latest['decision'],
                               latest['decision_debug'])
            updated.clear()

def parse_args(argv=None):
    """Parse this script's command-line arguments.

    Args:
        argv: argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``run_meta_id`` (also accepted as ``--id``)
        and ``output`` attributes, both strings.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in usage lines), so the summary must go in ``description``.
    parser = argparse.ArgumentParser(
        description="Create pretty ztrace reports for PRC runs")
    parser.add_argument('--run_meta_id', '--id',
                        required=True,
                        type=str,
                        help='Run ID to grab ztrace information for')
    parser.add_argument('--output',
                        required=True,
                        type=str,
                        help='Path to output CSV')
    return parser.parse_args(argv)

# Example run IDs for --id, keyed by what appears to be a release version
# (the two '19.18' entries are different runs of the same version):
#   '19.18': '20190505T204940-kitt_10@1557092361493110784-1557092776566482176',
#   '19.18': '20190502T160550-kitt_06@1556815394234288384-1556815819274500864',
#   '19.16': '20190505T211409-kitt_01@1557093786379488256-1557094231020820480',

def main():
    """Write the synchronized planner rows for ``--id`` to ``--output`` as CSV."""
    cli_args = parse_args()
    rows = time_synchronize_planner_data(cli_args.run_meta_id)
    report = pd.DataFrame(rows)
    report.to_csv(cli_args.output)

if __name__ == "__main__":
    main()
