#!/usr/bin/env python3
"""Unit tests for compare_2ws_4ws.py."""
import math
import os
import sys

import pytest

from mined_metric.builder.metrics_impl.pisces.utils.compare_2ws_4ws import (
    PLANNER_MOTION_COMPARATOR_ARGS,
    compute_2ws_4ws_errors,
)
from vis.data_analysis.data_comparator.planner_motion_comparator import (
    PlannerMotionComparator,
)


def test_construct_planner_motion_comparator():
    """Check that the constructor exposes its inputs on the comparator.

    Builds a PlannerMotionComparator from placeholder dataset names and
    wheel-base values together with PLANNER_MOTION_COMPARATOR_ARGS, then
    verifies the dataset / wheel-base attributes and the ``layers`` list.
    """
    bl_name, comp_name = "baseline_dataset", "comparison_dataset"
    bl_wb, comp_wb = 1.2, 3.4

    planner_cmp = PlannerMotionComparator(
        bl_name, comp_name, bl_wb, comp_wb, PLANNER_MOTION_COMPARATOR_ARGS
    )

    # The positional inputs should be readable back unchanged.
    assert planner_cmp.bl_dataset == bl_name
    assert planner_cmp.comp_dataset == comp_name
    assert planner_cmp.bl_wheel_base == bl_wb
    assert planner_cmp.comp_wheel_base == comp_wb
    # NOTE(review): presumably the layer list comes from
    # PLANNER_MOTION_COMPARATOR_ARGS -- confirm against that constant.
    assert planner_cmp.layers == ["decision"]


def test_compute_2ws_4ws_errors():
    """Tests computing 2WS/4WS errors.

    Calls ``compute_2ws_4ws_errors`` with a control and a candidate chum URI
    and checks the structure of the returned mapping: it must contain a
    "Projected errors" entry with a "decision" layer, and that layer must
    hold, for every expected error name, finite "Max absolute" and "RMS"
    values.
    """
    ctrl_uri = (
        "chum://vh6_sim@12354.0+6.56000001s?i=$vars,$empty,$vars,$empty&v="
        "logtests/PlannerModule/"
        "cf922660e343d02821cf1decaad02e45fccbf0788ca7c4195e91373239550f22_8cd37e88173511eb8d050242c0a80002,"
        "logtests/PlannerModule/"
        "cb5d5be345b8bb01da3cb9f61a0821ba6fac30a1c874c783f4e289cb9fab4c20_4179b316141b11eba2a60242c0a80002"
    )
    cand_uri = (
        "chum://vh6_sim@12354.0+6.56000001s?i=$vars,$empty,$vars,$empty&v="
        "logtests/PlannerModule/"
        "992455763f0fcdb362f3497823adbf8bdb58838364a9632c77d6fbbd086c533f_fdb01e00173411eb916b0242c0a80002,"
        "logtests/PlannerModule/"
        "cb5d5be345b8bb01da3cb9f61a0821ba6fac30a1c874c783f4e289cb9fab4c20_4179b316141b11eba2a60242c0a80002"
    )
    results = compute_2ws_4ws_errors(ctrl_uri, cand_uri)
    assert "Projected errors" in results

    projected_errors = results["Projected errors"]
    assert "decision" in projected_errors

    decision = projected_errors["decision"]
    errors = [
        "Decision s error(m)",
        "Projected Lateral error(m)",
        "Projected Heading error(deg)",
        "Projected Velocity vx error(m/s)",
        "Expected heading error diff(deg)",
        "Non-expected heading error diff(deg)",
    ]
    metrics = ["Max absolute", "RMS"]

    # Report every absent error name at once instead of a bare
    # ``assert False`` from an aggregated all([...]).
    missing_errors = [error for error in errors if error not in decision]
    assert not missing_errors, "missing error entries: {}".format(
        missing_errors
    )

    # Check each (error, metric) pair on its own so a failure names the
    # offending key and value.
    for error in errors:
        for metric in metrics:
            assert metric in decision[error], "{!r} is missing {!r}".format(
                error, metric
            )
            value = decision[error][metric]
            assert math.isfinite(
                value
            ), "{!r} / {!r} is not finite: {!r}".format(error, metric, value)


if __name__ == "__main__":
    # Script entry point: run pytest on the directory containing this file
    # with colored, doubly-verbose output and exit with pytest's return code.
    test_dir = os.path.dirname(__file__)
    sys.exit(pytest.main([test_dir, "--color=yes", "--verbose", "--verbose"]))
