diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index d8eec397..8e24aa2a 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -357,8 +357,8 @@ def __init__(self, json): input_tensors = self.nodes[id].get_input_tensors() output_tensors = self.nodes[id].get_output_tensors() - # track the various process and threads we have - if x["name"] == "__ROOT_THREAD__": + # track annonation to get thread ids of root nodes + if x["name"] == "[pytorch|profiler|execution_trace|thread]": tid = self.nodes[id].tid self.proc_group[pid][tid] = id