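Prerequisites
- A trained and quantized TFLite model: model.tflite
- An Android project that will use TFLite
- A phone for development
Deployment steps
- Configure the TFLite libraries in Gradle by adding the following dependencies to build.gradle (TFLite version 2.8.0):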
dependencies {
    ...
    implementation 'com.github.bumptech.glide:glide:4.13.2'
    implementation 'org.tensorflow:tensorflow-lite:2.8.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.8.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'
    ...
}
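After configuring the dependencies, make sure android.injected.testOnly is set to false in gradle.properties; otherwise installing the app for USB debugging fails with INSTALL_FAILED_TEST_ONLY.
# Avoids the error: INSTALL_FAILED_TEST_ONLY
android.injected.testOnly = false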
- Add the TFLite model to the project as an asset, under /app/src/main/assets
- Load the TFLite model in code and run inference
3.1 Import the required packages
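import android.content.res.AssetFileDescriptor;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import org.tensorflow.lite.gpu.GpuDelegate;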
3.2 Load the model
// Memory-maps the model file stored in the app's assets
private MappedByteBuffer loadModelFile(String model) throws IOException {
    AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    // The asset is stored inside the APK, so map only its own byte range
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
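// `model` is the asset file name, e.g. "model.tflite"
MappedByteBuffer modelFile = loadModelFile(model);
Interpreter.Options options = new Interpreter.Options();
GpuDelegate delegate = new GpuDelegate();
options.setNumThreads(4);       // threads used for ops that run on the CPU
options.addDelegate(delegate);  // offload supported ops to the GPU
options.setUseNNAPI(true);      // also request NNAPI acceleration
// `tflite` is an Interpreter field of the enclosing class
tflite = new Interpreter(modelFile, options);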
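The key step in loadModelFile is mapping only a byte range of a larger file, because an asset sits somewhere inside the APK rather than in a file of its own. The following is a minimal plain-Java sketch of that idea, runnable without Android. The class name MapRegionSketch, the helper mapRegion, and the fake file contents are all hypothetical and only stand in for the APK and the model.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

public class MapRegionSketch {
    /** Maps `length` bytes of `file`, starting at `offset`, as a read-only buffer. */
    static MappedByteBuffer mapRegion(Path file, long offset, long length) throws IOException {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            // A mapping remains valid after its channel has been closed.
            return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
        }
    }

    /** Writes a stand-in "APK" with a payload in the middle and maps only the payload. */
    static String demo() {
        try {
            Path file = Files.createTempFile("fake-apk", ".bin");
            file.toFile().deleteOnExit();
            // 7 header bytes, 11 payload bytes, then a footer.
            byte[] content = "HEADER|MODEL-BYTES|FOOTER".getBytes(StandardCharsets.US_ASCII);
            Files.write(file, content);
            MappedByteBuffer region = mapRegion(file, 7, 11);
            byte[] payload = new byte[region.remaining()];
            region.get(payload);
            return new String(payload, StandardCharsets.US_ASCII);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints "MODEL-BYTES"
    }
}
```

In the Android code, getStartOffset() and getDeclaredLength() play the role of the offset 7 and length 11 used here.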
3.3 Configure the input and output tensors
// Wrap the Bitmap `bmp` in a TensorImage that matches the model's input type
DataType imageDataType = tflite.getInputTensor(0).dataType();
TensorImage inputImageBuffer = new TensorImage(imageDataType);
inputImageBuffer.load(bmp);
// Resize input 0 to the bitmap size, in NHWC order: {batch, height, width, channels}
int[] inputShape = {1, bmp.getHeight(), bmp.getWidth(), 3};
tflite.resizeInput(0, inputShape);
// Allocate the output buffer; this assumes the output has the same NHWC shape as the input
DataType probabilityDataType = tflite.getOutputTensor(0).dataType();
int[] probabilityShape = {1, bmp.getHeight(), bmp.getWidth(), 3};
TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
// Run inference and read the output as floats
tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
float[] results = outputProbabilityBuffer.getFloatArray();
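The results array is flat, so reading a particular pixel means converting (row, column, channel) into a single offset. Below is a minimal sketch of that arithmetic, assuming the output is laid out as {1, height, width, channels} with a batch size of 1. The class NhwcIndexSketch and its helpers are hypothetical names used for illustration and are not part of TFLite.

```java
import java.util.Arrays;

public class NhwcIndexSketch {
    /** Flat offset of (row, col, channel) in a {1, height, width, channels} tensor. */
    static int nhwcOffset(int row, int col, int channel, int width, int channels) {
        return (row * width + col) * channels + channel;
    }

    /** Copies the channel values of one pixel out of a flat NHWC array. */
    static float[] pixelAt(float[] flat, int row, int col, int width, int channels) {
        float[] pixel = new float[channels];
        for (int c = 0; c < channels; c++) {
            pixel[c] = flat[nhwcOffset(row, col, c, width, channels)];
        }
        return pixel;
    }

    public static void main(String[] args) {
        int height = 2, width = 3, channels = 3;
        // Stand-in for `results`: each element holds its own flat index.
        float[] flat = new float[height * width * channels];
        for (int i = 0; i < flat.length; i++) {
            flat[i] = i;
        }
        // Pixel at row 1, column 2 starts at offset (1 * 3 + 2) * 3 = 15.
        System.out.println(Arrays.toString(pixelAt(flat, 1, 2, width, channels)));
        // prints [15.0, 16.0, 17.0]
    }
}
```

Swapping width and height in the shape leaves the element count unchanged but moves every pixel to a different offset, which is why the order of the dimensions matters when interpreting the output.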
With the steps above in place, connect the phone to run and debug the app over USB.