项目准备
- 训练并量化好的TFLite模型: model.tflite
- 需要使用TFLite的安卓工程
- 开发用手机
部署流程
- 在Gradle中配置TFLite相关库,在build.gradle中补充依赖库,具体如下(版本2.8.0)
dependencies {
...
implementation 'com.github.bumptech.glide:glide:4.13.2'
implementation 'org.tensorflow:tensorflow-lite:2.8.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.8.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'
...
}
配置完成后,注意在gradle.property中开启testOnly,否则USB调试失败
# Error: INSTALL_FAILED_TEST_ONLY'
android.injected.testOnly = false
- 将TFLite模型作为附件部署在工程中,位置为/app/src/main/assets
- 在代码中引入TFLite模型并进行推理
3.1. 引入相关包
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import org.tensorflow.lite.gpu.GpuDelegate;
3.2 载入模型
// 读取ByteBuffer
private MappedByteBuffer loadModelFile(String model) throws IOException{
AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
// 将ByteBuffer绑定到TFLite推理器
MappedByteBuffer modelFile = loadModelFile(model);
// 设置推理器选项
Interpreter.Options options = new Interpreter.Options();
GpuDelegate delegate = new GpuDeledate();
// 以下选项根据需求调整
options.setNumThreads(4); // 4 CPU threads
options.addDelegate(deledate); // 添加GPU支持
options.setUseNNAPI(true); // 使用NNAPI推理
tflite = new Interpreter(modelFile, options); // 用来推理的TFLite模型
3.3 配置输入输出节点
// 获取输入节点数据类型
DataType imageDataType = tflite.getInputTensor(0).dataType();
// 配置输入数据Buffer
TensorImage inputImageBuffer = new TensorImage(imageDataType);
inputImageBuffer.load(bmp);
// 调整输入节点信息
int[] inputShape = {1, bmp.getHeight(), bmp.getWidth(), 3};
tflite.resizeInput(tflite.getInputTensor(0).index(), inputShape);
// 类似的,调整输出节点数据
// 此处做图像增强,输入输出大小一致
DataType probabilityDataType = tflite.getOutputTensor(0).dataType();
int[] probabilityShape = {1, bmp.getWidth(), bmp.getHeight(), 3};
TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
// 得到输入输出的Buffer
ByteBuffer inputs = inputImageBuffer.getBuffer();
ByteBuffer outputs = outputProbabilityBuffer.getBuffer();
// 使用推理器进行推理
tflite.run(inputs, outputs);
// 读取输出Buffer信息
float[] results = outputProbabilityBuffer.getFloatArray();
综上,连接手机即可进行USB调试
文章出处登录后可见!
已经登录?立即刷新