import ori_utils
import numpy as np


def test_bin_width():
    """Check that ``bin_width`` splits a full turn (2*pi) evenly over the bins."""
    NUM_ORI_BINS = 60
    BIN_WIDTH = 2 * np.pi / NUM_ORI_BINS

    result = ori_utils.bin_width(NUM_ORI_BINS)
    # Take the absolute value of the *difference*, then compare to the
    # tolerance; applying abs to the boolean comparison would let any value
    # smaller than BIN_WIDTH (e.g. 0 or a negative width) pass.
    assert np.abs(result - BIN_WIDTH) < 1e-6, (
        str(result) + " vs. " + str(BIN_WIDTH)
    )


def test_ori_to_bin_and_label_works():
    """Check label, residuals and validity mask for an angle 10.5 bins in.

    The expected values encoded here are: a one-hot label at bin 10; residuals
    of +BIN_WIDTH at bin 9 and -BIN_WIDTH at bin 11, zero everywhere else
    (including bin 10); and a residual-validity mask equal to 1 on bins 9-11
    and 0 elsewhere.
    """
    NUM_ORI_BINS = 60
    BIN_WIDTH = ori_utils.bin_width(NUM_ORI_BINS)

    ori = 10.5 * BIN_WIDTH
    ori_label, ori_residuals, ori_residuals_valid = ori_utils.ori_to_bin_and_label(
        ori, NUM_ORI_BINS
    )

    ORI_LABEL_EXPECTED = np.zeros((NUM_ORI_BINS,))
    ORI_LABEL_EXPECTED[10] = 1

    assert np.all(ori_label == ORI_LABEL_EXPECTED), str(ori_label)

    ORI_RESIDUALS_EXPECTED = np.zeros((NUM_ORI_BINS,))
    ORI_RESIDUALS_EXPECTED[9] = BIN_WIDTH
    ORI_RESIDUALS_EXPECTED[11] = -1 * BIN_WIDTH

    assert np.all(np.abs(ori_residuals - ORI_RESIDUALS_EXPECTED) < 1e-6), str(
        ori_residuals
    )

    ORI_RESIDUALS_VALID_EXPECTED = np.zeros((NUM_ORI_BINS,))
    ORI_RESIDUALS_VALID_EXPECTED[9:12] = 1
    # On failure, show the validity mask that was actually checked.
    assert np.all(
        np.abs(ori_residuals_valid - ORI_RESIDUALS_VALID_EXPECTED) < 1e-6
    ), str(ori_residuals_valid)


def test_bin_and_label_to_ori_works():
    """Decode bin 10 with a residual of half a bin width.

    The expected result of 11 bin widths assumes bin 10 is centred at
    10.5 bin widths, so the half-width residual lands on 11 bin widths.
    """
    num_bins = 60
    width = ori_utils.bin_width(num_bins)

    decoded = ori_utils.bin_and_residual_to_ori(10, width * 0.5, num_bins)
    target = 11 * width

    assert np.abs(decoded - target) < 1e-5, " vs. ".join(
        (str(decoded), str(target))
    )


def test_is_invertible():
    """Encode pi into bins/residuals, decode it back, and expect pi again."""
    num_bins = 60
    angle = np.pi

    label, residuals, _ = ori_utils.ori_to_bin_and_label(angle, num_bins)

    # Decode using the highest-scoring bin and that bin's residual.
    top_bin = np.argmax(label)
    round_trip = ori_utils.bin_and_residual_to_ori(
        top_bin, residuals[top_bin], num_bins
    )

    assert np.abs(angle - round_trip) < 1e-6


def test_get_angle_dist():
    """Check ``get_angle_dist`` on base cases and full-turn shifted copies.

    Each base case ``(a, b, expected)`` is also checked with ``b`` shifted by
    2*pi and, separately, with ``a`` shifted by 2*pi; the expected distance is
    the same in all three variants.
    """
    full_turn = 2 * np.pi
    base = [
        (0, np.pi, np.pi),
        (0, 1.75 * np.pi, 0.25 * np.pi),
        (-0.25 * np.pi, 1.75 * np.pi, 0),
    ]
    shifted = [
        variant
        for a, b, want in base
        for variant in ([a, b + full_turn, want], [a + full_turn, b, want])
    ]

    for case in base + shifted:
        lhs, rhs, want = case
        assert np.isclose(ori_utils.get_angle_dist(lhs, rhs), want), str(case)


def test_get_angle_dist_180():
    """Expect 0 for angles pi apart and pi/2 for angles 3*pi/2 apart."""
    for case in [(0, np.pi, 0), (0, 3 * np.pi / 2, np.pi / 2)]:
        lhs, rhs, want = case
        got = ori_utils.get_angle_dist_180(lhs, rhs)
        assert np.isclose(got, want), str(case)


def test_get_angle_dist_90():
    """Expect 0 both for angles pi apart and for angles 3*pi/2 apart."""
    for case in [(0, np.pi, 0), (0, 3 * np.pi / 2, 0)]:
        lhs, rhs, want = case
        got = ori_utils.get_angle_dist_90(lhs, rhs)
        assert np.isclose(got, want), str(case)


if __name__ == "__main__":
    # Run every test in order when the file is executed directly (no pytest).
    for check in (
        test_bin_width,
        test_ori_to_bin_and_label_works,
        test_bin_and_label_to_ori_works,
        test_is_invertible,
        test_get_angle_dist,
        test_get_angle_dist_180,
        test_get_angle_dist_90,
    ):
        check()
    print("success!")
