11 changes: 11 additions & 0 deletions .dev_scripts/dcp_to_hf.py
@@ -1,5 +1,6 @@
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.model.moe.moe import MoEConfig
from transformers import AutoTokenizer
from cyclopts import App, Parameter
from pathlib import Path
import torch.distributed as dist
@@ -39,6 +40,12 @@ def dcp_to_hf(
help="Path to the DCP checkpoint, <work_dirs>/<timestamp>/checkpoints/ckpt-step-6"
),
],
tokenizer_path: Annotated[
Path,
Parameter(
help="Path to the tokenizer folder, usually the same as the hf_path"
),
],
hf_path: Annotated[
Path | None,
Parameter(
@@ -52,6 +59,7 @@ def dcp_to_hf(
),
] = "bf16",
):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
dist.init_process_group(backend="cuda:nccl,cpu:gloo")
torch.serialization.add_safe_globals(
[
@@ -98,6 +106,9 @@ def dcp_to_hf(
else:
model.save_hf(hf_path, save_dtype=torch.float8_e4m3fn)

if dist.get_rank() == 0:
tokenizer.save_pretrained(hf_path)


if __name__ == "__main__":
cli()
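
The change above threads a tokenizer_path argument through the conversion CLI and, on rank 0, copies the tokenizer files into the HF export next to the weights. As a quick sanity check (a sketch only; the path below is a placeholder, not from this PR), the exported folder should now load with transformers without pointing back at the original repo:

from transformers import AutoConfig, AutoTokenizer

hf_path = "work_dirs/20250101/ckpt-step-6-hf"  # assumed output directory passed as hf_path
config = AutoConfig.from_pretrained(hf_path)        # config.json written by save_hf
tokenizer = AutoTokenizer.from_pretrained(hf_path)  # files written by tokenizer.save_pretrained
print(type(tokenizer).__name__, config.model_type)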
19 changes: 16 additions & 3 deletions xtuner/v1/model/base.py
@@ -178,7 +178,7 @@ def hf_config(self) -> PretrainedConfig | None:
"""HuggingFace configuration."""
return None

def save_hf(self, hf_path: str | Path):
def save_hf(self, hf_path: str | Path, dtype: torch.dtype = torch.bfloat16):
"""Save the configuration to a HuggingFace-compatible format.

Args:
@@ -188,7 +188,20 @@ def save_hf(self, hf_path: str | Path):
if self.hf_config is None:
raise NotImplementedError("The `hf_config` property must be implemented to save in HuggingFace format.")

self.hf_config.save_pretrained(hf_path)
if dtype not in {torch.bfloat16, torch.float8_e4m3fn}:
raise NotImplementedError(f"Saving dtype {dtype} is not supported yet.")

hf_config = self.hf_config
if dtype is torch.float8_e4m3fn:
hf_config.quantization_config = {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"scale_fmt": "ue8m0",
"weight_block_size": [128, 128],
}

hf_config.save_pretrained(hf_path)


class ModelOutputs(TypedDict):
@@ -919,7 +932,7 @@ def _save_hf(
raise RuntimeError("Internal Error, both self.config.hf_config and self._hf_path are None")

if self.config.hf_config is not None:
self.config.save_hf(hf_dir)
self.config.save_hf(hf_dir, dtype=save_dtype)
else: # if self._hf_path is not None:
for file in cast(Path, self._hf_path).iterdir():
if file.suffix != ".safetensors":
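
As a reference for the new fp8 path in save_hf: when called with dtype=torch.float8_e4m3fn, the exported config.json should carry the quantization_config block added above. A minimal sketch of a check (the file path is a placeholder; the expected values are copied verbatim from the diff):

import json

expected = {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "scale_fmt": "ue8m0",
    "weight_block_size": [128, 128],
}

with open("path/to/hf_export/config.json") as f:  # placeholder path
    cfg = json.load(f)

assert cfg.get("quantization_config") == expected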