Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions et_replay/tools/et_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from param_bench.train.compute.python.lib import pytorch as lib_pytorch
from param_bench.train.compute.python.lib.init_helper import load_modules
from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._inductor.async_compile import AsyncCompile

# grid and split_scan_grid are dynamically loaded
Expand Down Expand Up @@ -1457,10 +1458,14 @@ def run_op(self, node, iter, cnt): # noqa: C901
outputs = []
if output_count == 0:
if node.kernel_backend == "triton":
# The last entry in inputs is stream
# ET captured the raw pointer of cudaStream_t,
# but triton needs the stream object, hard code to 0 for now
exec("func.run(*inputs[:-1], stream=0)")
# The last entry in inputs is the stream id, however if we pass that back into func.run()
# ET will attempt to dereference the stream again. To fully support triton kernels we need
# to keep a mapping between stream address and stream id. This is not implemented yet so
# we will capture the raw stream from the device and pass that back to the replay.
# In reality all PT2 compute ops are in the same stream to begin with, only collectives
# are in different streams, so this is fine.
stream = get_raw_stream(self.cuda_id)
func.run(*inputs[:-1], stream=stream)
else:
func(*inputs)
else:
Expand Down