8 changes: 6 additions & 2 deletions xtuner/v1/engine/train_engine.py
@@ -218,14 +218,18 @@ def grad_accumulation_steps(self, data_batches_len: int):
        intra_layer_micro_batch = self.intra_layer_micro_batch
        return data_batches_len // intra_layer_micro_batch

    # This method can also be called externally, e.g. at the beginning of compute_actor_logprobs or compute_ref_logprobs during RL training.
    def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
        if self.float8_handler is not None and self.float8_handler.enabled:
            self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)

    def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
        """Perform a training step with the given data batches and mesh.

        Args:
            data_batches (List[Dict]): The input data batches for the training step.
        """
        if self.float8_handler is not None and self.float8_handler.enabled:
            self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)
        self.maybe_precompute_float8_dynamic_scale_for_fsdp()

        loss_log: LossLog = {}  # type: ignore[typeddict-item]
        other_log: OtherLog = {}  # type: ignore[typeddict-item]
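Because the new hook no-ops when float8 is disabled, callers can invoke it unconditionally. A minimal caller-side sketch, where `engine` stands for an already-constructed TrainEngine and `data_batches` for a prepared list of ModelItem (both are stand-ins, not part of this diff):

# Sketch only; `engine` and `data_batches` are stand-ins.
engine.maybe_precompute_float8_dynamic_scale_for_fsdp()  # safe even without a float8 handler
loss_log, other_log = engine.train_step(data_batches)    # train_step now calls the same hook internally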
16 changes: 10 additions & 6 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -98,19 +98,23 @@ def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
        if self._processor is not None:
            self._processor.save_pretrained(hf_dir)

    def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
        """Perform a training step with the given data batches and mesh.

        Args:
            data_batches (List[Dict]): The input data batches for the training step.
        """
    # This method can also be called externally, e.g. at the beginning of compute_actor_logprobs or compute_ref_logprobs during RL training.
    def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
        if self.llm_float8_handler is not None and self.llm_float8_handler.enabled:
            self.llm_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.language_model)
        if self.vision_float8_handler is not None and self.vision_float8_handler.enabled:
            self.vision_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.vision_tower)
        if self.projector_float8_handler is not None and self.projector_float8_handler.enabled:
            self.projector_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.multi_modal_projector)

    def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
        """Perform a training step with the given data batches and mesh.

        Args:
            data_batches (List[Dict]): The input data batches for the training step.
        """
        self.maybe_precompute_float8_dynamic_scale_for_fsdp()

        loss_log: LossLog = {}  # type: ignore[typeddict-item]
        other_log: OtherLog = {}  # type: ignore[typeddict-item]
        intra_layer_micro_batch = self.intra_layer_micro_batch
16 changes: 9 additions & 7 deletions xtuner/v1/float8/float8_handler.py
@@ -36,8 +36,8 @@ def default_grouped_linear_filter_fn(mod: nn.Module, fqn: str):

# Should each handler correspond one-to-one with an Engine?
class Float8Handler:
    scaling_granularity_gemm: ScalingGranularity
    scaling_granularity_grouped_gemm: ScalingGranularity
    scaling_granularity_gemm: ScalingGranularity | None
    scaling_granularity_grouped_gemm: ScalingGranularity | None
    fsdp_mesh: Optional[DeviceMesh] = None
    tilewise_reduce_mesh_devided_64: Optional[DeviceMesh] = None
    tilewise_reduce_mesh_mapping: Dict[Tuple[int, int], DeviceMesh] = {}
@@ -61,12 +61,14 @@ def __init__(
            )
            return

        assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
            "scaling_granularity_gemm must be TILEWISE or TENSORWISE."
        )
        assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
            "scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
        assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE, None), (
            "scaling_granularity_gemm must be TILEWISE, TENSORWISE or None."
        )
        assert scaling_granularity_grouped_gemm in (
            ScalingGranularity.TILEWISE,
            ScalingGranularity.TENSORWISE,
            None,
        ), "scaling_granularity_grouped_gemm must be TILEWISE, TENSORWISE or None."

        self.scaling_granularity_gemm = scaling_granularity_gemm
        self.scaling_granularity_grouped_gemm = scaling_granularity_grouped_gemm
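With the relaxed assertions, float8 can be enabled for only one of the two GEMM paths. A minimal configuration sketch, assuming Float8Config accepts these two fields as keyword arguments (any other fields are left at their defaults):

from xtuner.v1.float8.config import Float8Config, ScalingGranularity

# Hypothetical setup: tile-wise float8 for grouped GEMMs only; dense GEMMs stay
# on the non-float8 path because scaling_granularity_gemm is None.
float8_cfg = Float8Config(
    scaling_granularity_gemm=None,
    scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE,
)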
6 changes: 3 additions & 3 deletions xtuner/v1/module/grouped_linear/moe_group_linear.py
@@ -3,7 +3,7 @@
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

from xtuner.v1.float8.config import ScalingGranularity
from xtuner.v1.float8.config import Float8Config, ScalingGranularity
from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear
from xtuner.v1.ops import group_gemm

@@ -55,10 +55,10 @@ def build_grouped_linear(
    num_routed_experts: int,
    moe_bias: bool = False,
    ep_mesh: DeviceMesh | None = None,
    float8_cfg=None,
    float8_cfg: Float8Config | None = None,
):
    """Build a grouped linear layer with optional float8 support."""
    if float8_cfg is None:
    if float8_cfg is None or float8_cfg.scaling_granularity_grouped_gemm is None:
        return GroupedLinear(in_features, out_features, num_routed_experts, moe_bias=moe_bias, ep_mesh=ep_mesh)
    elif float8_cfg.scaling_granularity_grouped_gemm == ScalingGranularity.TILEWISE:
        return TileWiseFloat8GroupedLinear(
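Dispatch sketch for the builder above, reusing the hypothetical float8_cfg from the previous example; the layer sizes are made up, the leading in_features/out_features parameters are assumed from the GroupedLinear call, and the import path is inferred from the file location:

from xtuner.v1.module.grouped_linear.moe_group_linear import build_grouped_linear

# No config, or the relevant granularity unset -> plain GroupedLinear.
plain = build_grouped_linear(2048, 768, num_routed_experts=64, float8_cfg=None)

# Tile-wise grouped-GEMM granularity -> TileWiseFloat8GroupedLinear.
fp8 = build_grouped_linear(2048, 768, num_routed_experts=64, float8_cfg=float8_cfg)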
6 changes: 3 additions & 3 deletions xtuner/v1/module/linear/linear.py
@@ -3,7 +3,7 @@
from torch.distributed.tensor import DTensor
from torch.nn import functional as F

from xtuner.v1.float8.config import ScalingGranularity
from xtuner.v1.float8.config import Float8Config, ScalingGranularity
from xtuner.v1.float8.float8_linear_tensor_wise import TensorWiseFloat8Linear
from xtuner.v1.float8.float8_linear_tile_wise import TileWiseFloat8Linear

@@ -30,10 +30,10 @@ def build_linear(
    bias: bool = True,
    device=None,
    dtype=None,
    float8_cfg=None,
    float8_cfg: Float8Config | None = None,
) -> nn.Module:
    """Build a linear layer with optional float8 support."""
    if float8_cfg is None:
    if float8_cfg is None or float8_cfg.scaling_granularity_gemm is None:
        return _Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
    elif float8_cfg.scaling_granularity_gemm is ScalingGranularity.TILEWISE:
        return TileWiseFloat8Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
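The dense builder follows the same pattern. A short usage sketch with assumed sizes; as above, constructing Float8Config by keyword is an assumption and the import path is inferred from the file location:

import torch
from xtuner.v1.float8.config import Float8Config, ScalingGranularity
from xtuner.v1.module.linear.linear import build_linear

# No float8 config, or scaling_granularity_gemm left as None -> plain linear layer.
proj = build_linear(1024, 4096, bias=False, dtype=torch.bfloat16, float8_cfg=None)

# Tile-wise GEMM granularity -> TileWiseFloat8Linear.
tile_cfg = Float8Config(
    scaling_granularity_gemm=ScalingGranularity.TILEWISE,
    scaling_granularity_grouped_gemm=None,
)
proj_fp8 = build_linear(1024, 4096, bias=False, dtype=torch.bfloat16, float8_cfg=tile_cfg)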
2 changes: 2 additions & 0 deletions xtuner/v1/rl/base/worker.py
@@ -368,6 +368,8 @@ def _init_data_mesh(
    def compute_actor_logprobs(
        self, seq_ctx_list: list[SequenceContext], loss_ctx_input_list: list[RLLossContextInputItem]
    ) -> list[RLLossContextInputItem]:
        # Precompute float8 dynamic scales only once, before the per-sequence forward passes.
        self._engine.maybe_precompute_float8_dynamic_scale_for_fsdp()
        for seq_ctx, loss_ctx_input in zip(seq_ctx_list, loss_ctx_input_list):
            output = self._engine.forward_only(seq_ctx=seq_ctx)
            loss_ctx_input.old_logprobs = gather_logprobs(output["logits"], loss_ctx_input.shifted_labels)
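The engine comment names compute_ref_logprobs as the other forward-only path that should trigger the precompute once per call. A hypothetical sketch mirroring the loop above; the ref_logprobs attribute and the return value are assumptions, not part of this diff:

    def compute_ref_logprobs(
        self, seq_ctx_list: list[SequenceContext], loss_ctx_input_list: list[RLLossContextInputItem]
    ) -> list[RLLossContextInputItem]:
        # Precompute FSDP float8 dynamic scales once, not once per sequence.
        self._engine.maybe_precompute_float8_dynamic_scale_for_fsdp()
        for seq_ctx, loss_ctx_input in zip(seq_ctx_list, loss_ctx_input_list):
            output = self._engine.forward_only(seq_ctx=seq_ctx)
            loss_ctx_input.ref_logprobs = gather_logprobs(output["logits"], loss_ctx_input.shifted_labels)
        return loss_ctx_input_list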