
Does the prompt content (or length) matter? #18


Description

@hygxy
python3 convert_from_jax.py --jax_path /home/xxx/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output converted_checkpoint.pkl --prompt "Sort bowls and paper cups into their designated places" --tokenizer_path /home/xxx/.cache/modelscope/hub/models/AI-ModelScope/paligemma-3b-pt-224/

python3 ./infer_test.py

leads to the following error:

Traceback (most recent call last):
  File "/home/xxx/realtime-vla/./infer_test.py", line 20, in <module>
    infer = Pi0Inference(converted_checkpoint, number_of_images, length_of_trajectory)
  File "/home/xxx/realtime-vla/pi0_infer.py", line 1341, in __init__
    self.record_infer_graph()
    ~~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/xxx/realtime-vla/pi0_infer.py", line 1349, in record_infer_graph
    self.record_run()
    ~~~~~~~~~~~~~~~^^
  File "/home/xxx/realtime-vla/pi0_infer.py", line 1345, in record_run
    pi0_model(self.weights, self.buffers, self.num_views)
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/realtime-vla/pi0_infer.py", line 1236, in pi0_model
    transformer_decoder(weights, buffers, encoder_seq_len)
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/realtime-vla/pi0_infer.py", line 1195, in transformer_decoder
    matmul_k8_256_n_softmax_mask0(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        buffers['decoder_q_buf'],
        ^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        encoder_seq_len
        ^^^^^^^^^^^^^^^
    )
    ^
  File "/home/xxx/realtime-vla/pi0_infer.py", line 1101, in matmul_k8_256_n_softmax_mask0
    matmul_abT_scale[(((total_queries + 31) // 32) * ((total_keys + 31) // 32),)](Q, K, out,
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
        total_queries, total_keys, head_dim, head_dim ** -0.5,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=64)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/anaconda3/lib/python3.13/site-packages/triton/runtime/jit.py", line 347, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/anaconda3/lib/python3.13/site-packages/triton/runtime/jit.py", line 591, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
    ^^^^^^^^^^
  File "/home/xxx/anaconda3/lib/python3.13/site-packages/triton/compiler/compiler.py", line 413, in __getattribute__
    self._init_handles()
    ~~~~~~~~~~~~~~~~~~^^
  File "/home/xxx/anaconda3/lib/python3.13/site-packages/triton/compiler/compiler.py", line 408, in _init_handles
    self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
                                                             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        self.name, self.kernel, self.metadata.shared, device)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
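
For context, the launch grid in the failing frame is computed directly from the query/key counts. A minimal sketch of that arithmetic (block sizes taken from the BLOCK_SIZE_M=32 / BLOCK_SIZE_N=32 in the call above; the sequence lengths below are hypothetical placeholders, since I have not printed the real total_queries/total_keys), showing how one extra prompt token can change the number of launched programs once it crosses a 32-wide block boundary:

def launch_grid(total_queries, total_keys, block_m=32, block_n=32):
    # Same formula as the failing call:
    # (((total_queries + 31) // 32) * ((total_keys + 31) // 32),)
    return ((total_queries + block_m - 1) // block_m) * ((total_keys + block_n - 1) // block_n)

# Hypothetical lengths: one extra token pushes total_keys across a block
# boundary, so an additional (possibly incorrectly masked) program is launched.
print(launch_grid(total_queries=60, total_keys=800))  # -> 50 programs
print(launch_grid(total_queries=60, total_keys=801))  # -> 52 programs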

Meanwhile, with the prompt "Sort bowls and paper cups into their designated" (without the last token), the script runs successfully:

rm converted_checkpoint.pkl
python3 convert_from_jax.py --jax_path /home/xxx/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output converted_checkpoint.pkl --prompt "Sort bowls and paper cups into their designated" --tokenizer_path /home/xxx/.cache/modelscope/hub/models/AI-ModelScope/paligemma-3b-pt-224/
python3 ./infer_test.py
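
One way to check whether the two prompts really tokenize to different lengths is a quick tokenizer comparison (a sketch, assuming the directory passed as --tokenizer_path loads with Hugging Face transformers' AutoTokenizer; convert_from_jax.py may load the tokenizer differently):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "/home/xxx/.cache/modelscope/hub/models/AI-ModelScope/paligemma-3b-pt-224/"
)

failing = "Sort bowls and paper cups into their designated places"
working = "Sort bowls and paper cups into their designated"

# If the failing prompt crosses whatever sequence-length boundary the
# converted kernels assume, the extra token(s) would explain the crash.
print(len(tok(failing).input_ids))
print(len(tok(working).input_ids))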

Content of infer_test.py:

import pickle
import torch
import time

converted_checkpoint = pickle.load(open('converted_checkpoint.pkl', 'rb'))
from pi0_infer import Pi0Inference

number_of_images = 3
length_of_trajectory = 50

infer = Pi0Inference(converted_checkpoint, number_of_images, length_of_trajectory)

start_time = time.time()
output_actions = infer.forward(
    torch.randn((number_of_images, 224, 224, 3), dtype=torch.bfloat16),  # (number_of_images, 224, 224, 3)
    torch.randn((32,), dtype=torch.bfloat16),                            # (32,)
    torch.randn((length_of_trajectory, 32), dtype=torch.bfloat16),       # (length_of_trajectory, 32)
)
end_time = time.time()
elapsed = end_time - start_time
print(f"infer耗时: {elapsed:.4f}s")
