import tensorflow as tf
from collections import namedtuple

# Result bundle returned by get_weighted_loss_and_error:
#   loss                    -- sum of the weighted per-label losses
#   weighted_per_label_loss -- per-label cross-entropy multiplied by its weight
#   err                     -- weighted fraction of mispredicted labels
LossAndErrorTuple = namedtuple(
    "LossAndErrorTuple",
    ["loss", "weighted_per_label_loss", "err"],
)


def assert_tensor_shapes_equal(t0, t1):
    """Raise AssertionError unless ``t0`` and ``t1`` have identical static shapes.

    The shapes are compared as the Python lists returned by
    ``.shape.as_list()``.

    Args:
        t0: Object exposing ``.shape.as_list()`` (e.g. a ``tf.Tensor``).
        t1: Object exposing ``.shape.as_list()``.

    Raises:
        AssertionError: If the two shape lists differ. It is raised explicitly
            rather than through an ``assert`` statement so the check is not
            stripped when Python runs with ``-O``.
    """
    s0 = t0.shape.as_list()
    s1 = t1.shape.as_list()
    if s0 != s1:
        raise AssertionError("Shape: %s did not match: %s" % (s0, s1))


def get_weighted_loss_and_error(gt_labels, pred_logits, weights=None):
    """Compute a weighted sparse softmax cross-entropy loss and weighted error.

    Args:
        gt_labels: Ground-truth class indices, dtype bool, int32 or int64.
            Must have exactly one fewer dimension than ``pred_logits``.
        pred_logits: Unnormalized class scores whose last axis indexes the
            classes, dtype float32 or float64.
        weights: Optional per-label weights with the same static shape as
            ``gt_labels``, dtype float32 or float64. Defaults to all ones in
            the dtype of ``pred_logits``.

    Returns:
        LossAndErrorTuple with:
          loss: sum (not mean) of ``weighted_per_label_loss``, in the dtype
            of ``pred_logits``.
          weighted_per_label_loss: per-label cross-entropy times its weight.
          err: weighted fraction of labels whose argmax prediction differs
            from the ground truth, in the dtype of ``weights``. The
            denominator is padded by 1e-6, so all-zero weights give 0.
    """
    assert len(gt_labels.shape) == (len(pred_logits.shape) - 1), (
        "gt_labels must have exactly one fewer dimension than pred_logits"
    )
    assert gt_labels.dtype in [tf.bool, tf.int32, tf.int64], (
        "unsupported gt_labels dtype: %s" % gt_labels.dtype
    )
    gt_labels = tf.cast(gt_labels, dtype=tf.int64)
    assert pred_logits.dtype in [tf.float32, tf.float64], (
        "unsupported pred_logits dtype: %s" % pred_logits.dtype
    )

    if weights is None:
        # Default weights follow the logits dtype, so float64 logits work too.
        # ones_like also carries gt_labels' static shape for the check below.
        weights = tf.ones_like(gt_labels, dtype=pred_logits.dtype)

    assert weights.dtype in [tf.float32, tf.float64], (
        "unsupported weights dtype: %s" % weights.dtype
    )
    assert_tensor_shapes_equal(gt_labels, weights)

    per_label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=gt_labels, logits=pred_logits
    )
    # The cross-entropy comes out in the logits dtype. Casting the weights to
    # it lets float32/float64 combinations multiply (tf.cast is a no-op when
    # the dtypes already match).
    weighted_per_label_loss = per_label_loss * tf.cast(
        weights, per_label_loss.dtype
    )
    loss = tf.reduce_sum(weighted_per_label_loss)

    # tf.argmax returns int64 by default, matching the cast gt_labels.
    pred_labels = tf.argmax(pred_logits, -1)
    # Build the 0/1 mistake indicator in the weights' dtype so the product
    # type-checks for float64 weights as well.
    mistakes = tf.cast(tf.not_equal(gt_labels, pred_labels), weights.dtype)
    # The epsilon keeps err finite (0) when every weight is zero.
    err = tf.reduce_sum(mistakes * weights) / (tf.reduce_sum(weights) + 1e-6)

    return LossAndErrorTuple(
        loss=loss, weighted_per_label_loss=weighted_per_label_loss, err=err
    )
