裁剪ONNX模型

按需裁剪ONNX模型,仅保留所需要的部分。

裁剪ONNX模型
封面图片由Midjourney生成

最近遇到一个模型量化与转换问题,涉及到了对ONNX运算图的修改,记录如下。

背景

目标模型为物体检测模型,由PyTorch训练得到,部署时转为ONNX格式。模型输入为预处理后的图像,输出为像素坐标与置信度。这个模型在移植到瑞芯微的嵌入式设备时遇到了性能问题。一次完整的推演耗时大约150ms,完全无法接受。

为了加快模型推演,首先考虑将模型权重量化为8bit精度。量化后发现模型输出的坐标数据是对的,但是置信度相关的数值一律为0。考虑坐标数值为3位整数,而置信度为0~1的浮点数,两者差距太大,有可能是8bit量化出了问题。模型最终输出的坐标值涉及到了模型预测值、anchor与图像尺寸。其中模型预测值经过sigmoid函数之后是0 ~ 1之间的数值。所以,量化出问题,很有可能是ONNX模型内置的后处理部分造成的。

要解决这个问题,需要将后处理部分去掉——这意味着要更改模型的计算图。