diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index 7dab6ba6..7753db56 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -42,6 +42,7 @@
 from param_bench.train.compute.python.lib import pytorch as lib_pytorch
 from param_bench.train.compute.python.lib.init_helper import load_modules
 from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
+from torch._C import _cuda_getCurrentRawStream as get_raw_stream
 from torch._inductor.async_compile import AsyncCompile

 # grid and split_scan_grid are dynamically loaded
@@ -1457,10 +1458,14 @@ def run_op(self, node, iter, cnt):  # noqa: C901
         outputs = []
         if output_count == 0:
             if node.kernel_backend == "triton":
-                # The last entry in inputs is stream
-                # ET captured the raw pointer of cudaStream_t,
-                # but triton needs the stream object, hard code to 0 for now
-                exec("func.run(*inputs[:-1], stream=0)")
+                # The last entry in inputs is the stream id; however, if we pass that back into func.run(),
+                # ET will attempt to dereference the stream again. To fully support triton kernels we need
+                # to keep a mapping between stream address and stream id. This is not implemented yet, so
+                # we capture the raw stream from the device and pass that back to the replay instead.
+                # In reality all PT2 compute ops are on the same stream to begin with; only collectives
+                # run on different streams, so this is fine.
+                stream = get_raw_stream(self.cuda_id)
+                func.run(*inputs[:-1], stream=stream)
             else:
                 func(*inputs)
         else:
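
For context, a minimal sketch of the new launch path outside the diff. `launch_triton_kernel` is a hypothetical helper and the shape of `inputs` is an illustrative assumption; `get_raw_stream` is the same `torch._C._cuda_getCurrentRawStream` binding imported above, and `func` stands for a compiled Triton kernel object whose `.run()` accepts a `stream` keyword, as in the patched code.

```python
# Hedged sketch, not part of the patch: `launch_triton_kernel` and the layout
# of `inputs` (kernel arguments followed by the captured stream pointer) are
# assumptions made for illustration.
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


def launch_triton_kernel(func, inputs, cuda_id):
    """Launch a compiled Triton kernel on the device's current raw stream.

    The ET trace records the raw cudaStream_t address as the last input, but
    translating that address back to a live stream would need an
    address-to-stream mapping that does not exist yet. Instead, query the
    current raw stream for `cuda_id` and pass it to the kernel's run() call.
    """
    stream = get_raw_stream(cuda_id)  # integer handle of the current CUDA stream
    return func.run(*inputs[:-1], stream=stream)  # drop the captured stream pointer
```

This mirrors what torch._inductor-generated wrapper code does, where `get_raw_stream(device_index)` returns the current CUDA stream handle for that device before launching Triton kernels on it.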