Android部署TFLite模型并启用GPU加速

项目准备

  1. 训练并量化好的TFLite模型: model.tflite
  2. 需要使用TFLite的安卓工程
  3. 开发用手机

部署流程

  1. 在Gradle中配置TFLite相关库,在build.gradle中补充依赖库,具体如下(版本2.8.0)
dependencies {
    ...
    implementation 'com.github.bumptech.glide:glide:4.13.2'
    implementation 'org.tensorflow:tensorflow-lite:2.8.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.8.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'
    ...
}

配置完成后,注意在gradle.property中开启testOnly,否则USB调试失败

# Error: INSTALL_FAILED_TEST_ONLY'
android.injected.testOnly = false
  1. 将TFLite模型作为附件部署在工程中,位置为/app/src/main/assets
  2. 在代码中引入TFLite模型并进行推理

3.1. 引入相关包

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import org.tensorflow.lite.gpu.GpuDelegate;

3.2 载入模型

// 读取ByteBuffer
private MappedByteBuffer loadModelFile(String model) throws IOException{
        AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

// 将ByteBuffer绑定到TFLite推理器
MappedByteBuffer modelFile = loadModelFile(model);
// 设置推理器选项
Interpreter.Options options = new Interpreter.Options();
GpuDelegate delegate = new GpuDeledate();
// 以下选项根据需求调整
options.setNumThreads(4); // 4 CPU threads
options.addDelegate(deledate); // 添加GPU支持
options.setUseNNAPI(true); // 使用NNAPI推理

tflite = new Interpreter(modelFile, options); // 用来推理的TFLite模型

3.3 配置输入输出节点

// 获取输入节点数据类型
DataType imageDataType = tflite.getInputTensor(0).dataType();
// 配置输入数据Buffer
TensorImage inputImageBuffer = new TensorImage(imageDataType);
inputImageBuffer.load(bmp);
// 调整输入节点信息
int[] inputShape = {1, bmp.getHeight(), bmp.getWidth(), 3};
tflite.resizeInput(tflite.getInputTensor(0).index(), inputShape);

// 类似的,调整输出节点数据
// 此处做图像增强,输入输出大小一致
DataType probabilityDataType = tflite.getOutputTensor(0).dataType();
int[] probabilityShape = {1, bmp.getWidth(), bmp.getHeight(), 3};
TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);

// 得到输入输出的Buffer
ByteBuffer inputs = inputImageBuffer.getBuffer();
ByteBuffer outputs = outputProbabilityBuffer.getBuffer();

// 使用推理器进行推理
tflite.run(inputs, outputs);

// 读取输出Buffer信息
float[] results = outputProbabilityBuffer.getFloatArray();

综上,连接手机即可进行USB调试

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年5月21日
下一篇 2022年5月21日

相关推荐