这是一部分介绍 tf 中 custom op 的文档,记录了学习和模仿中的一些内容
TF Custom OP
官方文档
官方文档中给出了一个例子 link
C++ part
这里分成两部分,一部分是注册 op 另外一部分是具体实现 op
注册 op
需要引用头文件
1 2 3
| #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" using namespace tensorflow;
|
接着使用一个 tf 文件中定义的宏 REGISTER_OP
,需要包括
- 输入
.Input("xxx: datatype")
- 输出
.Output("xxx: datatype")
- 参数
.Attrs("")
官方举例(zero_out.cc)
1 2 3 4 5 6 7
| REGISTER_OP("ZeroOut") .Input("to_zero: int32") .Output("zeroed: int32") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); });
|
另外官方还提到 OP 的名字需要是驼峰式命令(e.g. HelloWorldFromCode)
实现 op
需要引用头文件
1 2
| #include "tensorflow/core/framework/op_kernel.h" using namespace tensorflow;
|
关于关键字 explicit 简而言之就是为了防止隐式转换的
然后将上面定义的 OP 名称( 额外加上 ‘OP’ ),继承于 OpKernel
,还要重写其中的 Compute 方法,tf 提供了一个结构体 OpKernelContext
可以用来得到一些定义信息
官方给出的例子
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| #include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override { const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat<int32>();
Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output_flat = output_tensor->flat<int32>();
const int N = input.size(); for (int i = 1; i < N; i++) { output_flat(i) = 0; }
if (N > 0) output_flat(0) = input(0); } };
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
|
python part
在 python 程序里需要引入
1 2
| from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader
|
在python应用程序中使用加载动态文件
1 2
| zero_out_ops = load_library.load_op_library( resource_loader.get_path_to_datafile('_zero_out_ops.so'))
|
最后使用的时候可以为
1
| zero_out = zero_out_ops.zero_out
|
GPU knernel
GPU 实现方法需要一些额外的步骤,包括
预先处理
在 .h文件中,引用头文件
1
| #include <unsupported/Eigen/CXX11/Tensor>
|
以及按照需求定义一些宏防止多次引用和方便关闭功能
定义模板用于选择设备和数据类型
定义一个结构体
1 2 3 4 5
| template <typename T> struct ExampleFunctor<Eigen::GpuDevice, T> { void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out); }; #endif
|
再新建一个具体实现的文件夹.cc
同样之前的操作
1 2 3 4 5 6 7 8 9 10
| #include "kernel_example.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice;
|
同理需要注册
1 2 3 4 5 6 7 8
| REGISTER_OP("Example") .Attr("T: numbertype") .Input("input: T") .Output("input_times_two: T") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); });
|