import argparse
import os
import pprint

from typing import Dict

from google.protobuf import text_format
from prediction.learning.topdown.junction_exit_dense.junction_event_attributes_pb2 import BucketsCount


def parse_args():
    """Read the command-line options of this script from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by the parser; its ``input_dir``
        attribute holds the value of the mandatory ``--input_dir`` option.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--input_dir', required=True)
    parsed = arg_parser.parse_args()
    return parsed

def get_bucket_counts(directory):
    # type: (str) -> Dict[str, int]
    """Load the bucket counts stored in ``directory``.

    The counts are read from ``<directory>/buckets_count.pbtxt``, which is
    parsed as a text-format ``BucketsCount`` message.

    Args:
        directory: Directory expected to contain ``buckets_count.pbtxt``.

    Returns:
        A plain ``dict`` copy of the message's ``bucket_id_to_count`` field,
        or an empty dict (after printing a notice) when the file is missing.
    """
    pbtxt_path = os.path.join(directory, 'buckets_count.pbtxt')
    if os.path.isfile(pbtxt_path):
        counts_proto = BucketsCount()
        with open(pbtxt_path, 'r') as pbtxt_file:
            contents = pbtxt_file.read()
        text_format.Parse(contents, counts_proto)
        return dict(counts_proto.bucket_id_to_count)

    print('No bucket file found: {}'.format(pbtxt_path))
    return {}


def _normalize(counts):
    # type: (Dict[str, float]) -> Dict[str, float]
    """Return ``counts`` with every value divided by the sum of all values.

    When the values sum to zero (or the mapping is empty) every share is
    reported as 0.0 instead of raising ``ZeroDivisionError``.
    """
    total = float(sum(counts.values()))
    if not total:
        return {key: 0.0 for key in counts}
    return {key: value / total for key, value in counts.items()}


def _split_bucket_id(bucket_id):
    # type: (str) -> Tuple[str, str]
    """Split ``bucket_id`` at its last underscore into ``(prefix, suffix)``.

    An id without any underscore is returned as ``(bucket_id, '')`` so that a
    malformed id shows up in the report instead of aborting it.
    """
    parts = bucket_id.rsplit('_', 1)
    if len(parts) == 1:
        return parts[0], ''
    return parts[0], parts[1]


def _sum_by_component(shares, index):
    # type: (Dict[str, float], int) -> Dict[str, float]
    """Sum ``shares`` grouped by one component of the split bucket id.

    ``index`` selects the component returned by ``_split_bucket_id``:
    0 for the prefix, 1 for the suffix.
    """
    grouped = {}
    for bucket_id, share in shares.items():
        group = _split_bucket_id(bucket_id)[index]
        grouped[group] = grouped.get(group, 0) + share
    return grouped


def main():
    """Aggregate bucket counts below ``--input_dir`` and print a report.

    Every immediate subdirectory of the input directory is passed to
    ``get_bucket_counts`` and the per-bucket counts are summed. The report
    then pretty-prints, in order:

    * the absolute count of each bucket,
    * each bucket's share of the overall total,
    * the shares summed by bucket-id prefix (everything before the last
      underscore),
    * the shares summed by bucket-id suffix (everything after the last
      underscore),
    * for each prefix, the distribution over its suffixes, normalised so
      that each prefix sums to 1.
    """
    args = parse_args()

    bucket_count_map = {}
    for entry in os.listdir(args.input_dir):
        full_dir = os.path.join(args.input_dir, entry)
        if not os.path.isdir(full_dir):
            continue
        # .items() instead of .iteritems(): the latter only exists on
        # Python 2 dicts and raises AttributeError on Python 3.
        for bucket_id, count in get_bucket_counts(full_dir).items():
            bucket_count_map[bucket_id] = bucket_count_map.get(bucket_id, 0) + count

    print('Absolute:')
    pprint.pprint(bucket_count_map)

    print('Percent:')
    percent_dict = _normalize(bucket_count_map)
    pprint.pprint(percent_dict)

    print('Division by type_nways:')
    pprint.pprint(_sum_by_component(percent_dict, 0))

    # This header used to repeat 'Division by type_nways:' although the
    # section groups by the last id component.
    print('Division by turn:')
    pprint.pprint(_sum_by_component(percent_dict, 1))

    print('Dict dict')
    nested = {}
    for bucket_id, share in percent_dict.items():
        prefix, suffix = _split_bucket_id(bucket_id)
        nested.setdefault(prefix, {})[suffix] = share
    # Re-normalise inside each prefix so its suffix shares sum to 1.
    dict_dict = {prefix: _normalize(group) for prefix, group in nested.items()}
    pprint.pprint(dict_dict)



# Run the report only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
