From d218c3c8276fa90380196859d0b34a02b4cede70 Mon Sep 17 00:00:00 2001 From: theodorbadea Date: Thu, 3 Apr 2025 11:53:39 +0000 Subject: [PATCH 1/3] add collective dependencies as data_deps --- src/converter/pytorch_converter.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index b5ace295..dc177197 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -433,9 +433,16 @@ def convert_ctrl_dep_to_data_dep( stack: List[ChakraNode] = [chakra_node] last_visited_non_gpu: Optional[ChakraNode] = None last_visited_any: Optional[ChakraNode] = None - + pg_name_to_nccl_ops: Dict[str, List[int]] = {} while stack: - current_node = stack.pop() + current_node = stack.pop() + if current_node.type == COMM_COLL_NODE: + pg_name = json_node_map[current_node.id].pg_name + if pg_name in pg_name_to_nccl_ops.keys(): + current_node.data_deps.append(pg_name_to_nccl_ops[pg_name][-1]) + pg_name_to_nccl_ops[pg_name].append(current_node.id) + else: + pg_name_to_nccl_ops[pg_name] = [current_node.id] if current_node.id in visited: continue From d0f259acbae47156dad039944e2f1cec84577522 Mon Sep 17 00:00:00 2001 From: theodorbadea Date: Thu, 3 Apr 2025 12:27:08 +0000 Subject: [PATCH 2/3] fix lint error --- src/converter/pytorch_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index dc177197..2bcc8cda 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -438,7 +438,7 @@ def convert_ctrl_dep_to_data_dep( current_node = stack.pop() if current_node.type == COMM_COLL_NODE: pg_name = json_node_map[current_node.id].pg_name - if pg_name in pg_name_to_nccl_ops.keys(): + if pg_name in pg_name_to_nccl_ops: current_node.data_deps.append(pg_name_to_nccl_ops[pg_name][-1]) pg_name_to_nccl_ops[pg_name].append(current_node.id) else: From 89d82661630a58789a5d2b30a1ec47d49024160d Mon Sep 17 00:00:00 2001 From: theodorbadea Date: Mon, 7 Apr 2025 10:41:13 +0000 Subject: [PATCH 3/3] remove redundant check --- src/converter/pytorch_converter.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index 2bcc8cda..3bcde903 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -435,7 +435,9 @@ def convert_ctrl_dep_to_data_dep( last_visited_any: Optional[ChakraNode] = None pg_name_to_nccl_ops: Dict[str, List[int]] = {} while stack: - current_node = stack.pop() + current_node = stack.pop() + visited.add(current_node.id) + if current_node.type == COMM_COLL_NODE: pg_name = json_node_map[current_node.id].pg_name if pg_name in pg_name_to_nccl_ops: @@ -443,10 +445,7 @@ def convert_ctrl_dep_to_data_dep( pg_name_to_nccl_ops[pg_name].append(current_node.id) else: pg_name_to_nccl_ops[pg_name] = [current_node.id] - if current_node.id in visited: - continue - visited.add(current_node.id) json_node = json_node_map.get(current_node.id) if not json_node: continue