61 changes: 61 additions & 0 deletions tests/train/test_trainer.py
@@ -513,6 +513,67 @@ def test_resolve_ep_size(self, model_ep_size, fsdp_ep_size, target_ep_size):
assert trainer.config.model_cfg.ep_size == target_ep_size
self.cleanup_trainer(trainer)

def test_print_training_config(self):
"""Test that training config is printed on rank 0 only."""
import json

model_cfg = Qwen3Dense4BConfig()
trainer_cfg = self.build_trainer_cfg(model_cfg)

trainer = Trainer.from_config(trainer_cfg)

# Wait for all processes to finish logging
dist.barrier()

# Read the log file for each rank
rank = dist.get_rank()
log_file = trainer.log_dir / f"rank{rank}.log"

self.assertTrue(log_file.exists(), f"Log file should exist for rank {rank}")

with open(log_file, "r") as f:
log_content = f.read()

if rank == 0:
# Check that config was logged on rank 0
self.assertIn("Training config:", log_content, "Config should be logged on rank 0")

# Extract and verify the JSON content
# Find the start of the JSON config in the log content
config_marker = "Training config:"
config_start_pos = log_content.find(config_marker)
self.assertGreater(config_start_pos, -1, "Should find 'Training config:' in logs")

# Extract the JSON part (everything after "Training config: ")
json_start_pos = log_content.find("{", config_start_pos)
self.assertGreater(json_start_pos, -1, "Should find JSON object in log")

# Find the end of JSON by matching braces
brace_count = 0
json_end_pos = json_start_pos
for i in range(json_start_pos, len(log_content)):
if log_content[i] == "{":
brace_count += 1
elif log_content[i] == "}":
brace_count -= 1
if brace_count == 0:
json_end_pos = i + 1
break

json_str = log_content[json_start_pos:json_end_pos]
config_dict = json.loads(json_str)

# Verify key fields are present in the logged config
self.assertIn("model_cfg", config_dict)
self.assertIn("optim_cfg", config_dict)
self.assertIn("global_batch_size", config_dict)
self.assertEqual(config_dict["global_batch_size"], 8)
else:
# On non-rank-0, config should not be logged
self.assertNotIn("Training config:", log_content, "Config should not be logged on non-rank-0")

self.cleanup_trainer(trainer)

@property
def world_size(self) -> int:
return 8
7 changes: 6 additions & 1 deletion xtuner/_testing/utils.py
@@ -44,7 +44,12 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self._logger.remove(self._handle_id)
try:
self._logger.remove(self._handle_id)
except KeyboardInterrupt as e:
raise e
except:
# The handle may already have been removed; ignore any other cleanup error.
...

def get_output(self) -> str:
self._handle.seek(0)
14 changes: 12 additions & 2 deletions xtuner/v1/train/trainer.py
@@ -410,7 +410,7 @@ class Trainer:
backend (str): Backend for distributed training.
"""

config: TrainerConfig | None
_config: TrainerConfig | None
_META_PATH = ".xtuner"
_PROFILE_TIME_PATH = "profiling_time"
_PROFILE_MEMORY_PATH = "profiling_memory"
@@ -687,7 +687,8 @@ def from_config(cls, config: TrainerConfig) -> Self:
trainer_cfg=config,
internal_metrics_cfg=config.internal_metrics_cfg,
)
self.config = config
self._config = config
self._print_training_config()
return self

def fit(self):
@@ -936,6 +937,10 @@ def cur_epoch(self) -> int | None:
"""
return self._cur_epoch

@property
def config(self) -> TrainerConfig | None:
return self._config

def _init_logger(self, log_dir: Path):
# Logging system may need a better design
log_level = os.environ.get("XTUNER_LOG_LEVEL", "INFO").upper()
@@ -1794,6 +1799,11 @@ def _setup_env(self):
log_str += "=================================================="
logger.info(log_str)

def _print_training_config(self):
if self._config is not None and self.rank == 0:
config_str = self._config.model_dump_json(indent=2)
logger.info(f"Training config: {config_str}")

def _resolve_deprecate_compile_cfg(self, model_cfg: TransformerConfig | BaseComposeConfig, fsdp_cfg: FSDPConfig):
if not fsdp_cfg.torch_compile:
model_cfg.compile_cfg = False