#!/usr/bin/env python3

import matplotlib
matplotlib.use('Agg')   # Force non-interactive matplotlib backend

import argparse
import json
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import psycopg2

from data.chum import chumpy

# duration is the length of the final segment of the test that is of interest
def process_data(zimbeasts, duration):
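    """Trim each run to its trailing `duration` window and align runs in time.

    Yields one DataFrame per run with a `timedelta` column (seconds since the
    start of the trimmed window) so runs with different wall-clock start times
    can be compared point-for-point. Exits with code 99 if a run is shorter
    than the reference window.
    """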
    ref_duration = pd.Timedelta(duration)
    for logrun in zimbeasts:
        df = logrun.dropna(subset=['system_latency_p95'])
        run_duration = df.timestamp.max() - df.timestamp.min()
        print("Processing log: min time={}, max time={}, delta={}".format(df.timestamp.min(), df.timestamp.max(), run_duration))
        # Each entry is 1s of data, so the total duration is the range of timestamps + 1s
        if (run_duration + pd.Timedelta(1, unit='s')) < ref_duration:
            print("Fatal: Run duration is less than reference run: did you forget to specify desired duration of reference data?")
            exit(99)
        # remove all but the last <duration> seconds
        z = df[df.timestamp > (df.timestamp.max() - ref_duration)].copy()
        # line up by time since start
        z['timedelta'] = (z['timestamp'] - z['timestamp'].min()).apply(lambda x: x.total_seconds())
        yield z[['timedelta', 'system_latency_p95', 'run_meta_id']].dropna(subset=['system_latency_p95'])


def get_all_run_data(ref_run_id, logbot_runs, logbot_data_files):
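    """Fetch data for every run; the reference run is always element 0.

    Redshift-backed runs come first, followed by local report files -- the
    ordering that `main` relies on when slicing and labeling results.
    """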
    data = [get_redshift_data(run) for run in [ref_run_id] + logbot_runs]
    data.extend([get_local_data(run) for run in logbot_data_files])
    return data

def get_local_data(run_data_file):
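    """Load a logbot system latency report from a local file.

    The report is a stream of newline-delimited JSON objects; only entries
    with a 'system' section are kept. Returns a DataFrame shaped like the
    Redshift results so both sources can be processed uniformly.
    """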
    print("Using local data: " + run_data_file)
    # Read file to string
    with open(run_data_file, "r") as jfile:
        raw_data = '[' + jfile.read() + ']'
    # the file is newline-delimited JSON objects; add commas so it parses as a single JSON array
    json_data = json.loads(raw_data.replace("}\n{", "},\n{"))
    # flatten each entry to just the columns pandas needs
    pd_json = []
    for entry in json_data:
        if 'system' in entry:
            pd_json.append({'run_meta_id': run_data_file, 'timestamp': entry['timestamp'], 'system_latency_p95': entry['system']['latency_P95']})
    data = pd.DataFrame(pd_json)
    data['timestamp'] = data['timestamp'].apply(lambda x: pd.Timestamp(x, unit='s', tz='UTC'))
    return data

def get_redshift_data(run_meta_id):
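    """Fetch the p95 system latency time series for one run from Redshift."""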
    print("Retrieving redshift data for run: " + run_meta_id)
    # password is read from the REDSHIFT_PASSWORD environment variable
    # (name chosen here) rather than hardcoded in source
    conn = psycopg2.connect(
        host='redshift-db01.c1iujdhow47v.us-west-1.redshift.amazonaws.com',
        user='srv_planner',
        port=5439,
        password=os.environ['REDSHIFT_PASSWORD'],
        dbname='temp_db'
    )
    # an improvement would be to batch queries for run data
    query = """
    SELECT
        performance_metrics.ts as "timestamp",
        performance_metrics.system_latency_p95,
        performance_metrics.run_meta_id
    FROM performance_metrics

    WHERE (run_meta_id = %s)
    """

    data = pd.read_sql(query, conn, params=(run_meta_id,))
    conn.close()
    data['timestamp'] = data['timestamp'].apply(lambda x: pd.Timestamp(x, unit='s', tz='UTC'))

    return data

def bootstrap_series(table, N=1000, level=0.99, func='mean'):
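    """Bootstrap a two-sided confidence interval for a row-wise statistic.

    Columns of `table` are treated as interchangeable runs: resample them with
    replacement N times, measure how much the statistic of each resample
    deviates from the point estimate, and reflect the quantiles of those
    deviations around the point estimate (the "basic" bootstrap interval).
    Works for translation-equivariant statistics such as means and quantiles.
    Returns (lower_bound, upper_bound) series indexed like `table`'s rows.
    """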
    mean = table.aggregate(func, axis=1)
    deltas = table.subtract(mean, axis=0)
    mu_delta_dist = pd.concat((deltas.sample(frac=1, replace=True, axis=1).aggregate(func, axis=1) for _ in range(N)), axis=1)

    # now get the confidence level (two-sided)
    side = (1 - level)/2.0

    lb = mean - mu_delta_dist.quantile(1.0 - side, axis=1)
    ub = mean - mu_delta_dist.quantile(side, axis=1)
    return lb, ub
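
# a minimal usage sketch (hypothetical data; kept as a comment so the script's
# behavior is unchanged) -- rows are time steps, columns are runs:
#
#   import numpy as np
#   table = pd.DataFrame(np.random.rand(60, 8))   # 60 seconds x 8 runs
#   lb, ub = bootstrap_series(table, N=500, level=0.95)
#   assert (lb <= ub).all()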

def main(reference_run, logbot_runs, logbot_data_files, output_png, trim_start_sec=0, group_size=1, plot_all_runs=False):
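    """Compare logbot runs against a reference run.

    Builds a bootstrapped 99% confidence band for logbot p95 system latency,
    overlays the reference run on it, saves the figure to `output_png`, and
    returns (mean upper bound, mean interval width), both in ms.
    """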
    assert group_size < float(len(logbot_runs)+len(logbot_data_files))/2, "Number of runs must be more than 2x the group size"
    # Parse Actual run URI
    ref_run_store, ref_run_range = chumpy.parseChumUri(reference_run, False)
    ref_run_id = chumpy.getMetaIdFromChumUri(reference_run)
    assert ref_run_range is not None
    ref_start_epoch = float(ref_run_range.start_time) / 1e9 # Units are ns
    ref_end_epoch = float(ref_run_range.end_time) / 1e9 # Units are ns
    ref_start_epoch += trim_start_sec
    ref_duration = ref_end_epoch-ref_start_epoch
    print("Reference run: {}, starting: {}, effective duration: {}".format(ref_run_id, ref_start_epoch, ref_duration))

    data = get_all_run_data(ref_run_id, logbot_runs, logbot_data_files)

    zimbeasts = pd.concat(process_data(data[1:], duration='{}s'.format(ref_duration)))

    # manually handle ground truth data
    start = pd.Timestamp(ref_start_epoch, tz='UTC', unit='s')
    end = start + pd.Timedelta('{}s'.format(ref_duration))
    gt = data[0][(data[0].timestamp >= start) & (data[0].timestamp < end)].copy()
    gt['timedelta'] = (gt.timestamp - gt.timestamp.min()).apply(lambda x: x.total_seconds())

    zimbot_series = pd.pivot_table(zimbeasts, index='run_meta_id', columns='timedelta',
            values='system_latency_p95', aggfunc='mean').T # pivot then transpose; pandas errors when timedelta is used directly as the index

    if group_size > 1:
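        # form len(logbot_runs) synthetic "grouped" runs, each the mean of
        # `group_size` distinct runs drawn without replacement per group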
        def make_samples():
            for _ in range(len(logbot_runs)):
                yield zimbot_series.sample(group_size, replace=False, axis=1).aggregate('mean', axis=1)

        groups = pd.concat(make_samples(), axis=1)
    else:
        groups = zimbot_series

    # find the confidence interval for the 99.5th-percentile latency of the series
    p995_lb_series, p995_ub_series = bootstrap_series(groups, func=lambda x: x.quantile(0.995))

    # find the confidence interval for the 0.5th-percentile latency of the series
    p005_lb_series, p005_ub_series = bootstrap_series(groups, func=lambda x: x.quantile(0.005))

    # use the average of the 99.5th- and 0.5th-percentile confidence intervals
    # to form an estimate of the upper and lower bounds (respectively) of
    # logbot latency; this forms an (estimated) 99% confidence interval of
    # where logbot latency would lie.
    #
    # of intervals constructed this way, 99% will contain the "true" run time
    # of logbot. if this interval *also* contains data from the real run,
    # that is a nice bonus (although in theory it need not contain real-run
    # latency).
    lb_series, ub_series = (p005_lb_series + p005_ub_series)/2, (p995_lb_series + p995_ub_series)/2

    # make a plot
    ax = gt.groupby('timedelta').mean().system_latency_p95.dropna().sort_index().plot(linewidth=2, marker='o', color='#EA2E00', label=ref_run_id)
    ax.fill_between(lb_series.index, lb_series, ub_series, facecolor='#51EBAF',
                    edgecolor='k', linewidth=2, alpha=0.5, label='99% Zimbeast CI')

    if plot_all_runs:
        run_labels = logbot_runs + logbot_data_files
        for index, logrun in enumerate(data[1:]):
            series = pd.concat(process_data([logrun], duration='{}s'.format(ref_duration)))
            ax.plot(series.groupby('timedelta').mean().system_latency_p95.dropna().sort_index(), label=run_labels[index])
        # For now don't print labels when we have all the runs displayed
        # ax.legend(fontsize='xx-small')
    else:
        ax.legend()

    ax.figure.savefig(output_png, dpi=300)

    p95_upper_bound_mean = ub_series.mean()
    p95_estimate_bandwidth = (ub_series - lb_series).mean()
    print("Logbot Result Analysis: saved graph to {}".format(output_png))
    print("Logbot Result Analysis: Average P95 upper bound (99.5% confidence): {} ms".format(p95_upper_bound_mean))
    print("Logbot Result Analysis: Average bandwidth (99% confidence interval): {} ms".format(p95_estimate_bandwidth))

    return (p95_upper_bound_mean, p95_estimate_bandwidth)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Analyze Logbot runs')
    parser.add_argument('--reference_run', '-r',
            required=True,
            help="CHUM URI with offset and duration indications, like '20190320T214646-pri_010@1553118557.05+30.0'")
    parser.add_argument('--logbot_runs', '-l',
            required=False,
            nargs='+',
            action='append',
            help='Space separated list of logbot runs. Ex. \'-l <uri1> <uri2>\' (May also be specified multiple times: Ex. \'-l <uri1> -l <uri2> <uri3>\')')
    parser.add_argument('--logbot_data', '-d',
            required=False,
            nargs='+',
            action='append',
            help='Space separated list of logbot system latency report files. Ex. \'-d <file1> <file2>\'')
    parser.add_argument('--output-png', '-o',
            required=True,
            help='Path to write output figure to')
    parser.add_argument('--groups', '-n',
            type=int,
            default=1,
            help='Number of logbot runs to "group" together')
    parser.add_argument('--plot-all-runs', '-a',
            action='store_true',
            default=False,
            help='Plot all individual runs')
    parser.add_argument('--trim-start',
            type=int,
            default=0,
            help='Seconds to trim off the start of the reference run')
    parser.add_argument('--latency-limit',
            type=float,
            help='Check estimated P95 latency upper bound is less than this number (in ms). Exit with code 3 if greater.')
    args = parser.parse_args()

    plt.style.use('ggplot')
    # flatten list
    logbot_runs_flat = [item for sublist in args.logbot_runs for item in sublist] if args.logbot_runs is not None else []
    logbot_data_files_flat = [item for sublist in args.logbot_data for item in sublist] if args.logbot_data is not None else []

    (p95_upper_bound_mean, p95_estimate_bandwidth) = main(args.reference_run, logbot_runs_flat, logbot_data_files_flat, trim_start_sec=args.trim_start, output_png=args.output_png, group_size=args.groups, plot_all_runs=args.plot_all_runs)

    # Check results
    if args.latency_limit is not None:
        if float(p95_upper_bound_mean) > float(args.latency_limit):
            print("FAIL: System latency {} above limit of {}".format(p95_upper_bound_mean, args.latency_limit))
            sys.exit(3)
        else:
            print("PASS: System latency {} below limit of {}".format(p95_upper_bound_mean, args.latency_limit))
