Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ A powerful training toolkit for image generation models using Flow Matching tech
- Flexible configuration via JSON
- Multi-GPU training support with automatic device detection
- Configurable inference during training
- Wandb and Hugging Face integration
- Aim and Hugging Face integration
- Parameter efficient training with layer rotation and offloading

## Installation
Expand Down Expand Up @@ -70,10 +70,10 @@ The trainer is configured via a JSON file with the following sections:
"trained_double_blocks": 2,
"save_every": 6,
"save_folder": "checkpoints",
"wandb_key": null,
"wandb_project": null,
"wandb_run": "chroma",
"wandb_entity": null,
"aim_path": "./training",
"aim_experiment_name": "base",
"aim_hash": null,
"aim_steps": 0,
"hf_repo_id": null,
"hf_token": null
}
Expand All @@ -93,10 +93,10 @@ The trainer is configured via a JSON file with the following sections:
| `trained_double_blocks` | Number of trainable transformer double blocks |
| `save_every` | Save model checkpoint every X steps |
| `save_folder` | Directory to save model checkpoints |
| `wandb_key` | Weights & Biases API key (optional) |
| `wandb_project` | Weights & Biases project name (optional) |
| `wandb_run` | Weights & Biases run name (optional) |
| `wandb_entity` | Weights & Biases entity name (optional) |
| `aim_path` | Aim directory path (optional; on Windows, Aim is not installed automatically and must be installed manually) |
| `aim_experiment_name` | Aim experiment name (optional) |
| `aim_hash` | Aim hash (optional) |
| `aim_steps` | Aim steps (optional) |
| `hf_repo_id` | Hugging Face repository ID for pushing models (optional) |
| `hf_token` | Hugging Face API token (optional) |

Expand Down Expand Up @@ -243,10 +243,10 @@ You can set up multiple inference configurations to test different settings duri
"trained_double_blocks": 2,
"save_every": 6,
"save_folder": "testing",
"wandb_key": null,
"wandb_project": null,
"wandb_run": "chroma",
"wandb_entity": null,
"aim_path": "./training",
"aim_experiment_name": "base",
"aim_hash": null,
"aim_steps": 0,
"hf_repo_id": null,
"hf_token": null
},
Expand Down
14 changes: 9 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
torch==2.6.0
torchvision==0.21.0
torch>=2.6.0
torchvision>=0.21.0
numpy
matplotlib
tensorboard
torch_tb_profiler
tqdm
wandb
einops
pillow-avif-plugin
pillow-jxl-plugin
imageio[pyav]
transformers
safetensors
sentencepiece
bitsandbytes==0.45.3
bitsandbytes>=0.45.3
torchastic
torch-linear-assignment
huggingface_hub
scipy
scipy
aim; sys_platform != "win32"
triton; sys_platform != "win32"
triton-windows; sys_platform == "win32"
5 changes: 4 additions & 1 deletion src/dataloaders/bucketing_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

from .utils import read_jsonl

# Raise the csv field size limit as high as the platform allows.
# csv.field_size_limit() rejects values that do not fit in a C long with
# OverflowError (e.g. sys.maxsize on platforms where long is 32-bit), so the
# call must be guarded and fall back to the 32-bit signed maximum.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2147483647)
log = logging.getLogger(__name__)


Expand Down
44 changes: 32 additions & 12 deletions src/dataloaders/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

from tqdm import tqdm
from PIL import Image
import pillow_avif
import imageio.v3 as iio
import torch
from torch.utils.data import Dataset
import torchvision.transforms.v2 as v2

from io import BytesIO
import concurrent.futures
import requests
from requests.exceptions import RequestException, Timeout

from .utils import read_jsonl
Expand Down Expand Up @@ -70,15 +71,14 @@ def __init__(
random.seed(seed)
# just simple pil image to tensor conversion
self.image_transforms = v2.Compose(
[v2.ToTensor(), v2.Normalize(mean=[0.5], std=[0.5])]
[v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)

# TODO: batches has to be preprocessed for batching!!!!
self.batches = self._load_batches()

# slice batches using round robin
self._round_robin()
self.session = requests.Session()
self.thread_per_worker = thread_per_worker
# self.executor = concurrent.futures.ThreadPoolExecutor(thread_per_worker)

Expand Down Expand Up @@ -225,12 +225,11 @@ def _round_robin(self):

# </some utility method here>

def _load_image(self, sample, session, image_folder_path, timeout):
def _load_image(self, sample, image_folder_path, timeout):
try:
img_array = None
if sample["is_url_based"]:
response = session.get(sample["filename"], timeout=timeout)
response.raise_for_status() # Raises an HTTPError if the status code is 4xx/5xx
return Image.open(BytesIO(response.content)).convert("RGB")
img_array = iio.imread(sample["filename"])
else:
image_path = os.path.join(image_folder_path, sample["filename"])

Expand All @@ -243,18 +242,40 @@ def _load_image(self, sample, session, image_folder_path, timeout):
)
elif os.path.exists(image_path):
# Standard handling if the specified file exists
return Image.open(image_path).convert("RGB")
img_array = iio.imread(image_path)
else:
# Try alternative extensions if the main file doesn't exist
filename, _ = os.path.splitext(sample["filename"])
extensions = ["png", "jpg", "jpeg", "webp"]
extensions = ["png", "jpg", "jpeg", "webp", "bmp", "avif", "tif", "tiff"]
for ext in extensions:
alt_image_path = os.path.join(
image_folder_path, f"{filename}.{ext}"
)
if os.path.exists(alt_image_path):
return Image.open(alt_image_path).convert("RGB")
return None
img_array = iio.imread(alt_image_path)
if img_array is None:
return None
else:
if img_array.ndim == 2:
image = Image.fromarray(img_array, mode='L')
elif img_array.ndim == 3 or (img_array.ndim == 4):
if img_array.ndim == 4:
# When the image has a frame dimension, only the first frame is taken.
img_array = img_array[0]
height, width, channels = img_array.shape
if channels == 3: # RGB
image = Image.fromarray(img_array, mode='RGB')
elif channels == 4: # RGBA
image = Image.fromarray(img_array, mode='RGBA')
else:
raise ValueError(f"Unsupported number of channels: {channels}")
else:
raise ValueError(f"Unsupported image shape: {img_array.shape}")

if image.mode != 'RGB':
image = image.convert('RGB')

return image
except Exception as e:
log.error(
f"An error occurred: {e} for {sample['filename']} on rank {self.rank}"
Expand All @@ -274,7 +295,6 @@ def __getitem__(self, index):
executor.submit(
self._load_image,
sample,
self.session,
self.image_folder_path,
self.timeout,
)
Expand Down
11 changes: 7 additions & 4 deletions src/dataloaders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import random
from tqdm import tqdm

# Raise the csv field size limit as high as the platform allows.
# csv.field_size_limit() rejects values that do not fit in a C long with
# OverflowError (e.g. sys.maxsize on platforms where long is 32-bit), so the
# call must be guarded and fall back to the 32-bit signed maximum.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2147483647)


def save_as_jsonl(data, filename):
Expand All @@ -18,9 +21,9 @@ def save_as_jsonl(data, filename):
if os.path.join(*os.path.split(filename)[:-1]) != "":
os.makedirs(os.path.join(*os.path.split(filename)[:-1]), exist_ok=True)

with open(filename, "w") as f:
with open(filename, "w", encoding="utf-8") as f:
for item in tqdm(data):
json.dump(item, f)
json.dump(item, f, ensure_ascii=False)
f.write("\n")


Expand All @@ -35,7 +38,7 @@ def read_jsonl(filename):
"""

data = []
with open(filename, "r") as f:
with open(filename, "r", encoding="utf-8") as f:
for line in tqdm(f):
data.append(json.loads(line))
return data
Expand Down
2 changes: 1 addition & 1 deletion src/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def save_file_multipart(
index["metadata"].update(metadata)

with open(os.path.join(base_folder, "model.safetensors.index.json"), "w") as f:
json.dump(index, f, indent=2)
json.dump(index, f, indent=2, ensure_ascii=False)

return num_shards

Expand Down
6 changes: 5 additions & 1 deletion src/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ def _cuda_assignment(C):
from torch_linear_assignment import batch_linear_assignment
from torch_linear_assignment import assignment_to_indices

assignment = batch_linear_assignment(C.unsqueeze(dim=0))
original_device = C.device
C_cpu = C.cpu().to(torch.float32).unsqueeze(dim=0)
assignment = batch_linear_assignment(C_cpu)
row_indices, col_indices = assignment_to_indices(assignment)
row_indices = row_indices.to(original_device)
col_indices = col_indices.to(original_device)
matching_pairs = (row_indices, col_indices)

return C, matching_pairs
Expand Down
4 changes: 2 additions & 2 deletions src/models/chroma/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def forward(
# just in case in different GPU for simple pipeline parallel
if self.training:
img, txt = ckpt.checkpoint(
block, img, txt, pe, double_mod, txt_img_mask
block, img, txt, pe, double_mod, txt_img_mask, use_reentrant=False
)
else:
img, txt = block(
Expand All @@ -266,7 +266,7 @@ def forward(
for i, block in enumerate(self.single_blocks):
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
if self.training:
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask, use_reentrant=False)
else:
img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
img = img[:, txt.shape[1] :, ...]
Expand Down
2 changes: 2 additions & 0 deletions src/models/chroma/module/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def forward(
if position_bias != None
else position_bias
),
use_reentrant=False
)
pass
else:
Expand Down Expand Up @@ -668,6 +669,7 @@ def forward_mid(
if position_bias != None
else position_bias
),
use_reentrant=False
)
pass
else:
Expand Down
6 changes: 3 additions & 3 deletions src/models/lumina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def patchify_and_embed(
# refine context
for layer in self.context_refiner:
if self.training:
cap_feats = ckpt.checkpoint(layer, cap_feats, cap_mask, cap_freqs_cis)
cap_feats = ckpt.checkpoint(layer, cap_feats, cap_mask, cap_freqs_cis, use_reentrant=False)
else:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

Expand Down Expand Up @@ -718,7 +718,7 @@ def patchify_and_embed(
for layer in self.noise_refiner:
if self.training:
padded_img_embed = ckpt.checkpoint(
layer, padded_img_embed, padded_img_mask, img_freqs_cis, t
layer, padded_img_embed, padded_img_mask, img_freqs_cis, t, use_reentrant=False
)
else:
padded_img_embed = layer(
Expand Down Expand Up @@ -804,7 +804,7 @@ def forward(self, x, t, cap_feats, cap_mask):

for layer in self.layers:
if self.training:
x = ckpt.checkpoint(layer, x, mask, freqs_cis, adaln_input)
x = ckpt.checkpoint(layer, x, mask, freqs_cis, adaln_input, use_reentrant=False)
else:
x = layer(x, mask, freqs_cis, adaln_input)

Expand Down
20 changes: 12 additions & 8 deletions src/trainer/train_chroma.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import platform
import sys
import os
import json
Expand Down Expand Up @@ -120,7 +121,11 @@ def setup_distributed(rank, world_size):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Initialize process group
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
backend = "nccl" # Default backend for distributed training
if platform.system() == "Windows":
# Windows does not support NCCL, use GLOO instead
backend = "gloo"
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
torch.cuda.set_device(rank)


Expand Down Expand Up @@ -307,17 +312,17 @@ def cast_linear(module, dtype):

def save_config_to_json(filepath: str, **configs):
    """Serialize named config objects to a single JSON file.

    Args:
        filepath: Destination path; the file is created or overwritten.
        **configs: Section name -> config object. Each value is passed to
            ``asdict`` (presumably ``dataclasses.asdict``, so values are
            expected to be dataclass instances -- confirm against the file's
            imports).

    The file is written exactly once, as UTF-8 with ``ensure_ascii=False`` so
    non-ASCII text is stored verbatim instead of as ``\\uXXXX`` escapes.
    """
    json_data = {key: asdict(value) for key, value in configs.items()}
    # Explicit encoding: the platform default is not guaranteed to be UTF-8,
    # and ensure_ascii=False emits raw non-ASCII characters.
    with open(filepath, "w", encoding="utf-8") as json_file:
        json.dump(json_data, json_file, indent=4, ensure_ascii=False)


def dump_dict_to_json(data, file_path):
    """Write ``data`` to ``file_path`` as indented JSON.

    Args:
        data: Any object ``json.dump`` can serialize.
        file_path: Destination path; the file is created or overwritten.

    The file is written exactly once, as UTF-8 with ``ensure_ascii=False`` so
    non-ASCII text is stored verbatim instead of as ``\\uXXXX`` escapes.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8,
    # and ensure_ascii=False emits raw non-ASCII characters.
    with open(file_path, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4, ensure_ascii=False)


def load_config_from_json(filepath: str):
    """Load and return the parsed contents of a JSON file.

    Args:
        filepath: Path to a UTF-8 encoded JSON file.

    Returns:
        Whatever ``json.load`` produces for the file's top-level value.
    """
    # Open once, with an explicit encoding: the platform default is not
    # guaranteed to be UTF-8.
    with open(filepath, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


Expand Down Expand Up @@ -768,9 +773,8 @@ def train_chroma(rank, world_size, debug=False):
if not debug:
synchronize_gradients(model)

scheduler.step()

optimizer.step()
scheduler.step()
optimizer.zero_grad()

if rank == 0:
Expand Down
Loading