diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index b5ace295..3bcde903 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -433,13 +433,19 @@ def convert_ctrl_dep_to_data_dep( stack: List[ChakraNode] = [chakra_node] last_visited_non_gpu: Optional[ChakraNode] = None last_visited_any: Optional[ChakraNode] = None - + pg_name_to_nccl_ops: Dict[str, List[int]] = {} while stack: current_node = stack.pop() - if current_node.id in visited: - continue + visited.add(current_node.id) + + if current_node.type == COMM_COLL_NODE: + pg_name = json_node_map[current_node.id].pg_name + if pg_name in pg_name_to_nccl_ops: + current_node.data_deps.append(pg_name_to_nccl_ops[pg_name][-1]) + pg_name_to_nccl_ops[pg_name].append(current_node.id) + else: + pg_name_to_nccl_ops[pg_name] = [current_node.id] - visited.add(current_node.id) json_node = json_node_map.get(current_node.id) if not json_node: continue