tf custom op

这是一部分介绍 tf 中 custom op 的文档,记录了学习和模仿中的一些内容

TF Custom OP

官方文档

官方文档中给出了一个例子 link

C++ part

这里分成两部分,一部分是注册 op 另外一部分是具体实现 op

注册 op

需要引用头文件

1
2
3
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;

接着使用一个 tf 文件中定义的宏 REGISTER_OP ,需要包括

  • 输入.Input("xxx: datatype")
  • 输出.Output("xxx: datatype")
  • 参数.Attrs("")

官方举例(zero_out.cc)

1
2
3
4
5
6
7
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});

另外官方还提到 OP 的名字需要是驼峰式命令(e.g. HelloWorldFromCode)

实现 op

需要引用头文件

1
2
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;

关于关键字 explicit 简而言之就是为了防止隐式转换的

然后将上面定义的 OP 名称( 额外加上 ‘OP’ ),继承于 OpKernel,还要重写其中的 Compute 方法,tf 提供了一个结构体 OpKernelContext 可以用来得到一些定义信息

官方给出的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();

// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();

// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output_flat(i) = 0;
}

// Preserve the first input value if possible.
if (N > 0) output_flat(0) = input(0);
}
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

python part

在 python 程序里需要引入

1
2
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader

在python应用程序中使用加载动态文件

1
2
zero_out_ops = load_library.load_op_library(
resource_loader.get_path_to_datafile('_zero_out_ops.so'))

最后使用的时候可以为

1
zero_out = zero_out_ops.zero_out

GPU knernel

GPU 实现方法需要一些额外的步骤,包括

  • 预先处理

    在 .h文件中,引用头文件

    1
    #include <unsupported/Eigen/CXX11/Tensor>

    以及按照需求定义一些宏防止多次引用和方便关闭功能

  • 定义模板用于选择设备和数据类型

  • 定义一个结构体

    1
    2
    3
    4
    5
    template <typename T>
    struct ExampleFunctor<Eigen::GpuDevice, T> {
    void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
    };
    #endif
  • 再新建一个具体实现的文件夹.cc

同样之前的操作

1
2
3
4
5
6
7
8
9
10
#include "kernel_example.h" //这是上面这个 .h 文件

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

同理需要注册

1
2
3
4
5
6
7
8
REGISTER_OP("Example")
.Attr("T: numbertype")
.Input("input: T")
.Output("input_times_two: T")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});

tf custom op
http://home.ustc.edu.cn/~ustcxwy0271/2022/04/01/tf-custom-op/
作者
Xu Weiye
发布于
2022年4月1日
许可协议