import numpy as np


def ori_to_bin_and_label(ori, NUM_ORI_BINS):
    """Encode an orientation as a one-hot bin label plus per-bin residuals.

    The circle is split into ``NUM_ORI_BINS`` equal-width bins starting at
    angle 0.  The bin containing ``ori`` receives the label.  That bin and its
    two circular neighbours each receive a residual (``ori`` minus the bin
    centre) and are flagged as valid.

    Residuals are expressed as the shortest signed angular offset, i.e. they
    are wrapped into [-pi, pi].  Without the wrap, a neighbour bin lying
    across the 0 / 2*pi seam (or any ``ori`` outside [0, 2*pi)) would get a
    residual that is off by a multiple of 2*pi.

    Args:
        ori: Scalar orientation in radians.  Values outside [0, 2*pi) are
            accepted; the bin index is taken modulo ``NUM_ORI_BINS``.
        NUM_ORI_BINS: Number of orientation bins.

    Returns:
        Tuple ``(ori_label, ori_residuals, ori_residuals_valid)``, each of
        shape ``(NUM_ORI_BINS,)``: a float32 one-hot label, float32 residuals
        (zero where not valid) and an int32 0/1 validity mask.
    """
    two_pi = 2 * np.pi

    ori_residuals = np.zeros((NUM_ORI_BINS,), dtype=np.float32)
    ori_residuals_valid = np.zeros((NUM_ORI_BINS), dtype=np.int32)

    ori_bin = int(np.floor((ori / two_pi) * NUM_ORI_BINS)) % NUM_ORI_BINS

    for offset_idx in (-1, 0, 1):
        offset_ori_bin = (ori_bin + offset_idx) % NUM_ORI_BINS
        offset_ori_bin_ctr = (offset_ori_bin + 0.5) * two_pi / NUM_ORI_BINS

        residual = ori - offset_ori_bin_ctr
        # Only touch residuals that actually crossed the seam, so values that
        # were already in range stay bit-identical to the plain difference.
        if not -np.pi <= residual <= np.pi:
            residual = (residual + np.pi) % two_pi - np.pi

        ori_residuals[offset_ori_bin] = np.float32(residual)
        ori_residuals_valid[offset_ori_bin] = 1

    ori_label = np.zeros((NUM_ORI_BINS,), dtype=np.float32)
    ori_label[ori_bin] = 1

    return ori_label, ori_residuals, ori_residuals_valid


def bin_width(NUM_ORI_BINS):
    """Return the angular width, in radians, of one of ``NUM_ORI_BINS`` bins.

    The full circle (2*pi) is divided evenly by the bin count.
    """
    full_circle = 2 * np.pi
    return full_circle / NUM_ORI_BINS


def bin_and_residual_to_ori(bin_idx, residual, NUM_ORI_BINS):
    """Reconstruct an orientation from a bin index and a residual.

    Args:
        bin_idx: Index of the orientation bin.
        residual: Offset in radians that is added to the bin centre.
        NUM_ORI_BINS: Number of bins the full circle (2*pi) is divided into.

    Returns:
        The bin centre, ``width * (bin_idx + 0.5)``, plus ``residual``.
    """
    width = 2 * np.pi / NUM_ORI_BINS
    bin_centre = width * (bin_idx + 0.5)
    return bin_centre + residual


def logits_and_residuals_to_ori(ori_logits, ori_residuals, num_ori_bins=None):
    """Decode an orientation from per-bin logits and per-bin residuals.

    The entry of ``ori_logits`` with the largest value selects the bin; the
    residual stored at that same index is then combined with the bin centre
    by ``bin_and_residual_to_ori``.

    Args:
        ori_logits: Array exposing ``.shape``; its first dimension is taken as
            the number of bins (presumably 1-D, one score per bin).
        ori_residuals: Indexable collection of residuals, one per bin.
        num_ori_bins: Optional bin count.  When given, it is asserted to equal
            ``ori_logits.shape[0]``; when omitted, that value is used.

    Returns:
        The orientation returned by ``bin_and_residual_to_ori``.
    """
    if num_ori_bins is None:
        num_ori_bins = ori_logits.shape[0]
    else:
        assert num_ori_bins == ori_logits.shape[0]

    best_bin = np.argmax(ori_logits)
    return bin_and_residual_to_ori(
        best_bin, ori_residuals[best_bin], num_ori_bins
    )


def bin_and_label_to_ori(ori_label, ori_residuals, num_ori_bins=None):
    """Decode an orientation from a bin label and per-bin residuals.

    NOTE(review): the previous body referenced the undefined names
    ``ori_labels`` and ``ori`` and therefore always raised ``NameError``.
    It has been rewritten as the decoder suggested by the function name and
    signature; confirm this matches the intended contract.

    The labelled bin is the entry of ``ori_label`` with the largest value
    (the single 1 of a one-hot label).  The result is that bin's centre plus
    the residual stored at the same index.

    Args:
        ori_label: Sequence of per-bin label values, e.g. a one-hot vector.
        ori_residuals: Sequence of per-bin residuals in radians.
        num_ori_bins: Optional bin count.  When given it must equal the
            length of ``ori_label``; when omitted that length is used.

    Returns:
        The decoded orientation in radians, as a Python ``float``.

    Raises:
        ValueError: If ``num_ori_bins`` does not match the label length, or
            if ``ori_label`` has no non-zero entry.
    """
    ori_label = np.asarray(ori_label)
    ori_residuals = np.asarray(ori_residuals)

    if num_ori_bins is None:
        num_ori_bins = ori_label.shape[0]
    elif num_ori_bins != ori_label.shape[0]:
        raise ValueError(
            "num_ori_bins does not match the length of ori_label"
        )

    # An all-zero label carries no bin; argmax would silently pick bin 0.
    if not np.any(ori_label):
        raise ValueError("ori_label has no active bin")

    ori_bin = int(np.argmax(ori_label))
    ori_bin_ctr = (ori_bin + 0.5) * (2 * np.pi) / num_ori_bins

    # float() keeps the result type independent of the residual's dtype.
    return ori_bin_ctr + float(ori_residuals[ori_bin])


def get_angle_dist(x, y):
    """
    Shortest absolute angular separation, in radians, between x and y.

    Both inputs are angles in radians and are first reduced modulo 2*pi.
    The result is the smaller of the direct difference and the difference
    taken the other way around the circle.
    """
    full_turn = 2 * np.pi

    a = x % full_turn
    b = y % full_turn

    lower = min(a, b)
    upper = max(a, b)

    direct = np.abs(upper - lower)
    # Going the other way: from the larger angle through 2*pi to the smaller.
    around = np.abs(full_turn + lower - upper)
    return min(direct, around)


def get_angle_dist_90(x, y):
    """
    Returns the smallest ``get_angle_dist`` between x and y shifted by any
    multiple of pi/2 (0, pi/2, pi, 3*pi/2), i.e. the distance between the two
    angles when orientations a quarter turn apart are treated as equivalent.

    Both x and y are in radians.
    """
    # The original step list contained np.pi twice; the duplicate only
    # repeated one distance computation and never changed the minimum.
    quarter_turn_steps = (0, np.pi / 2, np.pi, 3 * np.pi / 2)
    return min(
        get_angle_dist(x, (y + step) % (2 * np.pi))
        for step in quarter_turn_steps
    )


def get_angle_dist_180(x, y):
    """
    Returns the smaller of ``get_angle_dist(x, y)`` and
    ``get_angle_dist(x, y + pi)``, i.e. the distance between the two angles
    when orientations half a turn apart are treated as equivalent.

    Both x and y are in radians.
    """
    full_turn = 2 * np.pi
    return min(
        get_angle_dist(x, (y + half_turn_step) % full_turn)
        for half_turn_step in (0, np.pi)
    )
