My current code:
// For Eigen::ThreadPoolDevice. #define EIGEN_USE_THREADS 1 #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" using namespace tensorflow; REGISTER_OP("ArrayContainerCreate") .Attr("T: type") .Attr("container: string = ''") .Attr("shared_name: string = ''") .Output("resource: resource") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc(Array container, random index access)doc"); REGISTER_OP("ArrayContainerGetSize") .Input("handle: resource") .Output("out: int32") .SetShapeFn(shape_inference::ScalarShape) ; // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_mgr.h struct ArrayContainer : public ResourceBase { ArrayContainer(const DataType& dtype) : dtype_(dtype) {} string DebugString() override { return "ArrayContainer"; } int64 MemoryUsed() const override { return 0; }; mutex mu_; const DataType dtype_; int32 get_size() { mutex_lock l(mu_); return (int32) 42; } }; // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_op_kernel.h class ArrayContainerCreateOp : public ResourceOpKernel<ArrayContainer> { public: explicit ArrayContainerCreateOp(OpKernelConstruction* context) : ResourceOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_)); } private: virtual bool IsCancellable() const { return false; } virtual void Cancel() {} Status CreateResource(ArrayContainer** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new ArrayContainer(dtype_); if(*ret == nullptr) return 
errors::ResourceExhausted("Failed to allocate"); return Status::OK(); } Status VerifyResource(ArrayContainer* ar) override { if(ar->dtype_ != dtype_) return errors::InvalidArgument("Data type mismatch: expected ", DataTypeString(dtype_), " but got ", DataTypeString(ar->dtype_), "."); return Status::OK(); } DataType dtype_; }; REGISTER_KERNEL_BUILDER(Name("ArrayContainerCreate").Device(DEVICE_CPU), ArrayContainerCreateOp); class ArrayContainerGetSizeOp : public OpKernel { public: using OpKernel::OpKernel; void Compute(OpKernelContext* context) override { ArrayContainer* ar; OP_REQUIRES_OK(context, GetResourceFromContext(context, "handle", &ar)); core::ScopedUnref unref(ar); int32 size = ar->get_size(); Tensor* tensor_size = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &tensor_size)); tensor_size->flat<int32>().setConstant(size); } }; REGISTER_KERNEL_BUILDER(Name("ArrayContainerGetSize").Device(DEVICE_CPU), ArrayContainerGetSizeOp);
I compile that. Note that I first got an "undefined symbol: _ZN6google8protobuf8internal26fixed_address_empty_stringE" error, but I resolved that by adding these additional compiler flags:
import os

from google.protobuf.pyext import _message as msg

# Path of the compiled protobuf extension module shipped with the Python
# protobuf package; the custom op library is linked against this file.
lib = msg.__file__
extra_compiler_flags = [
    # Embed the library's directory as a runtime search path (rpath).
    "-Xlinker", "-rpath", "-Xlinker", os.path.dirname(lib),
    # Add the same directory to the link-time library search path.
    "-L", os.path.dirname(lib),
    # ":<filename>" makes the linker look for exactly this file name.
    "-l", ":" + os.path.basename(lib)]
I read about that here.
Then I load that as a module via tf.load_op_library.
Then, I have this Python code:
# Create the container resource, then query its size through the handle.
handle = mod.array_container_create(T=tf.int32)
size = mod.array_container_get_size(handle=handle)
When I try to evaluate size, I get the error:
InvalidArgumentError (see above for traceback): Trying to access resource located in device 14ArrayContainer from device /job:localhost/replica:0/task:0/cpu:0 [[Node: ArrayContainerGetSize = ArrayContainerGetSize[_device="/job:localhost/replica:0/task:0/cpu:0"](array_container)]]
The device name (14ArrayContainer) somehow seems to be messed up. Why is that? What is the problem with the code?
For some more testing, I added this additional code in the ArrayContainerCreateOp:
// Build a handle for this op's container/name and dump its fields next to the
// values the kernel itself reports, for comparison.
ResourceHandle debug_handle = MakeResourceHandle<ArrayContainer>(
    context, cinfo_.container(), cinfo_.name());

// Fields as stored in the freshly made handle.
printf("created. device: %s\n", debug_handle.device().c_str());
printf("container: %s\n", debug_handle.container().c_str());
printf("name: %s\n", debug_handle.name().c_str());

// Values read directly from the kernel context and container info.
printf("actual device: %s\n",
       context->device()->attributes().name().c_str());
printf("actual name: %s\n", cinfo_.name().c_str());
This gives me the output:
created. device: 14ArrayContainer container: 14ArrayContainer name: 14ArrayContainer actual device: /job:localhost/replica:0/task:0/cpu:0 actual name: _2_array_container
So clearly, this is at least part of the problem.
This looks like something is messed up with the protobuf? Maybe I am linking the wrong lib? But I have not found which lib to link instead.
(I also posted an issue about this here.)
0 comments:
Post a Comment