How to call TVM FFI functions from TensorFlow? #412

@xuantengh


I have an AOT-compiled TileLang kernel (the top-k kernel from TileLang's examples), and I want to invoke it from TensorFlow like this:

import tensorflow as tf
import tvm_ffi


def dlpackify(x: tf.Tensor):
    # Export a TensorFlow tensor as a DLPack PyCapsule.
    return tf.experimental.dlpack.to_dlpack(x)


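# Load the AOT-compiled module and look up its entry point.
mod = tvm_ffi.load_module("build/topk.so")
func = mod.get_function("main")

# Input logits: 320 rows of 128 scores; select the top k per row.
x = tf.random.normal([320, 128], dtype=tf.float32)
print(x.shape)
k = 6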

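# Preallocated output buffers for the top-k values and their indices.
topk_tensors, topk_indices = (
    tf.zeros([320, k], dtype=tf.float32),
    tf.zeros([320, k], dtype=tf.int32),
)
# Pass the raw DLPack capsules to the compiled kernel.
func(dlpackify(x), dlpackify(topk_tensors), dlpackify(topk_indices))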

However, the kernel does not seem to read the correct shape information from the TensorFlow tensors:

Traceback (most recent call last):
  File "/mnt/cephfs/load.py", line 23, in <module>
    func(dlpackify(x), dlpackify(topk_tensors), dlpackify(topk_indices))
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "<unknown>", line 0, in topk_kernel
RuntimeError: kernel topk_kernel input logits ndim expected 2, but got 391978384

The bogus ndim value changes on every run.
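One thing worth trying, as a minimal sketch rather than a confirmed fix: wrap each capsule explicitly instead of passing raw PyCapsules to the function. A plausible explanation for the varying ndim is that a bare capsule is not converted into a DLTensor argument, so the kernel reads unrelated memory, but that is a guess. The sketch assumes that `tvm_ffi.from_dlpack` accepts a legacy `"dltensor"` PyCapsule and returns a `tvm_ffi.Tensor` that the compiled function accepts. The helper names `check_roundtrip`, `to_ffi_tensor`, and `run_topk` are hypothetical. `check_roundtrip` first confirms that TensorFlow's own DLPack export preserves the shape, which would point at the calling side rather than at TensorFlow.

```python
def check_roundtrip(x):
    """Hypothetical helper: export `x` via DLPack and re-import it into TensorFlow.

    Returns True if the shape survives, i.e. TensorFlow's export is not the problem.
    """
    import tensorflow as tf  # lazy import so this file imports without TensorFlow

    y = tf.experimental.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(x))
    return tuple(y.shape) == tuple(x.shape)


def to_ffi_tensor(x):
    """Hypothetical helper: wrap a TensorFlow tensor as a tvm_ffi tensor via DLPack."""
    import tensorflow as tf
    import tvm_ffi

    capsule = tf.experimental.dlpack.to_dlpack(x)
    # Assumption: tvm_ffi.from_dlpack accepts a legacy "dltensor" PyCapsule.
    return tvm_ffi.from_dlpack(capsule)


def run_topk(so_path, rows=320, cols=128, k=6):
    """Hypothetical driver mirroring the snippet above, with explicit wrapping."""
    import tensorflow as tf
    import tvm_ffi

    func = tvm_ffi.load_module(so_path).get_function("main")
    x = tf.random.normal([rows, cols], dtype=tf.float32)
    if not check_roundtrip(x):
        raise RuntimeError("shape lost during TensorFlow DLPack export")

    values = tf.zeros([rows, k], dtype=tf.float32)
    indices = tf.zeros([rows, k], dtype=tf.int32)
    # Pass tvm_ffi tensors instead of raw PyCapsules.
    func(to_ffi_tensor(x), to_ffi_tensor(values), to_ffi_tensor(indices))
    return values, indices
```

Usage: `values, indices = run_topk("build/topk.so")`. If `check_roundtrip` passes but the kernel still reports a wrong ndim, the problem is most likely in how the arguments reach the FFI call rather than in TensorFlow's DLPack export.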

Labels: enhancement (New feature or request)
