我正在尝试从 onnx 模型中提取输入层、输出层及其形状等数据。我知道有 python 接口可以做到这一点。我想做类似的事情code https://stackoverflow.com/questions/56734576/find-input-shape-from-onnx-file但在c++中。我还粘贴了链接中的代码。我已经在 python 中尝试过了,它对我有用。我想知道是否有 C++ API 可以做同样的事情。
import onnx
model = onnx.load(r"model.onnx")
# The model is represented as a protobuf structure and it can be accessed
# using the standard python-for-protobuf methods
# iterate through inputs of the graph
for input in model.graph.input:
print (input.name, end=": ")
# get type of input tensor
tensor_type = input.type.tensor_type
# check if it has a shape:
if (tensor_type.HasField("shape")):
# iterate through dimensions of the shape:
for d in tensor_type.shape.dim:
# the dimension may have a definite (integer) value or a symbolic identifier or neither:
if (d.HasField("dim_value")):
print (d.dim_value, end=", ") # known dimension
elif (d.HasField("dim_param")):
print (d.dim_param, end=", ") # unknown dimension with symbolic name
else:
print ("?", end=", ") # unknown dimension with no name
else:
print ("unknown rank", end="")
print()
我也是c++新手,请帮助我。
ONNX 格式本质上是protobuf https://developers.google.com/protocol-buffers,因此可以用协议编译器支持的任何语言打开。
如果是 C++
- 获取 onnx 原型文件(onnx 仓库 https://github.com/onnx/onnx/blob/master/onnx/onnx.proto3)
- 编译它
protoc --cpp_out=. onnx.proto3
命令。它将生成onnx.proto3.pb.cc
and onnx.proto3.pb.h
files
- 链接protobuf库(可能是protobuf-lite),生成的cpp文件和以下代码:
#include <fstream>
#include <cassert>
#include "onnx.proto3.pb.h"
void print_dim(const ::onnx::TensorShapeProto_Dimension &dim)
{
switch (dim.value_case())
{
case onnx::TensorShapeProto_Dimension::ValueCase::kDimParam:
std::cout << dim.dim_param();
break;
case onnx::TensorShapeProto_Dimension::ValueCase::kDimValue:
std::cout << dim.dim_value();
break;
default:
assert(false && "should never happen");
}
}
void print_io_info(const ::google::protobuf::RepeatedPtrField< ::onnx::ValueInfoProto > &info)
{
for (auto input_data: info)
{
auto shape = input_data.type().tensor_type().shape();
std::cout << " " << input_data.name() << ":";
std::cout << "[";
if (shape.dim_size() != 0)
{
int size = shape.dim_size();
for (int i = 0; i < size - 1; ++i)
{
print_dim(shape.dim(i));
std::cout << ", ";
}
print_dim(shape.dim(size - 1));
}
std::cout << "]\n";
}
}
int main(int argc, char **argv)
{
std::ifstream input("mobilenet.onnx", std::ios::ate | std::ios::binary); // open file and move current position in file to the end
std::streamsize size = input.tellg(); // get current position in file
input.seekg(0, std::ios::beg); // move to start of file
std::vector<char> buffer(size);
input.read(buffer.data(), size); // read raw data
onnx::ModelProto model;
model.ParseFromArray(buffer.data(), size); // parse protobuf
auto graph = model.graph();
std::cout << "graph inputs:\n";
print_io_info(graph.input());
std::cout << "graph outputs:\n";
print_io_info(graph.output());
return 0;
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)