4 changes: 4 additions & 0 deletions .gitignore
@@ -4,6 +4,7 @@ __pycache__
*eggs
.vscode/
env/
test/
venv/
workdir/
makefile
@@ -14,13 +15,16 @@ makefile
*events.out.tfevents*
algoperf/workloads/librispeech_conformer/data_dir
algoperf/workloads/librispeech_conformer/work_dir
algoperf/workloads/*
algorithms/archived_paper_baselines
*.flac
*.npy
*.csv
*.vocab
wandb/
*.txt
scoring/plots/
datasets/dataset_setup.py

!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv
53 changes: 53 additions & 0 deletions algoperf/dist_utils.py
@@ -0,0 +1,53 @@
import os
from collections import namedtuple
import logging
import json

def dict_to_namedtuple(d, name="Hyperparameters"):
  NT = namedtuple(name, d.keys())
  return NT(**d)

def write_hparams(hparams, tuning_dir_name):
  """Write hparams.json if absent; otherwise load it for consistency.

  Supports either:
    - dict
    - objects with _asdict() (e.g., namedtuple)
  Returns a dict.
  """
  path = os.path.join(tuning_dir_name, 'hparams.json')

  # Normalize to dict for writing/training.
  payload = hparams._asdict() if hasattr(hparams, "_asdict") else hparams
  if not isinstance(payload, dict):
    raise TypeError(f"hparams must be dict-like; got {type(payload)}")

  # If exists, try to load. If empty/corrupt, treat as missing.
  if os.path.exists(path):
    logging.info('Loading hparams from %s.', path)
    try:
      with open(path, 'r', encoding='UTF-8') as f:
        content = f.read().strip()
        if not content:
          raise ValueError("hparams.json is empty")
        loaded = json.loads(content)
        if not isinstance(loaded, dict):
          raise ValueError("hparams.json is not a dict")
        return loaded
    except Exception as e:
      logging.warning(
        'Failed to load %s (%s). Re-writing hparams.json.', path, repr(e)
      )
      # Fall through to rewrite.

  # Atomic write: write temp then replace.
  tmp_path = path + '.tmp'
  with open(tmp_path, 'w', encoding='UTF-8') as f:
    json.dump(payload, f, indent=2)
    f.flush()
    os.fsync(f.fileno())
  os.replace(tmp_path, path)

  logging.info('Saving hparams to %s.', path)
  return payload
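A minimal usage sketch of the two helpers above, assuming a writable tuning directory (the path and values below are illustrative, not part of the diff's fixed API):

import tempfile

from algoperf.dist_utils import dict_to_namedtuple, write_hparams

# The first call writes hparams.json; any later call with the same tuning dir
# loads the stored file instead, so restarts and other ranks stay consistent.
tuning_dir = tempfile.mkdtemp()  # illustrative stand-in for a trial's tuning dir
stored = write_hparams({'warmup_factor': 0.05, 'dropout_rate': 0.1, 'd0': 1e-6}, tuning_dir)
reloaded = write_hparams({'warmup_factor': 0.9}, tuning_dir)  # existing file wins
assert reloaded == stored

# Convert to attribute access, matching how the submission reads hyperparameters.
hp = dict_to_namedtuple(stored)
print(hp.warmup_factor, hp.d0)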
10 changes: 6 additions & 4 deletions algoperf/logger_utils.py
@@ -69,10 +69,12 @@ def get_log_dir(
      )
    else:
      if RANK == 0:
        resume = input(
          'Found existing experiment dir with the same name: {}. Do you wish '
          'to resume training from this dir? [y/N]:'.format(experiment_path)
        )

        # resume = input(
        #   'Found existing experiment dir with the same name: {}. Do you wish '
        #   'to resume training from this dir? [y/N]:'.format(experiment_path)
        # )
        resume = "y"
        if resume.lower() != 'y':
          sys.exit()

2 changes: 1 addition & 1 deletion algoperf/workloads/cifar/cifar_pytorch/workload.py
@@ -72,7 +72,7 @@ def _build_dataset(
    dataset = CIFAR10(
      root=data_dir,
      train=split in ['train', 'eval_train', 'validation'],
      download=False,
      download=True,
      transform=transform,
    )
    assert self.num_train_examples + self.num_validation_examples == 50000
1 change: 1 addition & 0 deletions algoperf/workloads/librispeech_conformer/metrics.py
@@ -131,6 +131,7 @@ def load_tokenizer(


def wer(tokenizer_vocab_path):
  print("CAWDA", tokenizer_vocab_path)
  tokenizer = load_tokenizer(tokenizer_vocab_path)

  @flax.struct.dataclass
236 changes: 236 additions & 0 deletions algorithms/dadapt/submission.py
@@ -0,0 +1,236 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch."""

from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import torch.distributed.nn as dist_nn
from absl import logging
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, LambdaLR
from dadaptation import DAdaptSGD
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]


def init_optimizer_state(
  workload: spec.Workload,
  model_params: spec.ParameterContainer,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  rng: spec.RandomState,
) -> spec.OptimizerState:
  """Creates a D-Adaptation SGD optimizer and a learning rate schedule."""
  del model_state
  del rng

  optimizer_state = {
    'optimizer': DAdaptSGD(
      model_params.parameters(),
      lr=1.0,  # ignored by D-Adapt
      d0=getattr(hyperparameters, "d0", 1e-6),  # initial D value
    ),
  }

  def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
    warmup_steps = int(hyperparameters.warmup_factor * step_hint)
    # warmup_steps = 0
    warmup = LinearLR(
      optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
    )
    constant_decay = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    # Note: cosine_decay is constructed but not used below; the returned
    # schedule is a linear warmup followed by a constant LR.
    cosine_steps = max(step_hint - warmup_steps, 1)
    cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
    return SequentialLR(
      optimizer, schedulers=[warmup, constant_decay], milestones=[warmup_steps]
    )

  optimizer_state['scheduler'] = pytorch_cosine_warmup(
    workload.step_hint, hyperparameters, optimizer_state['optimizer']
  )

  return optimizer_state


def update_params(
  workload: spec.Workload,
  current_param_container: spec.ParameterContainer,
  current_params_types: spec.ParameterTypeTree,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  batch: Dict[str, spec.Tensor],
  loss_type: spec.LossType,
  optimizer_state: spec.OptimizerState,
  eval_results: List[Tuple[int, float]],
  global_step: int,
  rng: spec.RandomState,
  train_state: Optional[Dict[str, Any]] = None,
) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del train_state
  del eval_results

  current_model = current_param_container
  current_model.train()
  optimizer_state['optimizer'].zero_grad()

  logits_batch, new_model_state = workload.model_fn(
    params=current_model,
    augmented_and_preprocessed_input_batch=batch,
    model_state=model_state,
    mode=spec.ForwardPassMode.TRAIN,
    rng=rng,
    update_batch_norm=True,
    dropout_rate=hyperparameters.dropout_rate,
  )

  label_smoothing = (
    hyperparameters.label_smoothing
    if hasattr(hyperparameters, 'label_smoothing')
    else 0.0
  )

  loss_dict = workload.loss_fn(
    label_batch=batch['targets'],
    logits_batch=logits_batch,
    mask_batch=batch.get('weights'),
    label_smoothing=label_smoothing,
  )
  summed_loss = loss_dict['summed']
  n_valid_examples = loss_dict['n_valid_examples']
  if USE_PYTORCH_DDP:
    # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
    summed_loss = dist_nn.all_reduce(summed_loss)
    n_valid_examples = dist_nn.all_reduce(n_valid_examples)
  loss = summed_loss / n_valid_examples

  loss.backward()

  if hasattr(hyperparameters, 'grad_clip'):
    grad_clip = hyperparameters.grad_clip
    torch.nn.utils.clip_grad_norm_(
      current_model.parameters(), max_norm=grad_clip
    )

  # Only step the optimizer if at least one gradient is non-zero; the LR
  # scheduler is stepped every iteration either way.
  grads_nonzero = any(
    p.grad is not None and torch.any(p.grad != 0)
    for p in current_model.parameters()
  )
  if grads_nonzero:
    optimizer_state['optimizer'].step()
  else:
    logging.warning(f"[step={global_step}] skipped step due to zero gradients")
  optimizer_state['scheduler'].step()

  # Log training metrics: loss, grad_norm, lr, effective lr, and the D value.
  if global_step <= 10 or global_step % 500 == 0:
    group = optimizer_state['optimizer'].param_groups[0]
    base_lr = group["lr"]
    d_value = group.get("d", 1.0)
    g0_norm = group.get("g0_norm", 1.0)
    effective_lr = base_lr * d_value / g0_norm

    current_lr = optimizer_state['optimizer'].param_groups[0]['lr']
    with torch.no_grad():
      parameters = [p for p in current_model.parameters() if p.grad is not None]
      grad_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
      )
    if workload.metrics_logger is not None:
      workload.metrics_logger.append_scalar_metrics(
        {
          'loss': loss.item(),
          'grad_norm': grad_norm.item(),
          'lr': current_lr,
          'lr_effective': effective_lr,
          'd_value': d_value,
        },
        global_step,
      )
    logging.info(
      '%d) loss = %0.3f, grad_norm = %0.3f',
      global_step,
      loss.item(),
      grad_norm.item(),
    )

  return (optimizer_state, current_param_container, new_model_state)


def prepare_for_eval(
  workload: spec.Workload,
  current_param_container: spec.ParameterContainer,
  current_params_types: spec.ParameterTypeTree,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  loss_type: spec.LossType,
  optimizer_state: spec.OptimizerState,
  eval_results: List[Tuple[int, float]],
  global_step: int,
  rng: spec.RandomState,
) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 32
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 128
  elif workload_name == 'librispeech_deepspeech':
    return 128
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  elif workload_name == 'cifar':
    return 64
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')


def data_selection(
  workload: spec.Workload,
  input_queue: Iterator[Dict[str, spec.Tensor]],
  optimizer_state: spec.OptimizerState,
  current_param_container: spec.ParameterContainer,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  global_step: int,
  rng: spec.RandomState,
) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
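The effective step size logged in update_params comes from fields that DAdaptSGD stores on its param groups ('d' and 'g0_norm', read with .get above). A small standalone sketch of that computation on a toy model, assuming the dadaptation package is installed; the toy model and data are illustrative:

import torch
from dadaptation import DAdaptSGD

# Toy model and a single D-Adapt SGD step, mirroring the submission's setup.
model = torch.nn.Linear(4, 1)
opt = DAdaptSGD(model.parameters(), lr=1.0, d0=1e-6)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()

# Same formula as in update_params: lr stays at 1.0, so the step size is driven
# by the adapted distance estimate d relative to the initial gradient norm.
group = opt.param_groups[0]
effective_lr = group['lr'] * group.get('d', 1.0) / group.get('g0_norm', 1.0)
print(f"d = {group.get('d', 1.0):.3e}, effective lr = {effective_lr:.3e}")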
14 changes: 14 additions & 0 deletions algorithms/dadapt/tuning_search_space.json
@@ -0,0 +1,14 @@
{
  "warmup_factor": {
    "feasible_points": [0.05]
  },
  "label_smoothing": {
    "feasible_points": [0.1, 0.2]
  },
  "dropout_rate": {
    "feasible_points": [0.0, 0.1]
  },
  "d0": {
    "feasible_points": [1e-6]
  }
}
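Each hyperparameter above is a discrete list of feasible_points; the tuning harness picks one value per key and hands the submission a single hyperparameter set. A hypothetical sketch of that expansion using dict_to_namedtuple from algoperf/dist_utils.py (the actual sampling is done by the AlgoPerf runner, so this enumeration is only illustrative):

import itertools
import json

from algoperf.dist_utils import dict_to_namedtuple

with open('algorithms/dadapt/tuning_search_space.json') as f:
  space = json.load(f)

# Enumerate every combination of feasible points (4 combinations here).
keys = sorted(space)
for values in itertools.product(*(space[k]['feasible_points'] for k in keys)):
  hp = dict_to_namedtuple(dict(zip(keys, values)))
  print(hp.warmup_factor, hp.label_smoothing, hp.dropout_rate, hp.d0)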