Build an image preprocessing model using PyTorch and integrate it into your model using ONNX

Vilson Rodrigues
Sep 6, 2023


Reduce your project’s dependencies with ONNX

Credits: Stable Diffusion 2-1

Traditional image preprocessing tools like OpenCV and Pillow typically process images one at a time on the CPU, without built-in batching or GPU support. To address this, we can use a library such as TorchVision to build an image preprocessing pipeline.

However, shipping the entire PyTorch stack to production can be costly. An alternative is to export both the preprocessing pipeline and the model to ONNX and then compose them into a single model.

To make this easy to reproduce, I created a Colab notebook. Look here.

1. Export Image Model to ONNX

I chose ResNet50 as the base model. When exporting it, we'll mark the batch dimension as a dynamic axis. ResNet50 expects each image to have shape (3, 224, 224), in (C, H, W) format.

Basic dependencies:

!pip install torch torchvision onnx

Generate the ONNX model:

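import torch
import torchvision.models as models

# ImageNet-pretrained weights; IMAGENET1K_V1 is what the older
# `pretrained=True` flag loaded (the `weights=` API needs torchvision >= 0.13)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.eval()

# (B, C, H, W): one 3-channel 224x224 image
input_dim = (1, 3, 224, 224)
dummy_input = torch.randn(input_dim)

onnx_path = 'resnet50.onnx'

# Mark axis 0 (batch) of the input and output as dynamic
dynamic = {'input': {0: 'batch'}, 'output': {0: 'batch'}}

torch.onnx.export(model,
                  dummy_input,
                  onnx_path,
                  verbose=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes=dynamic,
                  opset_version=17)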

2. Build an Image Preprocessing Model

The model takes input of shape (B, H, W, C) and returns (B, C, H, W). The operations I chose for this demonstration, in the order they are applied, are:

  • Permute/transpose channels
  • Resize
  • Normalize

We will also use dynamic axes: batch, height, and width can take any size, so the model accepts images of any resolution as long as they have 3 color channels.
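For sanity-checking the exported model, a plain NumPy version of the same three steps can be handy. This is a minimal sketch, not the author's code: the function names are hypothetical, and the resize assumes PyTorch's legacy "nearest" rule (source index = floor(i * in_size / out_size)), so a few pixels could differ from the exported model at rounding boundaries.

```python
import numpy as np


def nearest_indices(in_size: int, out_size: int) -> np.ndarray:
    # Assumed legacy "nearest" mapping: floor(dst * in / out), clamped to range
    scale = in_size / out_size
    idx = np.floor(np.arange(out_size) * scale).astype(np.int64)
    return np.minimum(idx, in_size - 1)


def preprocess_reference(images: np.ndarray, size=(224, 224),
                         mean=(127, 127, 127), std=(128, 128, 128)) -> np.ndarray:
    """NumPy sketch of TransposeResizeNormalize: (B, H, W, 3) -> (B, 3, h, w)."""
    if images.ndim != 4 or images.shape[-1] != 3:
        raise ValueError("expected a (B, H, W, 3) array")
    x = images.astype(np.float32).transpose(0, 3, 1, 2)  # (B, C, H, W)
    rows = nearest_indices(x.shape[2], size[0])
    cols = nearest_indices(x.shape[3], size[1])
    x = x[:, :, rows][:, :, :, cols]                      # nearest-neighbour resize
    mean = np.asarray(mean, dtype=np.float32).reshape(1, 3, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(1, 3, 1, 1)
    return (x - mean) / std                               # per-channel normalize


batch = np.full((2, 123, 555, 3), 255.0)
out = preprocess_reference(batch)
print(out.shape)  # (2, 3, 224, 224)
```

Its output can be compared with the ONNX model's output using np.testing.assert_allclose, keeping in mind the rounding caveat above.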

import torch
import torchvision.transforms as transforms

class TransposeResizeNormalize(torch.nn.Module):
    def __init__(
        self,
        resize,
        mean_values=(127, 127, 127),
        scale_factor=(128, 128, 128)
    ):
        super().__init__()
        # Nearest-neighbour resize; antialias only affects bilinear/bicubic
        # modes, so it is ignored here
        self.resize = transforms.Resize(
            resize,
            antialias=True,
            interpolation=transforms.InterpolationMode.NEAREST)
        # Per-channel (x - mean) / std
        self.normalize = transforms.Normalize(
            mean=mean_values,
            std=scale_factor)

    def forward(self, x):
        # (B, H, W, C) -> (B, C, H, W)
        x = x.permute(0, 3, 1, 2)
        x = self.resize(x)
        x = self.normalize(x)
        return x

model_prep = TransposeResizeNormalize(resize=(224, 224))

# N, H, W, C
dummy_input = torch.randn(1, 123, 555, 3)

# Batch, height and width are dynamic; the output is always 224x224
dynamic = {'input': {0: 'batch', 1: 'height', 2: 'width'},
           'output': {0: 'batch'}}

path_export_model_prep = 'prep.onnx'

torch.onnx.export(model_prep,
                  dummy_input,
                  path_export_model_prep,
                  opset_version=17,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes=dynamic,
                  verbose=True)
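The defaults mean_values=127 and scale_factor=128 are demonstration values that map 0-255 pixels to roughly [-1, 1]. torchvision's pretrained ResNet50 weights were trained with the ImageNet statistics mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225) on images scaled to [0, 1]. Because this preprocessing model receives raw 0-255 pixels, a sketch of the equivalent constants (assuming RGB channel order) is:

```python
import numpy as np

# torchvision's ImageNet statistics, defined for pixels scaled to [0, 1]
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

# For raw 0-255 pixels: (x / 255 - m) / s == (x - 255 * m) / (255 * s)
mean_255 = IMAGENET_MEAN * 255  # about [123.675, 116.28, 103.53]
std_255 = IMAGENET_STD * 255    # about [58.395, 57.12, 57.375]

# Quick numerical check of the identity on a few random RGB pixels
pixels = np.random.default_rng(0).uniform(0, 255, size=(4, 3))
assert np.allclose((pixels / 255 - IMAGENET_MEAN) / IMAGENET_STD,
                   (pixels - mean_255) / std_255)
```

These values could then be passed as TransposeResizeNormalize(resize=(224, 224), mean_values=tuple(mean_255), scale_factor=tuple(std_255)).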

Simplifying ONNX models with onnx-simplifier (onnxsim) is good practice:

!pip install onnxsim
!onnxsim prep.onnx prep-sim.onnx

LGTM 😎

The simplified model is half the size of the original.

3. Compose Models

The ONNX Python API provides the onnx.compose module. We only need to map the output of the preprocessing model to the input of the image model. Note that both models must use the same IR version and the same opset version.

First, load models:

import onnx
from onnx import compose

model = onnx.load('resnet50.onnx')
prep = onnx.load('prep-sim.onnx')

# Both models must share the same IR and opset versions.
# If they differ, you could force them to match, e.g.:
# model.opset_import[0].version = 17
# model.ir_version = 8

Add a prefix to the preprocessing model's names to avoid name conflicts, then merge:

# Add a prefix to every name in the preprocessing graph to avoid conflicts
prep_with_prefix = compose.add_prefix(prep, prefix="prep_")

model_prep = compose.merge_models(
    prep_with_prefix,
    model,
    io_map=[('prep_output',  # output of the preprocessing model
             'input')])      # input of the ResNet50 model

Serialize:

onnx.save_model(model_prep, 'resnet_prep.onnx')

4. Test

!pip install onnxruntime

Define a random input:

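import numpy as np

# A random (B, H, W, C) batch: one 213x152 RGB image, as float32
image = np.random.rand(1, 213, 152, 3).astype(np.float32)

Load and predict!

import onnxruntime as ort

model_path = 'resnet_prep.onnx'
session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
# Read the input name from the session instead of hard-coding it
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: image})
outputs[0].shape
#> (1, 1000)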
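torchvision's ResNet50 ends with a fully connected layer and no softmax, so the (1, 1000) output holds raw logits over the 1000 ImageNet classes. A minimal NumPy sketch (hypothetical helper name) to turn them into top-k class probabilities:

```python
import numpy as np


def top_k(logits: np.ndarray, k: int = 5):
    """For a (B, num_classes) logits array, return top-k (class, prob) per image."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    idx = np.argsort(-probs, axis=1)[:, :k]               # highest probability first
    return [[(int(i), float(p[i])) for i in row] for row, p in zip(idx, probs)]


print(top_k(np.array([[1.0, 3.0, 2.0]]), k=2))
```

Applied as top_k(outputs[0]), this gives class indices that can be mapped to names, for example via torchvision's ResNet50_Weights.IMAGENET1K_V1.meta["categories"]. With a random input image the predictions are, of course, meaningless.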

🤠 Thanks, see you later!
