TensorFlow Lite模型,云侧训练与安卓端侧推理
- 引言
- 一、云侧深度模型的训练代码
- 1.加载数据集的格式分析
- 1.1 从数据集加载的数据格式
- 1.2 对加载的数据进行处理
- 2. 深度模型搭建
- 3. 模型训练、评估、保存、转换
- 4. 模型预测
- 二、端侧安卓的推理代码
- 1. 安卓项目配置
- 1.1 app.gradle引入依赖
- 1.2 AndroidManifest.xml新增照相机权限
- 1.3 模型放置
- 2. 安卓端侧代码实现
- 2.1 布局文件
- 2.2 主函数文件
- 2.3 mnist数据集工具类
- 三、测试结果
- 参考网址
- 总结
引言
本次博客主要基于TensorFlow官网的demo进行学习,把学习过程的心得理解记录。其主要内容为TensorFlow云侧训练深度模型,并转换为手机端lite深度模型,最后在安卓手机端侧利用该模型进行推理得出预测结果。本次学习以mnist数据集为例,毕竟入手深度学习,mnist相当于学习编程语言的Hello World!利用的工具有Anaconda的Jupyter Notebook,和Android Studio。
一、云侧深度模型的训练代码
1.加载数据集的格式分析
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
class MNISTLoader():
def __init__(self):
mnist = tf.keras.datasets.mnist
(self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)
self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)
self.train_label = self.train_label.astype(np.int32)
self.test_label = self.test_label.astype(np.int32)
self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]
导入TensorFlow和numpy包即可,我们会用到TensorFlow的Keras,它是用 Python 编写的高级神经网络 API,支持快速的构建网络框架。
1.1 从数据集加载的数据格式
先对MNISTLoader这个类进行分析,该类先加载了数据集数据,如下。
(train_data, train_label), (test_data, test_label) = mnist.load_data()
打印数据格式如下。
print("train_data:变量类型={0},变量形状={1},数据类型={2}".format(type(train_data), train_data.shape, train_data.dtype))
print("train_label:变量类型={0},变量形状={1},数据类型={2}".format(type(train_label), train_label.shape,train_label.dtype))
print("test_data:变量类型={0},变量形状={1},数据类型={2}".format(type(test_data), test_data.shape,test_data.dtype))
print("test_label:变量类型={0},变量形状={1},数据类型={2}".format(type(test_label), test_label.shape,test_label.dtype))
打印结果如下。
train_data:变量类型=<class 'numpy.ndarray'>,变量形状=(60000, 28, 28),数据类型=uint8
train_label:变量类型=<class 'numpy.ndarray'>,变量形状=(60000,),数据类型=uint8
test_data:变量类型=<class 'numpy.ndarray'>,变量形状=(10000, 28, 28),数据类型=uint8
test_label:变量类型=<class 'numpy.ndarray'>,变量形状=(10000,),数据类型=uint8
也就是说加载了60000张28×28的图片作为训练集,10000张28×28的图片作为测试集。其中的数据类型为uint8,取值为0~255。
接着又用了np.expand_dims()为图片的数据集进行了维度扩展,axis=-1表示在原来的变量形状的最后一个维度增加多一维,-1在python的索引通常都是表示最后一个索引。为什么要增加这么个维度呢?因为最后一个维度的数值表示图片的通道数。比如图片为RGB图时,最后一个维度的数值是3,而mnist的数据集为灰度图片,即单通道表示的图片,所以最后一个维度数值是1。train_label、test_label的数据则是用0~9表示对应数据集的各个类。
1.2 对加载的数据进行处理
对加载的数据进行的运算,主要包括对图片进行0~1数值的归一化,维度扩展,和数据类型转换;对标签值进行数值类型转换。注意对数值类型转换尤为重要,这跟后续在安卓端编程中需要用到什么数据类型来作为输入输出要对应起来。数据转换的语句如下。
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)
train_label = train_label.astype(np.int32)
再次运行如下语句查看数据格式
print("train_data:变量类型={0},变量形状={1},数据类型={2}".format(type(train_data), train_data.shape, train_data.dtype))
print("train_label:变量类型={0},变量形状={1},数据类型={2}".format(type(train_label), train_label.shape, train_label.dtype))
得到了新的数据格式,作为最终输入到模型进行训练的数据格式
train_data:变量类型=<class 'numpy.ndarray'>,变量形状=(60000, 28, 28, 1),数据类型=float32
train_label:变量类型=<class 'numpy.ndarray'>,变量形状=(60000,),数据类型=int32
2. 深度模型搭建
用Keras的Sequential来按顺序搭建模型,超级简单。需要添加的神经网络层,只需要add进来就可以了,Keras提供了很多常用的网络层。同时目前最新版本的Keras搭建模型时,每一层(包括首层输入层)的输入会根据上一层的输出自动推断,所以不需要input_shape参数。
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(
filters=32,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu
))
model.add(tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2))
model.add(tf.keras.layers.Conv2D(
filters=64,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu
))
model.add(tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2))
model.add(tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,)))
model.add(tf.keras.layers.Dense(units=1024, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(units=10, activation=tf.nn.softmax))
3. 模型训练、评估、保存、转换
num_epochs = 20
batch_size = 50
learning_rate = 0.001
save_path = r"D:\code\jupyter\saved"
data_loader = MNISTLoader()
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)
model.fit(data_loader.train_data, data_loader.train_label,
epochs=num_epochs, batch_size=batch_size)
print(model.evaluate(data_loader.test_data, data_loader.test_label))
model.save(save_path)
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
open(os.path.join(save_path, "mnist_savedmodel_quantized.tflite"),
"wb").write(tflite_quant_model)
模型的损失函数采用了sparse_categorical_crossentropy
,则不同类的label直接用数字表示就可以了,如数字2的图片对应的label值为2。
模型训练时会动态给出结果如下:
1200/1200 [==============================] - 42s 35ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9924
模型评估时会动态给出结果如下:
313/313 [==============================] - 2s 6ms/step - loss: 0.0375 - sparse_categorical_accuracy: 0.9881
最后模型mnist_savedmodel_quantized.tflite保存到了相应的路径save_path,同时,利用转换器转换为适合安卓手机端使用的量化模型。
4. 模型预测
im = data_loader.test_data[0].reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
plt.close()
预测图片如下:
im = im.reshape(1, 28, 28, 1)
print("各个类的概率:{0}".format(model.predict(im)))
print("最大概率的类:{0}".format(model.predict_classes(im)))
关于模型的输入格式,由于我们在构建model的时候,首层Conv2D
没有使用data_format
参数,其默认输入格式为channels_last
,即batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)
。所以reshape
的第一个数字是batch_size,最后一个数字是颜色通道数。
输出结果如下:
各个类的概率:[[9.9865129e-09 4.3024698e-08 5.2642001e-05 3.9080669e-06 2.2962024e-10
2.2086294e-07 5.7997704e-13 9.9992096e-01 2.2194282e-08 2.2103426e-05]]
最大概率的类:[7]
通过上面的例子可知,我们直接预测的输出是一个包含各个类的预测概率的数组,而通过model.predict_classes(im)
则会拿到预测数组里分值最高的数值对应的索引,model.predict_classes()
该方法将会被抛弃,提示使用np.argmax(model.predict(x), axis=-1)
二、端侧安卓的推理代码
安卓端实现通过调用相机获取图片输入,接着通过模型推理后打印日志输出结果。
1. 安卓项目配置
1.1 app.gradle引入依赖
android {
aaptOptions {
noCompress "tflite"
}
}
dependencies {
implementation 'org.tensorflow:tensorflow-lite:2.4.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.2.0'
}
1.2 AndroidManifest.xml新增照相机权限
<uses-permission android:name="android.permission.CAMERA" />
1.3 模型放置
把转换后的模型mnist_savedmodel_quantized.tflite放置到src\main\assets目录下,没该目录的需新建一个。
2. 安卓端侧代码实现
2.1 布局文件
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity"
android:orientation="vertical"
android:gravity="center">
<ImageView
android:id="@+id/camera_image"
android:layout_weight="1"
android:layout_width="wrap_content"
android:layout_height="0dp">
</ImageView>
<Button
android:id="@+id/open_camera_button"
android:text="打开相机"
android:layout_width="wrap_content"
android:layout_height="wrap_content">
</Button>
</LinearLayout>
2.2 主函数文件
package com.example.tensorflowlite;
import java.io.IOException;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import androidx.annotation.Nullable;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import android.Manifest;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
public class MainActivity extends AppCompatActivity implements View.OnClickListener {
private static final String TAG = "MainActivity";
private static final String MODEL_PATH = "mnist_savedmodel_quantized.tflite";
private static final int CAMERA_PERMISSION_REQ_CODE = 1;
private static final int CAMERA_CAPTURE_REQ_CODE = 2;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Button button = findViewById(R.id.open_camera_button);
button.setOnClickListener(this);
}
private void openCamera() {
if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
if (ActivityCompat.shouldShowRequestPermissionRationale(this, Manifest.permission.CAMERA)) {
Log.e(TAG, "error");
} else {
ActivityCompat.requestPermissions(this, new String[] {Manifest.permission.CAMERA},
CAMERA_PERMISSION_REQ_CODE);
}
} else {
Intent camera = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
startActivityForResult(camera, CAMERA_CAPTURE_REQ_CODE);
}
}
@Override
public void onClick(View v) {
switch (v.getId()) {
case R.id.open_camera_button:
openCamera();
break;
default:
Log.i(TAG, "nothing");
}
}
@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (resultCode == RESULT_OK && requestCode == CAMERA_CAPTURE_REQ_CODE) {
Bundle extras = data.getExtras();
Bitmap bitmap = (Bitmap) extras.get("data");
ImageView cameraImage = findViewById(R.id.camera_image);
cameraImage.setImageBitmap(bitmap);
inference(bitmap);
}
}
private void inference(Bitmap bitmap) {
try {
Interpreter interpreter =
new Interpreter(FileUtil.loadMappedFile(this, MODEL_PATH), new Interpreter.Options());
float[][] labelProbArray = new float[1][10];
interpreter.run(MnistUtil.convertBitmapToByteBuffer(bitmap), labelProbArray);
for (int i = 0; i < labelProbArray[0].length; i++) {
Log.i(TAG, labelProbArray[0][i] + "");
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
在主活动页中,通过点击底部打开相机按钮,拍照后返回主页,在主页显示照片图像同时日志打印推理的结果。主要的函数有:openCamera()
打开相机,onActivityResult(int requestCode, int resultCode, @Nullable Intent data)
等待相机回调结果获取图片,inference(Bitmap bitmap)
对图像进行推理,同时显示图像和打印推理结果。
2.3 mnist数据集工具类
package com.example.tensorflowlite;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import android.graphics.Bitmap;
public class MnistUtil {
public static ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
int dimImgWidth = 28;
int dimImgHeight = 28;
int dimBatchSize = 1;
int dimPixelSize = 1;
int numBytesPerChannel = 4;
int[] intValues = new int[dimImgWidth * dimImgHeight];
Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, dimImgWidth, dimImgHeight, true);
scaleBitmap.getPixels(intValues, 0, scaleBitmap.getWidth(), 0, 0, scaleBitmap.getWidth(),
scaleBitmap.getHeight());
ByteBuffer imgData =
ByteBuffer.allocateDirect(numBytesPerChannel * dimBatchSize * dimImgWidth * dimImgHeight * dimPixelSize);
imgData.order(ByteOrder.nativeOrder());
imgData.rewind();
int pixel = 0;
for (int i = 0; i < dimImgWidth; ++i) {
for (int j = 0; j < dimImgHeight; ++j) {
int val = intValues[pixel++];
addImgValue(imgData, val);
}
}
return imgData;
}
private static void addImgValue(ByteBuffer imgData, int val) {
int mImageMean = 0;
float mImageStd = 255.0f;
imgData.putFloat(((val & 0xFF) - mImageMean) / mImageStd);
}
}
注意这里的图像缓冲区大小为什么要乘以4:ByteBuffer.allocateDirect(numBytesPerChannel * dimBatchSize * dimImgWidth * dimImgHeight * dimPixelSize)
创建了一个4×1×28×28×1大小的缓冲区存储图片,因为缓冲区是以字节byte来存储的,通过计算,每个图像像素点最终转化为float型,而float型在java虚拟机中以4个字节存在,所以需要乘以4。在图像比较大的时候,缓冲区是很重要的。
三、测试结果
日志打印所有类的概率如下:
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.0030371095
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.003125498
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.011447249
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.055658735
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 7.467345E-5
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.05097304
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 1.911169E-5
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.8677362
2021-07-08 10:13:03.326 15543-15543/com.example.tensorflowlite I/MainActivity: 9.3077944E-4
2021-07-08 10:13:03.326 15543-15543/com.example.tensorflowlite I/MainActivity: 0.006997687
结果为0~9按顺序打印后,可以看到数字7的概率为0.8677362。
参考网址
官方安卓端侧代码
官方云侧训练模型代码
Keras中文文档
TensorFlow Lite中文文档
总结
你学会了吗?
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)