From a5aa8857d4b649fc90955456e44458c716eefe93 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Tue, 23 Dec 2025 10:02:22 +0000
Subject: [PATCH 1/3] [Enhance] add training config printing to trainer

---
 xtuner/v1/train/trainer.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index f9640e2ca..f2ba7b764 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -410,7 +410,7 @@ class Trainer:
         backend (str): Backend for distributed training.
     """
 
-    config: TrainerConfig | None
+    _config: TrainerConfig | None
     _META_PATH = ".xtuner"
     _PROFILE_TIME_PATH = "profiling_time"
     _PROFILE_MEMORY_PATH = "profiling_memory"
@@ -687,7 +687,8 @@ def from_config(cls, config: TrainerConfig) -> Self:
             trainer_cfg=config,
             internal_metrics_cfg=config.internal_metrics_cfg,
         )
-        self.config = config
+        self._config = config
+        self._print_training_config()
         return self
 
     def fit(self):
@@ -936,6 +937,10 @@ def cur_epoch(self) -> int | None:
         """
         return self._cur_epoch
 
+    @property
+    def config(self) -> TrainerConfig | None:
+        return self._config
+
     def _init_logger(self, log_dir: Path):
         # Logging system maybe need better design
         log_level = os.environ.get("XTUNER_LOG_LEVEL", "INFO").upper()
@@ -1794,6 +1799,11 @@ def _setup_env(self):
         log_str += "=================================================="
         logger.info(log_str)
 
+    def _print_training_config(self):
+        if self._config is not None and self.rank == 0:
+            config_str = self._config.model_dump_json(indent=2)
+            logger.info(f"Training config: {config_str}")
+
     def _resolve_deprecate_compile_cfg(self, model_cfg: TransformerConfig | BaseComposeConfig, fsdp_cfg: FSDPConfig):
         if not fsdp_cfg.torch_compile:
             model_cfg.compile_cfg = False

From 28278b98adc748179160c35ad642af731502b983 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Tue, 23 Dec 2025 11:57:51 +0000
Subject: [PATCH 2/3] [Test] add test for training config printing

---
 tests/train/test_trainer.py | 61 +++++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py
index 961d8b1f8..adb350e82 100644
--- a/tests/train/test_trainer.py
+++ b/tests/train/test_trainer.py
@@ -513,6 +513,67 @@ def test_resolve_ep_size(self, model_ep_size, fsdp_ep_size, target_ep_size):
         assert trainer.config.model_cfg.ep_size == target_ep_size
         self.cleanup_trainer(trainer)
 
+    def test_print_training_config(self):
+        """Test that training config is printed on rank 0 only."""
+        import json
+
+        model_cfg = Qwen3Dense4BConfig()
+        trainer_cfg = self.build_trainer_cfg(model_cfg)
+
+        trainer = Trainer.from_config(trainer_cfg)
+
+        # Wait for all processes to finish logging
+        dist.barrier()
+
+        # Read the log file for each rank
+        rank = dist.get_rank()
+        log_file = trainer.log_dir / f"rank{rank}.log"
+
+        self.assertTrue(log_file.exists(), f"Log file should exist for rank {rank}")
+
+        with open(log_file, "r") as f:
+            log_content = f.read()
+
+        if rank == 0:
+            # Check that config was logged on rank 0
+            self.assertIn("Training config:", log_content, "Config should be logged on rank 0")
+
+            # Extract and verify the JSON content
+            # Find the start of the JSON config in the log content
+            config_marker = "Training config:"
+            config_start_pos = log_content.find(config_marker)
+            self.assertGreater(config_start_pos, -1, "Should find 'Training config:' in logs")
+
+            # Extract the JSON part (everything after "Training config: ")
+            json_start_pos = log_content.find("{", config_start_pos)
+            self.assertGreater(json_start_pos, -1, "Should find JSON object in log")
+
+            # Find the end of JSON by matching braces
+            brace_count = 0
+            json_end_pos = json_start_pos
+            for i in range(json_start_pos, len(log_content)):
+                if log_content[i] == "{":
+                    brace_count += 1
+                elif log_content[i] == "}":
+                    brace_count -= 1
+                    if brace_count == 0:
+                        json_end_pos = i + 1
+                        break
+
+            json_str = log_content[json_start_pos:json_end_pos]
+            config_dict = json.loads(json_str)
+
+            # Verify key fields are present in the logged config
+            self.assertIn("model_cfg", config_dict)
+            self.assertIn("optim_cfg", config_dict)
+            self.assertIn("global_batch_size", config_dict)
+            self.assertEqual(config_dict["global_batch_size"], 8)
+        else:
+            # On non-rank-0, config should not be logged
+            self.assertNotIn("Training config:", log_content, "Config should not be logged on non-rank-0")
+
+        self.cleanup_trainer(trainer)
+
     @property
     def world_size(self) -> int:
         return 8

From bf04f89eeb6686db5ed00c670ffbe0644028a547 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Tue, 23 Dec 2025 11:58:47 +0000
Subject: [PATCH 3/3] [Fix] add exception handling in LogCapture.__exit__

---
 xtuner/_testing/utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/xtuner/_testing/utils.py b/xtuner/_testing/utils.py
index 8c412ba22..dbc1517a2 100644
--- a/xtuner/_testing/utils.py
+++ b/xtuner/_testing/utils.py
@@ -44,7 +44,12 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        self._logger.remove(self._handle_id)
+        try:
+            self._logger.remove(self._handle_id)
+        except KeyboardInterrupt as e:
+            raise e
+        except:
+            ...
 
     def get_output(self) -> str:
         self._handle.seek(0)