"""A TFLight model using high-level layer based on the CIFAR-10 network.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import re

import tensorflow as tf

from tflight.examples.cifar10 import cifar10_input
from tflight.framework.supervised_model import SupervisedModel
from tensorflow.python.training import optimizer
from tflight.framework import memory_saving_gradients as tflight_memory_saving_gradients
import meteor_memory_saving_gradients

# Global flag registry. Flags read below but not defined in this file
# (e.g. `output`, `input`, `train_batch_size`, `offload_checkpoints`) are
# presumably defined by the tflight framework modules imported above --
# NOTE(review): confirm.
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer("gradient_mode", 0,
                            """Gradient mode to use.
                            0: vanilla
                            1: recompute
                            2: recompute & offload""")

# Parse argv right away so FLAGS.gradient_mode can be read at import time
# (below and in Piconet.__init__). known_only=True tolerates flags that are
# not defined yet instead of failing on them.
import sys; tf.flags.FLAGS(sys.argv, known_only=True)

# The output directory is derived from the gradient mode, presumably so runs
# in different modes do not overwrite each other.
# NOTE(review): this assignment runs after parsing, so it overrides any
# --output value given on the command line.
FLAGS.output = "/tmp/piconet_tflight_{}".format(FLAGS.gradient_mode)

# CIFAR-10 data-set constants, re-exported from cifar10_input so the rest of
# this module can refer to them by their bare names.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

# Training-schedule constants.
# Learning rate used at global step 0.
INITIAL_LEARNING_RATE = 0.1
# Number of epochs between two consecutive learning-rate decays.
NUM_EPOCHS_PER_DECAY = 350.0
# Factor the learning rate is multiplied by at each decay.
LEARNING_RATE_DECAY_FACTOR = 0.1

def _strip_tower_prefix(name):
    """Return ``name`` with every ``tower_<N>/`` component removed.

    Falsy names (``None`` or ``""``) map to ``""``.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    return re.sub(r"tower_\d+/", "", name) if name else ""


def delayed_layer(x, delay=0.1):
    """Return a tensor equal to ``x`` whose evaluation sleeps for ``delay`` seconds.

    ``x`` is routed through a stateless ``tf.py_func`` that sleeps and then
    returns its input unchanged. The result is reshaped back to ``tf.shape(x)``
    because ``py_func`` outputs carry no static shape. Inside this function
    the gradient of ``PyFuncStateless`` ops is overridden with ``Identity``,
    so backprop treats the delay as an identity.

    Args:
      x: Tensor to delay.
      delay: Seconds to sleep per evaluation (default 0.1, the previously
        hard-coded value).

    Returns:
      A tensor with the same values, dtype and shape as ``x``.
    """
    def slow_identity(x_):
        time.sleep(delay)
        return x_

    with tf.get_default_graph().gradient_override_map({"PyFuncStateless": "Identity"}):
        # New op names reuse the source op's name minus any "tower_<N>/" part.
        delayed = tf.py_func(slow_identity,
                             [x],
                             x.dtype,
                             stateful=False,
                             name=_strip_tower_prefix(x.op.name) + "_delayed")
        # Restore x's shape, which py_func does not propagate.
        delayed = tf.reshape(delayed,
                             tf.shape(x),
                             name=_strip_tower_prefix(delayed.op.name) + "_reshape")

    return delayed


class Piconet(SupervisedModel):
    """CIFAR-10-style convolutional classifier with artificially delayed layers.

    Topology: conv1 -> pool1 -> norm1 -> conv2 -> norm2 -> pool2 -> local3 ->
    local4 -> softmax_linear. Several intermediate activations and the logits
    are wrapped in ``delayed_layer``. The gradient implementation is chosen
    by ``--gradient_mode``.
    """

    def __init__(self):
        """Configure the gradient implementation from ``FLAGS.gradient_mode``.

        Mode 0 keeps TensorFlow's default gradients. Modes 1 and 2 replace
        ``optimizer.gradients.gradients`` with
        ``tflight_memory_saving_gradients.gradients_speed``. This is a
        module-level monkey patch that affects every later gradient
        computation going through that attribute, not just this model.
        Mode 2 also sets ``FLAGS.offload_checkpoints = True``.

        Raises:
          ValueError: if ``FLAGS.gradient_mode`` is not 0, 1 or 2.
        """
        super(Piconet, self).__init__()

        mode = FLAGS.gradient_mode
        if mode not in (0, 1, 2):
            raise ValueError("Invalid argument --gradient_mode {}".format(mode))
        if mode in (1, 2):
            # "recompute": swap in the memory-saving gradient implementation.
            optimizer.gradients.gradients = tflight_memory_saving_gradients.gradients_speed
        if mode == 2:
            # "recompute & offload": additionally enable checkpoint offloading.
            FLAGS.offload_checkpoints = True

        print("Using gradient_mode", mode)

    def _input(self, num_gpus):
        """Construct distorted CIFAR training inputs using the Reader ops.

        Args:
          num_gpus: Number of GPUs. The batch read here holds
            ``FLAGS.train_batch_size * num_gpus`` examples.

        Returns:
          images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
          labels: Labels. 1D tensor of [batch_size] size.
        """
        input_path = os.path.join(FLAGS.input, 'cifar-10-batches-bin')
        images, labels = \
            cifar10_input.distorted_inputs(data_dir=input_path,
                                           batch_size=FLAGS.train_batch_size * num_gpus)
        return images, labels

    def _inference(self, sample):
        """Build the CIFAR-10 model graph.

        Args:
          sample: Image batch as returned by ``_input`` (4D: batch, height,
            width, channels).

        Returns:
          Unscaled logits of shape [batch, NUM_CLASSES], passed through
          ``delayed_layer``.
        """
        # conv1
        conv1 = tf.layers.conv2d(sample,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 strides=[1, 1],
                                 padding='same',
                                 activation=tf.nn.relu,
                                 name='conv1')

        # pool1
        pool1 = tf.layers.max_pooling2d(conv1,
                                        pool_size=[3, 3],
                                        strides=[2, 2],
                                        padding='same',
                                        name='pool1')
        pool1 = delayed_layer(pool1)

        # norm1
        norm1 = tf.nn.local_response_normalization(pool1,
                                                   depth_radius=4,
                                                   bias=1.0,
                                                   alpha=0.001 / 9.0,
                                                   beta=0.75,
                                                   name='norm1')

        # conv2
        conv2 = tf.layers.conv2d(norm1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 strides=[1, 1],
                                 padding='same',
                                 activation=tf.nn.relu,
                                 name='conv2')

        # norm2
        norm2 = tf.nn.local_response_normalization(conv2,
                                                   depth_radius=4,
                                                   bias=1.0,
                                                   alpha=0.001 / 9.0,
                                                   beta=0.75,
                                                   name='norm2')

        # pool2
        pool2 = tf.layers.max_pooling2d(norm2,
                                        pool_size=[3, 3],
                                        strides=[2, 2],
                                        padding='same',
                                        name='pool2')

        # local3
        pool2_flatten = tf.layers.flatten(pool2)

        pool2_flatten = delayed_layer(pool2_flatten)

        local3 = tf.layers.dense(pool2_flatten,
                                 units=384,
                                 activation=tf.nn.relu,
                                 name='local3')
        local3 = delayed_layer(local3)

        # local4
        local4 = tf.layers.dense(local3,
                                 units=192,
                                 activation=tf.nn.relu,
                                 name='local4')
        local4 = delayed_layer(local4)

        # softmax (linear layer; the softmax itself is applied inside the loss)
        softmax_linear = tf.layers.dense(local4,
                                         units=NUM_CLASSES,
                                         name='softmax_linear')

        return delayed_layer(softmax_linear)

    def _loss(self, prediction, ground_truth):
        """Compute the mean softmax cross-entropy over the batch.

        No weight-decay / L2 term is added here.

        Args:
          prediction: Logits from inference(), shape [batch_size, NUM_CLASSES].
          ground_truth: Integer class labels from input(); reshaped to
            [batch_size].

        Returns:
          Scalar loss tensor of type float.
        """
        # Use the dynamic shape so this also works when the batch dimension
        # is not known statically. Converting a partially known static
        # TensorShape to a tensor raises a ValueError.
        ground_truth = tf.reshape(ground_truth, tf.shape(prediction)[:-1])
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=ground_truth, logits=prediction, name="cross_entropy_per_example")
        cross_entropy_mean = tf.reduce_mean(cross_entropy, name="cross_entropy")

        return cross_entropy_mean

    def _learning_rate(self, global_step):
        """Return a staircase exponentially decaying learning rate.

        The rate starts at INITIAL_LEARNING_RATE and is multiplied by
        LEARNING_RATE_DECAY_FACTOR every NUM_EPOCHS_PER_DECAY epochs. One
        epoch is NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.train_batch_size
        steps.

        NOTE(review): _input reads train_batch_size * num_gpus examples per
        batch. If global_step advances once per multi-GPU step, the real decay
        period is num_gpus times longer in epochs. Confirm against the
        framework.
        """
        num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.train_batch_size
        decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
        return tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                          global_step,
                                          decay_steps,
                                          LEARNING_RATE_DECAY_FACTOR,
                                          staircase=True)

    def _optimizer(self, learning_rate):
        """Return a plain SGD optimizer using ``learning_rate``."""
        return tf.train.GradientDescentOptimizer(learning_rate)
