Open
Labels
enhancement (New feature or request)
Description
I have an AOT-compiled TileLang kernel (top-k, from their examples), and now I want to invoke it from TensorFlow like this:
import tensorflow as tf
import tvm_ffi

def dlpackify(x: tf.Tensor):
    # Export the TF tensor as a DLPack capsule.
    return tf.experimental.dlpack.to_dlpack(x)

mod = tvm_ffi.load_module("build/topk.so")
func = mod.get_function("main")
x = tf.random.normal([320, 128], dtype=tf.float32)
print(x.shape)
k = 6
# Output buffers for the top-k values and their indices.
topk_tensors, topk_indices = (
    tf.zeros([320, k], dtype=tf.float32),
    tf.zeros([320, k], dtype=tf.int32),
)
func(dlpackify(x), dlpackify(topk_tensors), dlpackify(topk_indices))

But the kernel does not seem to see the correct shape information for the TF tensors:
Traceback (most recent call last):
  File "/mnt/cephfs/load.py", line 23, in <module>
    func(dlpackify(x), dlpackify(topk_tensors), dlpackify(topk_indices))
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "<unknown>", line 0, in topk_kernel
RuntimeError: kernel topk_kernel input logits ndim expected 2, but got 391978384
(The bogus ndim value is different on every run.)
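To narrow down whether the capsule itself or the receiving side is at fault, it may help to read `ndim` and `shape` straight out of the capsule before passing it on. Below is a minimal sketch using only `ctypes`. It assumes the capsule is a legacy (unversioned) DLPack capsule named `"dltensor"` that wraps a `DLManagedTensor` laid out as in the DLPack header. `describe_capsule` and `make_fake_capsule` are hypothetical helpers written for this illustration, not part of tvm_ffi or TensorFlow.

```python
import ctypes

CAPSULE_NAME = b"dltensor"  # assumed legacy DLPack capsule name


class DLDevice(ctypes.Structure):
    _fields_ = [("device_type", ctypes.c_int32), ("device_id", ctypes.c_int32)]


class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8), ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]


class DLTensor(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int32),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),
        ("byte_offset", ctypes.c_uint64),
    ]


class DLManagedTensor(ctypes.Structure):
    _fields_ = [("dl_tensor", DLTensor), ("manager_ctx", ctypes.c_void_p),
                ("deleter", ctypes.c_void_p)]


_is_valid = ctypes.pythonapi.PyCapsule_IsValid
_is_valid.restype = ctypes.c_int
_is_valid.argtypes = [ctypes.py_object, ctypes.c_char_p]

_get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
_get_pointer.restype = ctypes.c_void_p
_get_pointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

_new_capsule = ctypes.pythonapi.PyCapsule_New
_new_capsule.restype = ctypes.py_object
_new_capsule.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]


def describe_capsule(capsule):
    """Return (ndim, shape, (code, bits, lanes)) without consuming the capsule."""
    if not _is_valid(capsule, CAPSULE_NAME):
        raise ValueError("not an unconsumed legacy DLPack capsule")
    ptr = _get_pointer(capsule, CAPSULE_NAME)
    tensor = ctypes.cast(ptr, ctypes.POINTER(DLManagedTensor)).contents.dl_tensor
    shape = [tensor.shape[i] for i in range(tensor.ndim)]
    return tensor.ndim, shape, (tensor.dtype.code, tensor.dtype.bits,
                                tensor.dtype.lanes)


def make_fake_capsule(shape):
    """Build a stand-in float32 CPU capsule so the sketch runs without TensorFlow."""
    count = 1
    for dim in shape:
        count *= dim
    data = (ctypes.c_float * count)()
    shape_arr = (ctypes.c_int64 * len(shape))(*shape)
    managed = DLManagedTensor()
    managed.dl_tensor.data = ctypes.addressof(data)
    managed.dl_tensor.device = DLDevice(1, 0)       # 1 = kDLCPU
    managed.dl_tensor.ndim = len(shape)
    managed.dl_tensor.dtype = DLDataType(2, 32, 1)  # 2 = kDLFloat, 32 bits
    managed.dl_tensor.shape = ctypes.cast(shape_arr, ctypes.POINTER(ctypes.c_int64))
    capsule = _new_capsule(ctypes.addressof(managed), CAPSULE_NAME, None)
    # The caller must keep the backing objects alive as long as the capsule is used.
    return capsule, (data, shape_arr, managed)


if __name__ == "__main__":
    capsule, keep_alive = make_fake_capsule((320, 128))
    print(describe_capsule(capsule))
```

With TensorFlow installed, calling `describe_capsule(dlpackify(x))` on the tensors from the snippet above should show whether the capsule already carries `ndim == 2`. If it does, the garbage value would more likely come from how the callee interprets the argument than from TensorFlow's export, though that is only a guess.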