From a1a0ddfb00cdab74d7e5866d42ae4068595debce Mon Sep 17 00:00:00 2001 From: sjh00 Date: Mon, 9 Jun 2025 19:34:58 +0800 Subject: [PATCH 1/6] Fix Windows support. Add ImageIO for AVIF/TIFF/BMP support, handle animated formats. Ensure UTF-8 for multilingual. 1. Update the usage of the new transforms v2.Compose method, add partial code for the Windows system to solve the error caused by the mismatch between the Python int maximum value and the C long maximum value on the Windows system, and address the issue that the code does not handle the lack of NCCL support on the Windows system. 2. Use ImageIO to read images better, add support for AVIF/TIFF/BMP, and add support for reading the first frame of AVIF/WEBP images with frame dimensions (which could not be recognized before). 3. Ensure that JSON files are saved in UTF-8 encoding to be compatible with file paths, names, and configuration text in more languages. --- requirements.txt | 3 ++ src/dataloaders/bucketing_logic.py | 5 ++- src/dataloaders/dataloader.py | 44 ++++++++++++++++------- src/dataloaders/utils.py | 11 +++--- src/general_utils.py | 2 +- src/trainer/train_chroma.py | 17 +++++---- src/trainer/train_chroma_lora.py | 18 ++++++---- src/trainer/train_chroma_rectification.py | 17 +++++---- src/trainer/train_lumina.py | 17 +++++---- test/dataloaders/dataloader_test.py | 2 +- 10 files changed, 92 insertions(+), 44 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7245b0f..5f50e99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,10 @@ torch_tb_profiler tqdm wandb einops +pillow-avif-plugin pillow-jxl-plugin +imageio[pillow] +imageio[pyav] transformers safetensors sentencepiece diff --git a/src/dataloaders/bucketing_logic.py b/src/dataloaders/bucketing_logic.py index 1709348..d9249ce 100644 --- a/src/dataloaders/bucketing_logic.py +++ b/src/dataloaders/bucketing_logic.py @@ -15,7 +15,10 @@ from .utils import read_jsonl -csv.field_size_limit(sys.maxsize) +try: + csv.field_size_limit(sys.maxsize) +except OverflowError: + csv.field_size_limit(2147483647) log = logging.getLogger(__name__) diff --git a/src/dataloaders/dataloader.py b/src/dataloaders/dataloader.py index 60ade19..0ac3012 100644 --- a/src/dataloaders/dataloader.py +++ b/src/dataloaders/dataloader.py @@ -7,13 +7,14 @@ from tqdm import tqdm from PIL import Image +import pillow_avif +import imageio.v3 as iio import torch from torch.utils.data import Dataset import torchvision.transforms.v2 as v2 from io import BytesIO import concurrent.futures -import requests from requests.exceptions import RequestException, Timeout from .utils import read_jsonl @@ -70,7 +71,7 @@ def __init__( random.seed(seed) # just simple pil image to tensor conversion self.image_transforms = v2.Compose( - [v2.ToTensor(), v2.Normalize(mean=[0.5], std=[0.5])] + [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)] ) # TODO: batches has to be preprocessed for batching!!!! @@ -78,7 +79,6 @@ def __init__( # slice batches using round robbin self._round_robin() - self.session = requests.Session() self.thread_per_worker = thread_per_worker # self.executor = concurrent.futures.ThreadPoolExecutor(thread_per_worker) @@ -225,12 +225,11 @@ def _round_robin(self): # - def _load_image(self, sample, session, image_folder_path, timeout): + def _load_image(self, sample, image_folder_path, timeout): try: + img_array = None if sample["is_url_based"]: - response = session.get(sample["filename"], timeout=timeout) - response.raise_for_status() # Raises an HTTPError if the status code is 4xx/5xx - return Image.open(BytesIO(response.content)).convert("RGB") + img_array = iio.imread(sample["filename"]) else: image_path = os.path.join(image_folder_path, sample["filename"]) @@ -243,18 +242,40 @@ def _load_image(self, sample, session, image_folder_path, timeout): ) elif os.path.exists(image_path): # Standard handling if the specified file exists - return Image.open(image_path).convert("RGB") + img_array = iio.imread(image_path) else: # Try alternative extensions if the main file doesn't exist filename, _ = os.path.splitext(sample["filename"]) - extensions = ["png", "jpg", "jpeg", "webp"] + extensions = ["png", "jpg", "jpeg", "webp", "bmp", "avif", "tif", "tiff"] for ext in extensions: alt_image_path = os.path.join( image_folder_path, f"{filename}.{ext}" ) if os.path.exists(alt_image_path): - return Image.open(alt_image_path).convert("RGB") - return None + img_array = iio.imread(alt_image_path) + if img_array is None: + return None + else: + if img_array.ndim == 2: + image = Image.fromarray(img_array, mode='L') + elif img_array.ndim == 3 or (img_array.ndim == 4 and img_array.shape[0] == 1): + if img_array.ndim == 4: + # When the image has a frame dimension, only the first frame is taken. + img_array = img_array[0] + height, width, channels = img_array.shape + if channels == 3: # RGB + image = Image.fromarray(img_array, mode='RGB') + elif channels == 4: # RGBA + image = Image.fromarray(img_array, mode='RGBA') + else: + raise ValueError(f"Unsupported number of channels: {channels}") + else: + raise ValueError(f"Unsupported image shape: {img_array.shape}") + + if image.mode != 'RGB': + image = image.convert('RGB') + + return image except Exception as e: log.error( f"An error occurred: {e} for {sample['filename']} on rank {self.rank}" @@ -274,7 +295,6 @@ def __getitem__(self, index): executor.submit( self._load_image, sample, - self.session, self.image_folder_path, self.timeout, ) diff --git a/src/dataloaders/utils.py b/src/dataloaders/utils.py index b2daaed..64ad3f8 100644 --- a/src/dataloaders/utils.py +++ b/src/dataloaders/utils.py @@ -5,7 +5,10 @@ import random from tqdm import tqdm -csv.field_size_limit(sys.maxsize) +try: + csv.field_size_limit(sys.maxsize) +except OverflowError: + csv.field_size_limit(2147483647) def save_as_jsonl(data, filename): @@ -18,9 +21,9 @@ def save_as_jsonl(data, filename): if os.path.join(*os.path.split(filename)[:-1]) != "": os.makedirs(os.path.join(*os.path.split(filename)[:-1]), exist_ok=True) - with open(filename, "w") as f: + with open(filename, "w", encoding="utf-8") as f: for item in tqdm(data): - json.dump(item, f) + json.dump(item, f, ensure_ascii=False) f.write("\n") @@ -35,7 +38,7 @@ def read_jsonl(filename): """ data = [] - with open(filename, "r") as f: + with open(filename, "r", encoding="utf-8") as f: for line in tqdm(f): data.append(json.loads(line)) return data diff --git a/src/general_utils.py b/src/general_utils.py index d63b20f..736f565 100644 --- a/src/general_utils.py +++ b/src/general_utils.py @@ -77,7 +77,7 @@ def save_file_multipart( index["metadata"].update(metadata) with open(os.path.join(base_folder, "model.safetensors.index.json"), "w") as f: - json.dump(index, f, indent=2) + json.dump(index, f, indent=2, ensure_ascii=False) return num_shards diff --git a/src/trainer/train_chroma.py b/src/trainer/train_chroma.py index 6d6ef82..0a64479 100644 --- a/src/trainer/train_chroma.py +++ b/src/trainer/train_chroma.py @@ -1,3 +1,4 @@ +import platform import sys import os import json @@ -120,7 +121,11 @@ def setup_distributed(rank, world_size): os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Initialize process group - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + backend = "nccl" # Default backend for distributed training + if platform.system() == "Windows": + # Windows does not support NCCL, use GLOO instead + backend = "gloo" + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) torch.cuda.set_device(rank) @@ -307,17 +312,17 @@ def cast_linear(module, dtype): def save_config_to_json(filepath: str, **configs): json_data = {key: asdict(value) for key, value in configs.items()} - with open(filepath, "w") as json_file: - json.dump(json_data, json_file, indent=4) + with open(filepath, "w", encoding="utf-8") as json_file: + json.dump(json_data, json_file, indent=4, ensure_ascii=False) def dump_dict_to_json(data, file_path): - with open(file_path, "w") as json_file: - json.dump(data, json_file, indent=4) + with open(file_path, "w", encoding="utf-8") as json_file: + json.dump(data, json_file, indent=4, ensure_ascii=False) def load_config_from_json(filepath: str): - with open(filepath, "r") as json_file: + with open(filepath, "r", encoding="utf-8") as json_file: return json.load(json_file) diff --git a/src/trainer/train_chroma_lora.py b/src/trainer/train_chroma_lora.py index 4a5dbdf..e76068b 100644 --- a/src/trainer/train_chroma_lora.py +++ b/src/trainer/train_chroma_lora.py @@ -1,4 +1,4 @@ -import sys +import platform import os import json from datetime import datetime @@ -128,7 +128,11 @@ def setup_distributed(rank, world_size): os.environ["WORLD_SIZE"] = str(world_size) # Initialize process group - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + backend = "nccl" # Default backend for distributed training + if platform.system() == "Windows": + # Windows does not support NCCL, use GLOO instead + backend = "gloo" + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) torch.cuda.set_device(rank) @@ -312,17 +316,17 @@ def cast_linear(module, dtype): def save_config_to_json(filepath: str, **configs): json_data = {key: asdict(value) for key, value in configs.items()} - with open(filepath, "w") as json_file: - json.dump(json_data, json_file, indent=4) + with open(filepath, "w", encoding="utf-8") as json_file: + json.dump(json_data, json_file, indent=4, ensure_ascii=False) def dump_dict_to_json(data, file_path): - with open(file_path, "w") as json_file: - json.dump(data, json_file, indent=4) + with open(file_path, "w", encoding="utf-8") as json_file: + json.dump(data, json_file, indent=4, ensure_ascii=False) def load_config_from_json(filepath: str): - with open(filepath, "r") as json_file: + with open(filepath, "r", encoding="utf-8") as json_file: return json.load(json_file) diff --git a/src/trainer/train_chroma_rectification.py b/src/trainer/train_chroma_rectification.py index 8427e8f..fa26179 100644 --- a/src/trainer/train_chroma_rectification.py +++ b/src/trainer/train_chroma_rectification.py @@ -1,3 +1,4 @@ +import platform import sys import os import json @@ -124,7 +125,11 @@ def setup_distributed(rank, world_size): os.environ["WORLD_SIZE"] = str(world_size) # Initialize process group - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + backend = "nccl" # Default backend for distributed training + if platform.system() == "Windows": + # Windows does not support NCCL, use GLOO instead + backend = "gloo" + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) torch.cuda.set_device(rank) @@ -414,17 +419,17 @@ def cast_linear(module, dtype): def save_config_to_json(filepath: str, **configs): json_data = {key: asdict(value) for key, value in configs.items()} - with open(filepath, "w") as json_file: - json.dump(json_data, json_file, indent=4) + with open(filepath, "w", encoding="utf-8") as json_file: + json.dump(json_data, json_file, indent=4, ensure_ascii=False) def dump_dict_to_json(data, file_path): - with open(file_path, "w") as json_file: - json.dump(data, json_file, indent=4) + with open(file_path, "w", encoding="utf-8") as json_file: + json.dump(data, json_file, indent=4, ensure_ascii=False) def load_config_from_json(filepath: str): - with open(filepath, "r") as json_file: + with open(filepath, "r", encoding="utf-8") as json_file: return json.load(json_file) diff --git a/src/trainer/train_lumina.py b/src/trainer/train_lumina.py index 4f5bce7..844238d 100644 --- a/src/trainer/train_lumina.py +++ b/src/trainer/train_lumina.py @@ -1,3 +1,4 @@ +import platform import sys import os import json @@ -106,7 +107,11 @@ def setup_distributed(rank, world_size): os.environ["WORLD_SIZE"] = str(world_size) # Initialize process group - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + backend = "nccl" # Default backend for distributed training + if platform.system() == "Windows": + # Windows does not support NCCL, use GLOO instead + backend = "gloo" + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) torch.cuda.set_device(rank) @@ -295,17 +300,17 @@ def cast_linear(module, dtype): def save_config_to_json(filepath: str, **configs): json_data = {key: asdict(value) for key, value in configs.items()} - with open(filepath, "w") as json_file: - json.dump(json_data, json_file, indent=4) + with open(filepath, "w", encoding="utf-8") as json_file: + json.dump(json_data, json_file, indent=4, ensure_ascii=False) def dump_dict_to_json(data, file_path): - with open(file_path, "w") as json_file: - json.dump(data, json_file, indent=4) + with open(file_path, "w", encoding="utf-8") as json_file: + json.dump(data, json_file, indent=4, ensure_ascii=False) def load_config_from_json(filepath: str): - with open(filepath, "r") as json_file: + with open(filepath, "r", encoding="utf-8") as json_file: return json.load(json_file) diff --git a/test/dataloaders/dataloader_test.py b/test/dataloaders/dataloader_test.py index 58c48b1..c68badb 100644 --- a/test/dataloaders/dataloader_test.py +++ b/test/dataloaders/dataloader_test.py @@ -36,7 +36,7 @@ images, caption, index = dataset[i] with open(f"preview/{i}.jsonl", 'w') as f: for item in caption: - json.dump(item, f) # Dump the item as a JSON object + json.dump(item, f, ensure_ascii=False) # Dump the item as a JSON object f.write('\n') # Write a newline after each JSON object save_image(make_grid(images.clip(-1, 1)), f"preview/{i}.jpg", normalize=True) From 4325b180f5543267ed3b825d479bc0d596acda9b Mon Sep 17 00:00:00 2001 From: sjh00 Date: Mon, 9 Jun 2025 21:54:02 +0800 Subject: [PATCH 2/6] Optimize the training process and fix some bugs. 1. add triton or triton-windows 2. batch_linear_assignment --> .cpu().to(torch.float32) 3. use_reentrant=False 4. ensure the optimizer is executed before the scheduler --- requirements.txt | 4 +++- src/math_utils.py | 6 +++++- src/models/chroma/model.py | 4 ++-- src/models/chroma/module/t5.py | 2 ++ src/models/lumina/model.py | 6 +++--- src/trainer/train_chroma.py | 3 +-- src/trainer/train_chroma_lora.py | 3 +-- src/trainer/train_chroma_rectification.py | 3 +-- src/trainer/train_lumina.py | 3 +-- 9 files changed, 19 insertions(+), 15 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5f50e99..64b953f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,6 @@ bitsandbytes==0.45.3 torchastic torch-linear-assignment huggingface_hub -scipy \ No newline at end of file +scipy +triton; sys_platform != "win32" +triton-windows; sys_platform == "win32" \ No newline at end of file diff --git a/src/math_utils.py b/src/math_utils.py index 2b16228..3af1cb7 100644 --- a/src/math_utils.py +++ b/src/math_utils.py @@ -42,8 +42,12 @@ def _cuda_assignment(C): from torch_linear_assignment import batch_linear_assignment from torch_linear_assignment import assignment_to_indices - assignment = batch_linear_assignment(C.unsqueeze(dim=0)) + original_device = C.device + C_cpu = C.cpu().to(torch.float32).unsqueeze(dim=0) + assignment = batch_linear_assignment(C_cpu) row_indices, col_indices = assignment_to_indices(assignment) + row_indices = row_indices.to(original_device) + col_indices = col_indices.to(original_device) matching_pairs = (row_indices, col_indices) return C, matching_pairs diff --git a/src/models/chroma/model.py b/src/models/chroma/model.py index 9b1c779..563a146 100644 --- a/src/models/chroma/model.py +++ b/src/models/chroma/model.py @@ -249,7 +249,7 @@ def forward( # just in case in different GPU for simple pipeline parallel if self.training: img, txt = ckpt.checkpoint( - block, img, txt, pe, double_mod, txt_img_mask + block, img, txt, pe, double_mod, txt_img_mask, use_reentrant=False ) else: img, txt = block( @@ -260,7 +260,7 @@ def forward( for i, block in enumerate(self.single_blocks): single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] if self.training: - img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask) + img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask, use_reentrant=False) else: img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) img = img[:, txt.shape[1] :, ...] diff --git a/src/models/chroma/module/t5.py b/src/models/chroma/module/t5.py index e9d4252..33de33b 100644 --- a/src/models/chroma/module/t5.py +++ b/src/models/chroma/module/t5.py @@ -606,6 +606,7 @@ def forward( if position_bias != None else position_bias ), + use_reentrant=False ) pass else: @@ -668,6 +669,7 @@ def forward_mid( if position_bias != None else position_bias ), + use_reentrant=False ) pass else: diff --git a/src/models/lumina/model.py b/src/models/lumina/model.py index 669f63e..74e9567 100644 --- a/src/models/lumina/model.py +++ b/src/models/lumina/model.py @@ -689,7 +689,7 @@ def patchify_and_embed( # refine context for layer in self.context_refiner: if self.training: - cap_feats = ckpt.checkpoint(layer, cap_feats, cap_mask, cap_freqs_cis) + cap_feats = ckpt.checkpoint(layer, cap_feats, cap_mask, cap_freqs_cis, use_reentrant=False) else: cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) @@ -718,7 +718,7 @@ def patchify_and_embed( for layer in self.noise_refiner: if self.training: padded_img_embed = ckpt.checkpoint( - layer, padded_img_embed, padded_img_mask, img_freqs_cis, t + layer, padded_img_embed, padded_img_mask, img_freqs_cis, t, use_reentrant=False ) else: padded_img_embed = layer( @@ -804,7 +804,7 @@ def forward(self, x, t, cap_feats, cap_mask): for layer in self.layers: if self.training: - x = ckpt.checkpoint(layer, x, mask, freqs_cis, adaln_input) + x = ckpt.checkpoint(layer, x, mask, freqs_cis, adaln_input, use_reentrant=False) else: x = layer(x, mask, freqs_cis, adaln_input) diff --git a/src/trainer/train_chroma.py b/src/trainer/train_chroma.py index 0a64479..4da4f7b 100644 --- a/src/trainer/train_chroma.py +++ b/src/trainer/train_chroma.py @@ -773,9 +773,8 @@ def train_chroma(rank, world_size, debug=False): if not debug: synchronize_gradients(model) - scheduler.step() - optimizer.step() + scheduler.step() optimizer.zero_grad() if rank == 0: diff --git a/src/trainer/train_chroma_lora.py b/src/trainer/train_chroma_lora.py index e76068b..241bb30 100644 --- a/src/trainer/train_chroma_lora.py +++ b/src/trainer/train_chroma_lora.py @@ -782,9 +782,8 @@ def train_chroma(rank, world_size, debug=False): if not debug: synchronize_gradients(model) - scheduler.step() - optimizer.step() + scheduler.step() optimizer.zero_grad() if training_config.wandb_project is not None and rank == 0: diff --git a/src/trainer/train_chroma_rectification.py b/src/trainer/train_chroma_rectification.py index fa26179..c506888 100644 --- a/src/trainer/train_chroma_rectification.py +++ b/src/trainer/train_chroma_rectification.py @@ -937,9 +937,8 @@ def train_chroma(rank, world_size, debug=False): if not debug: synchronize_gradients(model) - scheduler.step() - optimizer.step() + scheduler.step() optimizer.zero_grad() if rank == 0: diff --git a/src/trainer/train_lumina.py b/src/trainer/train_lumina.py index 844238d..32ef825 100644 --- a/src/trainer/train_lumina.py +++ b/src/trainer/train_lumina.py @@ -734,9 +734,8 @@ def train_lumina(rank, world_size, debug=False): if not debug: synchronize_gradients(model) - scheduler.step() - optimizer.step() + scheduler.step() optimizer.zero_grad() if training_config.wandb_project is not None and rank == 0: From dd264f79935665a2f3e76a875eb0d3f06d3558e4 Mon Sep 17 00:00:00 2001 From: sjh00 Date: Mon, 9 Jun 2025 22:00:22 +0800 Subject: [PATCH 3/6] modify an if-else statement --- src/dataloaders/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dataloaders/dataloader.py b/src/dataloaders/dataloader.py index 0ac3012..1de0464 100644 --- a/src/dataloaders/dataloader.py +++ b/src/dataloaders/dataloader.py @@ -258,7 +258,7 @@ def _load_image(self, sample, image_folder_path, timeout): else: if img_array.ndim == 2: image = Image.fromarray(img_array, mode='L') - elif img_array.ndim == 3 or (img_array.ndim == 4 and img_array.shape[0] == 1): + elif img_array.ndim == 3 or (img_array.ndim == 4): if img_array.ndim == 4: # When the image has a frame dimension, only the first frame is taken. img_array = img_array[0] From 2b9f9da1445ff03d5f0302bdf8e94e7de2ca3802 Mon Sep 17 00:00:00 2001 From: sjh00 Date: Tue, 10 Jun 2025 03:10:21 +0800 Subject: [PATCH 4/6] Update requirements.txt --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 64b953f..b045d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -torch==2.6.0 -torchvision==0.21.0 +torch>=2.6.0 +torchvision>=0.21.0 numpy matplotlib tensorboard From e0d7811233c46bb0a60dbd62f88f89538692ff5b Mon Sep 17 00:00:00 2001 From: sjh00 Date: Tue, 17 Jun 2025 17:55:39 +0800 Subject: [PATCH 5/6] refactor(monitoring): Replace wandb with aim and update related configurations Update requirements.txt to add aim dependency Remove wandb imports in train_chroma_lora.py Modify readme.md documentation and configuration examples --- readme.md | 26 +++++++++++++------------- requirements.txt | 3 ++- src/trainer/train_chroma_lora.py | 1 - 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/readme.md b/readme.md index c23f3c0..5c31f1e 100644 --- a/readme.md +++ b/readme.md @@ -8,7 +8,7 @@ A powerful training toolkit for image generation models using Flow Matching tech - Flexible configuration via JSON - Multi-GPU training support with automatic device detection - Configurable inference during training -- Wandb and Hugging Face integration +- Aim and Hugging Face integration - Parameter efficient training with layer rotation and offloading ## Installation @@ -70,10 +70,10 @@ The trainer is configured via a JSON file with the following sections: "trained_double_blocks": 2, "save_every": 6, "save_folder": "checkpoints", - "wandb_key": null, - "wandb_project": null, - "wandb_run": "chroma", - "wandb_entity": null, + "aim_path": "./training", + "aim_experiment_name": "base", + "aim_hash": null, + "aim_steps": 0, "hf_repo_id": null, "hf_token": null } @@ -93,10 +93,10 @@ The trainer is configured via a JSON file with the following sections: | `trained_double_blocks` | Number of trainable transformer double blocks | | `save_every` | Save model checkpoint every X steps | | `save_folder` | Directory to save model checkpoints | -| `wandb_key` | Weights & Biases API key (optional) | -| `wandb_project` | Weights & Biases project name (optional) | -| `wandb_run` | Weights & Biases run name (optional) | -| `wandb_entity` | Weights & Biases entity name (optional) | +| `aim_path` | Aim directory path (optional) | +| `aim_experiment_name` | Aim experiment name (optional) | +| `aim_hash` | Aim hash (optional) | +| `aim_steps` | Aim steps (optional) | | `hf_repo_id` | Hugging Face repository ID for pushing models (optional) | | `hf_token` | Hugging Face API token (optional) | @@ -243,10 +243,10 @@ You can set up multiple inference configurations to test different settings duri "trained_double_blocks": 2, "save_every": 6, "save_folder": "testing", - "wandb_key": null, - "wandb_project": null, - "wandb_run": "chroma", - "wandb_entity": null, + "aim_path": "./training", + "aim_experiment_name": "base", + "aim_hash": null, + "aim_steps": 0, "hf_repo_id": null, "hf_token": null }, diff --git a/requirements.txt b/requirements.txt index b045d68..aaec0cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,5 @@ torch-linear-assignment huggingface_hub scipy triton; sys_platform != "win32" -triton-windows; sys_platform == "win32" \ No newline at end of file +triton-windows; sys_platform == "win32" +aim \ No newline at end of file diff --git a/src/trainer/train_chroma_lora.py b/src/trainer/train_chroma_lora.py index fbae386..add2916 100644 --- a/src/trainer/train_chroma_lora.py +++ b/src/trainer/train_chroma_lora.py @@ -21,7 +21,6 @@ import random from transformers import T5Tokenizer -import wandb from src.dataloaders.dataloader import TextImageDataset from src.models.chroma.model import Chroma, chroma_params From 7d6771bd13ca0cdfc852c14ec7cabfde5df6737b Mon Sep 17 00:00:00 2001 From: sjh00 Date: Wed, 18 Jun 2025 09:21:20 +0800 Subject: [PATCH 6/6] fix --- readme.md | 2 +- requirements.txt | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/readme.md b/readme.md index 5c31f1e..516f7b1 100644 --- a/readme.md +++ b/readme.md @@ -93,7 +93,7 @@ The trainer is configured via a JSON file with the following sections: | `trained_double_blocks` | Number of trainable transformer double blocks | | `save_every` | Save model checkpoint every X steps | | `save_folder` | Directory to save model checkpoints | -| `aim_path` | Aim directory path (optional) | +| `aim_path` | Aim directory path (optional, Windows need to install by yourself) | | `aim_experiment_name` | Aim experiment name (optional) | | `aim_hash` | Aim hash (optional) | | `aim_steps` | Aim steps (optional) | diff --git a/requirements.txt b/requirements.txt index aaec0cb..69de9f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,20 +5,18 @@ matplotlib tensorboard torch_tb_profiler tqdm -wandb einops pillow-avif-plugin pillow-jxl-plugin -imageio[pillow] imageio[pyav] transformers safetensors sentencepiece -bitsandbytes==0.45.3 +bitsandbytes>=0.45.3 torchastic torch-linear-assignment huggingface_hub scipy +aim; sys_platform != "win32" triton; sys_platform != "win32" -triton-windows; sys_platform == "win32" -aim \ No newline at end of file +triton-windows; sys_platform == "win32" \ No newline at end of file