2 changes: 1 addition & 1 deletion xtuner/v1/engine/train_engine.py
@@ -77,7 +77,7 @@ def __init__(self, model_path, cache_dir=None, from_hub="huggingface"):
             self.use_safetensors = False
         elif "model.safetensors" in os.listdir(self.model_path):
             with safe_open(os.path.join(self.model_path, "model.safetensors"), framework="pt") as f:
-                self.weight_map = {k: "model.safetensors" for k in f.keys()}
+                self.weight_map = dict.fromkeys(f.keys(), "model.safetensors")
             self.use_safetensors = True
         else:
             raise FileNotFoundError
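
A note on the dict.fromkeys changes (this file, xtuner/v1/model/base.py, and xtuner/v1/utils/loader.py): mapping every key to the same immutable value is exactly what dict.fromkeys is for, so the rewrite is behavior-preserving and a little more direct than a comprehension; it is also the spelling that comprehension-focused lint rules typically suggest. A minimal standalone sketch of the equivalence, using a made-up key list rather than a real safetensors file:

# Hypothetical key list; in the PR the keys come from safe_open(...).keys()
# or from name_list inside _save_hf.
keys = ["model.embed_tokens.weight", "lm_head.weight"]

# Before: comprehension mapping every key to the same file name.
weight_map_old = {k: "model.safetensors" for k in keys}

# After: dict.fromkeys expresses "same value for every key" directly.
weight_map_new = dict.fromkeys(keys, "model.safetensors")

assert weight_map_old == weight_map_new  # same mapping, same key order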
2 changes: 1 addition & 1 deletion xtuner/v1/model/base.py
@@ -763,7 +763,7 @@ def _save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16)
 
                 safetensor_index += 1
                 safetensor_name = f"model-{safetensor_index:04d}-fused-save_rank{save_rank}.safetensors"
-                weight_map.update({name: safetensor_name for name in name_list})
+                weight_map.update(dict.fromkeys(name_list, safetensor_name))
                 assert save_executor is not None, "Internal Error, save_executor should not be None"
                 future = save_executor.submit(
                     _save_file,
59 changes: 36 additions & 23 deletions xtuner/v1/train/trainer.py
@@ -21,7 +21,7 @@
 from pydantic import BaseModel, ConfigDict, field_serializer, field_validator, model_validator
 from torch.distributed import init_process_group
 from torch.distributed.device_mesh import init_device_mesh
-from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, LRScheduler, SequentialLR
 from typing_extensions import NotRequired, Self, TypedDict
 
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -764,7 +764,7 @@ def build_engine(
         engine.model.set_hf(model_path)
         return engine
 
-    def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> torch.optim.lr_scheduler.LRScheduler:
+    def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> LRScheduler:
         """Build the learning rate scheduler.
 
         Args:
@@ -774,36 +774,49 @@ def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> torch.optim.lr_scheduler.LRScheduler:
             torch.optim.lr_scheduler.LRScheduler: Configured learning rate scheduler.
         """
         if lr_cfg.warmup_ratio < 1:
-            warmup_steps = int(lr_cfg.warmup_ratio * scheduler_step)
+            warmup_step = int(lr_cfg.warmup_ratio * scheduler_step)
         else:
-            warmup_steps = int(lr_cfg.warmup_ratio)
+            warmup_step = int(lr_cfg.warmup_ratio)
 
         def warmup_fn(x):
-            return x / warmup_steps if x < warmup_steps else 1
+            return x / warmup_step if x < warmup_step else 1
 
         warmup_scheduler = LambdaLR(self._engine.optimizer, warmup_fn)
 
-        scheduler: torch.optim.lr_scheduler.LRScheduler
-        if lr_cfg.lr_type == "linear":
-            scheduler = LinearLR(
-                self._engine.optimizer,
-                start_factor=1.0,
-                end_factor=lr_cfg.lr_min / self._engine.optimizer.defaults["lr"],
-                total_iters=scheduler_step - warmup_steps,
+        scheduler_after_warmup: LRScheduler
+        lr_scheduler: LRScheduler
+
+        if warmup_step < scheduler_step:
+            if lr_cfg.lr_type == "linear":
+                scheduler_after_warmup = LinearLR(
+                    self._engine.optimizer,
+                    start_factor=1.0,
+                    end_factor=lr_cfg.lr_min / self._engine.optimizer.defaults["lr"],
+                    total_iters=scheduler_step - warmup_step,
+                )
+            elif lr_cfg.lr_type == "cosine":
+                scheduler_after_warmup = CosineAnnealingLR(
+                    self._engine.optimizer, T_max=scheduler_step - warmup_step, eta_min=lr_cfg.lr_min
+                )
+            elif lr_cfg.lr_type == "constant":
+                scheduler_after_warmup = LambdaLR(self._engine.optimizer, lambda x: 1.0)
+            else:
+                raise ValueError(f"Unsupported lr type: {lr_cfg.lr_type}")
+            lr_scheduler = SequentialLR(
+                optimizer=self._engine.optimizer,
+                schedulers=[warmup_scheduler, scheduler_after_warmup],
+                milestones=[warmup_step],
             )
-        elif lr_cfg.lr_type == "cosine":
-            scheduler = CosineAnnealingLR(
-                self._engine.optimizer, T_max=scheduler_step - warmup_steps, eta_min=lr_cfg.lr_min
+        elif warmup_step == scheduler_step:
+            self.logger.warning(
+                f"You're setting warmup_step ({warmup_step}) to be equal to scheduler_step ({scheduler_step}), "
+                "which is generally not recommended."
            )
-        elif lr_cfg.lr_type == "constant":
-            scheduler = LambdaLR(self._engine.optimizer, lambda x: 1.0)
+            lr_scheduler = warmup_scheduler
         else:
-            raise ValueError(f"Unsupported lr type: {lr_cfg.lr_type}")
-        lr_scheduler = SequentialLR(
-            optimizer=self._engine.optimizer,
-            schedulers=[warmup_scheduler, scheduler],
-            milestones=[warmup_steps],
-        )
+            raise ValueError(
+                f"Expected warmup_step ({warmup_step}) to be no more than scheduler_step ({scheduler_step})"
+            )
         return lr_scheduler
 
     def _maybe_save(self, is_snapshot: bool = False) -> bool:
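
For context on the trainer.py change: the old code always chained the warmup scheduler and a decay scheduler inside SequentialLR, which presumably breaks down once warmup_step reaches scheduler_step (the decay scheduler would be built over zero or negative remaining steps). The new code only builds the SequentialLR when steps remain after warmup, falls back to the bare warmup scheduler with a warning when warmup consumes the entire schedule, and rejects warmup_step greater than scheduler_step. A standalone sketch of the warmup-plus-cosine composition, assuming a toy model and made-up step counts rather than the trainer's actual configuration:

# Toy reproduction of the warmup + SequentialLR composition; the model,
# optimizer, and step counts below are made up for illustration.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR

scheduler_step = 10  # total number of scheduler steps
warmup_step = 3      # e.g. int(warmup_ratio * scheduler_step)
lr_min = 0.0

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# Linear warmup from 0 to the base lr over warmup_step steps.
warmup = LambdaLR(optimizer, lambda x: x / warmup_step if x < warmup_step else 1)

if warmup_step < scheduler_step:
    # Cosine decay over the remaining steps, then chain the two:
    # warmup first, switch to decay at the warmup_step milestone.
    decay = CosineAnnealingLR(optimizer, T_max=scheduler_step - warmup_step, eta_min=lr_min)
    lr_scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_step])
else:
    # Degenerate case the PR now handles explicitly: warmup covers the whole schedule.
    lr_scheduler = warmup

for step in range(scheduler_step):
    optimizer.step()
    lr_scheduler.step()
    print(step, lr_scheduler.get_last_lr())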
2 changes: 1 addition & 1 deletion xtuner/v1/utils/loader.py
@@ -68,7 +68,7 @@ def __init__(
             self.use_safetensors = False
         elif "model.safetensors" in os.listdir(self.model_path):
             with safe_open(os.path.join(self.model_path, "model.safetensors"), framework="pt") as f:
-                self.weight_map = {k: "model.safetensors" for k in f.keys()}
+                self.weight_map = dict.fromkeys(f.keys(), "model.safetensors")
             self.use_safetensors = True
         else:
             raise FileNotFoundError