From c1e910550e1d495afd48214c06741113223b2d2c Mon Sep 17 00:00:00 2001 From: Leon Lin Date: Thu, 22 May 2025 13:15:32 -0700 Subject: [PATCH] Regenerate pt2_et trace with corrected stream_id, fix test (#207) Summary: Pull Request resolved: https://github.com/facebookresearch/param/pull/207 This diff regenerates the pt2_et.json.tar.gz trace file used in the test_et_replay.py test file. Reviewed By: shengfukevin Differential Revision: D75181044 --- et_replay/tools/et_replay.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 7dab6ba6..7753db56 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -42,6 +42,7 @@ from param_bench.train.compute.python.lib import pytorch as lib_pytorch from param_bench.train.compute.python.lib.init_helper import load_modules from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch +from torch._C import _cuda_getCurrentRawStream as get_raw_stream from torch._inductor.async_compile import AsyncCompile # grid and split_scan_grid are dynamically loaded @@ -1457,10 +1458,14 @@ def run_op(self, node, iter, cnt): # noqa: C901 outputs = [] if output_count == 0: if node.kernel_backend == "triton": - # The last entry in inputs is stream - # ET captured the raw pointer of cudaStream_t, - # but triton needs the stream object, hard code to 0 for now - exec("func.run(*inputs[:-1], stream=0)") + # The last entry in inputs is the stream id, however if we pass that back into func.run() + # ET will attempt to dereference the stream again. To fully support triton kernels we need + # to keep a mapping between stream address and stream id. This is not implemented yet so + # we will capture the raw stream from the device and pass that back to the replay. + # In reality all PT2 compute ops are in the same stream to begin with, only collectives + # are in different streams, so this is fine. 
+ stream = get_raw_stream(self.cuda_id) + func.run(*inputs[:-1], stream=stream) else: func(*inputs) else: