想要修改onnx模型文件的节点名称,要么在最初的pytorch代码里去改,要么就直接在onnx模型文件里改。
而我这里直接在onnx模型文件改,我有一个onnx文件,输出节点的名字是这样的:
这不改就看着真难受,那么就用python改:
import onnx
model = onnx.load("model.onnx")
idx_start = 0
for input in model.graph.input:
for node in model.graph.node:
for i, name in enumerate(node.input):
if name == input.name:
node.input[i] = "input_" + str(idx_start)
input.name = "input_" + str(idx_start)
idx_start += 1
idx_start = 0
for output in model.graph.output:
for node in model.graph.node:
for i, name in enumerate(node.output):
if name == output.name:
node.output[i] = "output_" + str(idx_start)
output.name = "output_" + str(idx_start)
idx_start += 1
onnx.save(model, "modified_model.onnx")
改完后:
其实修改其他节点的名称也可以这样去做,注意修改的是要关注到前后连接的节点。
验证一下改的东西对不对吧:
import numpy as np
import onnxruntime as ort
img = np.load("img.npy")
session = ort.InferenceSession("modified_model.onnx")
output_new = session.run(None, {"input_0": img})
session = ort.InferenceSession("model.onnx")
output_old = session.run(None, {"x": img})
print(np.allclose(output_new, output_old))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)