接上文
3.代码分析
main函数首先将入参写入参数结构体
Settings s;
// Runtime configuration for the label_image example; populated from
// command-line arguments in main() and passed to RunInference().
struct Settings {
bool verbose = false;            // print interpreter state / extra logging
bool accel = false;              // enable NNAPI hardware acceleration
bool input_floating = false;     // input tensor is float32 (vs. quantized uint8)
bool profiling = false;          // collect and print per-op profiling events
int loop_count = 1;              // number of Invoke() iterations to time
float input_mean = 127.5f;       // normalization: pixel = (raw - mean) / std
float input_std = 127.5f;        // used only for floating-point input models
string model_name = "./mobilenet_quant_v1_224.tflite";  // .tflite model path
string input_bmp_name = "./grace_hopper.bmp";           // input image (BMP)
string labels_file_name = "./labels.txt";               // class-label list
string input_layer_type = "uint8_t";                    // input tensor element type
int number_of_threads = 4;       // passed to interpreter->SetNumThreads()
};
运行执行函数 RunInference(&s);
【背景知识:https://blog.csdn.net/jILRvRTrc/article/details/80553561】
RunInference();
1)建立模型model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
2)建立OpResolver 用于指向每个node的操作函数 tflite::ops::builtin::BuiltinOpResolver resolver;
3)建立解释器 tflite::InterpreterBuilder(*model, resolver)(&interpreter);
4)对解释器进行参数设置包括
interpreter->UseNNAPI(s->accel); 是否使用NNAPI加速
interpreter->SetNumThreads(s->number_of_threads);设置运行线程数
bmp文件读入并进行必要的resize
int image_width = 224;
int image_height = 224;
int image_channels = 3;
std::vector<uint8_t> in = read_bmp(s->input_bmp_name, &image_width,
&image_height, &image_channels, s);
resize<float>(interpreter->typed_tensor<float>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
打印运行参数相关信息
if (s->verbose) PrintInterpreterState(interpreter.get());
// output something like
// time (ms) , Node xxx, OpCode xxx, symbolic name
// 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D
运行模型及获得运行时间
struct timeval start_time, stop_time;
gettimeofday(&start_time, nullptr);
for (int i = 0; i < s->loop_count; i++) {
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke tflite!\n";
}
}
gettimeofday(&stop_time, nullptr);
打印profiling(注:下方代码中的 profiler 对象的声明与 StartProfiling 调用在节选之外,完整代码见 label_image.cc 源文件)
if (s->profiling) {
profiler->StopProfiling();
auto profile_events = profiler->GetProfileEvents();
for (int i = 0; i < profile_events.size(); i++) {
auto op_index = profile_events[i]->event_metadata;
const auto node_and_registration =
interpreter->node_and_registration(op_index);
const TfLiteRegistration registration = node_and_registration->second;
PrintProfilingInfo(profile_events[i], op_index, registration);
}
}
获取输出
int output = interpreter->outputs()[0];
switch (interpreter->tensor(output)->type) {
case kTfLiteFloat32:
get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
num_results, threshold, &top_results, true);
加载label并显示对应输出结果
std::vector<string> labels;
size_t label_count;
if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
exit(-1);
for (const auto& result : top_results) {
const float confidence = result.first;
const int index = result.second;
LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)