简化您的 ONNX 模型
项目描述
ONNX 简化器
ONNX 很棒,但有时过于复杂。
背景
有一天,我想将以下简单的 reshape 操作导出到 ONNX:
import torch
class JustReshape(torch.nn.Module):
def __init__(self):
super(JustReshape, self).__init__()
def forward(self, x):
return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))
net = JustReshape()
model_name = 'just_reshape.onnx'
dummy_input = torch.randn(2, 3, 4, 5)
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
这个模型中的输入形状是静态的,所以我期望的是
但是,我得到了以下复杂的模型:
我们的解决方案
ONNX Simplifier 用于简化 ONNX 模型。它推断整个计算图,然后用它们的常量输出(又名常量折叠)替换冗余运算符。
网页版
我们在 convertmodel.com 上发布了 ONNX Simplifier。它开箱即用,不需要任何安装。请注意,它在本地浏览器中运行,您的模型是完全安全的。
蟒蛇版本
pip3 install -U pip && pip3 install onnxsim
然后
onnxsim input_onnx_model output_onnx_model
有关更多高级功能,请尝试以下命令获取帮助消息
onnxsim -h
示范
复杂模型 与其简化版本的整体比较 :
脚本内工作流程
如果您想在另一个脚本中嵌入 ONNX 简化器 python 包,就这么简单。
import onnx
from onnxsim import simplify
# load your predefined ONNX model
model = onnx.load(filename)
# convert model
model_simp, check = simplify(model)
assert check, "Simplified ONNX model could not be validated"
# use model_simp as a standard ONNX model object
您可以在onnxsim/onnx_simplifier.py中查看 API 的更多详细信息
使用 ONNX 简化器的项目
聊天
我们为ONNX创建了一个中文QQ群!
ONNX QQ群(中文):1021964010,验证码:nndab。欢迎加入!
对于英语用户,我在ONNX Slack上很活跃。你可以在那里找到和我(大鹊仙)聊天。