Tensorflow lite源码分析(Tensorflow 2.14)
Tensorflow lite源码分析(Tensorflow 2.14)
装载现有模型
首先Android app使用tensorflow api中的interpreter创建了一个解释器,该解释器有多种构造函数和继承,这里使用的是文件中读取模型,同时使用了一个委托(delegate),但是我们最后再来看这个委托是什么。
1 | private val tflite by lazy { |
该函数调用了Interpreter
类,具体函数过程如下:
1 | //tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java |
核心在于第四行和第六行,第四行将buffer赋给this对象,第六行通过buffer创建模型。这里就接着调用了JNI跳转到C/C++的实现。
1 | //tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc |
在上面逻辑中,最核心的一行是第18行,从buffer创建模型。其调用的是model_builder.cc
文件下的BuildFromBuffer
函数.
1 | //tensorflow/lite/core/model_builder.cc |
在做了一些错误检查和给buffer分配内存后转化成专门的Allocation类(这么做的原因还是为了抽象,把各种从文件读的模型,从内存buffer读的模型都转化成allocation类,最后进行统一处理),然后进行核心的函数处理BuildFromAllocation
:
1 | // tensorflow/lite/core/model_builder.cc |
它使用了GetModel
方法:
1 | // tensorflow/tensorflow/lite/core/model_builder.h |
可以发现这里调用了google的flatbuffer
库,具体源码参考flatbuffers/include/flatbuffers/buffer.h at d3e8cb60a133be5387008864d3fc212e31774b63 · google/flatbuffers (github.com)。实际调用引入过程是schema_generated.h
代码中引入了头文件 #include "flatbuffers/flatbuffers.h"
,然后 flatbuffers/flatbuffers.h
引入头文件flatbuffers/buffer.h
,最终得到GetModel
的实现。
1 | // https://github.com/google/flatbuffers/blob/d3e8cb60a133be5387008864d3fc212e31774b63/include/flatbuffers/buffer.h |
OK 到这一步为止,我们可以知道实际仅仅是吧flatbuffer格式的文件读进了内存,并用各种抽象进行管理,内存布局仍然是flatbuffer中定义的内存布局。并且最后的包装成了一个Model对象,Model *_model
。
然后让我们来看看TFLite中模型具体是怎么定义的吧,这个涉及到的文件为schema.fbs(tensorflow/tensorflow/lite/schema/schema.fbs)
Model:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28table Model {
// 一个int32类型的值,表示模型的版本号。目前最新的版本号是3
version:uint;
// 一个OperatorCode类型的数组,表示模型中使用的算子的编码。每个算子有一个唯一的编码,用于标识算子的名称和版本。
operator_codes:[OperatorCode];
// 一个SubGraph类型的数组,表示模型中包含的子图。每个子图有一个输入张量列表、一个输出张量列表、一个状态张量列表、一个算子节点列表和一个名称。子图是模型执行的基本单元,可以理解为一个计算图. 0th子图为main图
subgraphs:[SubGraph];
// A description of the model.
description:string;
// 一个Buffer类型的数组,表示模型中所有张量数据的存储空间。每个缓冲区有一个字节数组和一个哈希值。通过缓冲区的索引,可以从字节数组中读取或写入张量的数据。0th的buffer必须是空buffer
// Note the 0th entry of this array must be an empty buffer (sentinel).
// This is a convention so that tensors without a buffer can provide 0 as
// their buffer.
buffers:[Buffer];
// 一个int32类型的数组,表示模型中包含的元数据所在的缓冲区的索引。元数据是一种用于描述模型属性和关联文件等信息的结构,可以用于提高模型的可解释性和兼容性。
metadata_buffer:[int];
// Metadata about the model. 包含name和buffer,即名字字段和该metadata所在的buffer
metadata:[Metadata];
// 一个SignatureDef类型的数组,表示模型中包含的签名定义。每个签名定义有一个输入映射、一个输出映射和一个方法名称。签名定义是一种用于描述模型输入和输出接口以及功能等信息的结构,可以用于提高模型的可用性和灵活性。
signature_defs:[SignatureDef];
}Subgraph:
1
2
3
4
5
6
7
8
9
10
11
12
13table SubGraph {
// A list of all tensors used in this subgraph.
tensors:[Tensor];
// 子图的输入张量索引,例如inputs [0] 表示tensors数组中第0个张量表示输入张量
inputs:[int];
// 输出张量索引
outputs:[int];
// operators数组
operators:[Operator];
name:string;
}
Tensor:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21table Tensor {
// int 类型的数字,表示张量的形状,例如[1, 224, 224, 3],表示1个图片样本,224*224像素大小,RGB3个通道
shape:[int];
// 张量的数据类型,比如int32, float32, int8等等 ,也定义在该文件中的enum TensorType字段
type:TensorType;
// buffer:一个int32类型的值,表示张量的数据所在的缓冲区的索引。缓冲区是一种存储模型中所有张量数据的结构,每个缓冲区有一个唯一的索引和一个字节数组。通过这个字段,可以从缓冲区中读取或写入张量的数据。
buffer:uint;
name:string; // 一个string类型的值,表示张量的名称。这个字段可以用于标识或描述张量的作用或来源,也可以用于调试或可视化。
quantization:QuantizationParameters; // Optional. 一个QuantizationParameters类型的表,表示张量的量化参数。量化是一种将浮点数转换为整数或低位数来减少模型大小和提高性能的技术。这个表中包含了一些用于反量化或重新量化张量数据的参数,如最小值、最大值、比例因子、零点等。
is_variable:bool = false; //一个bool类型的值,表示张量是否是变量。变量是一种可以在模型运行过程中改变值的张量,通常用于存储模型的参数或状态。如果这个字段为true,则表示张量是变量,否则表示张量是常量。
// 一个SparsityParameters类型的表,表示张量的稀疏性参数。稀疏性是一种将张量中大部分为零的元素压缩或省略来减少模型大小和提高性能的技术。这个表中包含了一些用于解压缩或重新压缩张量数据的参数,如稀疏维度、非零值索引、非零值块等。
sparsity:SparsityParameters; // Optional.
shape_signature:[int]; // Optional.
has_rank: bool = false;
// 用作嵌套张量类型字段
variant_tensors:[VariantSubType];
}Buffer:
1
2
3
4
5
6
7
8
9
10table Buffer {
data:[ubyte] (force_align: 16);
// In a model that is larger than 2GB, then buffers instead uses the following
// attributes to find stored data, which is outside of flatbuffers
// the offset is calculated relative to the beginning of the file and is only
// valid if > 1.
offset: ulong;
size: ulong;
}
处理模型
接着上面的步骤,我们回到java包装器的地方,回顾一下代码:
1 | //tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java |
这里将模型读入buffer后,接着运行了init函数,那么init函数里面干了什么事情呢?
1 | // tensorflow/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java |
很复杂,但是我们关注一个核心函数,createInterpreter
。同样这里是一个JNI包装器,实际从java代码调用到了C++的API。
1 | // tensorflow/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc |
看来核心就在第15行开始的函数。这一行是什么意思呢,把一些修饰符删去后可以知道,resolver = OpResolverLazyDelegateProxy(CreateOpResolver(), useXnnpack != JNI_FALSE)
所以我们先看看 CreateOpResolver
干了什么事情。
1 | // tensorflow/tensorflow/lite/create_op_resolver_with_selected_ops.cc |
可以发现注册了很多内置的算子。这些算子应该是tensorflow中实现的。实现都在tensorflow/tensorflow/lite/kernels文件夹中。然后实际核心运算可以看到,使用的是std标准库中的std::abs<int32_t>
, 这里就能够和内核驱动进行一个连接。
然后再回头来看看InterpreterBuilder interpreter_builder(*model, *resolver);
这个调用了以下构造函数,其中options_experimental
为默认的参数, 也就是进行一些赋值后返回。
1 | // tensorflow/tensorflow/lite/core/interpreter_builder.h |
运行模型
运行模型首先还是调用的是java的API,这里是run函数。这个函数先判断一些错误检查,然后获取到input tensor,这个可以从之前说过的graph中的index获取。然后如果没有给tensor分配空间则分配空间。
1 | // tensorflow/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java |
最后又调用了重载的run函数,可以看到又是一个JNI跳转,于是再次回到tensorflow/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
搜索跳转的包装函数run。
1 | JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( |
核心在于第五行,调用了interpreter结构体的invoke函数。我们看看invoke函数做了什么, 这个interpreter已经是C++的结构体了,不再是java的结构体了,所以直接调用的C++代码,在tensorflow/tensorflow/lite/core/interpreter.cc
文件。
1 | // tensorflow/tensorflow/lite/core/interpreter.cc |
可以发现调用了primary_subgraph().Invoke()
,我们先看看primary_subgraph()
是什么东西:
1 | //tensorflow/tensorflow/lite/core/interpreter.h |
这里设计到了一个subgraphs_
的变量,这是什么呢,这其实是一个子图的vector容器,在创建Interpreter时,其有一个私有的成员变量为subgraphs_
, 并且在Interpreter的构造函数中会将index为0的图加入这个vector中,也就是说primary_subgraph就是最初完整的那个模型图,后续会对这个图进行分裂成各个子图。这一部分在table subgraph中已经讲得很明确了。
OK, 知道了就是获取到从文件中读取到的模型图,然后回来invoke
函数,又是一层皮,还调用了InvokeImpl
。
1 | //tensorflow/tensorflow/lite/core/subgraph.cc |
可以发现调用了OpInvoke
函数, 然后他又调用了Registration的invoke函数,这就和前面结合起来了。
1 | TfLiteStatus Subgraph::OpInvoke(const TfLiteRegistration& op_reg, |
委托
在运行模型那部分代码最后出现了一个结构体TfLiteRegistrationExternal
, 这个类需要和TfLiteRegistration
一起分析。
参考链接
TensorFlow Lite 源代码分析-模型加载和执行 - actorfit (actorsfit.com)
TensorFlow Lite 深度解析 - 谷歌中国工程师教学视频_哔哩哔哩_bilibili
LVC21-113: TensorFlow Lite Delegates on Arm-based Devices | Linaro Resources Hub