from collections import namedtuple
import numpy as np

# import ipdb
from vision.classification.mono_depth.vb.vb_dataset_utils import (
    get_frustum_from_camera_transform,
)
import vision.classification.mono_depth.training.md.vis_utils as md_vis_utils
import cv2
import vision.classification.mono_depth.training.md_vb2.summary_utils as summary_utils
import vision.classification.mono_depth.training.md_vb2.common as common
import vision.classification.mono_depth.training.md_vb2.viz.utils as md_vb2_viz_utils
from vision.tracking.pose_estimation.proto.object_3d_state_6dof_pb2 import (
    Object3DState6DofProto,
)
from vision.tracking.pose_estimation.object_3d_state_6dof_py import (
    Object3DState6Dof,
)
import base.geometry.transformations as transformations
import vision.classification.mono_depth.vb2.ori_utils as ori_utils

# Pair of rigid transforms, named with the file's "<dst>_from_<src>" convention.
Transforms = namedtuple("Transforms", ["base_from_camera", "smooth_from_base"])

# One box expressed in several reference frames (frustum, camera, base, smooth).
Boxes = namedtuple("Boxes", ["frustum_box", "camera_box", "base_box", "smooth_box"])


def make_int_tuple(t):
    """Return a tuple holding ``int(x)`` for every element ``x`` of ``t``.

    Used to turn float pixel coordinates into the integer tuples OpenCV's
    drawing functions expect. ``int`` truncates toward zero.
    """
    return tuple(map(int, t))


def display_corners(safe_camera, corners_in_cam_coords, display_im):
    """Draw a projected 8-corner box wireframe onto ``display_im`` in place.

    Each camera-frame corner is projected with
    ``safe_camera.bound_checking_project``. Any corner for which that call
    returns ``None`` is skipped, together with every edge touching it.

    Edges are drawn with color (0, 0, 255), which is red in OpenCV's BGR
    order. Corners 0-3 also get a colored circle (blue, green, red, white)
    so that their ordering is visible.
    """
    projected = list(map(safe_camera.bound_checking_project, corners_in_cam_coords))

    assert len(projected) == 8

    # Corners 0-3 form one face ring and 4-7 the opposite ring. The last
    # group of edges joins corner k to corner k + 4.
    edges = [(k, (k + 1) % 4) for k in range(4)]
    edges += [(k + 4, (k + 1) % 4 + 4) for k in range(4)]
    edges += [(k, k + 4) for k in range(4)]

    for start_idx, end_idx in edges:
        start_pt = projected[start_idx]
        end_pt = projected[end_idx]
        if start_pt is not None and end_pt is not None:
            cv2.line(
                display_im,
                make_int_tuple(start_pt),
                make_int_tuple(end_pt),
                (0, 0, 255),
                1,
            )

    corner_colors = (
        (255, 0, 0),  # blue
        (0, 255, 0),  # green
        (0, 0, 255),  # red
        (255, 255, 255),  # white
    )
    for pt, color in zip(projected[:4], corner_colors):
        if pt is not None:
            cv2.circle(display_im, make_int_tuple(pt), 5, color, 3)


def get_rotation_about_y_axis(ori):
    """Return a 3x3 rotation matrix about the y axis for angle ``ori`` (radians).

    The matrix is ``[[c, 0, -s], [0, 1, 0], [s, 0, c]]``, with
    ``c = cos(ori)`` and ``s = sin(ori)``. This is the transpose of the
    textbook right-handed Ry(ori), i.e. it equals Ry(-ori).
    """
    cos_o = np.cos(ori)
    sin_o = np.sin(ori)
    rows = (
        [cos_o, 0, -sin_o],
        [0, 1, 0],
        [sin_o, 0, cos_o],
    )
    return np.array(rows)


def get_rotation_about_z_axis(ori):
    """Return the right-handed 3x3 rotation matrix about the z axis.

    Args:
        ori: rotation angle in radians.

    Returns:
        ``[[c, -s, 0], [s, c, 0], [0, 0, 1]]``, with ``c = cos(ori)`` and
        ``s = sin(ori)``. The matrix is orthonormal with determinant 1.
    """
    c = np.cos(ori)
    s = np.sin(ori)

    # The second row must be [s, c, 0]. A "-s" there gives a matrix with
    # determinant cos(2*ori), which is not a rotation.
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


def get_box_orientation(box):
    """Return the heading angle of a box's first edge in the x-z plane.

    Takes the vector from ``box.corners()[0]`` to ``box.corners()[1]`` and
    keeps its x and z components (indices 0 and 2). After normalizing, it
    returns ``arctan2(z, x)`` in radians.
    """
    corner_pts = box.corners()
    edge_xz = (corner_pts[1] - corner_pts[0])[[0, 2]]
    unit_xz = edge_xz / np.sqrt(np.sum(edge_xz * edge_xz))
    dx, dz = unit_xz
    return np.arctan2(dz, dx)


class VB2BoxResult2(object):
    """Decode network outputs for one object into a single 3D box.

    The box is centered on the predicted object center. It is built in the
    center's frustum frame and then re-expressed in camera and base-link
    frames. For compatibility with callers, the results are stored as
    one-element lists: ``all_frustum_boxes``, ``all_camera_boxes`` and
    ``all_base_link_boxes``.
    """

    def __init__(
        self,
        safe_camera,
        roi_coords_xy,
        ctr_depth,
        ctr_coords_rel,
        extents,
        depth_logits,
        depth_residuals,
        ctr_visible_label,
        frustum_ctr_ori_logits,
        frustum_ctr_ori_residuals,
        display_im=None,
        im=None,
    ):
        """Build the frustum, camera and base-link boxes.

        Args:
            safe_camera: camera object. Must provide ``unsafe_unproject``
                and ``base_from_camera``, and is passed to
                ``get_frustum_from_camera_transform``.
            roi_coords_xy: ``[x0, y0, x1, y1]`` ROI in full-frame pixels.
            ctr_depth: distance along the unit image ray through the
                predicted center at which the box center is placed. It is
                used as a ray length, not as camera-z depth.
            ctr_coords_rel: ``(x, y)`` offset of the center from the ROI
                center, as a fraction of ROI width/height.
            extents: box extents ``(x, y, z)`` written into the proto.
            frustum_ctr_ori_logits, frustum_ctr_ori_residuals: decoded into
                ``self.frustum_ori`` by
                ``ori_utils.logits_and_residuals_to_ori``.
            depth_logits, depth_residuals, ctr_visible_label, display_im, im:
                accepted for interface compatibility; not used here.
        """
        self.extents = extents
        self.roi_coords_xy = roi_coords_xy
        self.roi_dims_xy = np.array(
            (
                self.roi_coords_xy[2] - self.roi_coords_xy[0],
                self.roi_coords_xy[3] - self.roi_coords_xy[1],
            )
        )

        self.ctr_depth = ctr_depth
        self.roi_ctr_xy = np.array(
            [
                (self.roi_coords_xy[0] + self.roi_coords_xy[2]) / 2.0,
                (self.roi_coords_xy[1] + self.roi_coords_xy[3]) / 2.0,
            ]
        ).astype(np.float64)

        self.frustum_ori = ori_utils.logits_and_residuals_to_ori(
            frustum_ctr_ori_logits, frustum_ctr_ori_residuals
        )
        self.ctr_coords_rel = ctr_coords_rel
        self.ctr_im_coords = self.rel_ctr_coords_to_im_coords(ctr_coords_rel)

        frustum_from_camera = get_frustum_from_camera_transform(
            self.ctr_im_coords, safe_camera
        )
        # Keep only the 3x3 rotation. Clear the translation column and the
        # bottom row, then restore the homogeneous 1.
        frustum_from_camera[3, :] = 0
        frustum_from_camera[:, 3] = 0
        frustum_from_camera[3, 3] = 1
        camera_from_frustum = np.linalg.inv(frustum_from_camera)

        # Place the box center ctr_depth along the unit ray through the
        # predicted center pixel, then express it in the frustum frame.
        ray = safe_camera.unsafe_unproject(self.ctr_im_coords)
        ray = ray / np.sqrt(np.sum(ray * ray))
        ctr_pt_cam_coords = ray * ctr_depth
        ctr_pt_frustum_coords = np.dot(
            frustum_from_camera[:3, :3], ctr_pt_cam_coords
        )

        rpy = [np.pi / 2.0, -1 * self.frustum_ori, 0]

        frustum_box_proto = Object3DState6DofProto()
        frustum_box_proto.extents.x = self.extents[0]
        frustum_box_proto.extents.y = self.extents[1]
        frustum_box_proto.extents.z = self.extents[2]
        frustum_box_proto.pose.euler_rotation.roll_radians = rpy[0]
        frustum_box_proto.pose.euler_rotation.pitch_radians = rpy[1]
        frustum_box_proto.pose.euler_rotation.yaw_radians = rpy[2]
        frustum_box_proto.pose.translation.x = ctr_pt_frustum_coords[0]
        frustum_box_proto.pose.translation.y = ctr_pt_frustum_coords[1]
        frustum_box_proto.pose.translation.z = ctr_pt_frustum_coords[2]

        frustum_box = Object3DState6Dof(frustum_box_proto)

        self.all_frustum_boxes = [frustum_box]
        self.all_camera_boxes = [frustum_box.transformed(camera_from_frustum)]

        self.base_from_camera = safe_camera.base_from_camera()
        self.all_base_link_boxes = [
            b.transformed(self.base_from_camera) for b in self.all_camera_boxes
        ]

    def rel_ctr_coords_to_im_coords(self, rel_ctr_coords):
        """Map ROI-relative fractional ``(x, y)`` to full-frame pixels.

        Computes ``rel * roi_dims_xy + roi_ctr_xy``.
        """
        return rel_ctr_coords * self.roi_dims_xy + self.roi_ctr_xy


class VB2BoxResult(object):
    """Decode per-surface-point network outputs into candidate 3D boxes.

    For every visible surface point, one box is built in that point's
    frustum frame and then re-expressed in camera and base-link frames.
    Results are in ``all_frustum_boxes``, ``all_camera_boxes`` and
    ``all_base_link_boxes``. These are in the same order, with one entry per
    visible surface point.
    """

    def __init__(
        self,
        safe_camera,
        roi_coords_xy,
        depth,
        ctr_coords_rel,
        extents,
        depth_logits,
        depth_residuals,
        ori_logits,
        ori_residuals,
        surface_pt_visible_labels,
        frustum_ctr_ori_logits,
        frustum_ctr_ori_residuals,
        display_im=None,
        im=None,
    ):
        """Build one candidate box per visible surface point.

        Args:
            safe_camera: camera object. Must provide ``unsafe_unproject``
                and ``base_from_camera``, and is passed to
                ``get_frustum_from_camera_transform``.
            roi_coords_xy: ``[x0, y0, x1, y1]`` in full-frame pixels.
            depth: per-surface-point distances. ``depth[i]`` places surface
                point ``i`` along its unit image ray.
            ctr_coords_rel: one ``(x, y)`` per surface point, each an offset
                from the ROI center as a fraction of ROI width/height.
            extents: box extents ``(x, y, z)``.
            depth_logits: optional depth-bin logits; only read by
                ``get_debug_html``.
            depth_residuals: stored on the instance; not used here.
            ori_logits: per-surface-point orientation-bin logits. The
                argmax is taken along axis 1.
            ori_residuals: per-surface-point, per-bin residuals. They are
                added, in radians, to the chosen bin center.
            surface_pt_visible_labels: only surface points with a truthy
                entry produce a box.
            frustum_ctr_ori_logits, frustum_ctr_ori_residuals, display_im:
                accepted for interface compatibility; not used.
            im: stored as ``self.im``.
        """
        self.depth = depth
        self.roi_coords_xy = roi_coords_xy
        self.ctr_coords_rel = ctr_coords_rel
        self.extents = extents
        self.depth_logits = depth_logits
        self.depth_residuals = depth_residuals
        self.im = im

        self.roi_dims_xy = np.array(
            (
                self.roi_coords_xy[2] - self.roi_coords_xy[0],
                self.roi_coords_xy[3] - self.roi_coords_xy[1],
            )
        )
        self.roi_ctr_xy = np.array(
            [
                (self.roi_coords_xy[0] + self.roi_coords_xy[2]) / 2.0,
                (self.roi_coords_xy[1] + self.roi_coords_xy[3]) / 2.0,
            ]
        ).astype(np.float64)

        # Centers of common.NUM_ORI_BINS equal-width bins covering [0, 2*pi).
        self.frustum_discr_oris = (
            (
                (np.arange(common.NUM_ORI_BINS) + 0.5)
                / float(common.NUM_ORI_BINS)
            )
            * 2
            * np.pi
        )

        # Per surface point: center of the highest-scoring bin plus that
        # bin's residual.
        all_best_ori_idxs = np.argmax(ori_logits, 1)
        self.all_frustum_oris = (
            self.frustum_discr_oris[all_best_ori_idxs]
            + ori_residuals[
                np.arange(len(all_best_ori_idxs)), all_best_ori_idxs
            ]
        )

        all_surface_pt_im_coords = [
            self.rel_ctr_coords_to_im_coords(x) for x in self.ctr_coords_rel
        ]

        self.all_frustum_boxes = []
        self.all_camera_boxes = []
        for surface_pt_idx, surface_pt_im_coords in enumerate(
            all_surface_pt_im_coords
        ):
            if not surface_pt_visible_labels[surface_pt_idx]:
                continue

            frustum_from_camera = get_frustum_from_camera_transform(
                surface_pt_im_coords, safe_camera
            )
            # Keep only the 3x3 rotation. Clear the translation column and
            # the bottom row, then restore the homogeneous 1.
            frustum_from_camera[3, :] = 0
            frustum_from_camera[:, 3] = 0
            frustum_from_camera[3, 3] = 1
            camera_from_frustum = np.linalg.inv(frustum_from_camera)

            # Place the surface point depth[i] along its unit image ray, then
            # express it in the frustum frame.
            ray = safe_camera.unsafe_unproject(surface_pt_im_coords)
            ray = ray / np.sqrt(np.sum(ray * ray))
            surface_pt_cam_coords = ray * self.depth[surface_pt_idx]
            surface_pt_frustum_coords = np.dot(
                frustum_from_camera[:3, :3], surface_pt_cam_coords
            )

            frustum_ori = self.all_frustum_oris[surface_pt_idx]
            rpy = [np.pi / 2.0, -1 * frustum_ori, 0]
            R = transformations.euler_matrix(*rpy)[:3, :3]

            # Offset of this surface point from the box center in the
            # rotated box frame. It is the SURFACE_PT_DATA row scaled
            # elementwise by extents.
            surface_pt_box_coords = np.dot(
                R,
                (common.SURFACE_PT_DATA.pts[surface_pt_idx, :] * self.extents),
            )
            # Shift the box so that its surface point coincides with the
            # observed one.
            translation = surface_pt_frustum_coords - surface_pt_box_coords

            frustum_box_proto = Object3DState6DofProto()
            frustum_box_proto.extents.x = self.extents[0]
            frustum_box_proto.extents.y = self.extents[1]
            frustum_box_proto.extents.z = self.extents[2]
            frustum_box_proto.pose.euler_rotation.roll_radians = rpy[0]
            frustum_box_proto.pose.euler_rotation.pitch_radians = rpy[1]
            frustum_box_proto.pose.euler_rotation.yaw_radians = rpy[2]
            frustum_box_proto.pose.translation.x = translation[0]
            frustum_box_proto.pose.translation.y = translation[1]
            frustum_box_proto.pose.translation.z = translation[2]

            frustum_box = Object3DState6Dof(frustum_box_proto)

            self.all_frustum_boxes.append(frustum_box)
            self.all_camera_boxes.append(
                frustum_box.transformed(camera_from_frustum)
            )

        self.base_from_camera = safe_camera.base_from_camera()
        self.all_base_link_boxes = [
            b.transformed(self.base_from_camera) for b in self.all_camera_boxes
        ]

    def rel_ctr_coords_to_im_coords(self, rel_ctr_coords):
        """Map ROI-relative fractional ``(x, y)`` to full-frame pixels.

        Computes ``rel * roi_dims_xy + roi_ctr_xy``.
        """
        return rel_ctr_coords * self.roi_dims_xy + self.roi_ctr_xy

    @staticmethod
    def _depth_prob_table_html(depth_logits):
        """Render softmax(depth_logits) as an HTML table, one row per bin."""
        logits = np.asarray(depth_logits, dtype=np.float64)
        if logits.size == 0:
            return "<table></table>"

        # Subtract the max before exponentiating. Otherwise large logits
        # overflow to inf, and inf/inf turns every probability into nan.
        exp_logits = np.exp(logits - np.max(logits))
        depth_probs = exp_logits / np.sum(exp_logits)

        rows = ["<table>"]
        rows.extend(
            "<tr><td>%d</td><td>%0.2f</td></tr>" % (i, depth_prob)
            for i, depth_prob in enumerate(depth_probs)
        )
        rows.append("</table>")
        return "".join(rows)

    def get_debug_html(self, im):
        """Return an HTML snippet for debugging this result.

        Args:
            im: full-size image frame. It is copied, not modified.

        Returns:
            ``"<BR/>"``-joined HTML containing:
              * the ROI crop with the ROI rectangle, ROI center and
                ``ctr_coords_xy`` drawn on it;
              * the frustum corner-offset plot;
              * a depth-probability table, if ``depth_logits`` is not None.
        """
        htmls = []
        im = im.copy()

        # Draw the ROI boundary. After cropping it lies on the image border.
        rect_uc = tuple([int(z) for z in self.roi_coords_xy[0:2]])
        rect_lr = tuple([int(z) for z in self.roi_coords_xy[2:4]])
        cv2.rectangle(im, rect_uc, rect_lr, (0, 0, 255))

        # Draw the ROI center.
        roi_ctr = (int(self.roi_ctr_xy[0]), int(self.roi_ctr_xy[1]))
        cv2.circle(im, roi_ctr, 3, (0, 255, 0))

        # NOTE(review): self.ctr_coords_xy is never assigned in this class;
        # callers presumably set it before calling this method — confirm.
        circ_ctr = (int(self.ctr_coords_xy[0]), int(self.ctr_coords_xy[1]))
        cv2.circle(im, circ_ctr, 5, (255, 0, 0))

        crop = md_vis_utils.crop_roi(
            im,
            self.roi_coords_xy[0],
            self.roi_coords_xy[1],
            self.roi_coords_xy[2],
            self.roi_coords_xy[3],
        )
        crop = crop / 255.0
        htmls.append(md_vis_utils.encode_im_as_html(crop))

        # Draw the frustum corner offsets.
        frustum_corner_offsets_im = np.zeros(
            (summary_utils.NR, summary_utils.NC, 3), dtype=np.int32
        )
        # NOTE(review): self.corner_offsets_frustum is never assigned in this
        # class; presumably set externally — confirm.
        summary_utils.draw_corners(
            summary_utils.PIXEL_DIM,
            summary_utils.NR,
            summary_utils.NC,
            frustum_corner_offsets_im,
            self.corner_offsets_frustum,
            [255, 255, 255],
        )
        htmls.append(
            md_vis_utils.encode_im_as_html(frustum_corner_offsets_im / 255.0)
        )

        if self.depth_logits is not None:
            htmls.append(self._depth_prob_table_html(self.depth_logits))

        return "<BR/>".join(htmls)

    def get_transformed_ori_vec(self, x_from_base=None):
        """Return the unit orientation vector in frame ``x``.

        The vector is derived from the most probable orientation bin.

        Args:
            x_from_base: optional transform composed with
                ``self.base_from_frustum``. If None, the result is in the
                base frame.

        Returns:
            The normalized, squeezed orientation vector (presumably shape
            ``(3,)``).

        Side effects:
            Caches intermediates on ``self``: ``max_prob_idx``,
            ``max_frustum_ori``, ``frustum_ori_vec``,
            ``x_from_frustum_tform``, ``x_ori_vec`` (transformed and
            un-normalized, before subtracting the origin), ``x_ori_orig`` and
            ``x_ori_vec_len``.
        """
        # NOTE(review): self.discr_ori_probs and self.base_from_frustum are
        # never assigned in this class; presumably set externally — confirm.
        max_prob_idx = np.argmax(self.discr_ori_probs)
        self.max_prob_idx = max_prob_idx

        # NOTE(review): frustum_discr_oris are already bin centers (+0.5 bin),
        # so this adds a second half-bin offset before the -pi/2 shift.
        # Confirm this is intended.
        max_frustum_ori = (
            self.frustum_discr_oris[max_prob_idx]
            + (0.5 / float(common.NUM_ORI_BINS)) * 2 * np.pi
            - np.pi / 2
        )
        self.max_frustum_ori = max_frustum_ori

        # Unit heading in the frustum x-z plane.
        frustum_ori_vec = np.array(
            [[np.cos(max_frustum_ori), 0, np.sin(max_frustum_ori)]]
        )
        self.frustum_ori_vec = frustum_ori_vec

        if x_from_base is not None:
            x_from_frustum_tform = np.dot(x_from_base, self.base_from_frustum)
        else:
            x_from_frustum_tform = self.base_from_frustum
        self.x_from_frustum_tform = x_from_frustum_tform

        x_ori_vec = md_vb2_viz_utils.transform_pts(
            x_from_frustum_tform, frustum_ori_vec
        )
        self.x_ori_vec = x_ori_vec

        # Transform the origin as well. The difference of the two cancels any
        # translation in the transform, leaving a pure direction.
        x_ori_orig = md_vb2_viz_utils.transform_pts(
            x_from_frustum_tform, np.zeros((1, 3))
        )
        self.x_ori_orig = x_ori_orig

        x_ori_vec = x_ori_vec - x_ori_orig
        self.x_ori_vec_len = np.sqrt(np.sum(x_ori_vec * x_ori_vec))
        x_ori_vec = x_ori_vec / self.x_ori_vec_len

        return np.squeeze(x_ori_vec)
