In [None]:
# import tensorflow.compat.v1 as tf
import tensorflow as tf
import glob
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import colors
import numpy as np
from PIL import Image
from IPython.display import clear_output

import IPython
import vehicle.perception.ll3d.learning.topdown_segmentation_net as net_def
from tflight.layers.training_mode import set_training_mode
from vehicle.perception.ll3d.learning.topdown_segmentation_data import (
    make_loss_mask,
    make_sharded_inputs,
    map_data_and_label_anchors,
    map_training_data_proto_func)
from vehicle.perception.ll3d.learning.topdown_segmentation_options import \
    read_options
from vision.detection.training.tf_sseg_det.anchor_utils import anchors

from tensorflow.python.keras.utils.conv_utils import normalize_data_format
from tensorflow.python import debug as tf_debug

from vehicle.perception.ll3d.learning.topdown_segmentation_net import (
    TopdownSegmentationModel,
    make_topdown_detector_net
)
from vehicle.perception.ll3d.learning.topdown_segmentation_net_utils import (
    TopdownDetectorNet
)

%matplotlib inline
%config InlineBackend.figure_format='retina'

tf.__version__

In [None]:
from typing import Any, List

import tensorflow as tf
from tensorflow.core.framework.graph_pb2 import GraphDef


def create_graph_def(
    inputs: List[tf.keras.Input], outputs: List[tf.Tensor]
) -> GraphDef:
    """
    Create a frozen GraphDef for the given Keras inputs and outputs.

    This function uses the TF 2.x Keras API. It should be run in eager mode.
    Both arguments are lists, even when there is only one input or output.
    A simple use case is:
    ```
    inputs = [tf.keras.Input(shape=..., batch_size=1)]
    outputs = [my_model(inputs[0], training=False)]
    graph_def = create_graph_def(inputs, outputs)
    ```
    The resulting GraphDef can be directly saved in binary form and passed to the UFF parser.
    A vision.TensorFlowWrapperOptions using the resulting GraphDef would look like
    ```
    opts = vision.TensorFlowWrapperOptions()
    opts.graph_file = "path/to/graph_def.pb"
    opts.input_tensor_name = [t.op.name for t in inputs]
    opts.output_tensor_names = [t.op.name for t in outputs]
    ...
    ```

    Args:
        inputs: Keras input tensors; one placeholder is created per entry,
            using the entry's shape, dtype and name.
        outputs: Tensors computed from `inputs` that the graph should produce.

    Returns:
        A GraphDef with variables converted to constants, in which the input
        and output nodes carry the op names of `inputs` and `outputs`.
    """
    # Import the helper by its full module path: `tf.python` is not part of the
    # public API and is not reliably exposed as an attribute of the top-level
    # `tensorflow` package, so `tf.python.framework...` can raise AttributeError.
    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2,
    )

    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # Trace the model under a fixed signature so that a concrete function can
    # be obtained without supplying example tensors.
    @tf.function(  # type: ignore[misc]
        input_signature=[
            tf.TensorSpec(t.shape, dtype=t.dtype, name=t.name) for t in inputs
        ]
    )
    def func(*args: tf.Tensor) -> Any:
        return model(args)

    concrete_func = func.get_concrete_function()
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    # The frozen_func.graph.inputs will be a bunch of Placeholders.
    apply_renames(graph_def, frozen_func.graph.inputs, inputs)
    # The frozen_func.graph.outputs will be a bunch of Identities.
    apply_renames(graph_def, frozen_func.graph.outputs, outputs)

    return graph_def


def apply_renames(
    graph_def: GraphDef,
    graph_tensors: List[tf.Tensor],
    desired_tensors: List[tf.Tensor],
) -> None:
    """
    Rename nodes in a GraphDef, in place.

    The graph_tensors should be represented in graph_def.
    The desired_tensors are the tensors whose names we wish to use.

    The two lists are paired by position: the node named after
    `graph_tensors[i].op.name` is renamed to `desired_tensors[i].op.name`.
    Every reference to a renamed node in another node's inputs is updated as
    well, including references that carry an output port (`node:1`) or that
    are control dependencies (`^node`).

    Args:
        graph_def: The GraphDef to modify.
        graph_tensors: Tensors whose op names currently appear in graph_def.
        desired_tensors: Tensors whose op names should be used instead.

    Raises:
        AssertionError: If the two tensor lists differ in length.
    """

    assert len(graph_tensors) == len(desired_tensors)
    renames = {
        t.op.name: n.op.name for t, n in zip(graph_tensors, desired_tensors)
    }

    def rename_input(ref: str) -> str:
        # An input reference is "node", "node:port", or "^node" (control
        # dependency). Only the node-name part is looked up; the "^" prefix
        # and the ":port" suffix are carried over unchanged.
        prefix = "^" if ref.startswith("^") else ""
        base, sep, port = ref[len(prefix):].partition(":")
        return prefix + renames.get(base, base) + sep + port

    for node in graph_def.node:
        node.name = renames.get(node.name, node.name)
        for idx, old_input in enumerate(node.input):
            node.input[idx] = rename_input(old_input)

In [None]:
# Location of the topdown-segmentation options file (.pbjson) handed to the
# model-configuration constructor below.
# NOTE(review): this is an absolute path inside one user's home directory, so
# the cell only runs unmodified on a machine with that checkout.
opt_pbjson = "/home/svolta/driving2/vehicle/perception/ll3d/learning/data/topdown_segmentation_options.pbjson"
# Build the model configuration from the options file; later cells read
# objectives and target counts from `model_config`.
model_config = TopdownSegmentationModel(options_pbjson=opt_pbjson)

# Clear any output this cell has produced so far.
clear_output()

In [None]:
# Start from a clean default graph before constructing the network.
tf.compat.v1.reset_default_graph()

# Instantiate the detector network from the loaded model configuration.
model = TopdownDetectorNet(
    model_config.seg_objectives,
    model_config.class_targets_per_location,
    model_config.box_targets_per_location,
    model_config.box_objectives,
    learned_deconvs=False,
)

# Symbolic input with a fixed batch size of 1; call the network with
# training disabled.
inputs = tf.keras.Input(batch_size=1, shape=[149, 2, 256, 256])
outputs_dict = model(inputs, training=False)

# `outputs_dict` is iterated as a sequence of mappings; collect every value
# into one flat list, in iteration order.
outputs_list = [
    tensor for head in outputs_dict for _name, tensor in head.items()
]

# Freeze the network into a GraphDef whose nodes use the Keras tensor names.
graph_def = create_graph_def([inputs], outputs_list)

In [None]:
# Serialize the frozen graph to disk in binary (non-text) form.
tf.io.write_graph(
    graph_or_graph_def=graph_def,
    logdir="/tmp/multi_res_tds/try_train",
    name="frozen_model.pb",
    as_text=False,
)

# Report the size of the exported graph, then list each node name on its own
# line.
print(f"{len(graph_def.node)} ops in the final graph.")
for n in graph_def.node:
    print(n.name)