# Junction Exit Prediction Metrics

In [None]:
from typing import Optional

from IPython.display import display, HTML
from mined_metric.jupyter.utils.prediction_util import (
    collect_samples,
    create_deep_dive_classification_ui,
    group_samples,
    filter_exp_data,
    get_validation_data,
    Sample,
    SampleClassifier,
    SampleUtils,
    NO_SCROLL_STYLE,
)

In [None]:
# Data loading
VALIDATION_ID = "VALIDATION_ID_PLACEHOLDER"

# Maps a display label for each experiment to its raw data dict. The label is
# the experiment's "additionalMsg" when it is non-empty, otherwise its "id".
exp_datas = {}

# Only loading by validation ID is implemented in this cell: every experiment
# returned for VALIDATION_ID is added to exp_datas.
for data in get_validation_data(VALIDATION_ID):
    experiment = data["experiment"]
    key = experiment["additionalMsg"] or experiment["id"]
    # Several experiments may share the same "additionalMsg". Without this
    # check the later one would silently overwrite the earlier one and drop
    # out of every metric computed below, so disambiguate with the id.
    if key in exp_datas:
        key = f"{key} ({experiment['id']})"
    exp_datas[key] = data

# The return value is ignored, so this presumably filters exp_datas in place
# (NOTE(review): confirm against filter_exp_data's implementation).
filter_exp_data(exp_datas)
filtered_samples = collect_samples(exp_datas)
# NOTE(review): the name suggests samples grouped by some id; confirm the key.
groups_by_id = group_samples(exp_datas)

In [None]:
class TdjeSampleClassifier(SampleClassifier):
    """Score-cutoff classifier for junction-exit prediction samples.

    Samples tagged "FIFO" are judged against their own cutoff; every other
    sample uses the general cutoff.

    Args:
        fifo_threshold: Minimum score for a "FIFO"-tagged sample to be
            classified positive.
        threshold: Minimum score for any other sample to be classified
            positive.
    """

    def __init__(self, fifo_threshold: float, threshold: float):
        self.fifo_threshold = fifo_threshold
        self.threshold = threshold

    def classify(self, sample: Sample) -> Optional[bool]:
        """Compare *sample*'s score with the cutoff that applies to it.

        Returns:
            ``None`` if ``SampleUtils.get_score`` returns ``None``; otherwise
            the result of ``score >= cutoff``.
        """
        if SampleUtils.has_tags(sample, ["FIFO"]):
            cutoff = self.fifo_threshold
        else:
            cutoff = self.threshold

        value = SampleUtils.get_score(sample)
        return None if value is None else value >= cutoff

In [None]:
# Display graphs
# Tags offered by the deep-dive classification UI.
# NOTE(review): presumably these must match the tag strings attached to
# samples — confirm against the sample data.
tags = [
    "FIFO", "Priority", "TrafficLight",
    "NearestJunction", "NotNearestJunction",
    "3Ways", "4Ways", "5Ways", "6Ways", "7Ways", "8Ways", "9Ways", "10Ways",
    "TurnLeft", "TurnRight", "TurnStraight",
    "LocationSF", "LocationSLAC", "LocationVegas",
]

# One single-tag group per entry; presumably the UI's initial tag selections.
default_tags_list = [[tag] for tag in ("FIFO", "Priority", "TrafficLight")]

# Thresholds can also be customized per experiment ID; see the documentation
# of create_deep_dive_classification_ui for details.
ui, out = create_deep_dive_classification_ui(
    filtered_samples,
    groups_by_id,
    tags,
    default_tags_list,
)
display(HTML(NO_SCROLL_STYLE))
display(ui, out)