import vision.classification.mono_depth.tf.net_utils as net_utils
from vision.classification.mono_depth.tf.net_utils import TensorInfo
import tensorflow as tf
import vision.classification.mono_depth.training.md_vb2.common as common


def _get_per_class_input_tensor_schema():
    """Return the input tensor schema for a single object class.

    Fields (dtype, shape):
      * ``roi_coords``      -- float32, ``(None, 4)``
      * ``box_to_im_inds``  -- int32,   ``(None,)``
      * ``proto_str``       -- string,  ``(None,)``

    The fields are handed to ``net_utils.create_nt_with_init_values`` as
    alternating ``name, TensorInfo`` arguments under the type name
    ``"VB2PerClassInputTensorSchema"``.
    """
    # Each entry is (field name, TensorInfo(dtype, shape)). The leading None
    # dimension is left unspecified; presumably one row per box/ROI -- confirm
    # against the data pipeline.
    field_specs = (
        ("roi_coords", TensorInfo(tf.float32, (None, 4))),
        ("box_to_im_inds", TensorInfo(tf.int32, (None,))),
        ("proto_str", TensorInfo(tf.string, (None,))),
    )
    # Flatten to name, info, name, info, ... as the helper expects.
    flat_args = [token for spec in field_specs for token in spec]
    return net_utils.create_nt_with_init_values(
        "VB2PerClassInputTensorSchema", *flat_args
    )


def get_input_tensor_schema():
    """Return the full input tensor schema, with one entry per class.

    Each id in ``common.CLASS_IDS`` is taken in order and paired with a
    freshly built per-class input schema. The flattened sequence
    ``class_id, schema, class_id, schema, ...`` is then passed to
    ``net_utils.create_nt_with_init_values`` under the type name
    ``"VB2InputTensorSchema"``.

    The class ids are used as field names, so they presumably need to be
    valid identifiers. NOTE(review): confirm in ``common``.
    """
    interleaved = [
        entry
        for class_id in common.CLASS_IDS
        for entry in (class_id, _get_per_class_input_tensor_schema())
    ]
    return net_utils.create_nt_with_init_values(
        "VB2InputTensorSchema", *interleaved
    )


def _get_per_class_gt_tensor_schema():
    """Return the ground-truth tensor schema for a single object class.

    The shapes are expressed in terms of ``common.NUM_SURFACE_PTS`` (P)
    and ``common.NUM_ORI_BINS`` (B). The leading None dimension is left
    unspecified, presumably one row per box (confirm with callers).

    Fields (dtype, shape):
      * ``ctr_coords``                      -- float32, ``(None, P, 2)``
      * ``ctr_depths``                      -- int32,   ``(None, P)``
      * ``ctr_depth_residuals``             -- float32, ``(None, P)``
      * ``ori_labels``                      -- float32, ``(None, P, B)``
      * ``ori_residuals``                   -- float32, ``(None, P, B)``
      * ``ori_residuals_valid``             -- int32,   ``(None, P, B)``
      * ``extents``                         -- float32, ``(None, 3)``
      * ``car_length_labels``               -- int32,   ``(None,)``
      * ``car_width_labels``                -- int32,   ``(None,)``
      * ``car_height_labels``               -- int32,   ``(None,)``
      * ``surface_pt_in_image``             -- bool,    ``(None, P)``
      * ``frustum_ctr_ori_logits``          -- float32, ``(None, B)``
      * ``frustum_ctr_ori_residuals``       -- float32, ``(None, B)``
      * ``frustum_ctr_ori_residuals_valid`` -- int32,   ``(None, B)``

    The fields are passed to ``net_utils.create_nt_with_init_values`` as
    alternating ``name, TensorInfo`` arguments under the type name
    ``"VB2PerClassGTTensorSchema"``.
    """
    n_pts = common.NUM_SURFACE_PTS
    n_ori = common.NUM_ORI_BINS
    field_specs = (
        ("ctr_coords", TensorInfo(tf.float32, (None, n_pts, 2))),
        # Integer index used to select a depth logit (classification target).
        ("ctr_depths", TensorInfo(tf.int32, (None, n_pts))),
        ("ctr_depth_residuals", TensorInfo(tf.float32, (None, n_pts))),
        ("ori_labels", TensorInfo(tf.float32, (None, n_pts, n_ori))),
        ("ori_residuals", TensorInfo(tf.float32, (None, n_pts, n_ori))),
        ("ori_residuals_valid", TensorInfo(tf.int32, (None, n_pts, n_ori))),
        ("extents", TensorInfo(tf.float32, (None, 3))),
        ("car_length_labels", TensorInfo(tf.int32, (None,))),
        ("car_width_labels", TensorInfo(tf.int32, (None,))),
        ("car_height_labels", TensorInfo(tf.int32, (None,))),
        # True if a surface point is visible, i.e. its projection lies
        # within the image bounds.
        ("surface_pt_in_image", TensorInfo(tf.bool, (None, n_pts))),
        ("frustum_ctr_ori_logits", TensorInfo(tf.float32, (None, n_ori))),
        ("frustum_ctr_ori_residuals", TensorInfo(tf.float32, (None, n_ori))),
        (
            "frustum_ctr_ori_residuals_valid",
            TensorInfo(tf.int32, (None, n_ori)),
        ),
    )
    # Flatten to name, info, name, info, ... as the helper expects.
    flat_args = [token for spec in field_specs for token in spec]
    return net_utils.create_nt_with_init_values(
        "VB2PerClassGTTensorSchema", *flat_args
    )


def get_gt_tensor_schema():
    """Return the full ground-truth tensor schema, with one entry per class.

    Each id in ``common.CLASS_IDS`` is taken in order and paired with a
    freshly built per-class ground-truth schema. The flattened sequence
    ``class_id, schema, ...`` is then passed to
    ``net_utils.create_nt_with_init_values`` under the type name
    ``"VB2GTTensorSchema"``.
    """
    interleaved = [
        entry
        for class_id in common.CLASS_IDS
        for entry in (class_id, _get_per_class_gt_tensor_schema())
    ]
    return net_utils.create_nt_with_init_values(
        "VB2GTTensorSchema", *interleaved
    )


def get_per_class_output_tensor_schema():
    """Return the network-output tensor schema for a single object class.

    Every field is float32. The shapes are expressed in terms of
    ``common.NUM_SURFACE_PTS`` (P), ``common.NUM_DEPTH_LOGITS`` (D) and
    ``common.NUM_ORI_BINS`` (B), plus the per-dimension car size bin
    counts from ``common``. The leading None dimension is left unspecified.

    The fields are passed to ``net_utils.create_nt_with_init_values`` as
    alternating ``name, TensorInfo`` arguments.

    NOTE(review): the type name used here is ``"VB2OutputTensorData"``,
    which breaks the ``...TensorSchema`` naming of the sibling schemas.
    It is left unchanged because it is the runtime type name.
    """
    n_pts = common.NUM_SURFACE_PTS
    n_depth = common.NUM_DEPTH_LOGITS
    n_ori = common.NUM_ORI_BINS

    def _float_info(shape):
        # Every output field shares the float32 dtype.
        return TensorInfo(tf.float32, shape)

    field_specs = (
        ("depth_logits", _float_info((None, n_pts, n_depth))),
        ("depth_residuals", _float_info((None, n_pts, n_depth))),
        ("ctr_coords", _float_info((None, n_pts, 2))),
        ("ori_logits", _float_info((None, n_pts, n_ori))),
        # NOTE(review): one residual per surface point, whereas the GT
        # schema's ``ori_residuals`` is per orientation bin, (None, P, B).
        # Confirm that the loss selects a bin before comparing the two.
        ("ori_residuals", _float_info((None, n_pts))),
        ("extents", _float_info((None, 3))),
        ("extents_uncertainty", _float_info((None, 3))),
        ("car_length_logits", _float_info((None, common.NUM_CAR_LENGTH_BINS))),
        ("car_width_logits", _float_info((None, common.NUM_CAR_WIDTH_BINS))),
        ("car_height_logits", _float_info((None, common.NUM_CAR_HEIGHT_BINS))),
        ("visibility_logits", _float_info((None, n_pts, 2))),
        ("frustum_ctr_ori_logits", _float_info((None, n_ori))),
        ("frustum_ctr_ori_residuals", _float_info((None, n_ori))),
    )
    # Flatten to name, info, name, info, ... as the helper expects.
    flat_args = [token for spec in field_specs for token in spec]
    return net_utils.create_nt_with_init_values("VB2OutputTensorData", *flat_args)


def get_output_tensor_schema():
    """Return the full network-output tensor schema, with one entry per class.

    Each id in ``common.CLASS_IDS`` is taken in order and paired with a
    freshly built per-class output schema. The flattened sequence
    ``class_id, schema, ...`` is then passed to
    ``net_utils.create_nt_with_init_values`` under the type name
    ``"VB2OutputTensorSchema"``.
    """
    interleaved = []
    for class_id in common.CLASS_IDS:
        interleaved.extend((class_id, get_per_class_output_tensor_schema()))
    return net_utils.create_nt_with_init_values(
        "VB2OutputTensorSchema", *interleaved
    )
