import argparse
import os

from google.protobuf import text_format
from prediction.metrics.apps.evaluation_results_pb2 import EvaluationResultsProto


def parse_args():
    """Parse this script's command-line flags.

    Both ``--input_eval_results`` and ``--output_eval_results`` are mandatory;
    argparse exits with a usage error when either one is missing.

    Returns:
        An ``argparse.Namespace`` with ``input_eval_results`` and
        ``output_eval_results`` attributes holding the given strings.
    """
    arg_parser = argparse.ArgumentParser()
    for flag in ('--input_eval_results', '--output_eval_results'):
        arg_parser.add_argument(flag, required=True)
    return arg_parser.parse_args()


def load_eval_results(pbtxt):
    """Read a text-format ``EvaluationResultsProto`` from disk.

    Args:
        pbtxt: Path of the file holding the message in protobuf text format.

    Returns:
        A new ``EvaluationResultsProto`` populated from the file's contents.
    """
    with open(pbtxt, 'r') as handle:
        contents = handle.read()
    message = EvaluationResultsProto()
    text_format.Parse(contents, message)
    return message


def _entity_event_id(sample):
    """Return the sanitized entity event id for one sample.

    The id is ``'{run_id}_{entity_id}_{track_group_key}'``. Every ``(`` and
    ``)`` is then removed, and every ``', '`` is replaced by ``'_'``. For
    example, a track key rendered as ``(a, b)`` contributes ``a_b``. The
    substitutions apply to the whole string, not only the track-key part.
    """
    raw_id = '{run_id}_{entity_id}_{track_key}'.format(
        run_id=sample.run_id, entity_id=sample.entity_id,
        track_key=sample.track_group_key)
    return raw_id.replace('(', '').replace(')', '').replace(', ', '_')


def main():
    """Assign entity event ids to labeled samples and drop all others.

    Steps:
      1. Read the text-format ``EvaluationResultsProto`` named by
         ``--input_eval_results``.
      2. Keep only samples whose ``classification_result.label`` is truthy
         (the entity actually exited).
      3. Set ``entity_event_id`` on each kept sample.
      4. Write the filtered message in text format to
         ``--output_eval_results``.

    Raises:
        ValueError: if the input and output paths resolve to the same file.
    """
    args = parse_args()
    # Guard against overwriting the input. Resolved paths are compared so
    # that e.g. 'x.pbtxt' vs './x.pbtxt' is caught. An explicit raise is used
    # because `assert` is stripped under `python -O`.
    if (os.path.realpath(args.input_eval_results)
            == os.path.realpath(args.output_eval_results)):
        raise ValueError(
            'input and output must be different files: {}'.format(
                args.input_eval_results))

    eval_results = load_eval_results(args.input_eval_results)

    # Only keep samples where the entity actually exited. Each id depends
    # only on the sample's own fields, so no grouping or time-sorting is
    # needed. Kept samples stay in their input order.
    kept = [s for s in eval_results.sample if s.classification_result.label]
    for s in kept:
        s.entity_event_id = _entity_event_id(s)

    # Replace the repeated field with just the kept samples. Filtering on the
    # label, rather than on a non-empty entity_event_id, also drops unlabeled
    # samples that already carried an id in the input.
    eval_results.ClearField('sample')
    eval_results.sample.extend(kept)

    # NOTE(review): presumably something indexes into `sample` positions,
    # which the filtering above shifts -- confirm what "tags" refers to.
    print('Note: the indexes for tags will be incorrect!')

    with open(args.output_eval_results, 'w') as f:
        f.write(text_format.MessageToString(eval_results))


# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
