# USAGE: python plot_caffe_accuracy.py [--input CAFFE_LOG] [--output IMAGE]
# This is specific to the corner refiner model: it parses the log that caffe
# writes while training the corner refiner model.

import os
import numpy as np
import matplotlib.pyplot as plt
import argparse
import re
from scipy.interpolate import UnivariateSpline

# Super hacks to keep the plot readable:
# MAX_LOSS is the upper clamp applied to every number parsed from the log
# (in get_numbers and get_training_loss), so one huge value cannot dominate.
MAX_LOSS = 9001
# Y_RANGE is the fixed (bottom, top) y-axis limit passed to set_ylim.
Y_RANGE = (1, 6)

def get_numbers(acc, pattern, cap=None):
    """Parse one number from each log line that matches `pattern`.

    Args:
        acc: iterable of log lines (strings).
        pattern: regular expression with exactly one capture group; the
            captured text must be parseable by float().
        cap: upper clamp applied to every parsed value. Defaults to the
            module-level MAX_LOSS.

    Returns:
        A list of floats, one per line in which `pattern` matched exactly
        once. Values above `cap` are replaced by `cap`. Lines with no match
        or with several matches are skipped.

    Raises:
        ValueError: if the captured text is not a valid float.
    """
    if cap is None:
        cap = MAX_LOSS
    # Compile once instead of re-parsing the pattern for every line.
    regex = re.compile(pattern)
    out = []
    for line in acc:
        found = regex.findall(line)
        # Only trust lines that carry exactly one value.
        if len(found) != 1:
            continue
        out.append(min(float(found[0]), cap))
    return out

def get_file_lines(fn):
    """Return every line of the text file `fn`, trailing newlines included.

    The file is opened in a `with` block so the handle is closed as soon as
    the lines have been read.
    """
    with open(fn, 'r') as fid:
        return fid.readlines()

def get_pos_distance(lines):
    """Return the values logged on 'FLAG sum positive distance: ' lines.

    From each such line, get_numbers parses the text after the last comma.
    Lines it cannot match are dropped.
    """
    needle = 'FLAG sum positive distance: '
    flagged = [line for line in lines if needle in line]
    return get_numbers(flagged, '.+,(.+)')

def get_neg_distance(lines):
    """Return the values logged on 'FLAG sum negative distance: ' lines.

    From each such line, get_numbers parses the text after the last comma.
    Lines it cannot match are dropped.
    """
    flagged = []
    for line in lines:
        if 'FLAG sum negative distance: ' in line:
            flagged.append(line)
    return get_numbers(flagged, '.+,(.+)')

def get_triplet_count(lines):
    """Return the values logged on 'FLAG triplet count: ' lines.

    From each such line, get_numbers parses the number after the last comma.
    Lines it cannot match are dropped. Nothing is printed.
    """
    # Debug banner / length prints removed: they polluted stdout on every run
    # and used Python-2-only print syntax.
    relevant_lines = [l for l in lines if 'FLAG triplet count: ' in l]
    return get_numbers(relevant_lines, '.+,(.+)')

def get_accuracies(lines):
    """Return the accuracy values from 'Test net output #0:' log lines.

    get_numbers extracts the text following 'accuracy = ' on each such line.
    """
    marker = 'Test net output #0:'
    selected = [text for text in lines if marker in text]
    return get_numbers(selected, '.+accuracy = (.+)')

def get_testing_loss(lines):
    """Return the loss values from 'Test loss:' log lines.

    get_numbers extracts the text following 'Test loss: ' on each such line.
    """
    selected = list(filter(lambda text: 'Test loss:' in text, lines))
    return get_numbers(selected, '.+Test loss: (.+)')

def get_training_loss(lines, warmup=10, min_points=50, cap=None):
    """Parse the training loss from caffe 'Iteration ..., loss = X' lines.

    Args:
        lines: iterable of log lines (strings).
        warmup: number of leading points to drop when the series is long.
        min_points: the warm-up trim only happens when more than this many
            points were parsed, so short logs are returned whole.
        cap: upper clamp applied to every value. Defaults to the module-level
            MAX_LOSS.

    Returns:
        A list of floats, one per line whose loss pattern matched exactly
        once, clamped to `cap` and possibly trimmed by `warmup`.

    Raises:
        ValueError: if the captured loss text is not a valid float.
    """
    if cap is None:
        cap = MAX_LOSS
    regex = re.compile(r'Iteration .+, loss = (.+)')
    loss = []
    for line in lines:
        # Cheap substring pre-filter before running the regex.
        if 'loss' not in line or 'Iteration' not in line:
            continue
        found = regex.findall(line)
        if len(found) != 1:
            continue
        loss.append(min(float(found[0]), cap))
    # Presumably hides the first (typically large) losses; only applied when
    # enough points remain for the trim to be unnoticeable.
    if len(loss) > min_points:
        loss = loss[warmup:]
    return loss

def draw_image(input_fn, output_fn):
    """Plot the training and testing loss found in a caffe log and save it.

    Args:
        input_fn: path of the caffe log to parse.
        output_fn: path the figure is written to (saved at 300 dpi).

    Training loss is drawn in orange. Testing loss is drawn faintly in blue,
    with a smoothing spline over it when there are enough test points.
    """
    lines = get_file_lines(input_fn)
    loss = get_testing_loss(lines)
    training_loss = get_training_loss(lines)

    fig, ax1 = plt.subplots(1, 1, figsize=(8, 4))

    ax1.plot(training_loss, color='orange', label='training loss')

    # UnivariateSpline uses a cubic (k=3) spline by default and needs more
    # than k data points; with fewer, plot only the raw testing loss.
    if len(loss) > 3:
        x_loss = np.arange(len(loss))
        spl = UnivariateSpline(x_loss, loss)
        spl.set_smoothing_factor(0.15)
        ax1.plot(x_loss, spl(x_loss), color='b',
                 label='testing loss (smoothed)')
    ax1.plot(loss, color='b', alpha=0.3, label='testing loss')

    # NOTE(review): x is the index of each parsed log entry, not the solver
    # iteration. Training and testing losses are presumably logged at
    # different intervals, so the two curves may not share an x scale.
    ax1.set_title('training / testing loss')
    ax1.set_xlabel('iteration')
    ax1.set_ylabel('loss')
    ax1.set_ylim(Y_RANGE)
    ax1.grid(True)
    ax1.legend()

    # get_accuracies, get_pos_distance, get_neg_distance and get_triplet_count
    # parse further series from the same log; add axes here to plot them.

    fig.tight_layout()
    fig.savefig(output_fn, dpi=300)
    plt.close(fig)
    print("Saved  %s" % (output_fn,))

def main():
    """Command-line entry point: read --input/--output, then call draw_image."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input',
                        default='/tmp/caffe_tool.INFO',
                        help='input caffe log')
    parser.add_argument('--output',
                        default='caffe_tool.INFO.jpg',
                        help='output image')
    options = parser.parse_args()
    draw_image(input_fn=options.input, output_fn=options.output)


if __name__ == "__main__":
    main()
