A Friendly Introduction to TensorRT: Building Engines

Vilson Rodrigues
May 6, 2024 · 5 min read


Learn how to export models to an efficient inference format

Image credits: DALL-E 3.

NVIDIA TensorRT is an SDK for optimizing trained deep learning models to enable high-performance inference. It contains a deep learning inference optimizer for trained models and a runtime for execution.

Key Components

Network Definition

A representation of a model in TensorRT. A network definition is a graph of tensors and operators.

Builder

TensorRT’s model optimizer. The builder takes as input a network definition, performs device-independent and device-specific optimizations, and creates an engine.

Engine

A representation of a model that has been optimized by the TensorRT builder.

Logger

Associated with the builder and engine to capture errors, warnings, and other information during the build and inference phases.

ONNX parser

Takes a trained model that has already been converted to the ONNX format (for example, from PyTorch or TensorFlow) as input and populates a network definition in TensorRT.

Plan

An optimized inference engine in a serialized format. To initialize the inference engine, the application will first deserialize the model from the plan file. A typical application will build an engine once, and then serialize it as a plan file for later use.

Runtime

The component of TensorRT that performs inference on a TensorRT engine. The runtime API supports synchronous and asynchronous execution, profiling, and enumeration and querying of an engine's input and output bindings.
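
To make these roles concrete, here is a minimal sketch of how the pieces fit together during the build phase. The file names are placeholders, and the full, step-by-step version is developed later in this article.

import tensorrt as trt

# Logger: captures errors and warnings from the build and inference phases
logger = trt.Logger(trt.Logger.WARNING)

# Builder + network definition (the graph of tensors and operators);
# explicit batch is the default in TensorRT 10
builder = trt.Builder(logger)
network = builder.create_network(0)

# ONNX parser: populates the network definition from an ONNX file
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    parser.parse(f.read())

# Builder config drives the optimizations; the result is a serialized plan
config = builder.create_builder_config()
plan = builder.build_serialized_network(network, config)
with open("model.plan", "wb") as f:
    f.write(plan)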

Compute Precision

TensorRT supports the following compute precisions (a quick check of which ones your GPU accelerates follows the list):

  • TF32
  • FP32
  • FP16
  • FP8
  • BF16
  • INT64
  • INT32
  • UINT8
  • INT8
  • INT4
  • BOOL
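
Which of these precisions are actually accelerated depends on your GPU. A small sketch for checking a few of them, using the builder's platform query attributes:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

# Query what the current device supports / accelerates
print("TF32:", builder.platform_has_tf32)
print("Fast FP16:", builder.platform_has_fast_fp16)
print("Fast INT8:", builder.platform_has_fast_int8)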

Workflow

A common flow in TensorRT applications is to build and serialize the engine in an offline process (which can be slow). TensorRT optimizes an engine for specific hardware; by default, TensorRT engines are only compatible with the type of device on which they were built.

To run inference, load an engine, create an execution context, and execute (a short sketch follows the figure below).

TensorRT flow. Image credits: NVIDIA.
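
As a minimal sketch of the deployment phase (assuming the engine was serialized to a placeholder file model.plan), loading boils down to deserializing the plan and creating an execution context:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

# Deserialize the plan that was built offline into an engine for this device
with open("model.plan", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

# The execution context holds per-inference state; actually running inference
# (buffers, CUDA streams) is the topic of the next post
context = engine.create_execution_context()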

Code

This article uses the newest TensorRT version, v10. We will build an engine for a PyTorch-based ResNet50. The code is available on GitHub.

Install deps

pip install "torch>=2.3.0" "timm>=0.9.0" torchvision onnx "tensorrt>=10.0.1"

1. Export Model to ONNX

Download a pre-trained ResNet50 using timm (pytorch-image-models)

import timm

model = timm.create_model("resnet50.a1_in1k", pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

Define the model configuration for export

import torch

channels = 3
width = 224
height = 224
input_model = [channels, height, width]
max_batch_size = 4

Define the model input

shape_input_model = [max_batch_size] + input_model
# Generate a random tensor
tensor_input = torch.randn(shape_input_model)

PyTorch has two ways to export a model to ONNX: the Dynamo-based and the TorchScript-based exporters. Dynamo preserves the dynamic nature of the model instead of using traditional static tracing techniques, but the Dynamo exporter in PyTorch 2.3.0 is still in beta: when its output is passed to the TensorRT ONNX parser, an exception is raised that the input ~input_name~ is duplicate. So we use the TorchScript-based exporter here.

Here we export the model with dynamic axes (batch, 3, height, width).

# https://pytorch.org/docs/stable/onnx.html
if tensor_input.size(0) > 1:
    dynamic = {
        "inputs": {0: "batch", 2: "height", 3: "width"},
        "outputs": {0: "batch", 1: "logits"},
    }
else:
    dynamic = None

opset_version = 18
f = "model.onnx"

torch.onnx.export(
    model,
    tensor_input,
    f,
    verbose=True,
    input_names=["inputs"],
    output_names=["outputs"],
    opset_version=opset_version,
    do_constant_folding=True,  # note: some torch>=1.12 setups may require do_constant_folding=False
    dynamic_axes=dynamic,
)

2. Build Engines

The first step in a TensorRT application is to define a Logger. Here I chose verbose mode.

import tensorrt as trt 

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)

Create a builder and a builder config

builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()
# Set cache
cache = config.create_timing_cache(b"")
config.set_timing_cache(cache, ignore_mismatch=False)

The maximum workspace sets a memory limit for TensorRT layers. From the documentation: "One important property is the maximum workspace size. Layer implementations often require a temporary workspace, and this parameter limits the maximum size that any layer in the network can use. If insufficient workspace is provided, it is possible that TensorRT will not be able to find an implementation for a layer. By default, the workspace is set to the total global memory size of the given device; restrict it when necessary, for example, when multiple engines are to be built on a single device."

# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#build_engine_python
# max_workspace = (1 << 30)
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace)

TensorRT has two batch modes: implicit and explicit. In implicit batch mode, every tensor has an implicit batch dimension and all other dimensions must have constant length. In explicit batch mode, all dimensions are explicit and can be dynamic, that is, their length can change at execution time. Many new features, such as dynamic shapes and loops, are available only in this mode, and it is also required by the ONNX parser. In TensorRT 10, implicit batch is deprecated; explicit batch is the default and cannot be disabled.

# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#version-compat
# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#explicit-implicit-batch
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, TRT_LOGGER)

Parse the ONNX model

path_onnx_model = "./model.onnx"
with open(path_onnx_model, "rb") as f:
    if not parser.parse(f.read()):
        print(f"ERROR: Failed to parse the ONNX file {path_onnx_model}")
        for error in range(parser.num_errors):
            print(parser.get_error(error))

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]

Let’s look at what’s inside the inputs and outputs

# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work_dynamic_shapes
for input in inputs:
    print(f"Model {input.name} shape: {input.shape} {input.dtype}")
for output in outputs:
    print(f"Model {output.name} shape: {output.shape} {output.dtype}")

>> Model inputs shape: (-1, 3, -1, -1) DataType.FLOAT
>> Model outputs shape: (-1, 1000) DataType.FLOAT

The -1 indicates a runtime dimension: during the build phase it is not necessary to give TensorRT the real dimensions; they are only specified at runtime.
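
As a preview of the runtime side (covered in detail in the next post), those -1 dimensions are filled in on the execution context before inference. A minimal sketch, assuming the engine has already been built and loaded as in section 3:

# Assumes `engine` is a deserialized trt.ICudaEngine (see section 3)
context = engine.create_execution_context()

# Resolve the runtime dimensions for this particular request:
# batch of 2 images, 3 channels, 224x224 (must lie inside the optimization profile)
context.set_input_shape("inputs", (2, 3, 224, 224))

# The output shape is now fully specified as well
print(context.get_tensor_shape("outputs"))  # e.g. (2, 1000)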

With explicit batch, set the min, opt, and max shapes. This helps TensorRT search for better optimizations. The shapes of the optimization profile are set per input; in this case there is just one input, the images.

if max_batch_size > 1:
    # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#opt_profiles
    profile = builder.create_optimization_profile()
    min_shape = [1] + shape_input_model[-3:]
    opt_shape = [int(max_batch_size / 2)] + shape_input_model[-3:]
    max_shape = shape_input_model
    for input in inputs:
        profile.set_shape(input.name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

TensorRT has three reduced-precision options: FP16, INT8, and TF32 (Tensor Cores). Note that TensorRT will still choose a higher-precision kernel if it results in overall lower runtime, or if no low-precision implementation exists.

# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#reduced-precision
# https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix
half = True
int8 = False
if half:
    config.set_flag(trt.BuilderFlag.FP16)
elif int8:
    config.set_flag(trt.BuilderFlag.INT8)

Weight stripping helps create and optimize an engine without carrying unnecessary weights. At inference time, load the engine and refit it with the ONNX weights. The resulting plan is lighter and the weights are not duplicated.

# https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#weightless-build
# https://github.com/NVIDIA/TensorRT/tree/main/samples/python/sample_weight_stripping
strip_weights = False
if strip_weights:
    config.set_flag(trt.BuilderFlag.STRIP_PLAN)

# To remove the strip plan flag from the config:
# config.flags &= ~(1 << int(trt.BuilderFlag.STRIP_PLAN))

Build and save the engine

engine_bytes = builder.build_serialized_network(network, config)
engine_path = "./model.engine"
with open(engine_path, "wb") as f:
    f.write(engine_bytes)

3. Load Engines

Now load the engine. But first, if you are on Colab, execute:

# Colab bug...
import locale
locale.getpreferredencoding = lambda: "UTF-8"


def load_stripped_engine_and_refit(
    engine_path: str,
    onnx_model_path: str,
) -> trt.ICudaEngine:
    runtime = trt.Runtime(TRT_LOGGER)
    with open(engine_path, "rb") as engine_file:
        engine = runtime.deserialize_cuda_engine(engine_file.read())
    refitter = trt.Refitter(engine, TRT_LOGGER)
    parser_refitter = trt.OnnxParserRefitter(refitter, TRT_LOGGER)
    assert parser_refitter.refit_from_file(onnx_model_path)
    assert refitter.refit_cuda_engine()
    return engine


def load_normal_engine(engine_path: str) -> trt.ICudaEngine:
    runtime = trt.Runtime(TRT_LOGGER)
    with open(engine_path, "rb") as plan:
        engine = runtime.deserialize_cuda_engine(plan.read())
    return engine


if strip_weights:
    engine = load_stripped_engine_and_refit(engine_path, path_onnx_model)
else:
    engine = load_normal_engine(engine_path)

OK, it was hard but fun. In the next post we will run efficient inference with TensorRT. Stay tuned. Thanks, everyone. See you later.
