Tensorflow lite源码分析(Tensorflow 2.14)

Tensorflow lite源码分析(Tensorflow 2.14)

装载现有模型

​ 首先Android app使用tensorflow api中的interpreter创建了一个解释器,该解释器有多种构造函数和继承,这里使用的是文件中读取模型,同时使用了一个委托(delegate),但是我们最后再来看这个委托是什么。

1
2
3
4
5
6
7
8
9
10
11
private val tflite by lazy {
Interpreter(
FileUtil.loadMappedFile(this, MODEL_PATH),
Interpreter.Options().addDelegate(nnApiDelegate))
}
private val detector by lazy {
ObjectDetectionHelper(
tflite,
FileUtil.loadLabels(this, LABELS_PATH)
)
}

​ 该函数调用了Interpreter类,具体函数过程如下:

1
2
3
4
5
6
7
8
9
10
//tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
NativeInterpreterWrapper(ByteBuffer buffer, InterpreterImpl.Options options) {
TensorFlowLite.init();
......
this.modelByteBuffer = buffer;
long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
init(errorHandle, modelHandle, options);
}

​ 核心在于第四行和第六行,第四行将buffer赋给this对象,第六行通过buffer创建模型。这里就接着调用了JNI跳转到C/C++的实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
//tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) {
if (!tflite::jni::CheckJniInitializedOrThrow(env)) return 0;

BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) return 0;
const char* buf =
static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
jlong capacity = env->GetDirectBufferCapacity(model_buffer);
if (!VerifyModel(buf, capacity)) {
ThrowException(
env, tflite::jni::kIllegalArgumentException,
"ByteBuffer is not a valid TensorFlow Lite model flatbuffer");
return 0;
}
auto model = FlatBufferModel::BuildFromBuffer(
buf, static_cast<size_t>(capacity), error_reporter);
if (!model) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"ByteBuffer does not encode a valid model: %s",
error_reporter->CachedErrorMessage());
return 0;
}
return reinterpret_cast<jlong>(model.release());
}

​ 在上面逻辑中,最核心的一行是第18行,从buffer创建模型。其调用的是model_builder.cc文件下的BuildFromBuffer函数.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
//tensorflow/lite/core/model_builder.cc
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
ErrorReporter* error_reporter) {
//调用ValidateErrorReporter函数,检查error_reporter是否为nullptr,如果是,则返回默认的错误报告器,否则返回原来的error_reporter。
error_reporter = ValidateErrorReporter(error_reporter);
//创建一个MemoryAllocation对象,它是Allocation类的子类,用于封装缓冲区的地址和大小,并提供读取和写入的接口。MemoryAllocation对象使用new操作符在堆上分配内存,并使用caller_owned_buffer, buffer_size, error_reporter作为构造函数的参数。
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
return BuildFromAllocation(std::move(allocation), error_reporter);
}

ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
return e ? e : DefaultErrorReporter();
}

MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, ErrorReporter* error_reporter)
: Allocation(error_reporter, Allocation::Type::kMemory) {
......
buffer_ = ptr;
buffer_size_bytes_ = num_bytes;
}

​ 在做了一些错误检查和给buffer分配内存后转化成专门的Allocation类(这么做的原因还是为了抽象,把各种从文件读的模型,从内存buffer读的模型都转化成allocation类,最后进行统一处理),然后进行核心的函数处理BuildFromAllocation:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// tensorflow/lite/core/model_builder.cc
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
std::move(allocation), ValidateErrorReporter(error_reporter)));
if (!model->initialized()) {
model.reset();
} else {
model->ValidateModelBuffers(error_reporter);
}
return model;
}
// FlatBufferModel的构造函数实现
// 将model直接传递给类型model_成员变量
FlatBufferModel::FlatBufferModel(const Model* model, ErrorReporter* error_reporter)
: model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
//将allocation变量直接传递给allocation_变量,然后使用GetModel对allocation进行处理。
FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)), allocation_(std::move(allocation)) {
if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
return;
}

model_ = ::tflite::GetModel(allocation_->base());
}

​ 它使用了GetModel方法:

1
2
3
4
5
6
7
// tensorflow/tensorflow/lite/core/model_builder.h
const tflite::Model* GetModel() const { return model_; }

// tensorflow/tensorflow/lite/schema/schema_generated.h
inline const tflite::Model *GetModel(const void *buf) {
return ::flatbuffers::GetRoot<tflite::Model>(buf);
}

​ 可以发现这里调用了google的flatbuffer库,具体源码参考flatbuffers/include/flatbuffers/buffer.h at d3e8cb60a133be5387008864d3fc212e31774b63 · google/flatbuffers (github.com)。实际调用引入过程是schema_generated.h 代码中引入了头文件 #include "flatbuffers/flatbuffers.h" ,然后 flatbuffers/flatbuffers.h 引入头文件flatbuffers/buffer.h,最终得到GetModel的实现。

1
2
3
4
5
6
7
8
9
10
11
12
// https://github.com/google/flatbuffers/blob/d3e8cb60a133be5387008864d3fc212e31774b63/include/flatbuffers/buffer.h
template<typename T> const T *GetRoot(const void *buf) {
return GetMutableRoot<T>(const_cast<void *>(buf));
}

template<typename T> T *GetMutableRoot(void *buf) {
if (!buf) return nullptr;
EndianCheck();
return reinterpret_cast<T *>(
reinterpret_cast<uint8_t *>(buf) +
EndianScalar(*reinterpret_cast<uoffset_t *>(buf)));
}

​ OK 到这一步为止,我们可以知道实际仅仅是吧flatbuffer格式的文件读进了内存,并用各种抽象进行管理,内存布局仍然是flatbuffer中定义的内存布局。并且最后的包装成了一个Model对象,Model *_model

然后让我们来看看TFLite中模型具体是怎么定义的吧,这个涉及到的文件为schema.fbs(tensorflow/tensorflow/lite/schema/schema.fbs)

  • Model:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    table Model {
    // 一个int32类型的值,表示模型的版本号。目前最新的版本号是3
    version:uint;

    // 一个OperatorCode类型的数组,表示模型中使用的算子的编码。每个算子有一个唯一的编码,用于标识算子的名称和版本。
    operator_codes:[OperatorCode];

    // 一个SubGraph类型的数组,表示模型中包含的子图。每个子图有一个输入张量列表、一个输出张量列表、一个状态张量列表、一个算子节点列表和一个名称。子图是模型执行的基本单元,可以理解为一个计算图. 0th子图为main图
    subgraphs:[SubGraph];

    // A description of the model.
    description:string;

    // 一个Buffer类型的数组,表示模型中所有张量数据的存储空间。每个缓冲区有一个字节数组和一个哈希值。通过缓冲区的索引,可以从字节数组中读取或写入张量的数据。0th的buffer必须是空buffer
    // Note the 0th entry of this array must be an empty buffer (sentinel).
    // This is a convention so that tensors without a buffer can provide 0 as
    // their buffer.
    buffers:[Buffer];

    // 一个int32类型的数组,表示模型中包含的元数据所在的缓冲区的索引。元数据是一种用于描述模型属性和关联文件等信息的结构,可以用于提高模型的可解释性和兼容性。
    metadata_buffer:[int];

    // Metadata about the model. 包含name和buffer,即名字字段和该metadata所在的buffer
    metadata:[Metadata];

    // 一个SignatureDef类型的数组,表示模型中包含的签名定义。每个签名定义有一个输入映射、一个输出映射和一个方法名称。签名定义是一种用于描述模型输入和输出接口以及功能等信息的结构,可以用于提高模型的可用性和灵活性。
    signature_defs:[SignatureDef];
    }
  • Subgraph:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    table SubGraph {
    // A list of all tensors used in this subgraph.
    tensors:[Tensor];

    // 子图的输入张量索引,例如inputs [0] 表示tensors数组中第0个张量表示输入张量
    inputs:[int];
    // 输出张量索引
    outputs:[int];
    // operators数组
    operators:[Operator];
    name:string;
    }

  • Tensor:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    table Tensor {
    // int 类型的数字,表示张量的形状,例如[1, 224, 224, 3],表示1个图片样本,224*224像素大小,RGB3个通道
    shape:[int];
    // 张量的数据类型,比如int32, float32, int8等等 ,也定义在该文件中的enum TensorType字段
    type:TensorType;
    // buffer:一个int32类型的值,表示张量的数据所在的缓冲区的索引。缓冲区是一种存储模型中所有张量数据的结构,每个缓冲区有一个唯一的索引和一个字节数组。通过这个字段,可以从缓冲区中读取或写入张量的数据。
    buffer:uint;
    name:string; // 一个string类型的值,表示张量的名称。这个字段可以用于标识或描述张量的作用或来源,也可以用于调试或可视化。
    quantization:QuantizationParameters; // Optional. 一个QuantizationParameters类型的表,表示张量的量化参数。量化是一种将浮点数转换为整数或低位数来减少模型大小和提高性能的技术。这个表中包含了一些用于反量化或重新量化张量数据的参数,如最小值、最大值、比例因子、零点等。

    is_variable:bool = false; //一个bool类型的值,表示张量是否是变量。变量是一种可以在模型运行过程中改变值的张量,通常用于存储模型的参数或状态。如果这个字段为true,则表示张量是变量,否则表示张量是常量。

    // 一个SparsityParameters类型的表,表示张量的稀疏性参数。稀疏性是一种将张量中大部分为零的元素压缩或省略来减少模型大小和提高性能的技术。这个表中包含了一些用于解压缩或重新压缩张量数据的参数,如稀疏维度、非零值索引、非零值块等。
    sparsity:SparsityParameters; // Optional.

    shape_signature:[int]; // Optional.
    has_rank: bool = false;

    // 用作嵌套张量类型字段
    variant_tensors:[VariantSubType];
    }
  • Buffer:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    table Buffer {
    data:[ubyte] (force_align: 16);

    // In a model that is larger than 2GB, then buffers instead uses the following
    // attributes to find stored data, which is outside of flatbuffers
    // the offset is calculated relative to the beginning of the file and is only
    // valid if > 1.
    offset: ulong;
    size: ulong;
    }

处理模型

​ 接着上面的步骤,我们回到java包装器的地方,回顾一下代码:

1
2
3
4
5
6
7
8
9
//tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
NativeInterpreterWrapper(ByteBuffer buffer, InterpreterImpl.Options options) {
TensorFlowLite.init();
......
this.modelByteBuffer = buffer;
long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
init(errorHandle, modelHandle, options);
}

​ 这里将模型读入buffer后,接着运行了init函数,那么init函数里面干了什么事情呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// tensorflow/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java

// errorHandle:一个long类型的值,表示一个错误处理器的句柄,用于报告解释器运行过程中的错误信息。
// modelHandle:一个long类型的值,表示一个模型的句柄,用于加载和访问TensorFlow Lite模型的数据和元数据。
// options:一个InterpreterImpl.Options类型的对象,表示一些解释器的选项,如线程数、加速配置、精度控制等
private void init(long errorHandle, long modelHandle, InterpreterImpl.Options options) {
// some error handle, 检查options是否为空,检查options中是否有加速配置

// ......

this.errorHandle = errorHandle;
this.modelHandle = modelHandle;
// First create the interpreter without delegates. We need an interpreter in order to figure
// out whether the model contains any unresolved flex ops, and creating the interpreter with
// delegates might fail if there are any unresolved flex ops.
// (Alternatively, we could determine this without needing to recreate the interpreter
// by passing the tflite::Model in to here, and then traversing that?)
// 上面这段英文罗里吧嗦一大堆,实际就是先创建一个基本的解释器,如果后续有代理再进行处理。
ArrayList<Long> delegateHandles = new ArrayList<>();
this.interpreterHandle =
createInterpreter(
modelHandle,
errorHandle,
options.getNumThreads(),
options.getUseXNNPACK(),
delegateHandles);
// 判断是否有上述interpreter没法处理的op,也就是说上面这个函数就是处理模型成op的真实过程
this.originalGraphHasUnresolvedFlexOp = hasUnresolvedFlexOp(interpreterHandle);
// 这里添加委托,委托就是一种可以提高模型执行效率和兼容性的机制,可以将模型中的一些操作交给特定的硬件或软件来执行。
addDelegates(options);
// 初始化当前对象中包含InterpreterFactory接口的代理。InterpreterFactory接口是一种可以让代理自己创建解释器并管理其生命周期的机制。
initDelegatesWithInterpreterFactory();
delegateHandles.ensureCapacity(delegates.size());
for (Delegate delegate : delegates) {
delegateHandles.add(delegate.getNativeHandle());
}
// 说明有代理需要处理,所以把之前创建的基本解释器删除,添加有代理的解释器
if (!delegateHandles.isEmpty()) {
// If there are any delegates enabled, recreate the interpreter with those delegates.
delete(/* errorHandle= */ 0, /* modelHandle= */ 0, this.interpreterHandle);
this.interpreterHandle =
createInterpreter(
modelHandle,
errorHandle,
options.getNumThreads(),
options.getUseXNNPACK(),
delegateHandles);
}

// ......

// 创建两个TensorImpl类型的数组,分别赋值给当前对象的inputTensors和outputTensors成员变量。这两个数组用于存储模型的输入和输出张量的对象。数组的大小分别由调用getInputCount和getOutputCount方法得到。
this.inputTensors = new TensorImpl[getInputCount(interpreterHandle)];
this.outputTensors = new TensorImpl[getOutputCount(interpreterHandle)];

// ......
// 为模型的输入和输出张量分配内存
allocateTensors(interpreterHandle, errorHandle);
this.isMemoryAllocated = true;
}

​ 很复杂,但是我们关注一个核心函数,createInterpreter。同样这里是一个JNI包装器,实际从java代码调用到了C++的API。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// tensorflow/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
jint num_threads, jboolean useXnnpack, jobject delegate_handle_list) {
// some handler check for JNI
// 获取两个调用的参数,应该是JNI调用过来的时候需要进行一个数据结构的转换。
FlatBufferModel* model = convertLongToModel(env, model_handle);
if (model == nullptr) return 0;

BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) return 0;


std::unique_ptr<OpResolver> resolver =
std::make_unique<tflite::jni::OpResolverLazyDelegateProxy>(
tflite::CreateOpResolver(), useXnnpack != JNI_FALSE);

InterpreterBuilder interpreter_builder(*model, *resolver);
interpreter_builder.SetNumThreads(static_cast<int>(num_threads));

// Add delegate_list to interpreter_builder.

// Java: int size = delegate_list.size();
jint size = env->CallIntMethod(delegate_handle_list, list_size_method);
for (jint i = 0; i < size; ++i) {
// Java: Long jdelegate_handle = delegate_handle_list->get(i);
jobject jdelegate_handle =
env->CallObjectMethod(delegate_handle_list, list_get_method, i);
if (jdelegate_handle == nullptr) {
if (!env->ExceptionCheck()) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: null object in Delegate handle list");
}
return 0;
}
// Java: long delegate_handle = jdelegate_handle.longValue();
jlong delegate_handle =
env->CallLongMethod(jdelegate_handle, long_value_method);
if (delegate_handle == 0) {
if (!env->ExceptionCheck()) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Found invalid handle");
}
return 0;
}
auto delegate = reinterpret_cast<TfLiteOpaqueDelegate*>(delegate_handle);
interpreter_builder.AddDelegate(delegate);
}

看来核心就在第15行开始的函数。这一行是什么意思呢,把一些修饰符删去后可以知道,resolver = OpResolverLazyDelegateProxy(CreateOpResolver(), useXnnpack != JNI_FALSE) 所以我们先看看 CreateOpResolver 干了什么事情。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// tensorflow/tensorflow/lite/create_op_resolver_with_selected_ops.cc
std::unique_ptr<MutableOpResolver> CreateOpResolver() {
std::unique_ptr<MutableOpResolver> resolver =
std::make_unique<MutableOpResolver>();
RegisterSelectedOps(resolver.get());
return resolver;
}

// tensorflow/tensorflow/lite/core/create_op_resolver_with_builtin_ops.cc
std::unique_ptr<MutableOpResolver> CreateOpResolver() { // NOLINT
return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
}

BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ABS, Register_ABS(), /* min_version = */ 1,
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH());
AddBuiltin(BuiltinOperator_RELU, Register_RELU(), /* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
AddBuiltin(BuiltinOperator_RELU_0_TO_1, Register_RELU_0_TO_1());
AddBuiltin(BuiltinOperator_RELU6, Register_RELU6(), /* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_TANH, Register_TANH(), /* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D(),
/* min_version */ 1,
/* max_version */ 3);


enum BuiltinOperator : int32_t {
BuiltinOperator_ADD = 0,
BuiltinOperator_AVERAGE_POOL_2D = 1,
BuiltinOperator_CONCATENATION = 2,
BuiltinOperator_CONV_2D = 3,
BuiltinOperator_DEPTHWISE_CONV_2D = 4,
BuiltinOperator_DEPTH_TO_SPACE = 5,
BuiltinOperator_DEQUANTIZE = 6,
BuiltinOperator_EMBEDDING_LOOKUP = 7,
BuiltinOperator_FLOOR = 8,
...
BuiltinOperator_ABS = 101,
...

//tensorflow/tensorflow/lite/kernels/elementwise.cc
// elementwise::ElementWiseQuantizedInit:这个函数用于初始化一个逐元素运算的算子,它会根据输入和输出张量的类型和量化参数,创建一个逐元素运算的上下文对象,并返回其指针。
// elementwise::ElementWiseQuantizedFree:这个函数用于释放一个逐元素运算的算子,它会删除之前创建的逐元素运算的上下文对象,并释放其内存空间。
// PrepareAbs:这个函数用于准备一个绝对值算子,它会检查输入和输出张量的形状是否一致,并根据输入和输出张量的类型和量化参数,计算出绝对值运算所需的乘法和偏移量,并保存在逐元素运算的上下文对象中。
// elementwise::AbsEval:这个函数用于执行一个绝对值算子,它会根据输入和输出张量的类型和量化参数,以及之前计算出的乘法和偏移量,对输入张量中的每个元素求其绝对值,并将结果写入输出张量中。
GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType,
elementwise::kAbsName)
TfLiteRegistration* Register_ABS() {
static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
elementwise::ElementWiseQuantizedFree,
PrepareAbs, elementwise::AbsEval};
return &r;
}

TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteType type = input->type;
switch (type) {
case kTfLiteFloat32:
return EvalImpl<float>(context, node, std::abs<float>, type);
case kTfLiteInt8:
return AbsEvalQuantized<int8_t>(context, node, type);
case kTfLiteInt16:
return input->quantization.type == kTfLiteNoQuantization
? AbsInt16EvalImpl(context, node, type)
: AbsEvalQuantized<int16_t>(context, node, type);
case kTfLiteInt32:
return EvalImpl<int32_t>(context, node, std::abs<int32_t>, type);
default:
TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
TfLiteTypeGetName(type));
return kTfLiteError;
}
}

template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
std::function<T(T)> func,
std::function<TfLiteStatus(T)> validate_input_func,
TfLiteType expected_type) {
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);
T* out_data = GetTensorData<T>(output);
for (int64_t i = 0; i < num_elements; ++i) {
if (validate_input_func) {
TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
}
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
}

​ 可以发现注册了很多内置的算子。这些算子应该是tensorflow中实现的。实现都在tensorflow/tensorflow/lite/kernels文件夹中。然后实际核心运算可以看到,使用的是std标准库中的std::abs<int32_t>, 这里就能够和内核驱动进行一个连接。

然后再回头来看看InterpreterBuilder interpreter_builder(*model, *resolver); 这个调用了以下构造函数,其中options_experimental 为默认的参数, 也就是进行一些赋值后返回。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// tensorflow/tensorflow/lite/core/interpreter_builder.h
class InterpreterBuilder {
public:
/// For this constructor, the ErrorReporter will be extracted from the
/// FlatBufferModel.
/// `options` object is copied during construction. So caller can release it
// after calling the constructor.
InterpreterBuilder(const FlatBufferModel& model,
const OpResolver& op_resolver,
const InterpreterOptions* options_experimental = nullptr);


// tensorflow/tensorflow/lite/core/interpreter_builder.cc
InterpreterBuilder::InterpreterBuilder(
const FlatBufferModel& model, const OpResolver& op_resolver,
const InterpreterOptions* options_experimental)
: model_(model.GetModel()),
op_resolver_(op_resolver),
error_reporter_(ValidateErrorReporter(model.error_reporter())),
metadata_(model.ReadAllMetadata()),
allocation_(model.allocation()) {
if (options_experimental) {
options_ = *options_experimental;
}
}

运行模型

​ 运行模型首先还是调用的是java的API,这里是run函数。这个函数先判断一些错误检查,然后获取到input tensor,这个可以从之前说过的graph中的index获取。然后如果没有给tensor分配空间则分配空间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// tensorflow/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
void run(Object[] inputs, Map<Integer, Object> outputs) {
// some error check

// get input tensor and allocate space
for (int i = 0; i < inputs.length; ++i) {
TensorImpl tensor = getInputTensor(i);
int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
if (newShape != null) {
resizeInput(i, newShape);
}
}

boolean allocatedTensors = allocateTensorsIfNeeded();

for (int i = 0; i < inputs.length; ++i) {
getInputTensor(i).setTo(inputs[i]);
}

long inferenceStartNanos = System.nanoTime();
run(interpreterHandle, errorHandle);
long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
......


// tensorflow/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
private static native void run(long interpreterHandle, long errorHandle);

最后又调用了重载的run函数,可以看到又是一个JNI跳转,于是再次回到tensorflow/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc 搜索跳转的包装函数run。

1
2
3
4
5
6
7
8
9
10
11
12
JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
// ......

if (interpreter->Invoke() != kTfLiteOk) {
// TODO(b/168266570): Return InterruptedException.
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Failed to run on the given Interpreter: %s",
error_reporter->CachedErrorMessage());
return;
}
}

​ 核心在于第五行,调用了interpreter结构体的invoke函数。我们看看invoke函数做了什么, 这个interpreter已经是C++的结构体了,不再是java的结构体了,所以直接调用的C++代码,在tensorflow/tensorflow/lite/core/interpreter.cc文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// tensorflow/tensorflow/lite/core/interpreter.cc
TfLiteStatus Interpreter::Invoke() {
ScopedRuntimeInstrumentationProfile scoped_runtime_event(root_profiler_.get(),
"invoke");

// ......

TF_LITE_ENSURE_STATUS_WITH_SCOPED_INSTRUMENTATION(
scoped_runtime_event, primary_subgraph().Invoke());

if (!allow_buffer_handle_output_) {
for (int tensor_index : outputs()) {
TF_LITE_ENSURE_STATUS_WITH_SCOPED_INSTRUMENTATION(
scoped_runtime_event,
primary_subgraph().EnsureTensorDataIsReadable(tensor_index));
}
}

return kTfLiteOk;
}

​ 可以发现调用了primary_subgraph().Invoke(),我们先看看primary_subgraph() 是什么东西:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  //tensorflow/tensorflow/lite/core/interpreter.h
class Interpreter {
public:
Subgraph& primary_subgraph() {
return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
}
private:
// Subgraphs
std::vector<std::unique_ptr<Subgraph>> subgraphs_;
// ......
}

//tensorflow/tensorflow/lite/core/interpreter.cc
Interpreter::Interpreter(ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
//......
AddSubgraphs(1);
...
}


void Interpreter::AddSubgraphs(int subgraphs_to_add,
int* first_new_subgraph_index) {
// first_new_subgraph_index 默认是nullptr
const size_t base_index = subgraphs_.size();
if (first_new_subgraph_index) *first_new_subgraph_index = base_index;

subgraphs_.reserve(base_index + subgraphs_to_add);
for (int i = 0; i < subgraphs_to_add; ++i) {
Subgraph* subgraph = new Subgraph(
error_reporter_, external_contexts_, &subgraphs_, &resources_,
&resource_ids_, &initialization_status_map_, subgraphs_.size());
subgraphs_.emplace_back(subgraph);
}
}

这里设计到了一个subgraphs_的变量,这是什么呢,这其实是一个子图的vector容器,在创建Interpreter时,其有一个私有的成员变量为subgraphs_, 并且在Interpreter的构造函数中会将index为0的图加入这个vector中,也就是说primary_subgraph就是最初完整的那个模型图,后续会对这个图进行分裂成各个子图。这一部分在table subgraph中已经讲得很明确了。

​ OK, 知道了就是获取到从文件中读取到的模型图,然后回来invoke函数,又是一层皮,还调用了InvokeImpl

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
//tensorflow/tensorflow/lite/core/subgraph.cc
TfLiteStatus Subgraph::Invoke() {
auto status = InvokeImpl();
telemetry::TelemetryReportEvent(&context_, "Invoke", status);
return status;
}


TfLiteStatus Subgraph::InvokeImpl() {

// ......

EnsureTensorsVectorCapacity();
tensor_resized_since_op_invoke_ = false;
if (auto s = OpInvoke(registration, &node); s != kTfLiteOk) {
auto err = ReportOpError(&context_, node, registration, node_index,
"failed to invoke");
return s == kTfLiteCancelled ? s : err;
}

//......
return status;
}

可以发现调用了OpInvoke函数, 然后他又调用了Registration的invoke函数,这就和前面结合起来了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
TfLiteStatus Subgraph::OpInvoke(const TfLiteRegistration& op_reg,
TfLiteNode* node) {
if (op_reg.registration_external &&
op_reg.registration_external->node_index != -1) {
TfLiteRegistration* referenced_registration =
&nodes_and_registration_[op_reg.registration_external->node_index]
.second;
if (referenced_registration->invoke == nullptr) return kTfLiteError;
return referenced_registration->invoke(&context_, node);
}

if (op_reg.registration_external && op_reg.registration_external->invoke) {
return op_reg.registration_external->invoke(
reinterpret_cast<TfLiteOpaqueContext*>(&context_),
reinterpret_cast<TfLiteOpaqueNode*>(node));
}
if (op_reg.invoke == nullptr) return kTfLiteError;
return op_reg.invoke(&context_, node);
}

委托

在运行模型那部分代码最后出现了一个结构体TfLiteRegistrationExternal, 这个类需要和TfLiteRegistration 一起分析。

参考链接

TensorFlow Lite 源代码分析-模型加载和执行 - actorfit (actorsfit.com)

TensorFlow Lite 深度解析 - 谷歌中国工程师教学视频_哔哩哔哩_bilibili

LVC21-113: TensorFlow Lite Delegates on Arm-based Devices | Linaro Resources Hub