In [1]:
"""
Please save notebooks with important results to your own folder in JupyterHub.
See README for usage instructions:
https://git.zooxlabs.com/zooxco/driving/tree/master/mined_metric/jupyter/README.md
"""
from mined_metric.jupyter.utils.data_access_util import get_experiment_data, chum_summary_stats

# Experiment whose data and metadata are loaded by get_experiment_data() below.
# NOTE(review): the value is a literal placeholder and presumably must be
# replaced with a real experiment id before running the notebook — confirm.
EXPERIMENT_ID = "EXPERIMENT_ID_PLACEHOLDER"

In [2]:
# Retrieve all experiment data and metadata.
# The returned dict carries an "exception" key when retrieval went wrong
# (presumably instead of the usual payload). Stop here instead of printing and
# carrying on: every later cell indexes exp_data["meta"] directly and would
# otherwise fail with an unrelated error or compute results from invalid data.
exp_data = get_experiment_data(EXPERIMENT_ID)
if "exception" in exp_data:
    raise RuntimeError("Invalid data is returned: %s" % exp_data["exception"])

In [3]:
# Basic summary statistics over the Chum URIs referenced by the experiment
# metadata (one URI taken from each metadata entry).
chum_uris = [entry["chum_uri"] for entry in exp_data["meta"]]
stats = chum_summary_stats(chum_uris)

avg_duration_s = stats["avg_duration_s"]
print("Average duration of Chum URIs: {:.2f}s".format(avg_duration_s))
vehicles = stats["vehicles_included"]
print("Vehicles: {}".format(vehicles))

Average duration of Chum URIs: 5.00s
Vehicles: ['kitt_01', 'kitt_02', 'kitt_03', 'kitt_06', 'kitt_07', 'kitt_09', 'kitt_10']


In [4]:
class Sample:
    """A single scored sample taken from one experiment metadata entry.

    Stores the raw 'score', 'label' and 'classification' values as given.
    NOTE(review): a second `Sample` class with the same fields plus an
    ordering method (`__lt__`) is defined in a later cell and shadows this
    one once that cell runs; this duplicate could be removed.
    """

    def __init__(self, score, label, classification):
        self.score, self.label, self.classification = score, label, classification

    def to_str(self):
        """Return a one-line, human-readable description of the sample."""
        fields = (self.score, self.label, self.classification)
        return 'score: {}, label: {}, classification: {}'.format(*[str(f) for f in fields])

# Wrap every metadata entry in a Sample, preserving the metadata order.
samples = [
    Sample(entry['score'], entry['label'], entry['classification'])
    for entry in exp_data["meta"]
]

print('num of samples: {}'.format(len(samples)))


num of samples: 4020


In [5]:
# Install plotly (imported by the PR-curve cell below). The captured output shows
# pip falling back to a user installation.
# TODO(review): pin a version (plotly==X.Y.Z) so re-runs are reproducible; the
# deprecation warning shows pip running under Python 2.7, so the pinned
# release must still support 2.7.
!pip install plotly

[33mDEPRECATION: Python 2.7 reached the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 is no longer maintained. A future version of pip will drop support for Python 2.7. More details about Python 2 support in pip, can be found at https://pip.pypa.io/en/latest/development/release-process/#python-2-support[0m
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.zooxlabs.com/simple/


In [10]:
import numpy as np
import matplotlib.pyplot as plt
import plotly as plotly
import plotly.graph_objects as go

class Sample:
    """A single scored sample whose natural ordering is *descending* score.

    `__lt__` is intentionally inverted: a sample compares "less than" another
    when its score is higher, so a plain `list.sort()` puts the
    highest-scoring samples first, which is the order the PR threshold sweep
    in this cell walks through.
    """

    def __init__(self, score, label, classification):
        self.score, self.label, self.classification = score, label, classification

    def __lt__(self, other):
        # "less than" here means "ranks ahead of", i.e. has the higher score.
        return self.score > other.score

    def to_str(self):
        """Return a one-line, human-readable description of the sample."""
        fields = (self.score, self.label, self.classification)
        return 'score: {}, label: {}, classification: {}'.format(*[str(f) for f in fields])

# Wrap every metadata entry in a Sample, preserving the metadata order.
samples = [
    Sample(entry['score'], entry['label'], entry['classification'])
    for entry in exp_data["meta"]
]

print('num of samples: {}'.format(len(samples)))

class PRPoint:
    """One point of the precision/recall curve.

    As constructed by the threshold sweep in this cell:
      precision   -- positives among the results so far / results so far
      recall      -- positives among the results so far / total positives
      score       -- score of the last sample included (the threshold)
      num_results -- number of samples included at this threshold
    """

    def __init__(self, precision, recall, score, num_results):
        (self.precision, self.recall,
         self.score, self.num_results) = precision, recall, score, num_results

# Sweep a score threshold from the highest score down. After each sample we
# record precision/recall of "every sample scored at least this high"; samples
# that tie on score collapse into a single point, which is the last one seen for that score.
pr_points = []

# Sample.__lt__ is inverted, so a plain sort orders samples by descending score.
samples.sort()
npos = sum(1 for s in samples if s.label)
print('npos ' + str(npos))
if npos == 0:
    # Recall is undefined without any positives. Report 0.0 instead of
    # raising ZeroDivisionError part-way through the sweep.
    print('warning: no positive samples; recall is reported as 0.0')

tp = 0
fp = 0
for sample in samples:
    if sample.label:
        tp += 1
    else:
        fp += 1

    # Same score as the previous point means the same threshold: replace that
    # point so it reflects every sample sharing this score.
    if pr_points and pr_points[-1].score == sample.score:
        pr_points.pop()

    # float() keeps true division on the Python 2 kernel.
    precision = tp / float(tp + fp)
    recall = tp / float(npos) if npos else 0.0
    pr_point = PRPoint(precision, recall, sample.score, tp + fp)
    pr_points.append(pr_point)

len_points = len(pr_points)
print('num of pr_points: {}'.format(len_points))

# Pack the PR points into a numpy structured array so the curve can be
# plotted column-wise (points['recall'], points['precision']).
dtype = [('precision', float), ('recall', float), ('score', float), ('num_results', int)]
selected = [pr_points[i] for i in range(len_points)]
points = np.array(
    [(pt.precision, pt.recall, pt.score, pt.num_results) for pt in selected],
    dtype=dtype,
)

# Plot the PR curve as a fixed-size 400x400 figure; skipped when there are
# no points to draw.
if len(points):
    pr_trace = go.Scatter(x=points['recall'], y=points['precision'])
    fig = go.Figure(data=[pr_trace])
    fig.update_layout(
        title='PR Curve',
        xaxis_title='Recall',
        yaxis_title='Precision',
        autosize=False,
        width=400,
        height=400,
    )
    fig.show()
    
    





num of samples: 4020
npos 1134
num of pr_points: 3053
