From ebb7b75e046f1b85f7abe1ffb640169965907d63 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Fri, 12 Dec 2025 07:10:54 +0000
Subject: [PATCH] [Draft] Add fsdp_mesh in `load_spec`

---
 xtuner/v1/model/base.py                       | 198 ++++++++++++++----
 .../compose/intern_s1/modeling_intern_s1.py   |   2 +-
 .../compose/qwen3_vl/modeling_projector.py    |   2 +-
 .../compose/qwen3_vl/modeling_qwen3_vl.py     |   3 +-
 .../model/compose/qwen3_vl/modeling_vision.py |   5 +-
 xtuner/v1/model/dense/dense.py                |  11 +-
 xtuner/v1/model/moe/moe.py                    |  11 +-
 xtuner/v1/utils/load_spec.py                  |   2 +
 8 files changed, 175 insertions(+), 59 deletions(-)

diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index 8e7a92aa0..2ab743c66 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -13,6 +13,11 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    MixedPrecisionPolicy,
+    fully_shard,
+)
 from cyclopts import Parameter
 from more_itertools import consume
 from pydantic import BaseModel as PydanticBaseModel
@@ -458,6 +463,9 @@ def get_shard_placement(placements: tuple[Placement, ...]) -> Shard | None:
                 load_spec_mapping[name] = load_spec
 
         self.load_spec_mapping = load_spec_mapping
+        # TODO: Since a composable model can modify `load_spec_mapping`, we keep a copy of it here to make
+        # sure the model can still update the original `load_spec_mapping` correctly.
+        self._load_spec_mapping = load_spec_mapping.copy()
 
     def _to_float8(
         self,
@@ -498,7 +506,8 @@ def _get_shard_hf_param(
         def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> list[torch.Tensor]:
             # Get fsdp unsharded params
             _tensor_list, _spec_list = list(zip(*fsdp_tensor_list))
-            if self.fsdp_mesh is not None:
+            fsdp_mesh = self._get_fsdp_mesh(_spec_list)
+            if fsdp_mesh is not None:
                 fsdp_unsharded_tensor_list = self._fsdp_foreach_allgather(_tensor_list, _spec_list)  # type: ignore
             else:
                 fsdp_unsharded_tensor_list = _tensor_list  # type: ignore
@@ -519,29 +528,47 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> list[torch.Tensor]:
         if bucket_size is None:
             bucket_size = self.config.hf_save_cfg.bucket_size
-        safetensor_size = 0
-        tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
-        name_list: list[str] = []
+        fsdp_mesh_grouped_tensor_map: dict[
+            DeviceMesh | None,
+            tuple[list[tuple[torch.Tensor, LoadSpec]], list[str], int],
+        ] = {}
         for param, load_spec in params:
-            local_tensor = param._local_tensor if isinstance(param, DTensor) else param
+            if isinstance(param, DTensor):
+                local_tensor = param._local_tensor
+                fsdp_mesh = load_spec.fsdp_mesh
+            else:
+                local_tensor = param
+                fsdp_mesh = None
+            local_tensor = local_tensor.to(dtype=dtype)
+
+            if fsdp_mesh not in fsdp_mesh_grouped_tensor_map:
+                fsdp_mesh_grouped_tensor_map[fsdp_mesh] = ([], [], 0)
+
+            tensor_list, name_list, safetensor_size = fsdp_mesh_grouped_tensor_map[fsdp_mesh]
+
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params = _get_hf_params(tensor_list)
                 yield name_list, hf_params
+                fsdp_mesh_grouped_tensor_map.pop(fsdp_mesh)
                 safetensor_size = tensor_size
                 name_list = load_spec.hf_keys.copy()
                 tensor_list = [(local_tensor, load_spec)]
+
+                fsdp_mesh_grouped_tensor_map[fsdp_mesh] = (tensor_list, name_list, 0)
                 continue
             safetensor_size += tensor_size
             tensor_list.append((local_tensor, load_spec))
             name_list.append(load_spec.hf_keys[0])
+            fsdp_mesh_grouped_tensor_map[fsdp_mesh] = (tensor_list, name_list, safetensor_size)
 
-        if tensor_list:
-            hf_params = _get_hf_params(tensor_list)
-            yield name_list, hf_params
+        if fsdp_mesh_grouped_tensor_map:
+            for tensor_list, name_list, _ in fsdp_mesh_grouped_tensor_map.values():
+                hf_params = _get_hf_params(tensor_list)
+                yield name_list, hf_params
 
     def _get_fused_hf_param(
         self,
@@ -563,7 +590,9 @@ def _get_hf_params(
 
             tensor_list: list[torch.Tensor]
             tensor_list, spec_list = list(zip(*fsdp_tensor_list))  # type: ignore[assignment]
-            if self.fsdp_mesh is not None:
+
+            fsdp_mesh = self._get_fsdp_mesh(spec_list)
+            if fsdp_mesh is not None:
                 fsdp_unshard_tensor_list = self._fsdp_foreach_allgather(tensor_list, spec_list)  # type: ignore
             else:
                 fsdp_unshard_tensor_list = tensor_list  # type: ignore
@@ -671,28 +700,47 @@ def _get_hf_params(
         if bucket_size is None:
             bucket_size = self.config.hf_save_cfg.bucket_size
-        safetensor_size = 0
-        tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
-        name_list: list[str] = []
+        tensor_list: list[tuple[torch.Tensor, LoadSpec]]
+        name_list: list[str]
+
+        fsdp_mesh_grouped_tensor_map: dict[
+            DeviceMesh | None,
+            tuple[list[tuple[torch.Tensor, LoadSpec]], list[str], int],
+        ] = {}
         for param, load_spec in params:
-            local_tensor = param._local_tensor if isinstance(param, DTensor) else param
+            if isinstance(param, DTensor):
+                local_tensor = param._local_tensor
+                fsdp_mesh = load_spec.fsdp_mesh
+            else:
+                local_tensor = param
+                fsdp_mesh = None
+
+            if fsdp_mesh not in fsdp_mesh_grouped_tensor_map:
+                fsdp_mesh_grouped_tensor_map[fsdp_mesh] = ([], [], 0)
+
+            tensor_list, name_list, safetensor_size = fsdp_mesh_grouped_tensor_map[fsdp_mesh]
             local_tensor = local_tensor.bfloat16()
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params, name_list = _get_hf_params(tensor_list, name_list)
                 yield name_list, hf_params
+
+                fsdp_mesh_grouped_tensor_map.pop(fsdp_mesh)
                 safetensor_size = tensor_size
                 name_list = load_spec.hf_keys.copy()
                 tensor_list = [(local_tensor, load_spec)]
+                fsdp_mesh_grouped_tensor_map[fsdp_mesh] = (tensor_list, name_list, 0)
                 continue
             safetensor_size += tensor_size
             tensor_list.append((local_tensor, load_spec))
             name_list.extend(load_spec.hf_keys)
+            fsdp_mesh_grouped_tensor_map[fsdp_mesh] = (tensor_list, name_list, safetensor_size)
 
-        if tensor_list:
-            hf_params, name_list = _get_hf_params(tensor_list, name_list)
-            yield name_list, hf_params
+        if fsdp_mesh_grouped_tensor_map:
+            for tensor_list, name_list, _ in fsdp_mesh_grouped_tensor_map.values():
+                hf_params, name_list = _get_hf_params(tensor_list, name_list)
+                yield name_list, hf_params
 
     def _get_same_hf_param(
         self,
@@ -706,26 +754,40 @@
         if bucket_size is None:
             bucket_size = self.config.hf_save_cfg.bucket_size
         safetensor_size = 0
-        tensor_list: list[torch.Tensor] = []
-        load_spec_list: list[LoadSpec] = []
-        name_list: list[str] = []
+        tensor_list: list[torch.Tensor]
+        load_spec_list: list[LoadSpec]
+        name_list: list[str]
         buffer_tensor_list: list[torch.Tensor] = []
         buffer_name_list: list[str] = []
 
+        fsdp_mesh_grouped_tensor_map: dict[
+            DeviceMesh | None,
+            tuple[list[torch.Tensor], list[LoadSpec], list[str], int],
+        ] = {}
+
         for param, load_spec in params:
             if not isinstance(param, DTensor):
                 # in case, param is a buffer of module, FSDP will not shard it, so it's not a DTensor
                 buffer_tensor_list.append(param)
                 buffer_name_list.append(load_spec.hf_keys[0])
                 continue
-            local_tensor = param._local_tensor if isinstance(param, DTensor) else param
+            local_tensor = param._local_tensor
+            fsdp_mesh = load_spec.fsdp_mesh
+
+            if fsdp_mesh not in fsdp_mesh_grouped_tensor_map:
+                fsdp_mesh_grouped_tensor_map[fsdp_mesh] = ([], [], [], 0)
+
+            tensor_list, load_spec_list, name_list, safetensor_size = fsdp_mesh_grouped_tensor_map[fsdp_mesh]
+
             local_tensor = local_tensor.bfloat16()
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
-                if self.fsdp_mesh is not None:
+                fsdp_mesh = self._get_fsdp_mesh(load_spec_list)
+                if fsdp_mesh is not None:
                     gathered_tensor_list = self._fsdp_foreach_allgather(tensor_list, load_spec_list)
                 else:
                     gathered_tensor_list = tensor_list
+
                 gathered_tensor_list = [
                     self.param_to_safetensor(safetensor, name) for safetensor, name in zip(gathered_tensor_list, name_list)
@@ -736,29 +798,37 @@
                     )
                 gathered_tensor_list = [t.to(device=device) for t in gathered_tensor_list]
                 yield name_list, gathered_tensor_list
+                fsdp_mesh_grouped_tensor_map.pop(fsdp_mesh)
+
                 safetensor_size = tensor_size
                 name_list = load_spec.hf_keys.copy()
                 tensor_list = [local_tensor]
                 load_spec_list = [load_spec]
+
+                fsdp_mesh_grouped_tensor_map[fsdp_mesh] = (tensor_list, load_spec_list, name_list, 0)
                 continue
             safetensor_size += tensor_size
+
             tensor_list.append(local_tensor)
             name_list.append(load_spec.hf_keys[0])
             load_spec_list.append(load_spec)
+            fsdp_mesh_grouped_tensor_map[fsdp_mesh] = (tensor_list, load_spec_list, name_list, safetensor_size)
 
-        if tensor_list:
-            if self.fsdp_mesh is not None:
-                gathered_tensor_list = self._fsdp_foreach_allgather(tensor_list, load_spec_list)
-            else:
-                gathered_tensor_list = tensor_list
+        if fsdp_mesh_grouped_tensor_map:
+            for tensor_list, load_spec_list, name_list, _ in fsdp_mesh_grouped_tensor_map.values():
+                fsdp_mesh = self._get_fsdp_mesh(load_spec_list)
+                if fsdp_mesh is not None:
+                    gathered_tensor_list = self._fsdp_foreach_allgather(tensor_list, load_spec_list)
+                else:
+                    gathered_tensor_list = tensor_list
 
-            gathered_tensor_list = [
-                self.param_to_safetensor(safetensor, name) for safetensor, name in zip(gathered_tensor_list, name_list)
-            ]
-            if dtype == torch.float8_e4m3fn:
-                gathered_tensor_list, name_list = self._to_float8(gathered_tensor_list, name_list, tensor_list, dtype)
-            gathered_tensor_list = [t.to(device=device) for t in gathered_tensor_list]
-            yield name_list, gathered_tensor_list
+                gathered_tensor_list = [
+                    self.param_to_safetensor(safetensor, name) for safetensor, name in zip(gathered_tensor_list, name_list)
+                ]
+                if dtype == torch.float8_e4m3fn:
+                    gathered_tensor_list, name_list = self._to_float8(gathered_tensor_list, name_list, tensor_list, dtype)
+                gathered_tensor_list = [t.to(device=device) for t in gathered_tensor_list]
+                yield name_list, gathered_tensor_list
 
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list
@@ -1036,11 +1106,12 @@ def _load_same_hf_param(
             return [hf_key]
 
         loaded_tensor = loaded_tensor.to(local_tensor.device)
+        fsdp_mesh = load_spec.fsdp_mesh
 
-        if self.fsdp_mesh is not None and isinstance(param, nn.Parameter):
+        if fsdp_mesh is not None and isinstance(param, nn.Parameter):
             shape_before_fsdp = load_spec.shape
             _, _offset = compute_local_shape_and_global_offset(
-                shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
+                shape_before_fsdp, fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
             )
             fsdp_start = _offset[self.FSDP_SHARD_DIM]
             fsdp_end = fsdp_start + local_tensor.shape[self.FSDP_SHARD_DIM]
@@ -1072,7 +1143,8 @@ def _load_fused_hf_param(
         local_tensor = param._local_tensor if isinstance(param, DTensor) else param
         assert load_spec.dim == self.FSDP_SHARD_DIM, "Only support FSDP and model parallel sharding at the same dim!"
 
-        if self.fsdp_mesh is not None:
+        fsdp_mesh = load_spec.fsdp_mesh
+        if fsdp_mesh is not None:
             shape_before_fsdp = load_spec.shape
             if is_float8_weight(local_tensor):
                 # fp8 weights may be padded, so we need to calculate the hf_key_size base on local_tensor._ori_shape
@@ -1092,7 +1164,7 @@
             )
             hf_key_size = int(hf_key_size)
             _, _offset = compute_local_shape_and_global_offset(
-                shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
+                shape_before_fsdp, fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
             )
             fsdp_start = _offset[self.FSDP_SHARD_DIM]
             fsdp_end = fsdp_start + local_tensor.shape[self.FSDP_SHARD_DIM]
@@ -1157,11 +1229,12 @@ def _load_shard_hf_param(
         assert load_spec.shard_start is not None and load_spec.shard_end is not None, (
             "load_spec.shard_start and load_spec.shard_end should not be None for sharded params"
         )
+        fsdp_mesh = load_spec.fsdp_mesh
 
-        if self.fsdp_mesh is not None:
+        if fsdp_mesh is not None:
             shape_before_fsdp = load_spec.shape
             _, _offset = compute_local_shape_and_global_offset(
-                shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
+                shape_before_fsdp, fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
             )
             fsdp_start = _offset[self.FSDP_SHARD_DIM]
             fsdp_end = fsdp_start + local_tensor.shape[self.FSDP_SHARD_DIM]
@@ -1192,7 +1265,8 @@ def _has_meta_param(self, module: nn.Module, recurse: bool = False) -> bool:
     def _fsdp_foreach_allgather(
         self, tensor_list: list[torch.Tensor], load_spec_list: list[LoadSpec]
     ) -> list[torch.Tensor]:
-        assert self.fsdp_mesh is not None, "Internal Error, fsdp_mesh should not be None"
+        fsdp_mesh = self._get_fsdp_mesh(load_spec_list)
+        assert fsdp_mesh is not None, "Internal Error, fsdp_mesh should not be None"
 
         origin_fsdp_size = []
         padded_tensor_list = []
@@ -1424,3 +1498,47 @@ def _maybe_enable_compile(self, compile_cfg: dict[str, TorchCompileOption]):
 
         for target, option in compile_cfg.items():
             self._compile_overwrite(target, option)
+
+    def _get_fsdp_mesh(self, load_spec_list: list[LoadSpec]) -> DeviceMesh | None:
+        fsdp_mesh = load_spec_list[0].fsdp_mesh
+        assert all(load_spec.fsdp_mesh == fsdp_mesh for load_spec in load_spec_list), (
+            "Internal Error, all load_spec in the same module should have the same fsdp_mesh"
+        )
+        return fsdp_mesh
+
+    def _fully_shard(
+        self,
+        module: nn.Module,
+        mesh: DeviceMesh,
+        mp_policy: MixedPrecisionPolicy,
+        reshard_after_forward: bool = True,
+        offload_policy: CPUOffloadPolicy | None = None,
+    ):
+        fully_shard(
+            module,
+            mesh=mesh,
+            mp_policy=mp_policy,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
+        )
+        if self.hsdp_mesh is not None:
+            mesh = mesh[f"{self.config.mesh_prefix}.hsdp_shard"]
+        self._update_load_spec(module, mesh)
+
+    def _update_load_spec(self, module: nn.Module, fsdp_mesh: DeviceMesh):
+        for name, submodule in self.named_modules():
+            if submodule is module:
+                module_name = name
+                break
+        else:
+            raise RuntimeError(f"Internal Error, {module} not found in named_modules")
+
+        for name, _ in module.named_parameters():
+            name = self._clean_param_name(name)
+            name = f"{module_name}.{name}" if module_name else name
+            if name not in self._load_spec_mapping:
+                raise RuntimeError(f"Parameter {name} not found in load_spec_mapping")
+            load_spec = self.load_spec_mapping[name]
+            load_spec.fsdp_mesh = fsdp_mesh
+
+
diff --git a/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py b/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
index 98ae0ed95..f19b61eca 100644
--- a/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
+++ b/xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
@@ -136,7 +136,7 @@ def fully_shard(
         # Note: 非常关键,不能删除这个 assert
         assert self.fsdp_mesh is not None
 
-        fully_shard(
+        self._fully_shard(
             self,
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py b/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
index 958f4bebb..2e5d3ae77 100644
--- a/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
+++ b/xtuner/v1/model/compose/qwen3_vl/modeling_projector.py
@@ -108,7 +108,7 @@ def fully_shard(
        for param in self.parameters():
            param.requires_grad = False

-        fully_shard(
+        self._fully_shard(
            self,
            mesh=self.fsdp_mesh,
            mp_policy=mp_policy,
diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py
index a1b57d030..d431059d4 100644
--- a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py
+++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py
@@ -8,7 +8,6 @@
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
-    fully_shard,
     FSDPModule,
 )
 import torch.distributed as dist
@@ -90,7 +89,7 @@ def fully_shard(
         # Note: 非常关键,不能删除这个 assert
         assert self.fsdp_mesh is not None
 
-        fully_shard(
+        self._fully_shard(
             self,
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
index 21aef8434..8c5754268 100644
--- a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
+++ b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
@@ -14,7 +14,6 @@
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from transformers.models.llama.modeling_llama import repeat_kv
 from xtuner.v1.float8.float8_handler import Float8Handler
@@ -349,7 +348,7 @@ def fully_shard(
 
             self.blocks[layer_idx] = layer
 
-            fully_shard(
+            self._fully_shard(
                 layer,
                 mesh=self.fsdp_mesh,
                 mp_policy=mp_policy,
@@ -362,7 +361,7 @@ def fully_shard(
         for layer_cur, layer_next in zip(self.blocks[:-1], self.blocks[1:]):
             layer_cur.set_modules_to_forward_prefetch([layer_next])
 
-        fully_shard(
+        self._fully_shard(
             self,
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py
index 02298350b..b3472ac76 100644
--- a/xtuner/v1/model/dense/dense.py
+++ b/xtuner/v1/model/dense/dense.py
@@ -11,7 +11,6 @@
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from torch.distributed.tensor import DTensor
 from tqdm import tqdm
@@ -223,7 +222,7 @@ def fully_shard(
                 layer.forward = torch.compile(layer.forward, fullgraph=True)
 
             self.layers[str(layer_idx)] = layer
-            fully_shard(
+            self._fully_shard(
                 layer,
                 mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
                 mp_policy=mp_policy,
@@ -237,7 +236,7 @@ def fully_shard(
         ):
             layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
 
-        fully_shard(
+        self._fully_shard(
             self.embed_tokens,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
@@ -245,7 +244,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
 
-        fully_shard(
+        self._fully_shard(
             self.norm,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
@@ -253,7 +252,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
 
-        fully_shard(
+        self._fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
@@ -261,7 +260,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
 
-        fully_shard(
+        self._fully_shard(
             self,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py
index a5fa5c0d8..3a9e7de44 100644
--- a/xtuner/v1/model/moe/moe.py
+++ b/xtuner/v1/model/moe/moe.py
@@ -18,7 +18,6 @@
 from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from torch.distributed.tensor import DTensor, Replicate, distribute_tensor
 from tqdm import tqdm
@@ -733,7 +732,7 @@ def fully_shard(
                 reshard_after_forward = False
             else:
                 reshard_after_forward = self.fsdp_config.reshard_after_forward
-            fully_shard(
+            self._fully_shard(
                 layer,
                 mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
                 mp_policy=mp_policy,
@@ -747,7 +746,7 @@ def fully_shard(
         ):
             layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
 
-        fully_shard(
+        self._fully_shard(
             self.embed_tokens,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
@@ -755,7 +754,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
 
-        fully_shard(
+        self._fully_shard(
             self.norm,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
@@ -763,7 +762,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
 
-        fully_shard(
+        self._fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
@@ -771,7 +770,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
 
-        fully_shard(
+        self._fully_shard(
             self,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
             mp_policy=mp_policy,
diff --git a/xtuner/v1/utils/load_spec.py b/xtuner/v1/utils/load_spec.py
index ef95585fe..e0d2619d5 100644
--- a/xtuner/v1/utils/load_spec.py
+++ b/xtuner/v1/utils/load_spec.py
@@ -1,4 +1,5 @@
 import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
 from pydantic import BaseModel, ConfigDict
 
 from .enum_helper import StrEnum
@@ -21,6 +22,7 @@ class LoadSpec(BaseModel):
     shard_start: int | None = None
     shard_end: int | None = None
     group: dist.ProcessGroup | None = None
+    fsdp_mesh: DeviceMesh | None = None  # TODO: (yehaochen) Only a workaround
 
     def model_post_init(self, _) -> None:
         if self.load_enum == LoadEnum.SAME:
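
Note on the intended flow: every submodule sharded through `self._fully_shard` records the mesh it was sharded over on the `LoadSpec` of its parameters, and the save/load paths later read that mesh back instead of relying on a single `self.fsdp_mesh`. Below is a minimal, self-contained sketch of that bookkeeping only; the stand-in class, names, and string "meshes" are hypothetical and are not the real xtuner/torch types.

    from dataclasses import dataclass

    @dataclass
    class FakeLoadSpec:              # stand-in for xtuner.v1.utils.LoadSpec
        hf_key: str
        fsdp_mesh: object | None = None

    def update_load_spec(mapping, param_names, mesh):
        # record which mesh the parameters of one submodule were sharded over
        for name in param_names:
            mapping[name].fsdp_mesh = mesh

    def get_fsdp_mesh(specs):
        # params gathered in one bucket must all share the same mesh
        mesh = specs[0].fsdp_mesh
        assert all(s.fsdp_mesh == mesh for s in specs)
        return mesh

    mapping = {
        "vision.proj.weight": FakeLoadSpec("visual.proj.weight"),
        "layers.0.q.weight": FakeLoadSpec("model.layers.0.q_proj.weight"),
    }
    # a composed model can shard its vision tower and its LLM over different meshes
    update_load_spec(mapping, ["vision.proj.weight"], mesh="vision_mesh")
    update_load_spec(mapping, ["layers.0.q.weight"], mesh="llm_hsdp_mesh")

    # at save time, params are grouped per mesh before the FSDP all-gather
    assert get_fsdp_mesh([mapping["vision.proj.weight"]]) == "vision_mesh"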