#!/usr/bin/env python3
"""Unit tests for comparator.py."""
import json
import math
import os
import sys

import pandas as pd
import pytest

from mined_metric.builder.metrics_impl.pisces.utils.comparator import (
    Comparator,
    SteeringComparator,
)


def test_max_abs_diff_abs_value():
    """Check that max_abs_diff reports the magnitude of a sign flip.

    The two frames are identical except that the first ``decision_vel``
    sample is 10.4 in ``dfA`` and -10.4 in ``dfB``; the expected maximum
    absolute difference is therefore 20.8.
    """
    stamps = [
        1578503121.1,
        1578503121.2,
        1578503121.3,
        1578503121.4,
        1578503121.5,
    ]
    base = pd.DataFrame(
        {"Timestamp": stamps, "decision_vel": [10.4, 10.5, 10.6, 10.7, 10.8]}
    ).set_index("Timestamp")
    base.index = pd.to_datetime(base.index, unit="s")

    lhs = base.copy()
    rhs = base.copy()
    # Only the first sample's sign changes; every other value is shared.
    rhs["decision_vel"] = [-10.4, 10.5, 10.6, 10.7, 10.8]

    comparator = Comparator(name="scenario", dfA=lhs, dfB=rhs)
    assert comparator.max_abs_diff("decision_vel") == pytest.approx(20.8)


def test_max_abs_diff_normal():
    """Test max_abs_diff correctly reports no difference.

    ``df_b`` carries one extra sample (t=0.6) past the end of ``df_a``. On
    the shared timestamps every column is identical, and the test expects
    each compared column to report a max absolute diff of 0, i.e. the
    non-overlapping sample must not contribute.
    """
    df_a = pd.DataFrame(
        {
            "Timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
            "decision_vel": [10.4, 10.5, 10.6, 10.7, 10.8],
            "decision_ave_vel": [0.779, 0.101, 0.316, 0.150, 1.8],
        }
    )
    df_a = df_a.set_index(["Timestamp"])
    df_a.index = pd.to_datetime(df_a.index, unit="s")

    df_b = pd.DataFrame(
        {
            "Timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "decision_vel": [10.4, 10.5, 10.6, 10.7, 10.8, 10.9],
            "decision_ave_vel": [0.779, 0.101, 0.316, 0.150, 1.8, 1.9],
        }
    )
    df_b = df_b.set_index(["Timestamp"])
    df_b.index = pd.to_datetime(df_b.index, unit="s")

    comparator = Comparator(name="scenario", dfA=df_a, dfB=df_b)

    # Check every shared data column, not just the first one, so the
    # ``decision_ave_vel`` fixture data is actually exercised.
    for column in ("decision_vel", "decision_ave_vel"):
        assert comparator.max_abs_diff(column) == pytest.approx(0), column


def test_max_abs_diff_df_is_none():
    """Expect NaN when the only shared sample is missing in one frame.

    Both frames hold a single sample at t=0.1. ``dfA``'s ``decision_vel``
    there is ``None`` and ``dfB``'s is 1; max_abs_diff is expected to
    return NaN rather than a number.
    """
    frames = []
    for value in (None, 1):
        frame = pd.DataFrame({"Timestamp": [0.1], "decision_vel": [value]})
        frame = frame.set_index("Timestamp")
        frame.index = pd.to_datetime(frame.index, unit="s")
        frames.append(frame)
    missing, present = frames

    comparator = Comparator(name="scenario", dfA=missing, dfB=present)
    assert math.isnan(comparator.max_abs_diff("decision_vel"))


def test_max_abs_diff_fillna():
    """Check how max_abs_diff treats a missing sample next to real ones.

    At t=0.1 ``dfA`` has no ``decision_vel`` value while ``dfB`` has 1; at
    t=0.2 the values are 10.5 vs 10.6. The expected result, 0.1, equals
    the t=0.2 difference alone, so the missing t=0.1 sample must not
    contribute a larger difference.
    """

    def frame(vel):
        # Two-sample frame indexed by datetime; only decision_vel varies.
        out = pd.DataFrame(
            {
                "Timestamp": [0.1, 0.2],
                "decision_vel": vel,
                "decision_ave_vel": [1, 2],
            }
        ).set_index("Timestamp")
        out.index = pd.to_datetime(out.index, unit="s")
        return out

    comparator = Comparator(
        name="scenario",
        dfA=frame([None, 10.5]),
        dfB=frame([1, 10.6]),
    )
    assert comparator.max_abs_diff("decision_vel") == pytest.approx(0.1)


def test_steering_comparator():
    """Check the layout of the dict returned by SteeringComparator.diff().

    Both single-row frames share hash_id, sdl_target and results and differ
    only in their chum URIs. As asserted below, ``diff["candidate"]`` is
    expected to carry dfA's URIs and ``diff["control"]`` dfB's URIs.
    """

    def frame(control_uri, candidate_uri):
        # One-row steering record; every other field is held constant.
        return pd.DataFrame(
            {
                "hash_id": ["DEADBEEF"],
                "control_chum_uri": [control_uri],
                "candidate_chum_uri": [candidate_uri],
                "sdl_target": ["//some/sdl/target"],
                "results": [json.dumps({"decision": {}, "tracker": {}})],
            }
        )

    comparator = SteeringComparator(
        name="tester",
        dfA=frame("chum://hero@1234.0", "chum://hero@5678.0"),
        dfB=frame("chum://hero@1234.5", "chum://hero@6789.0"),
    )
    diff = comparator.diff()

    for section in ("candidate", "control"):
        assert section in diff

    cand = diff["candidate"]
    ctrl = diff["control"]

    assert cand["sdl_target"] == ctrl["sdl_target"]
    assert cand["ctrl_uri"] == "chum://hero@1234.0"
    assert cand["cand_uri"] == "chum://hero@5678.0"

    assert ctrl["ctrl_uri"] == "chum://hero@1234.5"
    assert ctrl["cand_uri"] == "chum://hero@6789.0"


if __name__ == "__main__":
    # Run pytest over this file's directory with colored output; the doubled
    # --verbose is pytest's -vv. Pytest's exit status becomes the process's.
    sys.exit(
        pytest.main(
            [
                os.path.dirname(__file__),
                "--color=yes",
                "--verbose",
                "--verbose",
            ]
        )
    )
