diff --git a/src/trace_link/chakra_host_trace_loader.py b/src/trace_link/chakra_host_trace_loader.py
index 8b2723b3..ef173f94 100644
--- a/src/trace_link/chakra_host_trace_loader.py
+++ b/src/trace_link/chakra_host_trace_loader.py
@@ -1,9 +1,11 @@
 import logging
 import sys
-from typing import List
+from typing import Any, Callable, Dict, List, Tuple
 
+from et_replay.execution_trace import EXECUTION_TRACE_THREAD_ANNOTATION as THREAD_ANNOTATION
+from et_replay.execution_trace import ExecutionTrace as PyTorchTrace
 from et_replay.execution_trace import Node as PyTorchOperator
-from et_replay.utils import load_execution_trace_file
+from et_replay.utils import read_dictionary_from_json_file
 
 # Increase the recursion limit for deep Chakra host execution traces.
 sys.setrecursionlimit(10**6)
@@ -12,25 +14,31 @@ class ChakraHostTraceLoader:
     """Loads Chakra host traces."""
 
-    def load(self, chakra_host_trace_file: str) -> List[PyTorchOperator]:
+    def load(self,
+             chakra_host_trace_file: str,
+             connect_host_trace: bool) -> Tuple[List[PyTorchOperator], Dict[str, Any]]:
         """
         Load and process the Chakra Host Execution Trace.
 
         Args:
             chakra_host_trace_file (str): Path to the PyTorch execution trace file.
+            connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.
 
         Returns:
-            List[PyTorchOperator]: List of PyTorch operators.
+            Tuple[List[PyTorchOperator], Dict[str, Any]]: Tuple containing list of PyTorch operators and host trace.
         """
         logging.debug(f"Starting to load Chakra host execution trace from file: {chakra_host_trace_file}.")
-        chakra_host_trace = load_execution_trace_file(chakra_host_trace_file)
+        host_trace = read_dictionary_from_json_file(chakra_host_trace_file)
 
-        root_node = chakra_host_trace.get_nodes()[1]  # Root node is usually 1-based
+        host_ops = self._create_host_ops(host_trace, connect_host_trace)
+        root_node = host_ops.get(1)  # Root node is usually 1-based
+
         chakra_host_ops = self.extract_chakra_host_ops(root_node)
+
         logging.debug(f"Extracted {len(chakra_host_ops)} operators from Chakra host execution trace.")
         logging.debug("Chakra host execution trace has been loaded and processed successfully.")
-        return chakra_host_ops
+        return chakra_host_ops, host_trace
 
     def extract_chakra_host_ops(self, node: PyTorchOperator) -> List[PyTorchOperator]:
         """
@@ -55,3 +63,82 @@ def traverse(node: PyTorchOperator):
         traverse(node)
         logging.debug(f"Traversed {len(nodes)} nodes from root node ID: {node.id}")
         return sorted(nodes, key=lambda x: x.id)
+
+    def _create_host_ops(self, host_trace: Dict[str, Any], connect_host_trace: bool) -> Dict[int, PyTorchOperator]:
+        """
+        Create host operators from the provided host trace.
+
+        This method processes the host trace, extracts nodes, and creates PyTorchOperator instances based on the
+        schema version specified in the host trace.
+
+        Args:
+            host_trace (Dict[str, Any]): The host trace dictionary.
+            connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.
+
+        Returns:
+            Dict[int, PyTorchOperator]: A dictionary mapping operator IDs to PyTorchOperator instances.
+ """ + schema: str = host_trace["schema"] + pid: int = host_trace["pid"] + nodes: List[Dict[str, Any]] = host_trace["nodes"] + + create_operator = self._get_operator_creation_method(schema) + if create_operator is None: + raise ValueError( + f"No corresponding node creation function found for schema version {schema}" + ) + + host_ops: Dict[int, PyTorchOperator] = {} + thread_roots: Dict[int, int] = {} + for node in nodes: + host_op = create_operator(pid, node) + host_ops[host_op.id] = host_op + if host_op.parent_id == 1 and THREAD_ANNOTATION in host_op.name: + thread_roots[host_op.tid] = host_op.id + + for host_op in host_ops.values(): + if host_op.parent_id in host_ops and host_op.id != 1: + parent = host_ops[host_op.parent_id] + host_op.set_parent(parent) + parent.add_child(host_op) + elif connect_host_trace is True: # connect orphans to the thread root + parent_id = thread_roots.get(host_op.tid, None) + if parent_id is not None: + host_op.parent_id = parent_id + parent = host_ops[parent_id] + host_op.set_parent(parent) + parent.add_child(host_op) + node = next(filter(lambda n: n["id"] == host_op.id, nodes), None) + if node is not None: + node["ctrl_deps"] = parent_id + + for host_op in host_ops.values(): + host_op.sort_children() + + return host_ops + + def _get_operator_creation_method(self, schema: str) -> Callable[[int, Dict[str, Any]], PyTorchOperator] | None: + """ + Get the operator creation method for the specified schema version. + + Args: + schema (str): The schema version of the host trace. + + Returns: + Callable[[int, Dict[str, Any]], PyTorchOperator] | None: Operator creation functor for the schema version, + or None if no functor is found. + """ + node_creation_func = { + "1.0.1": PyTorchTrace._create_node_v1_0_1, + "1.0.2-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4, + # 1.0.3 expands pg name to so it use the same parser as 1.0.2 + "1.0.3-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4, + # 1.0.4 adds PT2 kernel backend and kernel file + "1.0.4-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4, + # 1.1.0 includes new comm args in record_param_comms + "1.1.0-chakra.0.0.4": PyTorchTrace._create_node_v1_0_2_chakra_0_0_4, + # 1.1.1 includes tensor strides + "1.1.1-chakra.0.0.4": PyTorchTrace._create_node_v1_1_1_chakra_0_0_4, + # Add future versions here + } + return node_creation_func.get(schema) diff --git a/src/trace_link/trace_link.py b/src/trace_link/trace_link.py index 12074df5..99299e5e 100644 --- a/src/trace_link/trace_link.py +++ b/src/trace_link/trace_link.py @@ -37,6 +37,11 @@ def main() -> None: required=True, help="Path for the output Chakra host + device trace in the JSON format", ) + parser.add_argument("--connect-host-trace", + type=bool, + default=False, + help="Whether to connect host nodes with missing parents to the corresponding thread root node.", + ) parser.add_argument("--log-level", default="INFO", type=str, help="Log output verbosity level") args = parser.parse_args() @@ -44,7 +49,7 @@ def main() -> None: logging.basicConfig(level=args.log_level.upper()) linker = TraceLinker() - linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file) + linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file, args.connect_host_trace) logging.info(f"Linking process successful. 
     logging.info("Please run the chakra_converter for further postprocessing.")
diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py
index af247dd7..df2a6117 100644
--- a/src/trace_link/trace_linker.py
+++ b/src/trace_link/trace_linker.py
@@ -4,7 +4,7 @@
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from et_replay.execution_trace import (
     EXECUTION_TRACE_PROCESS_ANNOTATION,
@@ -36,7 +36,11 @@ def __init__(self) -> None:
         self.chakra_device_trace_loader = ChakraDeviceTraceLoader()
         self.id_assigner = UniqueIdAssigner()
 
-    def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, output_file: str) -> None:
+    def link(self, rank: int,
+             chakra_host_trace: str,
+             chakra_device_trace: str,
+             output_file: str,
+             connect_host_trace: bool) -> None:
         """
         Links Chakra host execution traces (ET) and Chakra device ET to generate Chakra host + device ET.
 
@@ -45,8 +49,9 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp
             chakra_host_trace (str): Path to the Chakra host execution trace file.
             chakra_device_trace (str): Path to the Kineto trace file.
             output_file (str): Path for the output PyTorch execution trace plus file.
+            connect_host_trace (bool): Connect host nodes with missing parents to the corresponding thread root node.
         """
-        host_ops = self.chakra_host_trace_loader.load(chakra_host_trace)
+        host_ops, host_trace = self.chakra_host_trace_loader.load(chakra_host_trace, connect_host_trace)
 
         (
             kineto_cpu_ops,
@@ -77,7 +82,7 @@ def link(self, rank: int, chakra_host_trace: str, chakra_device_trace: str, outp
         )
 
         chakra_execution_trace_plus_data = self.link_traces(
-            chakra_host_trace,
+            host_trace,
             host_ops,
             kineto_cpu_ops,
             sorted_kineto_cpu_ops,
@@ -382,7 +387,7 @@ def find_closest_start_kineto_op(
 
     def link_traces(
         self,
-        chakra_host_trace: str,
+        host_trace: Dict[str, Any],
         host_ops: List[PyTorchOperator],
         kineto_cpu_ops: List[KinetoOperator],
         sorted_kineto_cpu_ops: List[KinetoOperator],
@@ -399,7 +404,7 @@
         Link Chakra Host ET and Chakra Device ET to produce an enhanced Chakra ET (ET+).
 
         Args:
-            chakra_host_trace (str): Path to the Chakra host execution trace file.
+            host_trace (Dict[str, Any]): The Chakra host execution trace.
             host_ops (List[PyTorchOperator]): List of Chakra host operators.
             kineto_cpu_ops (List[KinetoOperator]): List of Kineto CPU operators.
             sorted_kineto_cpu_ops (List[KinetoOperator]): Sorted list of Kineto CPU operators.
@@ -448,7 +453,7 @@ def link_traces(
                 kineto_external_id_to_kineto_op_map,
             )
         chakra_execution_trace_plus_data = self.construct_et_plus_data(
-            chakra_host_trace,
+            host_trace,
            host_op_id_to_kineto_ops_map,
            host_op_id_to_inclusive_dur_map,
            host_op_id_to_exclusive_dur_map,
@@ -813,7 +818,7 @@ def link_gpu_ops(self, host_op: PyTorchOperator, kineto_gpu_ops: List[KinetoOper
 
     def construct_et_plus_data(
         self,
-        chakra_host_trace: str,
+        host_trace: Dict[str, Any],
         host_op_id_to_kineto_ops_map: Dict[int, List[KinetoOperator]],
         host_op_id_to_inclusive_dur_map: Dict[int, int],
         host_op_id_to_exclusive_dur_map: Dict[int, int],
@@ -827,7 +832,7 @@ def construct_et_plus_data(
         offering a comprehensive view of the execution.
 
         Args:
-            chakra_host_trace (str): Path to the Chakra host execution trace file.
+            host_trace (Dict[str, Any]): The Chakra host execution trace.
             host_op_id_to_kineto_ops_map (Dict[int, List[KinetoOperator]]): Map from Chakra host op IDs to Kineto GPU
                 ops.
             host_op_id_to_inclusive_dur_map (Dict[int, int]): Inclusive duration map for Chakra host ops.
@@ -840,10 +845,8 @@
             Dict: The constructed ET+ data.
         """
         logging.debug("Constructing ET+ data.")
-        with open(chakra_host_trace, "r") as file:
-            pytorch_et_data = json.load(file)
 
-        sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"])
+        sorted_nodes = sorted(host_trace["nodes"], key=lambda x: x["id"])
         gpu_ops = []
         for op in sorted_nodes:
             gpu_ops += self.process_op_and_dependents(
@@ -854,7 +857,7 @@
                 host_op_id_to_timestamp_map,
                 host_op_id_to_inter_thread_dep_map,
             )
-        pytorch_et_data["nodes"] += gpu_ops
+        host_trace["nodes"] += gpu_ops
 
         # Add sync dependencies
         sync_dep_mapping = {}
@@ -867,7 +870,7 @@
                 del gpu_op["sync_dep_to"]
 
         # Update parent-child relationships with new IDs
-        sorted_nodes = sorted(pytorch_et_data["nodes"], key=lambda x: x["id"])
+        sorted_nodes = sorted(host_trace["nodes"], key=lambda x: x["id"])
         for op in sorted_nodes:
             for key in sync_dep_mapping:
                 if self.id_assigner.lookup_new_id(key) == op["id"]:
@@ -875,7 +878,7 @@
             if "ctrl_deps" in op:
                 op["ctrl_deps"] = self.id_assigner.assign_or_retrieve_id(op["ctrl_deps"])
 
-        return pytorch_et_data
+        return host_trace
 
     def process_op_and_dependents(
         self,
diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py
index 8430867e..4174a982 100644
--- a/tests/trace_link/test_trace_linker.py
+++ b/tests/trace_link/test_trace_linker.py
@@ -594,7 +594,7 @@ def test_construct_et_plus_data(mock_json_load, mock_open, mock_process_op_and_d
     host_op_id_to_inter_thread_dep_map = {1: None, 2: None}
 
     pytorch_et_plus_data = trace_linker.construct_et_plus_data(
-        "path/to/pytorch_et_file",
+        mock_json_load.return_value,
         host_op_id_to_kineto_ops_map,
         host_op_id_to_inclusive_dur_map,
         host_op_id_to_exclusive_dur_map,
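
Usage sketch (not part of the patch): the snippet below illustrates how the reworked loader API introduced above might be exercised. The import path, the trace file name, and the printed fields are assumptions for illustration only; `connect_host_trace=True` mirrors what the new `--connect-host-trace` flag passes down from `trace_link.py`.

```python
# Minimal sketch, assuming the package-level import path below; adjust it to how
# chakra is installed in your environment. "host_trace.json" is a placeholder path.
from chakra.src.trace_link.chakra_host_trace_loader import ChakraHostTraceLoader

loader = ChakraHostTraceLoader()

# load() now returns both the sorted host operators and the raw trace dictionary,
# so callers such as TraceLinker.link() no longer need to re-read the JSON file.
host_ops, host_trace = loader.load("host_trace.json", connect_host_trace=True)

print(f"Loaded {len(host_ops)} host ops (schema {host_trace['schema']})")
```

From the command line, the same behavior would be requested by adding the new flag to the trace_link invocation (for example `--connect-host-trace` alongside the existing `--rank`, trace, and output arguments); the exact entry-point name depends on how the package registers its console scripts.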