from google.protobuf import text_format
from prediction.learning.learning_pb2 import InferenceInterfaceFormat, MLModel
from third_party.caffe.src.caffe.proto.caffe_pb2 import NetParameter, SolverParameter
from vision.inference.proto.inference_pb2 import CaffeWrapperOptions, WrapperWithFallbackOptions

import argparse
import logging
import os
import pprint
import re
import shutil
import subprocess
import sys
import threading
import uuid

# File and directory names used inside each experiment's model directory.

# Training log written by run_caffe_train() and parsed by
# find_best_snapshot_fn().
LOG_FN = 'optim.log'
# Per-experiment copies of the solver / inference / train pbtxt files, made by
# create_files_and_folders().
SOLVER_FN = 'solver.pbtxt'
INFERENCE_FN = 'inference.pbtxt'
TRAIN_FN = 'train.pbtxt'
# Sub-directory of the model directory used as the solver's snapshot prefix.
SNAPSHOTS_DN = 'snapshots'
# Outputs of the batchnorm-merge step (run_trt()); also referenced by eval().
MERGED_PBTXT_FN = 'merged.pbtxt'
MERGED_CAFFEMODEL_FN = 'merged.caffemodel'

"""
Copy-pasted from the train automator so that it can be used to search for
buckets. Ideally this would be merged back into the main train automator, but
that has not been done yet.
"""


def tee(process, fn=None):
    """Stream a subprocess's output to stdout and, optionally, to a log file.

    If `fn` already exists it is first moved to `{name}{number}{ext}` (using
    the first unused number starting at 0) so the previous log is not lost;
    the most recent log is always at `fn`.

    Args:
        process: An object with a readable `stdout` stream, a `wait()` method
            and a `returncode` attribute (e.g. the result of `brun`). The
            stream may yield either text or bytes.
        fn: Optional path of the log file to append the output to. When None,
            the output is only echoed to stdout.

    Raises:
        AssertionError: If the process exits with a non-zero return code.
    """
    # Local import: only needed to decode byte pipes on Python 3.
    import codecs

    # Number of pieces to buffer before appending them to the log file.
    flush_every = 10000

    # Rotate an existing log out of the way before writing a new one.
    if fn:
        rotated_fn = fn
        stem, ext = os.path.splitext(fn)
        suffix = 0
        while os.path.exists(rotated_fn):
            rotated_fn = stem + str(suffix) + ext
            suffix += 1
        if rotated_fn != fn:
            shutil.move(fn, rotated_fn)

    # An incremental decoder copes with multi-byte characters that arrive one
    # byte at a time.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    pending = []
    while True:
        # Read one unit at a time so the output is echoed as it is produced.
        chunk = process.stdout.read(1)
        if not chunk:
            # Both '' and b'' signal end of stream.
            break
        if isinstance(chunk, bytes) and not isinstance(chunk, str):
            chunk = decoder.decode(chunk)
        sys.stdout.write(chunk)
        # Only buffer when there is a file to flush to; otherwise the buffer
        # would grow for the whole lifetime of the process.
        if fn is not None:
            pending.append(chunk)
            if len(pending) >= flush_every:
                with open(fn, 'a') as f:
                    f.write(''.join(pending))
                pending = []

    # Write the rest of the buffer (this also creates the file when the
    # process produced no output at all).
    if fn is not None:
        with open(fn, 'a') as f:
            f.write(''.join(pending))

    process.wait()
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if process.returncode != 0:
        raise AssertionError(
            'Subprocess exited with non-zero return code %s'
            % process.returncode)


def brun(target, flags):
    '''
    Launches `target` with `flags` through the shell and returns the running
    subprocess.Popen handle (it does not wait for the command to finish).

    The command line is the space-joined target and flags, so arguments are
    subject to shell parsing. stderr is merged into the stdout pipe.
    '''
    command_line = ' '.join([target] + flags)
    print('Running command:', command_line)

    return subprocess.Popen(
        command_line,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)


def run_multiblob(input_dir, output_dir, config):
    '''
    Runs the datamatrix to multiblob proto command, streaming its output, and
    blocks until it finishes (nothing is returned).

    Optional flags are only passed when the matching config value is truthy.
    '''
    DATAMATRIX_TARGET = 'prediction/learning/data_matrix_to_multi_blob_proto'

    flags = [
        '--input_dir', input_dir,
        '--output_dir', output_dir,
        '--alsologtostderr',
    ]

    # (config key / flag name, converter). feature_spec_fn is passed through
    # unconverted.
    optional_flags = (
        ('subsample_pct', str),
        ('feature_spec_fn', None),
        ('blob_creator', str),
        ('num_junction_pose_s_buckets', str),
    )
    for key, convert in optional_flags:
        value = config[key]
        if value:
            flags += ['--' + key, convert(value) if convert else value]

    # Stream the output of the process.
    process = brun(DATAMATRIX_TARGET, flags)
    tee(process)


def run_caffe_train(model_dir, solver_fn):
    """Run `caffe_tool train` on GPU 0 with the given solver file.

    The command's output is streamed to stdout and appended to LOG_FN inside
    `model_dir`; the call blocks until training finishes.
    """
    CAFFE_TARGET = 'third_party/caffe/caffe_tool'
    training = brun(CAFFE_TARGET, ['train', '--gpu', '0', '--solver', solver_fn])
    # Pipe the output to the log file so it can be parsed afterwards.
    log_path = os.path.join(model_dir, LOG_FN)
    tee(training, log_path)


def run_trt(snapshot_fn, model_dir, inference_fn):
    """Run the batchnorm-merge tool on a snapshot.

    The tool receives the inference pbtxt and the snapshot weights, and is
    asked to write MERGED_PBTXT_FN and MERGED_CAFFEMODEL_FN into `model_dir`.
    Output is streamed to stdout; the call blocks until the tool exits.
    """
    MERGE_TARGET = 'third_party/caffe/scripts/merge_batchnorm/gen_merged_model_fusedbn'

    merge_flags = [inference_fn, snapshot_fn]
    merge_flags += ['--output_model', os.path.join(model_dir, MERGED_PBTXT_FN)]
    merge_flags += ['--output_weights',
                    os.path.join(model_dir, MERGED_CAFFEMODEL_FN)]

    tee(brun(MERGE_TARGET, merge_flags))


def find_best_snapshot_fn(log_file, snapshot_dir):
    """Find the snapshot with the lowest test loss by parsing a training log.

    The log is expected to contain lines with `Test loss: <float>`, each
    followed on the next line by `Iteration <N>`; the loss is attributed to
    iteration N.

    Args:
        log_file: Path of the training log to parse.
        snapshot_dir: Directory holding the `_iter_<N>.caffemodel` snapshots.

    Returns:
        A `(snapshot_path, loss)` tuple for the iteration with the lowest
        test loss.

    Raises:
        AssertionError: If a "Test loss" line cannot be parsed, or the line
            after it carries no iteration number.
        ValueError: If the log contains no test loss / iteration pair, or a
            loss value is not a valid float.
    """
    # Match only the digits of the iteration number. The previous greedy
    # `(.*),` pattern captured everything up to the *last* comma on the line.
    iteration_re = re.compile(r'Iteration (\d+)')
    loss_re = re.compile(r'Test loss: (.*)')

    snapshot_losses = {}
    pending_loss = None
    with open(log_file, 'r') as f:
        for line in f:
            if pending_loss is not None:
                # The line right after a test loss names the iteration.
                match = iteration_re.search(line)
                if match is None:
                    raise AssertionError(
                        'Expected an iteration number after a test loss, '
                        'got: ' + line.rstrip())
                snapshot_losses[match.group(1)] = pending_loss
                pending_loss = None
            if 'Test loss' in line:
                match = loss_re.search(line)
                if match is None:
                    raise AssertionError(
                        'Could not parse test loss from: ' + line.rstrip())
                pending_loss = float(match.group(1))

    if not snapshot_losses:
        raise ValueError('No test losses found in log file: ' + str(log_file))

    best_snapshot = min(snapshot_losses, key=snapshot_losses.get)
    best_loss = snapshot_losses[best_snapshot]
    print('Best loss: ', best_snapshot, best_loss)
    return (os.path.join(snapshot_dir, '_iter_' + best_snapshot + '.caffemodel'),
            best_loss)


def update_solver_file(solver_path, model_dir, config):
    """Rewrite the solver pbtxt at `solver_path` in place for this experiment.

    The solver's `net` is pointed at the train pbtxt inside `model_dir` and
    its `snapshot_prefix` at the snapshots sub-directory. The asserts guard
    against overwriting the template files named in `config`.
    """
    assert solver_path != config['solver_fn']
    experiment_train_path = os.path.join(model_dir, TRAIN_FN)
    assert experiment_train_path != config['train_fn']

    solver = SolverParameter()
    with open(solver_path, 'r') as solver_file:
        solver_text = solver_file.read()
    text_format.Parse(solver_text, solver)

    solver.net = experiment_train_path
    # Trailing slash: snapshots are written inside the snapshots directory.
    solver.snapshot_prefix = os.path.join(model_dir, SNAPSHOTS_DN) + '/'

    with open(solver_path, 'w') as solver_file:
        solver_file.write(text_format.MessageToString(solver))


def get_batch_name_from_file_name(fn):
    """Return the lower-cased text between the first '_' and the next '.'.

    For example 'train_Vehicle.lmdb' gives 'vehicle'. Raises ValueError when
    `fn` contains no underscore.
    """
    after_prefix = fn[fn.index('_') + 1:]
    stem = after_prefix.split('.', 1)[0]
    return stem.lower()


def update_train_file(train_path, lmdb_dir, config):
    """Rewrite the train pbtxt at `train_path` in place for this experiment.

    For every MultiBlobLmdbData layer the LMDB filename is re-rooted under
    `lmdb_dir`; when `config` carries a 'batch_balance' dict, the batch size
    of train layers (and of test layers if config['batch_balance_test'] is
    set) is taken from it. InnerProduct layers whose name starts with
    'fc_output_s_buckets' get config['num_junction_pose_s_buckets'] outputs.
    """
    # Guard against overwriting the template file.
    assert train_path != config['train_fn']

    net = NetParameter()
    with open(train_path, 'r') as train_file:
        text_format.Parse(train_file.read(), net)

    for layer in net.layer:
        if layer.type == 'MultiBlobLmdbData':
            data_param = layer.multi_blob_lmdb_data_param
            # Keep the file name, but point it into this experiment's LMDBs.
            basename = os.path.basename(data_param.filename)
            data_param.filename = os.path.join(lmdb_dir, basename)

            # Maybe update the train/test batch balance.
            is_balanced_layer = basename.startswith('train_') or (
                config['batch_balance_test'] and basename.startswith('test_'))
            if is_balanced_layer:
                batch_name = get_batch_name_from_file_name(basename)
                if 'batch_balance' in config:
                    assert batch_name in config['batch_balance']
                    data_param.batch_size = config['batch_balance'][batch_name]

        # Update output buckets.
        is_bucket_output = (layer.type == 'InnerProduct' and
                            layer.name.startswith('fc_output_s_buckets'))
        if is_bucket_output:
            layer.inner_product_param.num_output = config[
                'num_junction_pose_s_buckets']

    with open(train_path, 'w') as train_file:
        train_file.write(text_format.MessageToString(net))


def update_inference_file(inference_path, config):
    """Rewrite the inference pbtxt at `inference_path` in place.

    InnerProduct layers whose name starts with 'fc_output_s_buckets' get
    config['num_junction_pose_s_buckets'] outputs; nothing else is changed.
    """
    # Guard against overwriting the template file.
    assert inference_path != config['inference_fn']

    net = NetParameter()
    with open(inference_path, 'r') as inference_file:
        text_format.Parse(inference_file.read(), net)

    # Update output buckets.
    bucket_layers = [
        layer for layer in net.layer
        if layer.type == 'InnerProduct'
        and layer.name.startswith('fc_output_s_buckets')
    ]
    for layer in bucket_layers:
        layer.inner_product_param.num_output = config[
            'num_junction_pose_s_buckets']

    with open(inference_path, 'w') as inference_file:
        inference_file.write(text_format.MessageToString(net))


def create_files_and_folders(config):
    """Create the experiment's directories and its private pbtxt copies.

    The LMDB and model directories are
    `<config dir>/<experiment_group_name>/<experiment_id>`. The solver, train
    and inference templates named in `config` are copied into the model
    directory (overwriting earlier copies) and then rewritten to reference
    this experiment's paths.

    Args:
        config: The experiment config dict; must contain 'experiment_id'.

    Returns:
        A dict with the keys 'lmdb_dir', 'model_dir', 'solver_fn',
        'inference_fn' and 'snapshot_dir'.

    Raises:
        AssertionError: If 'experiment_id' is missing from `config`.
    """
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if 'experiment_id' not in config:
        raise AssertionError('config must contain an experiment_id')
    experiment_id = config['experiment_id']

    # Create the LMDB directory
    lmdb_dir = os.path.join(
        config['lmdb_dir'], config['experiment_group_name'], experiment_id)
    if not os.path.isdir(lmdb_dir):
        os.makedirs(lmdb_dir)
    print('Making LMDB directory:', lmdb_dir)

    # Create the model directory & snapshot directory
    model_dir = os.path.join(
        config['model_dir'], config['experiment_group_name'], experiment_id)
    snapshot_dir = os.path.join(model_dir, SNAPSHOTS_DN)
    print('Making model directory:', model_dir)
    # Check each directory on its own: the model directory may already exist
    # (e.g. when resuming an experiment) while the snapshot directory is
    # missing, and it must still be created in that case.
    for directory in (model_dir, snapshot_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    # Create the train file. model file and solver file.
    new_solver_fn = os.path.join(model_dir, SOLVER_FN)
    shutil.copyfile(config['solver_fn'], new_solver_fn)
    new_train_fn = os.path.join(model_dir, TRAIN_FN)
    shutil.copyfile(config['train_fn'], new_train_fn)
    new_inference_fn = os.path.join(model_dir, INFERENCE_FN)
    shutil.copyfile(config['inference_fn'], new_inference_fn)
    # Update the solver and train files so they reference the right files e.g.
    # right LMDBs
    update_solver_file(new_solver_fn, model_dir, config)
    update_train_file(new_train_fn, lmdb_dir, config)
    update_inference_file(new_inference_fn, config)

    return {
        'lmdb_dir': lmdb_dir,
        'model_dir': model_dir,
        'solver_fn': new_solver_fn,
        'inference_fn': new_inference_fn,
        'snapshot_dir': snapshot_dir,
    }


def save_config(config, directory):
    """Write `config` to `<directory>/config.txt`, overwriting any old file.

    Each entry becomes one `key: value` line (both converted with `str`);
    lines are newline-separated with no trailing newline.

    Args:
        config: Mapping of config names to values.
        directory: Existing directory to write `config.txt` into.
    """
    # `.items()` works on both Python 2 and 3; `.iteritems()` is Python 2 only.
    lines = [str(key) + ': ' + str(value) for key, value in config.items()]
    with open(os.path.join(directory, 'config.txt'), 'w') as f:
        f.write('\n'.join(lines))


def eval(config, model_dir):
    """Write an MLModel override for the merged model and run the eval app.

    The override template at config['model_override_fn'] is loaded, its
    input size and Caffe model/params paths are filled in from the merged
    model in `model_dir`, and the result is saved as `ml_model.pbtxt` there.
    The eval binary is then run with that override; its output is streamed
    and the call blocks until it exits.

    NOTE: this function shadows the builtin `eval` within this module.
    """
    model_override = MLModel()
    with open(config['model_override_fn'], 'r') as override_file:
        text_format.Parse(override_file.read(), model_override)
    assert model_override.type

    # Compute the input size.
    iif = model_override.Extensions[InferenceInterfaceFormat.ext]
    merged_pbtxt_fn = os.path.join(model_dir, MERGED_PBTXT_FN)
    net = NetParameter()
    with open(merged_pbtxt_fn) as merged_file:
        text_format.Parse(merged_file.read(), net)
    for layer in net.layer:
        if layer.type != 'Input':
            continue
        shape = layer.input_param.shape
        assert len(shape) == 1
        # The input size is the first dimension that is not 1, if any.
        first_non_unit_dim = next(
            (dim for dim in shape[0].dim if dim != 1), None)
        if first_non_unit_dim is not None:
            iif.input_size = first_non_unit_dim

    # Point the Caffe wrapper at the merged model and weights.
    fallback_options = iif.inference_options.Extensions[
        WrapperWithFallbackOptions.ext]
    caffe_wrapper_options = fallback_options.cpu_options.Extensions[
        CaffeWrapperOptions.ext]
    caffe_wrapper_options.model_file = merged_pbtxt_fn
    caffe_wrapper_options.params_file = os.path.join(
        model_dir, MERGED_CAFFEMODEL_FN)

    model_override_fn = os.path.join(model_dir, 'ml_model.pbtxt')
    with open(model_override_fn, 'w') as override_file:
        override_file.write(text_format.MessageToString(model_override))

    EVAL_TARGET = 'prediction/metrics/apps/eval_pred'
    description = config['experiment_group_name'] + '_' + config['experiment_id']
    flags = [
        '--description', description,
        '--max_samples', str(config['eval_num_events']),
        '--mapreduce_num_workers', str(config['eval_num_workers']),
        '--pipedream',
        '--model_override_fn', model_override_fn,
        '--app', config['eval_app'],
    ]
    # The dataset flag is optional.
    if config['eval_dataset']:
        flags.extend(['--dataset', config['eval_dataset']])

    tee(brun(EVAL_TARGET, flags))


def run(config):
    """Run the stages listed in config['stages'] (comma-separated).

    The experiment's files and folders are always created and the config is
    always saved. The optional stages 'multiblob', 'train', 'trt' and 'eval'
    then run in that fixed order; 'train' and 'trt' also record the best test
    loss in config['best_loss'] and save the config again.
    """
    stages = config['stages'].split(',')

    # Making the info is a non-optional stage.
    info = create_files_and_folders(config)
    model_dir = info['model_dir']
    save_config(config, model_dir)

    def record_best_snapshot():
        # Parse the training log, store the best loss in the config and hand
        # back the path of the matching snapshot.
        snapshot_fn, loss = find_best_snapshot_fn(
            os.path.join(model_dir, LOG_FN), info['snapshot_dir'])
        config['best_loss'] = loss
        return snapshot_fn

    if 'multiblob' in stages:
        run_multiblob(config['datamatrix_dir'], info['lmdb_dir'], config)

    if 'train' in stages:
        run_caffe_train(model_dir, info['solver_fn'])
        record_best_snapshot()
        # Save the config again so that we get the best loss.
        save_config(config, model_dir)

    if 'trt' in stages:
        # Look the snapshot up again in case train wasn't run.
        best_snapshot_fn = record_best_snapshot()
        run_trt(best_snapshot_fn, model_dir, info['inference_fn'])
        # Save the config again so that we get the best loss.
        save_config(config, model_dir)

    if 'eval' in stages:
        eval(config, model_dir)


def parse_batch_balance_args(train_fn):
    """Parse the `--bb_{batch_name}` flags derived from the train pbtxt.

    One float flag (default -1.0) is created for every MultiBlobLmdbData
    layer in `train_fn` whose LMDB file name starts with 'train_'; the batch
    name comes from get_batch_name_from_file_name(). Flags are read from the
    process command line and anything unrecognised is ignored.

    Returns the argparse namespace holding the `bb_*` values.
    """
    net = NetParameter()
    with open(train_fn, 'r') as train_file:
        text_format.Parse(train_file.read(), net)

    # File names of all the LMDB data layers.
    lmdb_basenames = (
        os.path.basename(layer.multi_blob_lmdb_data_param.filename)
        for layer in net.layer
        if layer.type == 'MultiBlobLmdbData'
    )
    # Only the train layers contribute batch-balance flags.
    batch_names = {
        get_batch_name_from_file_name(basename)
        for basename in lmdb_basenames
        if basename.startswith('train_')
    }

    parser = argparse.ArgumentParser()
    for batch_name in batch_names:
        parser.add_argument('--bb_' + batch_name, type=float, default=-1.0)

    known_args, _ = parser.parse_known_args()
    return known_args


def _parse_bool_flag(value):
    """Convert a command-line string such as 'true' or 'False' to a bool."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def parse_args():
    """Parse the command-line flags of the train automator.

    Unknown arguments are tolerated (and printed) because the batch balance
    `--bb_*` flags are parsed separately.

    Returns:
        The argparse namespace with all known flags.
    """
    parser = argparse.ArgumentParser()

    # The experiment name. Used to generate the folder name.
    parser.add_argument('--experiment_group_name', required=True)

    # File path to the solver pbtxt
    parser.add_argument('--solver_fn', required=True)
    # File path to the train pbtxt. This will be cloned to make the train pbtxt
    # for each experiment.
    parser.add_argument('--train_fn', required=True)
    # File path to the inference pbtxt. This will be used to generate the merged
    # model.
    parser.add_argument('--inference_fn', required=True)
    # File path to the MLModel definition which will be copied and used to
    # override the existing model.
    parser.add_argument('--model_override_fn', required=True)

    # The datamatrix directory.
    parser.add_argument('--datamatrix_dir', required=True)

    # Where to store the LMDBs
    parser.add_argument('--lmdb_dir', required=True)
    # Where to store the models
    parser.add_argument('--model_dir', required=True)

    # The eval app to run.
    parser.add_argument('--eval_app', required=True)
    # The eval dataset to use.
    parser.add_argument('--eval_dataset')
    # The number of workers to use for the evaluator.
    parser.add_argument('--eval_num_workers', type=int, default=5)
    # The number of events to use for the evaluator. This sets max_samples.
    parser.add_argument('--eval_num_events', type=int, default=1000)

    # A comma-separated list of the stages to run. Order doesn't matter.
    parser.add_argument('--stages', default='multiblob,train,trt,eval')

    # Specifies the format for directory creation. Should be a string, values in
    # {} will be replaced by their corresponding flag. An empty string will be
    # assigned a UUID. Make sure these are unique! This is not enforced.
    parser.add_argument('--directory_format', type=str, default='')

    parser.add_argument('--subsample_pct', type=float)

    # Feature spec args: allows the data matrix to output a subset of the
    # features.
    parser.add_argument('--feature_spec_fn')

    # Hacked in new parameter is here.
    parser.add_argument('--num_junction_pose_s_buckets', type=int, default=10)

    parser.add_argument('--resume_experiment_id')

    parser.add_argument('--blob_creator', type=str, default=False)

    # Total batch size that the batch balance flags (bb_{lower_case_class_name},
    # parsed separately) are normalized to.
    parser.add_argument('--batch_size', type=float, default=256)

    # Whether to also batch balance test data. A plain `type=bool` would turn
    # every non-empty string (including "False") into True, so the value is
    # parsed explicitly.
    parser.add_argument('--batch_balance_test', type=_parse_bool_flag,
                        default=False)

    known, unknown = parser.parse_known_args()
    if unknown:
        print('Encountered unknown arguments. If you see batch balance args ' +
              'here, that\'s ok: they will be parsed separately. Unknown ' +
              'args: ', unknown)
    return known


def make_config(args):
    """Turn the parsed args into the config dict and set its experiment ID.

    The returned dict is the namespace's own attribute dict (not a copy).
    The ID is, in order of preference: the `resume_experiment_id` flag, the
    `directory_format` string formatted with the config values, or a fresh
    UUID with the dashes removed.
    """
    config = vars(args)

    resumed_id = config['resume_experiment_id']
    if resumed_id:
        config['experiment_id'] = resumed_id
        return config

    directory_format = config['directory_format']
    if len(directory_format) > 0:
        experiment_id = directory_format.format(**config)
    else:
        experiment_id = uuid.uuid1().hex
    config['experiment_id'] = experiment_id
    return config


def add_batch_balance_to_config(batch_balance_args, config):
    """Add the normalized batch balance to `config` (modified in place).

    The `bb_*` values are either all negative, meaning "use the default
    batch sizes" (nothing is added to the config), or all non-negative, in
    which case each batch receives its share of config['batch_size'],
    truncated to an int and at least 1, under config['batch_balance'] keyed
    by batch name (the flag name without the `bb_` prefix).

    Args:
        batch_balance_args: Namespace of `bb_{batch_name}` float values.
        config: The experiment config dict; must contain 'batch_size'.

    Returns:
        The same `config` object.

    Raises:
        AssertionError: If only some of the values are negative.
        ZeroDivisionError: If values are given but they sum to zero.
    """
    # `.items()` / `.values()` work on both Python 2 and 3; the `iter*`
    # variants are Python 2 only.
    bb_dict = vars(batch_balance_args)

    negative_keys = [k for k, v in bb_dict.items() if v < 0]
    if negative_keys:
        print('Using default sample because found key with negative count: ',
              negative_keys[0])
        for k, v in bb_dict.items():
            # Raised explicitly (instead of `assert`) so the check survives
            # `python -O`.
            if not v < 0:
                raise AssertionError(
                    'Batch balancing must be specified for all batches or '
                    'none. Found key: ' + k)
        print('Using default batch balance')
        return config

    total = sum(bb_dict.values())
    # Make a normalized dict and remove the bb_ prefix. Value needs to be at
    # least 1.
    normalized = {k[3:]: max(1, int(v / total * config['batch_size']))
                  for k, v in bb_dict.items()}
    config['batch_balance'] = normalized
    print('Batch balance parms:', normalized)
    return config


def main():
    """Build the config from the command line, print it and run the stages."""
    config = make_config(parse_args())

    # The batch balance flags depend on the train pbtxt, so they are parsed
    # in a second pass and merged into the same config dict.
    bb_args = parse_batch_balance_args(config['train_fn'])
    config = add_batch_balance_to_config(bb_args, config)

    print('\n\nRunning with config:')
    pprint.pprint(config)
    print('\n\n')

    run(config)


# Only run the pipeline when executed as a script, not when imported.
if __name__ == '__main__':
    main()
