import argparse
import random
import time

from collections import namedtuple
from enum import Enum, auto
from itertools import zip_longest
from pprint import pprint
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

from google.protobuf import text_format
from prediction.data.entity_event_pb2 import (
    EntityEvent, EntityEventList, JunctionExitEvent
)
from prediction.learning.topdown.junction_exit_dense.junction_map_pb2 import JunctionMap
from vehicle.common.proto import junction_pb2


class JunctionType(Enum):
    """Junction categories used by this script."""

    FIFO = auto()
    PRIORITY = auto()
    TRAFFICLIGHT = auto()

    @classmethod
    def from_proto_type(cls, t):
        """Return the member matching a ``junction_pb2`` junction type.

        Args:
            t: One of the ``junction_pb2.JUNCTION_TYPE_*`` values.

        Returns:
            The corresponding ``JunctionType`` member.

        Raises:
            KeyError: If ``t`` is not one of the three known proto values.
        """
        proto_to_member = dict((
            (junction_pb2.JUNCTION_TYPE_FIFO, cls.FIFO),
            (junction_pb2.JUNCTION_TYPE_PRIORITY, cls.PRIORITY),
            (junction_pb2.JUNCTION_TYPE_TRAFFICLIGHT, cls.TRAFFICLIGHT),
        ))
        return proto_to_member[t]


class JunctionTag(NamedTuple):
    """Hashable (junction type, number of ways) pair.

    get_junction_tag() builds one per event and
    group_events_by_junction_tag() uses it as the outer dictionary key.
    """
    # Junction category as a JunctionType member.
    junction_type: JunctionType
    # Filled from the junction exit event's `num_ways` field.
    n_ways: int


def parse_args():
    """Parse this script's command-line flags from ``sys.argv``.

    Every flag is mandatory; argparse reports a usage error and exits when
    one is missing.

    Returns:
        The argparse namespace with ``input_events_pb``,
        ``junction_map_pbtxt``, ``output_pbtxt`` and ``stats_fn`` set.
    """
    required_flags = (
        '--input_events_pb',     # read by load_events()
        '--junction_map_pbtxt',  # read by load_junction_map()
        '--output_pbtxt',        # written by save_events()
        '--stats_fn',            # written by save_stats()
    )
    parser = argparse.ArgumentParser()
    for flag in required_flags:
        parser.add_argument(flag, required=True)
    return parser.parse_args()


def load_events(fn):
    # type: (str) -> EntityEventList
    """Load a binary-serialized ``EntityEventList`` from the file ``fn``.

    The whole file is read into memory and parsed in one call.
    """
    with open(fn, 'rb') as pb_file:
        serialized = pb_file.read()
    events = EntityEventList()
    # This is slow for large pbs, but comparable to the C++ timing.
    events.ParseFromString(serialized)
    return events


def load_junction_map(fn):
    # type: (str) -> JunctionMap
    """Load a text-format ``JunctionMap`` from the file ``fn``."""
    with open(fn) as pbtxt_file:
        pbtxt = pbtxt_file.read()
    junction_map = JunctionMap()
    text_format.Parse(pbtxt, junction_map)
    return junction_map


def get_junction_tag(event):
    """Build the ``JunctionTag`` for an event.

    The event must carry the ``JunctionExitEvent`` extension (asserted).
    The tag combines the extension's junction type, converted to a
    ``JunctionType`` member, with its ``num_ways`` value.
    """
    assert event.HasExtension(JunctionExitEvent.ext)
    exit_event = event.Extensions[JunctionExitEvent.ext]
    junction_type = JunctionType.from_proto_type(exit_event.junction_type)
    return JunctionTag(junction_type, exit_event.num_ways)


def get_junction_id(event, junction_map):
    # type: (EntityEvent, JunctionMap) -> Optional[int]
    """Return the junction ID of an event after applying the junction map.

    The event must carry the ``JunctionExitEvent`` extension (asserted).

    Returns:
        The event's own junction ID when the map has no entry for it, the
        remapped ID when it does, or None when the remapped ID is 0.
        is_event_valid() discards events for which this returns None.
    """
    assert event.HasExtension(JunctionExitEvent.ext)
    original_id = event.Extensions[JunctionExitEvent.ext].junction_id
    remapping = junction_map.old_to_new_junction_id
    # Test membership before indexing so the lookup below only ever
    # touches keys that are already present in the map.
    if original_id not in remapping:
        return original_id
    new_id = remapping[original_id]
    # NOTE(review): 0 presumably marks a junction that should be dropped --
    # confirm against the code that produces the junction map.
    if new_id == 0:
        return None
    return new_id


def is_event_valid(event, junction_map):
    """Return whether the event should be kept.

    An event is rejected when get_junction_id() yields None for it, or
    when none of its ticks offers more than one connection option.
    """
    if get_junction_id(event, junction_map) is None:
        return False

    assert event.HasExtension(JunctionExitEvent.ext)
    exit_event = event.Extensions[JunctionExitEvent.ext]
    # Must have more than one connection option for at least one tick. If
    # there's only one option, prediction is trivial.
    for tick in exit_event.tick_data:
        if len(tick.connection_option) > 1:
            return True
    return False


def group_events_by_junction_tag(eel, junction_map):
    # type: (EntityEventList, JunctionMap) -> Dict[JunctionTag, Dict[int, List[EntityEvent]]]
    """Bucket the valid events of ``eel`` by junction tag, then junction ID.

    Events rejected by is_event_valid() are skipped. Within each bucket the
    events keep the order in which they appear in ``eel.event``.

    Returns:
        A plain nested dict ``{tag: {junction_id: [event, ...]}}``.
    """
    grouped = {}  # type: Dict[JunctionTag, Dict[int, List[EntityEvent]]]
    for event in eel.event:
        if is_event_valid(event, junction_map):
            junction_id = get_junction_id(event, junction_map)
            # If event is valid, junction ID should not be None!
            assert junction_id is not None

            tag = get_junction_tag(event)
            per_junction = grouped.setdefault(tag, {})
            per_junction.setdefault(junction_id, []).append(event)
    return grouped


def filter_events_by_junction_by_tag(events_by_junction_by_tag):
    # type: (Dict[JunctionTag, Dict[int, List[EntityEvent]]]) -> Dict[JunctionTag, Dict[int, List[EntityEvent]]]
    """Drop junctions that were observed in too few distinct runs.

    The total number of runs is the number of distinct ``run_id`` values
    over all events in the input. A junction is kept, within a given tag,
    when its own events under that tag come from at least 10% of those runs.
    Each (tag, junction) bucket is judged on its own events, so a junction
    ID that shows up under several tags is evaluated once per tag.

    Args:
        events_by_junction_by_tag: ``{tag: {junction_id: [event, ...]}}``.
            Only the ``run_id`` attribute of the events is read.

    Returns:
        A new nested dict containing only the surviving junctions. Tags
        left without any junction are omitted. The event lists themselves
        are reused, not copied.
    """
    def keep_junctions(should_keep):
        # Rebuild the nested mapping with only the junctions whose event
        # list satisfies should_keep(events); drop tags that end up empty.
        kept_by_tag = {}
        for tag, events_by_junction in events_by_junction_by_tag.items():
            kept = {
                j: events
                for j, events in events_by_junction.items()
                if should_keep(events)
            }
            if kept:
                kept_by_tag[tag] = kept
        return kept_by_tag

    def filter_by_event_count():
        # Alternative strategy based on per-junction event counts. It is
        # not called by this function; only filter_by_run_count() is.
        MIN_EVENTS = 5000
        MAX_FACTOR = 3

        counts = sorted(
            len(events)
            for events_by_junction in events_by_junction_by_tag.values()
            for events in events_by_junction.values()
        )

        best_min_count = 0
        best_num_events = 0
        best_junction_count = 0
        for count in counts:
            if count <= best_min_count:
                continue
            eligible = [c for c in counts if c >= count]
            # Each junction contributes at most MAX_FACTOR * count events.
            num_events = sum(min(c, MAX_FACTOR * count) for c in eligible)
            num_junctions = len(eligible)
            if num_events > MIN_EVENTS and num_junctions > best_junction_count:
                best_min_count = count
                best_num_events = num_events
                best_junction_count = num_junctions
        print(best_min_count)
        print(best_num_events)
        print(best_junction_count)

        return keep_junctions(lambda events: len(events) >= best_min_count)

    def filter_by_run_count():
        # Keep junctions that appear in at least 10% of runs.
        MIN_RUN_FRACTION = 0.1
        runs = set()  # type: Set[str]
        for events_by_junction in events_by_junction_by_tag.values():
            for events in events_by_junction.values():
                runs.update(event.run_id for event in events)
        num_runs = len(runs)
        min_runs = MIN_RUN_FRACTION * num_runs
        print(
            f'Num runs: {num_runs}. Rejecting runs less '
            f'than: {MIN_RUN_FRACTION * num_runs}')

        # Count runs from the bucket's own events. A table keyed by junction
        # ID alone would let one tag's count overwrite another's whenever
        # the same junction ID occurs under two tags.
        return keep_junctions(
            lambda events: len({e.run_id for e in events}) >= min_runs)

    return filter_by_run_count()


def subsample(events_by_junction_by_tag, events_per_tag=1000, seed=42):
    # type: (Dict[JunctionTag, Dict[int, List[EntityEvent]]], int, Optional[int]) -> Dict[JunctionTag, Dict[int, List[EntityEvent]]]
    """Pick up to ``events_per_tag`` events per tag, balanced over junctions.

    For every tag, each junction's events are shuffled, the junctions are
    ordered from most to fewest events, and events are then taken
    round-robin (one per junction per round) until the quota is reached or
    the events run out.

    Args:
        events_by_junction_by_tag: ``{tag: {junction_id: [event, ...]}}``.
            The input lists are not modified.
        events_per_tag: Maximum number of events kept for each tag.
        seed: Seed for the private random generator. With the default the
            selection is deterministic and matches what seeding the
            module-level generator with 42 would produce.

    Returns:
        A new ``{tag: {junction_id: [event, ...]}}`` dict with the selected
        events. Junctions that received no event are absent.

    Raises:
        ValueError: If ``events_per_tag`` is negative.
    """
    if events_per_tag < 0:
        raise ValueError('events_per_tag must be non-negative')

    # A private generator keeps results reproducible without reseeding the
    # process-wide `random` state as a side effect.
    rng = random.Random(seed)

    ret = {}  # type: Dict[JunctionTag, Dict[int, List[EntityEvent]]]
    for tag, events_by_junction in events_by_junction_by_tag.items():
        # sample() with k == len(events) returns a shuffled copy.
        per_junction = [
            [(j, event) for event in rng.sample(events, k=len(events))]
            for j, events in events_by_junction.items()
        ]  # type: List[List[Tuple[int, EntityEvent]]]
        # Largest junction first. Sort ascending and then reverse (instead
        # of reverse=True) so that junctions with equally many events keep
        # the same relative order as before this parameterization.
        per_junction.sort(key=len)
        per_junction.reverse()

        # zip_longest pads exhausted junctions with None; skip the padding.
        round_robin = [
            pair
            for one_round in zip_longest(*per_junction)
            for pair in one_round
            if pair is not None
        ]

        sampled = {}  # type: Dict[int, List[EntityEvent]]
        for junction, event in round_robin[:events_per_tag]:
            sampled.setdefault(junction, []).append(event)
        ret[tag] = sampled
    return ret


def runs_per_junction(junction, events_by_junction):
    # type: (int, Dict[int, List[EntityEvent]]) -> int
    """Return how many distinct ``run_id`` values a junction's events have.

    Raises:
        KeyError: If ``junction`` is not a key of ``events_by_junction``.
    """
    distinct_runs = {event.run_id for event in events_by_junction[junction]}
    return len(distinct_runs)


def save_events(events_by_junction_by_tag, fn):
    """Write every event of the nested mapping to ``fn`` in text format.

    The events are collected into a single ``EntityEventList``, tag by tag
    and junction by junction in dictionary order, and serialized with
    ``text_format.MessageToString``.
    """
    merged = EntityEventList()
    for tag in events_by_junction_by_tag:
        for junction_events in events_by_junction_by_tag[tag].values():
            merged.event.extend(junction_events)
    with open(fn, 'w') as out_file:
        out_file.write(text_format.MessageToString(merged))


def save_stats(events_by_junction_by_tag, fn):
    """Write a plain-text summary of the nested event mapping to ``fn``.

    The report starts with the total event count. For every tag it then
    lists the tag's event count, its junction count, and one line per
    junction (ordered by ascending event count) with the junction's event
    count and its number of distinct runs.
    """
    # Note, we can't just sum up the events in the EEL because some of them may
    # have been discarded.
    with open(fn, 'w') as f:
        total_events = sum(
            len(events)
            for events_by_junction in events_by_junction_by_tag.values()
            for events in events_by_junction.values()
        )
        f.write(f'Total events: {total_events}\n')
        for tag, events_by_junction in events_by_junction_by_tag.items():
            f.write(f'Tag: {tag}\n')
            # (junction, event_count) pairs, fewest events first. sorted()
            # is stable, so ties keep dictionary order.
            counts = sorted(
                ((j, len(events)) for j, events in events_by_junction.items()),
                key=lambda pair: pair[1],
            )
            tag_event_count = sum(count for _, count in counts)
            tag_junction_count = len({j for j, _ in counts})
            f.write('\tEvent Count: {}\n'.format(tag_event_count))
            f.write('\tJunction Count: {}\n'.format(tag_junction_count))
            for junction, count in counts:
                run_count = runs_per_junction(junction, events_by_junction)
                f.write(
                    f'\t\tJunction: {junction}. Event count: {count}. '
                    f'Run count: {run_count}\n'
                )


def main():
    """Run the pipeline: load, group, filter, subsample, then save.

    Reads the junction map and the events named on the command line,
    writes the selected events to ``--output_pbtxt`` and a summary of them
    to ``--stats_fn``, printing a progress line before each stage.
    """
    print('Parsing args...')
    args = parse_args()

    print('Loading junction map...')
    junction_map = load_junction_map(args.junction_map_pbtxt)

    print('Loading events...')
    all_events = load_events(args.input_events_pb)

    print('Grouping events...')
    grouped = group_events_by_junction_tag(all_events, junction_map)
    kept = filter_events_by_junction_by_tag(grouped)
    sampled = subsample(kept)

    print('Saving events...')
    save_events(sampled, args.output_pbtxt)

    print('Saving stats...')
    save_stats(sampled, args.stats_fn)

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
