from vision.geometry.camera.safe_camera import SafeCamera
import argparse
import tensorflow as tf
from vision.classification.mono_depth.training.vb2.proto.vision_box2_dataset_pb2 import (
    VisionBox2Dataset,
    VisionBox2Datasets,
)
import os
from collections import namedtuple
import vision.classification.mono_depth.training.md.resnet_utils as resnet_utils
import numpy as np
import vision.classification.mono_depth.training.tf.net_utils as net_utils
from vision.classification.mono_depth.training.tf.net_utils import TensorInfo

import vision.classification.mono_depth.training.tf.nest as tf_nest
import vision.classification.mono_depth.training.md_vb2.common as common
from vision.tracking.pose_estimation.object_3d_state_6dof_py import (
    Object3DState6Dof,
)
import vision.classification.mono_depth.training.md_vb2.viz.utils as md_vb2_viz_utils
import vision.classification.mono_depth.training.vb2.surface_pt_utils as surface_pt_utils
import base.geometry.transformations as transformations
from vision.inference.memory.blob_proto_utils import blob_proto_to_mat
from vision.data.proto.annotated_image_pb2 import BlobRegion
import time
import collections

"""
Function to create dataset from directory path
  Should return tuple (dictionary?) with various tensors

Two categories of tensors
  Tensors for input to inference
  Tensors for GT eval

Function to do inference with given tensors

Should return dictionary of output tensors

Above could be one function for backbone, another function to create outputs off of backbone

Function to create losses - should take in output of inference, GT tensors and input to inference (sometimes useful data is needed there)
"""


def _get_per_class_input_tensor_schema():
    """Build the per-class input schema namedtuple.

    Each field is declared as ``TensorInfo(dtype, shape)``:
      roi_coords      float32  (None, 4)
      box_to_im_inds  int32    (None,)
      proto_str       string   (None,)
    """
    field_specs = [
        ("roi_coords", TensorInfo(tf.float32, (None, 4))),
        ("box_to_im_inds", TensorInfo(tf.int32, (None,))),
        ("proto_str", TensorInfo(tf.string, (None,))),
    ]
    # create_nt_with_init_values takes alternating name, value arguments.
    flat_args = [part for spec in field_specs for part in spec]
    return net_utils.create_nt_with_init_values(
        "VB2PerClassInputTensorSchema", *flat_args
    )


def get_input_tensor_schema():
    """Build the full input schema.

    It has one field per entry of ``common.CLASS_IDS``, named after that
    class id. Each field holds a fresh per-class input schema.
    """
    flat_args = [
        item
        for class_id in common.CLASS_IDS
        for item in (class_id, _get_per_class_input_tensor_schema())
    ]
    return net_utils.create_nt_with_init_values("VB2InputTensorSchema", *flat_args)


def _get_per_class_gt_tensor_schema():
    """Build the per-class ground-truth schema namedtuple.

    In the shapes below, the leading ``None`` is presumably the number of
    boxes (TODO confirm). ``num_pts`` is ``common.NUM_SURFACE_PTS`` and
    ``num_ori_bins`` is ``common.NUM_ORI_BINS``.
    """
    num_pts = common.NUM_SURFACE_PTS
    num_ori_bins = common.NUM_ORI_BINS
    field_specs = [
        ("ctr_coords", TensorInfo(tf.float32, (None, num_pts, 2))),
        # Depth-bin index per surface point (index for the depth logit).
        ("ctr_depths", TensorInfo(tf.int32, (None, num_pts))),
        ("ctr_depth_residuals", TensorInfo(tf.float32, (None, num_pts))),
        (
            "ori_labels",
            TensorInfo(tf.float32, (None, num_pts, num_ori_bins)),
        ),
        (
            "ori_residuals",
            TensorInfo(tf.float32, (None, num_pts, num_ori_bins)),
        ),
        (
            "ori_residuals_valid",
            TensorInfo(tf.int32, (None, num_pts, num_ori_bins)),
        ),
        ("extents", TensorInfo(tf.float32, (None, 3))),
        ("car_length_labels", TensorInfo(tf.int32, (None,))),
        ("car_width_labels", TensorInfo(tf.int32, (None,))),
        ("car_height_labels", TensorInfo(tf.int32, (None,))),
        # True if a surface point is visible in the image: its projection
        # lies in image bounds and (see get_surface_pt_info) at least one
        # adjacent surface normal faces the camera.
        ("surface_pt_in_image", TensorInfo(tf.bool, (None, num_pts))),
        (
            "frustum_ctr_ori_logits",
            TensorInfo(tf.float32, (None, num_ori_bins)),
        ),
        (
            "frustum_ctr_ori_residuals",
            TensorInfo(tf.float32, (None, num_ori_bins)),
        ),
        (
            "frustum_ctr_ori_residuals_valid",
            TensorInfo(tf.int32, (None, num_ori_bins)),
        ),
    ]
    # create_nt_with_init_values takes alternating name, value arguments.
    flat_args = [part for spec in field_specs for part in spec]
    return net_utils.create_nt_with_init_values(
        "VB2PerClassGTTensorSchema", *flat_args
    )


def get_gt_tensor_schema():
    """Build the full ground-truth schema.

    It has one field per entry of ``common.CLASS_IDS``, named after that
    class id. Each field holds a fresh per-class ground-truth schema.
    """
    flat_args = [
        item
        for class_id in common.CLASS_IDS
        for item in (class_id, _get_per_class_gt_tensor_schema())
    ]
    return net_utils.create_nt_with_init_values("VB2GTTensorSchema", *flat_args)


def get_per_class_output_tensor_schema():
    """Build the per-class network-output schema namedtuple.

    The namedtuple type is named ``VB2OutputTensorData``. In the shapes
    below, the leading ``None`` is presumably the number of boxes
    (TODO confirm).
    """
    num_pts = common.NUM_SURFACE_PTS
    num_depth_logits = common.NUM_DEPTH_LOGITS
    num_ori_bins = common.NUM_ORI_BINS
    field_specs = [
        (
            "depth_logits",
            TensorInfo(tf.float32, (None, num_pts, num_depth_logits)),
        ),
        (
            "depth_residuals",
            TensorInfo(tf.float32, (None, num_pts, num_depth_logits)),
        ),
        ("ctr_coords", TensorInfo(tf.float32, (None, num_pts, 2))),
        (
            "ori_logits",
            TensorInfo(tf.float32, (None, num_pts, num_ori_bins)),
        ),
        # NOTE(review): this is (None, num_pts), whereas the GT schema's
        # ori_residuals is (None, num_pts, num_ori_bins) -- confirm intended.
        ("ori_residuals", TensorInfo(tf.float32, (None, num_pts))),
        ("extents", TensorInfo(tf.float32, (None, 3))),
        ("extents_uncertainty", TensorInfo(tf.float32, (None, 3))),
        (
            "car_length_logits",
            TensorInfo(tf.float32, (None, common.NUM_CAR_LENGTH_BINS)),
        ),
        (
            "car_width_logits",
            TensorInfo(tf.float32, (None, common.NUM_CAR_WIDTH_BINS)),
        ),
        (
            "car_height_logits",
            TensorInfo(tf.float32, (None, common.NUM_CAR_HEIGHT_BINS)),
        ),
        ("visibility_logits", TensorInfo(tf.float32, (None, num_pts, 2))),
        (
            "frustum_ctr_ori_logits",
            TensorInfo(tf.float32, (None, num_ori_bins)),
        ),
        (
            "frustum_ctr_ori_residuals",
            TensorInfo(tf.float32, (None, num_ori_bins)),
        ),
    ]
    # create_nt_with_init_values takes alternating name, value arguments.
    flat_args = [part for spec in field_specs for part in spec]
    return net_utils.create_nt_with_init_values("VB2OutputTensorData", *flat_args)


def get_output_tensor_schema():
    """Build the full network-output schema.

    It has one field per entry of ``common.CLASS_IDS``, named after that
    class id. Each field holds a fresh per-class output schema.
    """
    flat_args = [
        item
        for class_id in common.CLASS_IDS
        for item in (class_id, get_per_class_output_tensor_schema())
    ]
    return net_utils.create_nt_with_init_values("VB2OutputTensorSchema", *flat_args)


# Named container for the training loss terms: ``loss`` first, followed by
# one field per individual loss component.
LossesTuple = namedtuple(
    "VB2LossesTuple",
    [
        "loss",
        "depth_loss",
        "residual_loss",
        "surface_pt_coords_loss",
        "ori_loss",
        "extents_loss",
        "car_length_loss",
        "car_width_loss",
        "car_height_loss",
        "visibility_loss",
        "ori_residuals_loss",
        "frustum_ctr_ori_loss",
    ],
)


def get_argparser():
    """Build the command-line parser for this module.

    Options:
        --train_dirs: comma-separated directories containing tfrecords for
            training (no default).
        --nr: input image height (default 300).
        --nc: input image width (default 480).
        --batch_size: batch size (default 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_dirs",
        type=str,
        help="Comma-separated directories containing tfrecords for training.",
    )
    parser.add_argument("--nr", type=int, help="Input image height", default=300)
    # --nc is the number of columns, i.e. the width (previously mislabelled
    # as "height" in the help text).
    parser.add_argument("--nc", type=int, help="Input image width", default=480)
    parser.add_argument("--batch_size", type=int, default=1)
    return parser


def parse_proto(args, im_idx, inp, input_schema, gt_schema):
    """Parse one serialized VisionBox2Dataset into per-class inputs and labels.

    Boxes are sorted by the distance of their ``cam_box`` translation from
    the camera origin, and only the ``MAX_NUM_BOXES`` closest are considered.
    A box is then skipped if any of these hold:
      * it is closer than ``common.DEPTH_BINS[0]``;
      * its class id is non-positive or larger than ``len(common.CLASS_IDS)``;
      * it is a car whose x-extent is below 1.6.

    Each kept box is appended under the field named after its class:
      * inputs: the normalized ROI and the image index;
      * ground truth: extents, car size-bin labels, frustum-center
        orientation labels and per-surface-point labels.

    Args:
        args: Parsed command-line arguments. Currently unused.
        im_idx: Index of this image in the batch; stored for every kept box
            in ``box_to_im_inds``.
        inp: A serialized ``VisionBox2Dataset`` proto.
        input_schema: Input schema (e.g. from ``get_input_tensor_schema``);
            its leaves are replaced with empty lists that are filled here.
        gt_schema: Ground-truth schema (e.g. from ``get_gt_tensor_schema``);
            handled the same way.

    Returns:
        ``(all_input_data, all_gt_data, vb_ds)``. ``vb_ds`` is the parsed
        proto. For cars whose y-extent exceeded their x-extent, its
        ``cam_box`` extents have been swapped in place.
    """
    vb_ds = VisionBox2Dataset()
    vb_ds.ParseFromString(inp)

    safe_camera = SafeCamera(vb_ds.camera, 0.1)

    all_input_data = tf_nest.map_structure(lambda x: [], input_schema)
    all_gt_data = tf_nest.map_structure(lambda x: [], gt_schema)

    # Order boxes by Euclidean distance of the box translation from the
    # camera origin, so that truncation below keeps the closest ones.
    boxes_and_depths = []
    for vb in vb_ds.vbox:
        t = vb.cam_box.pose.translation
        v = np.array([t.x, t.y, t.z])
        dist_to_ctr = np.sqrt(np.sum(v * v))
        boxes_and_depths.append((dist_to_ctr, vb))

    boxes_and_depths = sorted(boxes_and_depths, key=lambda x: x[0])

    # Keep at most this many (closest) boxes per image.
    MAX_NUM_BOXES = 100

    if len(boxes_and_depths) > MAX_NUM_BOXES:
        print("Too many boxes: %d" % len(boxes_and_depths))

    boxes_and_depths = boxes_and_depths[0:MAX_NUM_BOXES]

    for box_dist, vb in boxes_and_depths:
        # Skip boxes closer than the first depth-bin edge.
        if box_dist < common.DEPTH_BINS[0]:
            continue

        # Class ids are 1-based indices into common.CLASS_IDS.
        if vb.class_id <= 0:
            continue

        if vb.class_id > len(common.CLASS_IDS):
            print("Did not understand class id: " + str(vb.class_id))
            continue

        class_idx = common.CLASS_IDS[vb.class_id - 1]

        field_name = class_idx
        input_data = all_input_data._asdict()[field_name]
        gt_data = all_gt_data._asdict()[field_name]

        camera_box = Object3DState6Dof(vb.cam_box)

        extents = vb.cam_box.extents
        # NOTE(review): this minimum-length filter runs before the x/y
        # canonicalization below, so a car whose length is stored in y is
        # tested on its width -- confirm this is intended.
        if class_idx == "car" and extents.x < 1.6:
            continue

        # Size-bin labels: only cars get real labels; other classes use 0.
        length_idx = 0
        width_idx = 0
        height_idx = 0
        if class_idx == "car":
            if extents.y > extents.x:
                # Canonicalize so that x is the longer horizontal dimension.
                # This swaps the extents of vb.cam_box in place, so the
                # returned vb_ds sees the swap as well.
                # NOTE(review): no corresponding pi/2 yaw offset is applied
                # to any orientation label -- confirm.
                extents.x, extents.y = extents.y, extents.x

            # Bin the *canonicalized* extents: x is the length and y the
            # width. Binning the pre-swap values would label the short side
            # as the length.
            length_idx = np.digitize(extents.x, common.CAR_LENGTH_BINS)
            width_idx = np.digitize(extents.y, common.CAR_WIDTH_BINS)
            height_idx = np.digitize(extents.z, common.CAR_HEIGHT_BINS)

            assert length_idx >= 0
            assert width_idx >= 0
            assert height_idx >= 0
            assert length_idx < common.NUM_CAR_LENGTH_BINS
            assert width_idx < common.NUM_CAR_WIDTH_BINS
            assert height_idx < common.NUM_CAR_HEIGHT_BINS

        gt_data.extents.append(
            np.array([extents.x, extents.y, extents.z], dtype=np.float32)
        )

        gt_data.car_length_labels.append(np.int32(length_idx))
        gt_data.car_width_labels.append(np.int32(width_idx))
        gt_data.car_height_labels.append(np.int32(height_idx))

        # 2D box in full-image coordinates: row 0 = (x1, y1), row 1 = (x2, y2).
        roi_coords = np.reshape(
            np.array([vb.x1, vb.y1, vb.x2, vb.y2]), (2, 2)
        ).astype(np.float32)

        # Swap to (y, x) order and normalize by the full image size, giving
        # coords in [0, 1] for crop-and-resize.
        # NOTE(review): assumes full-resolution images are 1200 x 1920
        # (rows x cols) -- confirm.
        roi_coords_for_crop_and_resize = roi_coords[:, [1, 0]] / np.array(
            [1200, 1920], dtype=np.float32
        )
        input_data.roi_coords.append(
            roi_coords_for_crop_and_resize.flatten().astype(np.float32)
        )

        # Orientation of the box in the frustum through the 2D box center.
        ori = get_frustum_ori(
            safe_camera,
            (np.mean([vb.x1, vb.x2]), np.mean([vb.y1, vb.y2])),
            camera_box,
        )
        (
            frustum_ctr_ori_label,
            frustum_ctr_ori_residuals,
            frustum_ctr_ori_residuals_valid,
        ) = ori_to_bin_and_label(ori)

        gt_data.frustum_ctr_ori_logits.append(frustum_ctr_ori_label)
        gt_data.frustum_ctr_ori_residuals.append(frustum_ctr_ori_residuals)
        gt_data.frustum_ctr_ori_residuals_valid.append(
            frustum_ctr_ori_residuals_valid
        )

        # Per-surface-point labels. Projected coordinates are expressed
        # relative to the ROI as (p - roi_min) / roi_dims - 0.5, which is 0
        # at the ROI center.
        roi_dims = roi_coords[1, :] - roi_coords[0, :]
        all_ctr_proj_coords_in_roi = []
        all_ctr_depth_bin_idxs = []
        all_ctr_depth_residuals = []
        all_ori_labels = []
        all_ori_residuals = []
        all_ori_residuals_valid = []
        all_is_visible = []

        for surface_pt_idx, surface_pt in enumerate(common.SURFACE_PT_DATA.pts):
            adj_pt_idxs = common.SURFACE_PT_DATA.adjacent_pt_idxs[surface_pt_idx]

            pt_info = get_surface_pt_info(
                safe_camera,
                camera_box,
                surface_pt,
                common.SURFACE_PT_DATA.pts[adj_pt_idxs, :],
            )

            ctr_proj_coords_in_image = np.array(pt_info.ctr_coords).astype(
                np.float32
            )
            all_ctr_proj_coords_in_roi.append(
                (ctr_proj_coords_in_image - roi_coords[0, :]) / roi_dims - 0.5
            )

            all_ctr_depth_bin_idxs.append(pt_info.depth_bin_idx)
            all_ctr_depth_residuals.append(pt_info.depth_residual)
            all_ori_labels.append(pt_info.ori)
            all_ori_residuals.append(pt_info.ori_residuals)
            all_ori_residuals_valid.append(pt_info.ori_residuals_valid)
            all_is_visible.append(pt_info.is_visible)

        gt_data.ctr_coords.append(np.array(all_ctr_proj_coords_in_roi))
        gt_data.ctr_depths.append(np.array(all_ctr_depth_bin_idxs))
        gt_data.ctr_depth_residuals.append(np.array(all_ctr_depth_residuals))
        gt_data.ori_labels.append(np.array(all_ori_labels))
        gt_data.ori_residuals.append(np.array(all_ori_residuals))
        gt_data.ori_residuals_valid.append(np.array(all_ori_residuals_valid))
        gt_data.surface_pt_in_image.append(np.array(all_is_visible))

        input_data.box_to_im_inds.append(np.array(im_idx, dtype=np.int32))

    return all_input_data, all_gt_data, vb_ds


def parse_protos(args, inps, input_schema, gt_schema):
    """Parse a batch of serialized VisionBox2Dataset protos.

    Each element of ``inps`` is parsed with ``parse_proto``, using its
    position in ``inps`` as the image index. The per-image input and
    ground-truth structures are stacked with ``net_utils.stack_nested_data``.
    The parsed protos are copied, in order, into a ``VisionBox2Datasets``
    wrapper.

    Returns:
        ``(stacked_inputs, stacked_gts, serialized_ds_wrapper)``, where the
        last element is the ``SerializeToString()`` output of the wrapper.
    """
    parsed = [
        parse_proto(args, im_idx, serialized, input_schema, gt_schema)
        for im_idx, serialized in enumerate(inps)
    ]
    per_image_inputs = [inputs for inputs, _, _ in parsed]
    per_image_gts = [gts for _, gts, _ in parsed]

    stacked_inputs = net_utils.stack_nested_data(input_schema, per_image_inputs)
    stacked_gts = net_utils.stack_nested_data(gt_schema, per_image_gts)

    ds_wrapper = VisionBox2Datasets()
    for _, _, proto in parsed:
        ds_wrapper.ds.add().CopyFrom(proto)

    return stacked_inputs, stacked_gts, ds_wrapper.SerializeToString()


def get_2d_rotation_matrix(t):
    """Return the 2x2 rotation matrix [[cos t, -sin t], [sin t, cos t]].

    ``t`` is the rotation angle in radians.
    """
    cos_t, sin_t = np.cos(t), np.sin(t)
    rows = ((cos_t, -sin_t), (sin_t, cos_t))
    return np.array(rows)


def get_box_orientation(box):
    """Return the yaw of ``box`` from its corner-0 -> corner-1 edge.

    The edge is projected onto the (x, z) plane, normalized, and converted
    with ``arctan2(z, x)``, so the result is in [-pi, pi].
    """
    corners = box.corners()
    edge_xz = (corners[1] - corners[0])[[0, 2]]
    edge_len = np.sqrt(np.sum(edge_xz ** 2))
    dir_x, dir_z = edge_xz / edge_len
    return np.arctan2(dir_z, dir_x)


def get_frustum_ori(safe_camera, ctr_coords, camera_box):
    """Return the yaw of a camera-frame box in the frustum frame at ``ctr_coords``.

    The heading vector runs from box corner 0 to corner 1. It is mapped into
    the frustum frame with the matrix returned by
    ``safe_camera.get().frustum_from_camera(ctr_coords)``, projected onto the
    x-z plane, normalized, and converted with ``arctan2(z, x)``.

    Args:
        safe_camera: Camera wrapper; ``get().frustum_from_camera`` must return
            a 3x3 matrix (presumably a rotation) that is applied to 3D
            corner points with ``np.dot``.
        ctr_coords: ``(x, y)`` in full-image coordinates defining the frustum
            center.
        camera_box: Box in camera coordinates exposing ``corners()``; only
            corners 0 and 1 are used.

    Returns:
        The yaw in radians, wrapped with ``% (2 * pi)`` into [0, 2*pi).
    """
    camera_box_corners = np.array(camera_box.corners())
    camera_ori_pts = camera_box_corners[[0, 1], :]

    frustum_from_camera_tform = safe_camera.get().frustum_from_camera(ctr_coords)

    # Rotate both heading points into the frustum frame.
    frustum_ori_pts = np.dot(frustum_from_camera_tform, camera_ori_pts.T).T

    # Heading direction projected onto the x-z plane.
    frustum_ori_vec = (frustum_ori_pts[1, :] - frustum_ori_pts[0, :])[[0, 2]]
    frustum_ori_vec = frustum_ori_vec / np.sqrt(
        np.sum(frustum_ori_vec * frustum_ori_vec)
    )
    return np.arctan2(frustum_ori_vec[1], frustum_ori_vec[0]) % (2 * np.pi)


def ori_to_bin_and_label(ori, num_bins=None):
    """Encode an angle as a one-hot bin label plus per-bin residuals.

    The circle [0, 2*pi) is split into ``num_bins`` equal bins, and the bin
    containing ``ori`` gets label 1. Residuals are written for that bin and
    its two circular neighbours. Each residual is the signed angular
    difference ``ori - bin_center``, wrapped to [-pi, pi). Without the wrap,
    a neighbour across the 0 / 2*pi seam would get a residual near +/-2*pi
    instead of a small offset.

    Args:
        ori: Angle in radians, expected in [0, 2*pi). The bin index is taken
            modulo ``num_bins``.
        num_bins: Number of orientation bins. Defaults to
            ``common.NUM_ORI_BINS``.

    Returns:
        ``(ori_label, ori_residuals, ori_residuals_valid)``, each of shape
        ``(num_bins,)``:
          * a float32 one-hot label;
          * float32 residuals;
          * an int32 mask that is 1 where a residual was written.
    """
    if num_bins is None:
        num_bins = common.NUM_ORI_BINS

    ori_residuals = np.zeros((num_bins,), dtype=np.float32)
    ori_residuals_valid = np.zeros((num_bins,), dtype=np.int32)

    ori_bin = int(np.floor((ori / (2 * np.pi)) * num_bins)) % num_bins

    for offset_idx in (-1, 0, 1):
        offset_ori_bin = (ori_bin + offset_idx) % num_bins
        offset_ori_bin_ctr = (offset_ori_bin + 0.5) * (2 * np.pi) / num_bins
        # Shortest signed angular difference, in [-pi, pi).
        delta = (ori - offset_ori_bin_ctr + np.pi) % (2 * np.pi) - np.pi
        ori_residuals[offset_ori_bin] = np.float32(delta)
        ori_residuals_valid[offset_ori_bin] = 1

    ori_label = np.zeros((num_bins,), dtype=np.float32)
    ori_label[ori_bin] = 1

    return ori_label, ori_residuals, ori_residuals_valid


def get_surface_pt_info(safe_camera, camera_box, surface_pt, adj_surface_pts):
    """
    Compute visibility, projection, depth and orientation labels for one
    surface point of a box.

    surface_pt: point in box coordinates. It is scaled per axis by
        ``camera_box.extents()``, so it presumably lives in an
        extent-normalized box frame (TODO confirm against
        common.SURFACE_PT_DATA).
    adj_surface_pts: adjacent surface points in the same coordinates. They
        are used to build x-z plane surface normals for a back-facing check.
        If empty, the check is skipped.

    The point is marked visible only if both hold:
      * ``safe_camera.bound_checking_project`` returns a projection (not None);
      * at least one adjacent normal has a positive dot product with the
        direction from the point toward the camera origin (x-z plane).

    When visible, the returned info carries:
      * the projected image coordinates;
      * the depth bin index (``np.digitize`` of the Euclidean depth against
        ``common.DEPTH_BINS``);
      * the log-depth residual from that bin's center, divided by the bin
        width;
      * orientation labels from ``ori_to_bin_and_label``.
    Otherwise those fields keep their defaults: zeros, with the orientation
    label one-hot at bin 0.

    Returns a ``surface_pt_utils.SurfacePtInfo``. ``camera_pt`` (the point
    in camera coordinates) is filled in regardless of visibility.
    """
    # Defaults returned when the point is not visible.
    is_visible = False
    ctr_coords = np.zeros((2,), np.float32)
    ori_label = np.zeros((common.NUM_ORI_BINS,)).astype(np.float32)
    ori_label[0] = 1
    depth_bin_idx = np.int32(0)
    depth_residual = np.float32(0)

    ori_residuals = np.zeros((common.NUM_ORI_BINS,), dtype=np.float32)
    ori_residuals_valid = np.zeros((common.NUM_ORI_BINS), dtype=np.int32)

    # Determine surface point in camera coords; call it camera_pt.
    # First scale the point by the box extents, then map it through the box
    # pose.

    extents = camera_box.extents()
    pt_in_box_coords = np.array(
        [
            surface_pt[0] * extents[0],
            surface_pt[1] * extents[1],
            surface_pt[2] * extents[2],
        ]
    )

    camera_pt = md_vb2_viz_utils.transform_pts(
        camera_box.pose(), np.array([pt_in_box_coords])
    )[0, :]

    normal_check_passed = True
    if len(adj_surface_pts) > 0:
        adj_pts_in_box_coords = np.array(adj_surface_pts) * np.array(extents)
        adj_pts_in_camera_coords = md_vb2_viz_utils.transform_pts(
            camera_box.pose(), adj_pts_in_box_coords
        )

        # Edge vectors from this point to each adjacent point. They only
        # become normals after the 90-degree rotation below.
        adj_surface_normals = adj_pts_in_camera_coords - camera_pt

        # Only use x-z plane; normalize each edge to unit length.
        adj_surface_normals = adj_surface_normals[:, [0, 2]]
        adj_surface_normals = adj_surface_normals / np.expand_dims(
            np.sqrt(np.sum(adj_surface_normals * adj_surface_normals, 1)), 1
        )

        # Need to rotate by 90 degrees (edge direction -> in-plane normal).
        R = get_2d_rotation_matrix(np.pi / 2.0)
        adj_surface_normals = np.dot(R, adj_surface_normals.T).T

        # Need to flip normals so they are outward facing, i.e. have a
        # positive dot product with the box-center -> point direction.
        # Presumably pose() is a 4x4 homogeneous transform whose last column
        # holds the box center (TODO confirm).
        # NOTE(review): if the point and the box center coincide in x-z,
        # outward_vec is 0/0 = NaN and this point will fail the normal check.
        box_ctr_in_camera_coords = camera_box.pose()[0:3, -1]
        outward_vec = (camera_pt - box_ctr_in_camera_coords)[[0, 2]]
        outward_vec = outward_vec / np.sqrt(np.sum(outward_vec * outward_vec))

        dps = np.expand_dims(np.dot(adj_surface_normals, outward_vec), 1)

        # Multiply by +1 where already outward, -1 otherwise.
        adj_surface_normals *= 2 * (dps > 0) - 1

        # At least one of these normals should be in the "same" direction as
        # the vector from the surface pt to the camera.

        # Direction from the point to the camera origin, in the x-z plane.
        viewing_vec = -1 * camera_pt[[0, 2]]
        viewing_vec /= np.linalg.norm(viewing_vec)

        dps = np.dot(adj_surface_normals, viewing_vec)

        if not np.any(dps > 0):
            normal_check_passed = False

    # Try projecting camera pt into image; None is treated as "not projected".
    camera_pt_proj = safe_camera.bound_checking_project(camera_pt)

    # Also check that surface normal(s) of pt are compatible with viewing
    # direction.

    if camera_pt_proj is not None and normal_check_passed:
        is_visible = True
        ctr_coords = np.array(camera_pt_proj, dtype=np.float32)
        # Depth is the Euclidean distance from the camera origin.
        depth = np.linalg.norm(camera_pt)
        depth_bin_idx = np.digitize(depth, common.DEPTH_BINS).astype(np.int32)

        # Residual in log-depth space, normalized by the bin width.
        log_depth = np.log(depth)
        ctr_delta = log_depth - common.LOG_DEPTH_BIN_CTRS[depth_bin_idx]

        depth_residual = (
            ctr_delta / common.LOG_DEPTH_BIN_WIDTHS[depth_bin_idx]
        ).astype(np.float32)

        # Determine orientation in the frustum through the projected point.
        ori = get_frustum_ori(safe_camera, ctr_coords, camera_box)

        ori_label, ori_residuals, ori_residuals_valid = ori_to_bin_and_label(
            ori
        )

    return surface_pt_utils.SurfacePtInfo(
        is_visible=is_visible,
        ctr_coords=ctr_coords,
        ori=ori_label,
        depth_bin_idx=depth_bin_idx,
        depth_residual=depth_residual,
        camera_pt=camera_pt,
        ori_residuals=ori_residuals,
        ori_residuals_valid=ori_residuals_valid,
    )


def _get_region_from_vision_box2(vbox):
    assert vbox.HasField("vtrk")
    assert vbox.vtrk.HasField("track")
    assert vbox.vtrk.track.HasField("region")
    return vbox.vtrk.track.region


def parse_proto_v3(vb_ds, input_schema, gt_schema):
    """Parse a single-example VisionBox2Dataset into inputs and ground truth.

    Only ``vb_ds.vbox[0]`` is used; a warning is printed when more boxes are
    present. The inputs are the two blobs attached to the box region's
    ``BlobRegion`` extension: the image, and the mono-depth logits
    (``mono_depth/conv2d_5/BiasAdd``). No intermediate feature map is read.

    Args:
        vb_ds: A parsed ``VisionBox2Dataset`` proto.
        input_schema: Schema whose ``roi_image_patches`` and
            ``roi_md_depth_logits_patch`` leaves become lists.
        gt_schema: Passed through to ``_parse_groundtruth_from_proto``.

    Returns:
        ``(all_input_data, gt_data)``.
    """
    # NOTE(review): `_parse_groundtruth_from_proto` is not defined in this
    # module as seen here (only `_parse_groundtruth_from_proto_fast` is), and
    # get_input_tensor_schema() has no `roi_image_patches` field. Confirm
    # this function is still in use / working.
    box_count = len(vb_ds.vbox)
    if box_count > 1:
        print(
            "Your vbox container consists of more than one example. This "
            "is unexpected -- we will only extract the first example from "
            "the proto, abandoning the rest. Fix your dataset immediately!"
        )
    assert box_count >= 1
    first_box = vb_ds.vbox[0]

    region = _get_region_from_vision_box2(first_box)
    blobs = region.Extensions[BlobRegion.ext].blobs
    print("len(ext.blobs) = %s" % len(blobs))
    assert len(blobs) == 2
    for blob, expected_name in zip(blobs, ("image", "mono_depth/conv2d_5/BiasAdd")):
        assert blob.name == expected_name

    # Decoding the blobs is the most expensive part.
    all_input_data = tf_nest.map_structure(lambda x: [], input_schema)
    image_blob, depth_logits_blob = blobs[0], blobs[1]
    all_input_data.roi_image_patches.append(
        np.squeeze(blob_proto_to_mat(image_blob))
    )
    all_input_data.roi_md_depth_logits_patch.append(
        np.squeeze(blob_proto_to_mat(depth_logits_blob))
    )

    return all_input_data, _parse_groundtruth_from_proto(vb_ds, gt_schema)


# Ground-truth arrays for a single example, as returned by
# _parse_groundtruth_from_proto_fast.
SimpleGTTensorSchema = collections.namedtuple(
    "SimpleGTTensorSchema",
    "ctr_coords ctr_depths ctr_depth_residuals "
    "ori_labels ori_residuals ori_residuals_valid "
    "extents surface_pt_in_image "
    "frustum_ctr_ori_logits frustum_ctr_ori_residuals "
    "frustum_ctr_ori_residuals_valid",
)

# Model inputs for a single example, as returned by parse_proto_v3_fast.
BoxingInputsTuple = collections.namedtuple(
    "BoxingInputsTuple",
    "roi_crops roi_image_patches "
    "roi_md_depth_logits_patch roi_md_depth_residuals_patch",
)


def _parse_groundtruth_from_proto_fast(vb_ds):
    """Compute ground-truth labels for the first box of a VisionBox2Dataset.

    Only ``vb_ds.vbox[0]`` is used; a warning is printed if there are more.
    Its 2D region comes from ``vbox.vtrk.track.region`` and its 3D box from
    ``vbox.cam_box``. Unlike ``parse_proto``, this applies no class
    filtering, no extent canonicalization, and no car size-bin labels.

    Args:
        vb_ds: A parsed ``VisionBox2Dataset`` proto with ``camera`` set.

    Returns:
        A ``SimpleGTTensorSchema`` containing:
          * per-surface-point fields, with one row per entry of
            ``common.SURFACE_PT_DATA.pts``;
          * the box extents as float32 ``[x, y, z]``;
          * the frustum-center orientation labels.

    Raises:
        AssertionError: if ``camera`` or ``cam_box`` is missing, ``vbox`` is
            empty, or the region path is incomplete. Asserts are stripped
            under ``python -O``.
    """
    assert vb_ds.HasField("camera")
    safe_camera = SafeCamera(vb_ds.camera, 0.1)

    if len(vb_ds.vbox) > 1:
        print(
            "Your vbox container consists of more than one example. This "
            "is unexpected -- we will only extract the first example from "
            "the proto, abandoning the rest. Fix your dataset immediately!"
        )

    assert len(vb_ds.vbox) >= 1
    vbox = vb_ds.vbox[0]

    region = _get_region_from_vision_box2(vbox)

    # Populate groundtruth.
    assert vbox.HasField("cam_box")
    camera_box = Object3DState6Dof(vbox.cam_box)
    extents = vbox.cam_box.extents

    # Frustum-center orientation, using the center of the 2D region.
    center_in_image = (
        (region.x1 + region.x2) / 2.0,
        (region.y1 + region.y2) / 2.0,
    )
    ori = get_frustum_ori(safe_camera, center_in_image, camera_box)
    (
        frustum_ctr_ori_label,
        frustum_ctr_ori_residuals,
        frustum_ctr_ori_residuals_valid,
    ) = ori_to_bin_and_label(ori)

    # 2D region in the image: row 0 = (x1, y1), row 1 = (x2, y2).
    roi_coords = np.reshape(
        np.array([region.x1, region.y1, region.x2, region.y2]), (2, 2)
    ).astype(np.float32)
    roi_dims = roi_coords[1, :] - roi_coords[0, :]

    all_ctr_proj_coords_in_roi = []
    all_ctr_depth_bin_idxs = []
    all_ctr_depth_residuals = []
    all_ori_labels = []
    all_ori_residuals = []
    all_ori_residuals_valid = []
    all_is_visible = []

    for surface_pt_idx, surface_pt in enumerate(common.SURFACE_PT_DATA.pts):
        adj_pt_idxs = common.SURFACE_PT_DATA.adjacent_pt_idxs[surface_pt_idx]

        pt_info = get_surface_pt_info(
            safe_camera,
            camera_box,
            surface_pt,
            common.SURFACE_PT_DATA.pts[adj_pt_idxs, :],
        )

        # Projected point relative to the ROI: (p - roi_min) / roi_dims - 0.5,
        # which is 0 at the ROI center.
        ctr_proj_coords_in_image = np.array(pt_info.ctr_coords).astype(
            np.float32
        )
        all_ctr_proj_coords_in_roi.append(
            (ctr_proj_coords_in_image - roi_coords[0, :]) / roi_dims - 0.5
        )

        all_ctr_depth_bin_idxs.append(pt_info.depth_bin_idx)
        all_ctr_depth_residuals.append(pt_info.depth_residual)
        all_ori_labels.append(pt_info.ori)
        all_ori_residuals.append(pt_info.ori_residuals)
        all_ori_residuals_valid.append(pt_info.ori_residuals_valid)
        all_is_visible.append(pt_info.is_visible)

    return SimpleGTTensorSchema(
        ctr_coords=np.array(all_ctr_proj_coords_in_roi),
        ctr_depths=np.array(all_ctr_depth_bin_idxs),
        ctr_depth_residuals=np.array(all_ctr_depth_residuals),
        ori_labels=np.array(all_ori_labels),
        ori_residuals=np.array(all_ori_residuals),
        ori_residuals_valid=np.array(all_ori_residuals_valid),
        extents=np.array([extents.x, extents.y, extents.z], dtype=np.float32),
        surface_pt_in_image=np.array(all_is_visible),
        frustum_ctr_ori_logits=frustum_ctr_ori_label,
        frustum_ctr_ori_residuals=frustum_ctr_ori_residuals,
        frustum_ctr_ori_residuals_valid=frustum_ctr_ori_residuals_valid,
    )


def parse_proto_v3_fast(vb_ds):
    """Parse a single-example VisionBox2Dataset into namedtuple inputs and GT.

    Only ``vb_ds.vbox[0]`` is used; a warning is printed when more boxes are
    present. The inputs come from the two blobs on the box region's
    ``BlobRegion`` extension: the image, and the mono-depth logits
    (``mono_depth/conv2d_5/BiasAdd``). No intermediate feature map is read.

    Args:
        vb_ds: A parsed ``VisionBox2Dataset`` proto.

    Returns:
        ``(inputs, gt_data)``: a ``BoxingInputsTuple`` and the
        ``SimpleGTTensorSchema`` from ``_parse_groundtruth_from_proto_fast``.
    """
    num_boxes = len(vb_ds.vbox)
    if num_boxes > 1:
        print(
            "Your vbox container consists of more than one example. This "
            "is unexpected -- we will only extract the first example from "
            "the proto, abandoning the rest. Fix your dataset immediately!"
        )
    assert num_boxes >= 1

    region = _get_region_from_vision_box2(vb_ds.vbox[0])
    blobs = region.Extensions[BlobRegion.ext].blobs
    assert len(blobs) == 2
    assert blobs[0].name == "image"
    assert blobs[1].name == "mono_depth/conv2d_5/BiasAdd"

    # roi_crops and roi_md_depth_residuals_patch are not stored in the proto.
    # Both fields share this single 1-element float32 zero placeholder.
    placeholder = np.zeros(1, dtype=np.float32)
    image_patch = np.squeeze(blob_proto_to_mat(blobs[0]))
    depth_logits_patch = np.squeeze(blob_proto_to_mat(blobs[1]))
    inputs = BoxingInputsTuple(
        roi_crops=placeholder,
        roi_image_patches=image_patch,
        roi_md_depth_logits_patch=depth_logits_patch,
        roi_md_depth_residuals_patch=placeholder,
    )

    return inputs, _parse_groundtruth_from_proto_fast(vb_ds)


def decode_depths(
    logits_argmax,
    depth_residuals,
    log_depth_bin_ctrs=None,
    log_depth_bin_widths=None,
):
    """Decode depths from chosen log-depth bins and normalized residuals.

    depth = exp(log_bin_ctr[bin] + residual * log_bin_width[bin])

    Bins with an infinite width are treated as having width 0, so their
    residual is ignored and the bin center is returned.

    Args:
        logits_argmax: Integer bin index, or array of indices.
        depth_residuals: Residual(s) in units of the bin width; broadcast
            against the indexed bin centers.
        log_depth_bin_ctrs: Log-space bin centers. Defaults to
            ``common.LOG_DEPTH_BIN_CTRS``.
        log_depth_bin_widths: Log-space bin widths. Defaults to
            ``common.LOG_DEPTH_BIN_WIDTHS``. This array is never modified.

    Returns:
        The decoded depths.
    """
    if log_depth_bin_ctrs is None:
        log_depth_bin_ctrs = common.LOG_DEPTH_BIN_CTRS
    if log_depth_bin_widths is None:
        log_depth_bin_widths = common.LOG_DEPTH_BIN_WIDTHS

    # Build a zeroed copy instead of assigning into the widths array. When
    # that array is the module-level common.LOG_DEPTH_BIN_WIDTHS, an in-place
    # write would corrupt it for other users; get_surface_pt_info divides by
    # these widths.
    widths = np.asarray(log_depth_bin_widths)
    finite_widths = np.where(widths == np.inf, 0, widths)

    log_bin_ctrs = np.asarray(log_depth_bin_ctrs)[logits_argmax]
    log_depths = log_bin_ctrs + depth_residuals * finite_widths[logits_argmax]

    return np.exp(log_depths)


def decode_depths2(depth_logits, depth_residuals):
    """Decode depths from raw depth logits and residuals.

    This is a thin wrapper that delegates to ``common.decode_depths4``.
    """
    return common.decode_depths4(depth_logits, depth_residuals)
