In [None]:
import tensorflow.compat.v1 as tf
import glob
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import colors
import numpy as np
from PIL import Image

import IPython
import vehicle.perception.ll3d.learning.topdown_segmentation_net as net_def
from tflight.layers.training_mode import set_training_mode
from vehicle.perception.ll3d.learning.topdown_segmentation_data import (
    make_loss_mask,
    make_sharded_inputs,
    map_data_and_label_anchors,
    map_training_data_proto_func)
from vehicle.perception.ll3d.learning.topdown_segmentation_options import \
    read_options
from vision.detection.training.tf_sseg_det.anchor_utils import anchors

from tensorflow.python.keras.utils.conv_utils import normalize_data_format
from tensorflow.python import debug as tf_debug

# Render matplotlib figures inline in the notebook, using the high-DPI 'retina' format.
%matplotlib inline
%config InlineBackend.figure_format='retina'

# Display the TensorFlow version this notebook is running against.
tf.VERSION

In [None]:
# Sharded training dataset to visualise (cluster-specific absolute mount path).
DATASET_TRAIN_SHARDS = "/mnt/sun-pcp01/tds/stages/dataset/64ee8622-cbbb-4f91-9b72-9f5b22b203d7/remap_by_class/train_shards"

# TF1 graph-mode semantics are needed for tf.Session and the iterator initializer below.
tf.disable_v2_behavior()
FLAGS = tf.app.flags.FLAGS
FLAGS.train_batch_size = 2
FLAGS.test_batch_size = 2

with tf.device("/GPU:0"):
    tf.set_random_seed(1)
    model = net_def.TopdownSegmentationModel()

    # NOTE(review): this runs *after* the model is constructed. If
    # TopdownSegmentationModel reads the Keras image data format in its
    # constructor, it will have seen the default instead -- confirm the order.
    tf.keras.backend.set_image_data_format('channels_first')

    config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False
    )

    # Allocate GPU memory on demand, capped at half of the device.
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5

    sess = tf.Session(config=config)

    # Build the input pipeline on the CPU.
    with tf.device("/cpu:0"):
        train_iter = make_sharded_inputs(
            DATASET_TRAIN_SHARDS, model.anchor_labeler, model.opt, 1, use_dataset=False)

        train_data, train_labels = model._postprocess_input_iter(train_iter)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        sess.run(train_iter.initializer)

        # Fetch data and labels in ONE sess.run so they belong to the same
        # batch. Both are presumably derived from train_iter's get_next();
        # two separate runs would each advance the iterator, pairing the data
        # of one batch with the labels of the next.
        data, labels = sess.run([train_data['data'], train_labels])

In [None]:
# Number of consecutive input channels that make up each feature group.
CHANNELS_PER_FEATURE = 7

# Position of each feature group in the input tensor, in units of groups.
# Negative values are sentinels (presumably "no fixed position"; confirm) and
# are passed through unscaled.
feature_group_index = {
    "occupancy_index": 0,
    "mean_index": 1,
    "cov_index": -1,
    "semseg_index": 4,
    "freespace_index": 15,
    "raycast_index": 16,
    "aggregated_occupancy_index": 17,
    "binary_semseg_index": 14,
    "mean_intensity_index": 18,
    "mean_static_probability_index": 19
}

# First input channel of each feature group: group index * channels per group.
offsets = {
    name: index * CHANNELS_PER_FEATURE if index >= 0 else index
    for name, index in feature_group_index.items()
}
offsets

In [None]:
def plot_channel(sample, channel, num_subchannels=7, channel_offsets=None):
    """Plot every subchannel of one input feature group for a single sample.

    Args:
        sample: one example from the input batch, indexed as
            ``sample[channel_index][k]`` for k in {0, 1}. Two columns are
            plotted per subchannel; their meaning is not visible here.
        channel: key naming the feature group to plot (e.g. 'freespace_index').
        num_subchannels: number of consecutive channels in the group
            (default 7, matching the stride used to build ``offsets``).
        channel_offsets: optional mapping from group name to its first channel
            index; defaults to the notebook-level ``offsets`` dict.

    Returns:
        The matplotlib Figure holding a ``num_subchannels`` x 2 grid of images.

    Raises:
        ValueError: if the group's offset is negative. A sentinel entry such
            as 'cov_index' would otherwise silently wrap around to the last
            channels of ``sample``.
    """
    if channel_offsets is None:
        channel_offsets = offsets
    base = channel_offsets[channel]
    if base < 0:
        raise ValueError(
            "{!r} has no fixed channel offset (got {}); cannot plot it".format(channel, base))

    # squeeze=False keeps `ax` 2-D even when num_subchannels == 1.
    fig, ax = plt.subplots(
        num_subchannels, 2, figsize=(20, 10 * num_subchannels),
        squeeze=False, sharex=True, sharey=True, tight_layout=True)

    for subchannel in range(num_subchannels):
        planes = sample[base + subchannel]
        for col in range(2):
            ax[subchannel, col].imshow(planes[col])
            ax[subchannel, col].set_title("{} +{} [{}]".format(channel, subchannel, col))
    return fig

In [None]:
def plot_label(sample, name, cmap=None):
    """Plot every channel of one label tensor for a single sample.

    Args:
        sample: one example from a label batch. ``sample.shape[0]`` is the
            number of channels, and each channel is indexed as ``sample[c][k]``
            for k in {0, 1} (the two plotted columns).
        name: label name, used as the prefix of every axis title.
        cmap: optional matplotlib colormap (name or Colormap). A falsy value
            means matplotlib's default colormap.

    Returns:
        The matplotlib Figure holding a ``sample.shape[0]`` x 2 grid of images.
    """
    num_channels = sample.shape[0]
    # Scale the height with the number of rows. A fixed 70in height stretched
    # single-channel labels into one enormous, mostly empty row.
    fig, ax = plt.subplots(
        num_channels, 2, figsize=(20, 10 * num_channels), squeeze=False,
        sharex=True, sharey=True, tight_layout=True)

    for c in range(num_channels):
        for col in range(2):
            # `cmap or None` mirrors the original `if cmap:` test; None selects the default.
            ax[c, col].imshow(sample[c][col], cmap=cmap or None)
            ax[c, col].set_title("{} [{}][{}]".format(name, c, col))
    return fig

In [None]:
# Batch element whose labels are visualised in this cell.
SAMPLE_INDEX = 1

plot_label(labels['label'][SAMPLE_INDEX], 'label')
plt.show()

# 17-entry categorical colormap for 'label_orientation': 8 bright hues, the
# same 8 hues at half brightness, then white.
# NOTE(review): assumes the label holds integer classes 0..16 -- confirm
# against the label encoding.
cmap_orient = colors.ListedColormap([
    '#0000ff',  # Blue
    '#0091ff',  # Light blue
    '#00ffd9',  # Teal
    '#ff91ff',  # Pink
    '#ff0000',  # Red
    '#ff9100',  # Orange
    '#d9ff00',  # Yellow
    '#47ff00',  # Green

    '#00007f',  # Dark blue
    '#00487f',  # Dark light blue
    '#007f6C',  # Dark teal
    '#7f487f',  # Dark pink
    '#7f0000',  # Dark red
    '#7f4800',  # Dark orange (was '#7f9100', an olive nearly identical to dark yellow)
    '#6C7f00',  # Dark yellow
    '#477f00',  # Dark green
    '#ffffff',  # White
])
orient_fig = plot_label(
    labels['label_orientation'][SAMPLE_INDEX], 'label_orientation', cmap_orient)
# imshow autoscales each image to its own min/max. With a categorical colormap
# that shifts the colours whenever a panel lacks some classes. Pin the limits
# so that integer class k always maps to the k-th listed colour.
for orient_ax in orient_fig.axes:
    for image in orient_ax.get_images():
        image.set_clim(-0.5, cmap_orient.N - 0.5)
plt.show()

for name in [
    'label_weights',
    'distractor_weight',
]:
    print(name)
    plot_label(labels[name][SAMPLE_INDEX], name)
    plt.show()

In [None]:
# NOTE(review): the label plots use batch element 1 while this uses element 0;
# align the index if data and labels are meant to be compared side by side.
fig = plot_channel(data[0], 'freespace_index')
# Figure.show() only warns on the non-GUI inline backend. plt.show() is the
# idiomatic way to render (and release) the figure in a notebook.
plt.show()