From 8da0c79d97bda034523c10c55c608ca5ddf11cdf Mon Sep 17 00:00:00 2001 From: rka97 Date: Sun, 30 Nov 2025 01:54:55 +0000 Subject: [PATCH 1/5] add mixed precision training --- .../finewebedu_lm/finewebedu_lm_jax/models.py | 91 ++++++++------- .../finewebedu_lm_jax/workload.py | 49 +++++++- .../finewebedu_lm_pytorch/models.py | 110 ++++++++++++------ .../finewebedu_lm_pytorch/workload.py | 95 +++++++++------ algoperf/workloads/finewebedu_lm/workload.py | 23 +++- pyproject.toml | 3 +- 6 files changed, 249 insertions(+), 122 deletions(-) diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py index d08e9b7bf..3419fe6fb 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp +import jmp from flax import linen as nn @@ -26,18 +27,24 @@ class ModelConfig: use_residual_scaling: bool = True tie_embeddings: bool = True # Whether to tie input and output embed qknorm_epsilon: float = 1e-6 - - dtype: jnp.dtype = jnp.float32 attention_init: nn.initializers.Initializer = nn.initializers.normal( stddev=0.02 ) linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02) + param_dtype: jnp.dtype = jnp.float32 + compute_dtype: jnp.dtype = jnp.bfloat16 + output_dtype: jnp.dtype = jnp.bfloat16 def __post_init__(self): self.residual_init = nn.initializers.normal( stddev=0.02 / jnp.sqrt(2 * self.num_layers) ) + self.mp_policy = jmp.Policy( + compute_dtype=self.compute_dtype, + param_dtype=self.param_dtype, + output_dtype=self.output_dtype, + ) class Mlp(nn.Module): @@ -49,7 +56,11 @@ class Mlp(nn.Module): def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg linear = partial( - nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype + nn.Dense, + kernel_init=cfg.linear_init, + use_bias=False, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) # Adjust hidden dimension to keep the number of parameters invariant to # the activation function used since the GLU MLP has 3 * hidden_dim * D @@ -65,7 +76,8 @@ def __call__(self, x_BxLxD: jax.Array): x_BxLxD = nn.Dense( cfg.model_dim, use_bias=False, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init, @@ -96,7 +108,7 @@ def apply_rope(q, k, freqs_cis): def rotate_tensor(x): # Split into real and imaginary parts - x_r2 = x.reshape(*x.shape[:-1], -1, 2) + x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32) L = x.shape[1] freqs = freqs_cis[:, :L, :, :, :] @@ -109,7 +121,7 @@ def rotate_tensor(x): axis=-1, ) - return rotated_x_r2.reshape(*x.shape) + return rotated_x_r2.reshape(*x.shape).astype(x.dtype) # Apply rotation to Q and K separately rotated_q = rotate_tensor(q) @@ -141,7 +153,8 @@ def setup(self): features=(cfg.num_heads, self.Dh), kernel_init=cfg.attention_init, use_bias=False, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) self.multilinear_query = self.multilinear(name='query') self.multilinear_key = self.multilinear(name='key') @@ -150,7 +163,9 @@ def setup(self): seq_len = cfg.seq_len attn_scale0 = jnp.log2(seq_len**2 - seq_len) self.attn_scale = self.param( - 'attn_scale', nn.initializers.constant(attn_scale0), () + 'attn_scale', + nn.initializers.constant(attn_scale0, 
dtype=cfg.compute_dtype), + (), ) self.output_projection = nn.DenseGeneral( features=cfg.model_dim, @@ -160,7 +175,8 @@ def setup(self): if cfg.use_residual_scaling else cfg.linear_init, use_bias=False, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) def __call__(self, x_BxLxD: jax.Array): @@ -177,32 +193,17 @@ def __call__(self, x_BxLxD: jax.Array): # Apply QK normalization q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps - - # Compute attention scores - att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh) - - # Causal attention mask - L = x_BxLxD.shape[1] - mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) - - # Apply mask and softmax - _NEG_INF = jnp.finfo(cfg.dtype).min - att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) - att_BxHxLxL = ( - self.attn_scale * att_BxHxLxL - ) # Learned scaling factor for QK norm - att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) - att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) - - # Compute attention output - out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh) - - # Reshape and project output + q_BxLxHxDh *= self.attn_scale + out_BxLxHxDh = jax.nn.dot_product_attention( + query=q_BxLxHxDh, + key=k_BxLxHxDh, + value=v_BxLxHxDh, + is_causal=True, + scale=1.0, + implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None, + ) out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) - - # Output projection out_BxLxD = self.output_projection(out_BxLxD) - return out_BxLxD @@ -216,16 +217,16 @@ def __call__(self, in_BxLxD: jax.Array): cfg = self.docfg # x = x + attn( attn_norm(x) ) - x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - in_BxLxD - ) + x_BxLxD = nn.RMSNorm( + param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon + )(in_BxLxD) x_BxLxD = CausalAttn(cfg)(x_BxLxD) x_BxLxD += in_BxLxD # x = x + mlp( mlp_norm(x) ) - z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( - x_BxLxD - ) + z_BxLxD = nn.RMSNorm( + param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon + )(x_BxLxD) z_BxLxD = Mlp(cfg)(z_BxLxD) return x_BxLxD + z_BxLxD @@ -242,19 +243,24 @@ def setup(self): num_embeddings=cfg.vocab_size, features=cfg.model_dim, embedding_init=cfg.embed_init, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, ) self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)] - self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) + self.out_ln = nn.RMSNorm( + param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon + ) # Output projection - tied to input embeddings if configured if cfg.tie_embeddings: - self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) + self.output_proj = lambda x: self.embed.attend(x) else: self.output_proj = nn.Dense( cfg.vocab_size, kernel_init=cfg.embed_init, - dtype=cfg.dtype, + dtype=cfg.compute_dtype, + param_dtype=cfg.param_dtype, name='output_proj', ) @@ -357,6 +363,7 @@ def main(): # Make a prediction (forward pass) print('\nRunning forward pass...') + params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL)) logits = model.apply(params, x_BxL) # Print output shape and sample values diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py index ee4cffbbc..14366d9ea 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py +++ 
b/algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py @@ -1,9 +1,11 @@ """LM workload implemented in Jax.""" +from functools import partial from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp +import jmp from algoperf import jax_sharding_utils, param_utils, spec from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import ( @@ -13,10 +15,33 @@ from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload +replicated_sharding = jax_sharding_utils.get_replicate_sharding() +batch_sharding = jax_sharding_utils.get_batch_dim_sharding() + +# Dtype mapping from string to JAX dtype +DTYPE_MAP = { + 'float32': jnp.float32, + 'float16': jnp.float16, + 'bfloat16': jnp.bfloat16, +} + class LmWorkload(BaseLmWorkload): """LM JAX workload.""" + # Convert dtype strings from base class to JAX dtypes + @property + def _compute_dtype(self) -> Any: + return DTYPE_MAP[self._compute_dtype_str] + + @property + def _param_dtype(self) -> Any: + return DTYPE_MAP[self._param_dtype_str] + + @property + def _output_dtype(self) -> Any: + return DTYPE_MAP[self._output_dtype_str] + def _build_input_queue( self, data_rng: jax.random.PRNGKey, @@ -53,8 +78,14 @@ def init_model_fn( num_layers=self._n_layers, # num layers vocab_size=self._vocab_size, expanded_model_dim=self._mlp_dim, # feedforward dim - dtype=jnp.float32, + rmsnorm_epsilon=self._rmsnorm_epsilon, + qknorm_epsilon=self._qknorm_epsilon, + tie_embeddings=self._tie_embeddings, + param_dtype=self._param_dtype, + compute_dtype=self._compute_dtype, + output_dtype=self._output_dtype, ) + self._mp_policy: jmp.Policy = cfg.mp_policy self._model = TransformerDo(cfg) input_shape = (1, self._seq_len) # For token IDs @@ -66,8 +97,7 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) params = jax_sharding_utils.replicate(params) - model_state = None - return params, model_state + return params, None def model_fn( self, @@ -81,10 +111,12 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode, rng, update_batch_norm, model_state, dropout_rate inputs = batch['inputs'] + params, inputs = self._mp_policy.cast_to_compute((params, inputs)) # Convert one-hot inputs to token IDs if needed if inputs.ndim == 3: # one-hot encoded inputs = jnp.argmax(inputs, axis=-1) logits = self._model.apply({'params': params}, inputs) + logits = self._mp_policy.cast_to_output(logits) return logits, None def loss_fn( @@ -139,6 +171,17 @@ def loss_fn( 'per_example': per_example_losses, } + @partial( + jax.jit, + static_argnums=(0,), + in_shardings=( + replicated_sharding, + batch_sharding, + replicated_sharding, + replicated_sharding, + ), + out_shardings=(replicated_sharding), + ) def _eval_batch( self, params: spec.ParameterContainer, diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py index edee8318c..4c60198cc 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py +++ b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/models.py @@ -26,14 +26,24 @@ class ModelConfig: qknorm_epsilon: float = 1e-6 use_residual_scaling: bool = True tie_embeddings: bool = True + compute_dtype: torch.dtype = torch.bfloat16 + param_dtype: torch.dtype = torch.float32 class MLP(nn.Module): - def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): 
+ def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + dtype: torch.dtype = torch.float32, + ): super().__init__() - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False) - self.fc2 = nn.Linear(hidden_dim, dim, bias=False) + hidden_dim = int( + multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + ) + self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False, dtype=dtype) + self.fc2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype) self.glu = nn.GLU(dim=2) nn.init.normal_(self.fc1.weight, std=0.02) nn.init.normal_(self.fc2.weight, std=0.02) @@ -88,8 +98,12 @@ def __init__(self, cfg: ModelConfig): self.n_heads = cfg.num_heads self.head_dim = cfg.model_dim // cfg.num_heads - self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False) - self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False) + self.w_qkv = nn.Linear( + cfg.model_dim, 3 * cfg.model_dim, bias=False, dtype=cfg.param_dtype + ) + self.w_out = nn.Linear( + cfg.model_dim, cfg.model_dim, bias=False, dtype=cfg.param_dtype + ) # Split into Q, K, V sections wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0) for w in [wq, wk, wv]: @@ -99,7 +113,9 @@ def __init__(self, cfg: ModelConfig): self.eps = cfg.qknorm_epsilon # e.g., 1e-6 seq_len = cfg.seq_len attn_scale0 = math.log2(seq_len**2 - seq_len) - self.attn_scale = nn.Parameter(torch.tensor(attn_scale0)) + self.attn_scale = nn.Parameter( + torch.tensor(attn_scale0, dtype=cfg.param_dtype) + ) def forward(self, x, freqs_cis): bsz, seqlen, d = x.shape # (bsz, seqlen, d) @@ -142,13 +158,18 @@ class Block(nn.Module): def __init__(self, layer_id: int, cfg: ModelConfig): super().__init__() self.attn = Attention(cfg) - self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) + self.attn_norm = nn.RMSNorm( + cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype + ) self.mlp = MLP( dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of, + dtype=cfg.param_dtype, + ) + self.mlp_norm = nn.RMSNorm( + cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype ) - self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) self.layer_id = layer_id def forward(self, x, freqs_cis): @@ -166,12 +187,18 @@ def __init__(self, cfg: ModelConfig): head_dim = cfg.model_dim // cfg.num_heads assert cfg.model_dim % cfg.num_heads == 0 - self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim) + self.embed_tokens = nn.Embedding( + cfg.vocab_size, cfg.model_dim, dtype=cfg.param_dtype + ) self.layers = nn.ModuleList( [Block(idx, cfg) for idx in range(cfg.num_layers)] ) - self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon) - self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False) + self.out_norm = nn.RMSNorm( + cfg.model_dim, eps=cfg.rmsnorm_epsilon, dtype=cfg.param_dtype + ) + self.lm_head = nn.Linear( + cfg.model_dim, cfg.vocab_size, bias=False, dtype=cfg.param_dtype + ) # Initialize freqs_cis on CPU first (more memory efficient) self.register_buffer( @@ -215,6 +242,7 @@ def forward(self, x, targets=None): for layer in self.layers: x = layer(x, freqs_cis) # (bsz, seqlen, dim) out = self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) + if targets is not None: loss = F.cross_entropy( out.view(-1, out.size(-1)), targets.view(-1), ignore_index=-100 @@ -232,40 +260,43 @@ def predict(self, x, k=1): Returns: Tuple of (input_ids, predicted_ids) """ + # Determine device type for autocast + 
device_type = 'cuda' if x.is_cuda else 'cpu' - # Store original input - original_input = x.clone() - generated_input = x.clone() + with torch.autocast(device_type=device_type, dtype=self.cfg.compute_dtype): + # Store original input + original_input = x.clone() + generated_input = x.clone() - # Generate k tokens autoregressively - for i in range(k): - # Get logits for the entire sequence - logits = self(generated_input) + # Generate k tokens autoregressively + for i in range(k): + # Get logits for the entire sequence + logits = self(generated_input) - # Get the logits for the last token in each sequence - next_token_logits = logits[:, -1, :] + # Get the logits for the last token in each sequence + next_token_logits = logits[:, -1, :] - # Zero out the last token ID to prevent repetition - # This is a common issue - the model gets stuck repeating the last token - last_token_id = generated_input[:, -1] - next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) + # Zero out the last token ID to prevent repetition + # This is a common issue - the model gets stuck repeating the last token + last_token_id = generated_input[:, -1] + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) - # Get the most likely token - next_token = torch.argmax(next_token_logits, dim=-1) + # Get the most likely token + next_token = torch.argmax(next_token_logits, dim=-1) - # Append the predicted token to the sequence - next_token = next_token.unsqueeze(1) # Add sequence dimension - generated_input = torch.cat([generated_input, next_token], dim=1) + # Append the predicted token to the sequence + next_token = next_token.unsqueeze(1) # Add sequence dimension + generated_input = torch.cat([generated_input, next_token], dim=1) - # For debugging, print predictions for the first item in the batch - print('\nPyTorch detailed prediction (first item in batch):') - predicted_sequence = generated_input[0, -k:].tolist() - print(f' Predicted token IDs: {predicted_sequence}') - for i, token_id in enumerate(predicted_sequence): - print(f' Step {i + 1}: Predicted token {token_id}') + # For debugging, print predictions for the first item in the batch + print('\nPyTorch detailed prediction (first item in batch):') + predicted_sequence = generated_input[0, -k:].tolist() + print(f' Predicted token IDs: {predicted_sequence}') + for i, token_id in enumerate(predicted_sequence): + print(f' Step {i + 1}: Predicted token {token_id}') - # Return all tokens, not just the last k - return original_input, generated_input[:, -k:] + # Return all tokens, not just the last k + return original_input, generated_input[:, -k:] def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -318,6 +349,8 @@ def main(): # Instantiate the model model = Transformer(config) print(f'Model has {model.count_params():,} parameters.') + for n, p in model.named_parameters(): + print(f'{n}.dtype == {p.dtype}') # Create some random input data batch_size = 2 @@ -330,6 +363,7 @@ def main(): # Run a forward pass print(f'Running forward pass with input shape: {input_ids.shape}') logits = model(input_ids) + print(f'Output logits dtype: {logits.dtype}') print(f'Output logits shape: {logits.shape}') # Run prediction diff --git a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py index a25ca334a..ed922f9c2 100644 --- a/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py +++ 
b/algoperf/workloads/finewebedu_lm/finewebedu_lm_pytorch/workload.py @@ -19,10 +19,25 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +# Dtype mapping from string to PyTorch dtype +DTYPE_MAP = { + 'float32': torch.float32, + 'float16': torch.float16, + 'bfloat16': torch.bfloat16, +} + class LmWorkload(BaseLmWorkload): """LM PyTorch workload.""" + @property + def _compute_dtype(self) -> torch.dtype: + return DTYPE_MAP[self._compute_dtype_str] + + @property + def _param_dtype(self) -> torch.dtype: + return DTYPE_MAP[self._param_dtype_str] + def init_model_fn( self, rng: spec.RandomState, @@ -40,11 +55,14 @@ def init_model_fn( vocab_size=self._vocab_size, seq_len=self._seq_len, model_dim=self._emb_dim, # Model dimension - expanded_model_dim=self._mlp_dim, # MLP expansion factor - num_layers=self._n_layers, # Number of transformer layers - num_heads=self._n_heads, # Number of attention heads - rmsnorm_epsilon=1e-6, - tie_embeddings=True, + expanded_model_dim=self._mlp_dim, # MLP expanded dim + num_layers=self._n_layers, + num_heads=self._n_heads, + rmsnorm_epsilon=self._rmsnorm_epsilon, + qknorm_epsilon=self._qknorm_epsilon, + tie_embeddings=self._tie_embeddings, + compute_dtype=self._compute_dtype, + param_dtype=self._param_dtype, ) self._model = Transformer(cfg) self._param_shapes = param_utils.pytorch_param_shapes(self._model) @@ -81,13 +99,18 @@ def model_fn( spec.ForwardPassMode.EVAL: torch.no_grad, spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + + # Determine device type for autocast + device_type = 'cuda' if DEVICE.type == 'cuda' else 'cpu' + with contexts[mode](): - # Convert one-hot inputs to token IDs if needed - inputs = augmented_and_preprocessed_input_batch['inputs'] - if inputs.dim() == 3: # one-hot encoded - inputs = inputs.argmax(dim=-1) + with torch.autocast(device_type=device_type, dtype=self._compute_dtype): + # Convert one-hot inputs to token IDs if needed + inputs = augmented_and_preprocessed_input_batch['inputs'] + if inputs.dim() == 3: # one-hot encoded + inputs = inputs.argmax(dim=-1) - logits = model(inputs) + logits = model(inputs) return logits, None @@ -121,7 +144,7 @@ def _build_input_queue( batch['targets'], device=DEVICE, dtype=torch.int64 ), 'weights': torch.tensor( - batch['weights'], device=DEVICE, dtype=torch.float32 + batch['weights'], device=DEVICE, dtype=self._param_dtype ) if batch['weights'] is not None else None, @@ -157,29 +180,35 @@ def loss_fn( - 'n_valid_examples': Scalar tensor with the count of valid (non-masked) examples. - 'per_example': Tensor of shape [batch, length] with individual losses per example. 
""" - vocab_size = logits_batch.size(-1) - - # Compute cross-entropy loss with label smoothing - per_example_losses = torch.nn.functional.cross_entropy( - logits_batch.view(-1, vocab_size), - label_batch.view(-1), - reduction='none', - label_smoothing=label_smoothing, - ) - per_example_losses = per_example_losses.view_as(label_batch) - - # Apply weights if provided - if mask_batch is not None: - per_example_losses = per_example_losses * mask_batch - - # Calculate number of valid examples - n_valid_examples = ( - mask_batch.sum() - if mask_batch is not None - else torch.tensor( - label_batch.numel(), dtype=torch.float32, device=label_batch.device + # Determine device type for autocast + device_type = 'cuda' if logits_batch.is_cuda else 'cpu' + + with torch.autocast(device_type=device_type, dtype=self._compute_dtype): + vocab_size = logits_batch.size(-1) + + # Compute cross-entropy loss with label smoothing + per_example_losses = torch.nn.functional.cross_entropy( + logits_batch.view(-1, vocab_size), + label_batch.view(-1), + reduction='none', + label_smoothing=label_smoothing, + ) + per_example_losses = per_example_losses.view_as(label_batch) + + # Apply weights if provided + if mask_batch is not None: + per_example_losses = per_example_losses * mask_batch + + # Calculate number of valid examples + n_valid_examples = ( + mask_batch.sum() + if mask_batch is not None + else torch.tensor( + label_batch.numel(), + dtype=self._param_dtype, + device=label_batch.device, + ) ) - ) return { 'summed': per_example_losses.sum(), diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index 5d6e3d742..b5a258f6f 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -27,6 +27,16 @@ class BaseLmWorkload(spec.Workload): _mlp_dim: int = 4096 warmup_factor: float = 0.1 + # Model configuration + _rmsnorm_epsilon: float = 1e-6 + _qknorm_epsilon: float = 1e-6 + _tie_embeddings: bool = True + + # Dtype configuration (as strings, to be converted by framework-specific subclasses) + _compute_dtype_str: str = 'bfloat16' + _param_dtype_str: str = 'float32' + _output_dtype_str: str = 'bfloat16' # Only used by JAX + def __init__(self) -> None: super().__init__() self._param_shapes = None @@ -85,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 31_967 # 8.9 hours + return 31_967 # 8.9 hours @property def eval_period_time_sec(self) -> int: - return 2_571 # approximately 25 evals + return 2_571 # approximately 25 evals @property def step_hint(self) -> int: @@ -164,9 +174,12 @@ def _eval_model_on_split( eval_batch = next(self._eval_iters[split]) metrics = self._eval_batch(params, eval_batch, model_state, rng) for metric_name, metric_value in metrics.items(): - if metric_name not in eval_metrics: - eval_metrics[metric_name] = 0.0 - eval_metrics[metric_name] += metric_value + eval_metrics.update( + {metric_name: eval_metrics.get(metric_name, 0.0) + metric_value} + ) + print( + f"Completed eval batch {_ + 1}/{num_batches} for split '{split}' at global step {global_step}." 
+ ) eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['ppl'] = np.exp(eval_results['loss']).item() diff --git a/pyproject.toml b/pyproject.toml index 006e7e5cd..e3d86df3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ librispeech_conformer = [ "pydub==0.25.1", ] wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] -lm = ["transformers==4.26", "datasets==3.6.0"] +lm = ["transformers==4.26.0", "datasets==3.6.0"] # Frameworks jax_core_deps = [ @@ -99,6 +99,7 @@ jax_core_deps = [ "chex==0.1.86", "ml_dtypes==0.5.1", "protobuf==4.25.5", + "jmp>=0.0.4" ] jax_cpu = [ "jax==0.7.0", From 6f7d638adc190d9bce3f30ba3314c27dac1a8cc5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 1 Dec 2025 04:33:49 +0000 Subject: [PATCH 2/5] ImageNet and CIFAR mixed-precision support, need to debug slow pytorch - Introduced DTYPE enum to standardize data types (FLOAT32, FLOAT16, BFLOAT16) for JAX and PyTorch. - Updated input pipelines and model definitions in CIFAR and ImageNet workloads to utilize mixed precision. - Implemented casting policies for parameters and inputs using jmp and torch.autocast. --- algoperf/spec.py | 23 ++++++ .../cifar/cifar_jax/input_pipeline.py | 2 - algoperf/workloads/cifar/cifar_jax/models.py | 8 ++- .../workloads/cifar/cifar_jax/workload.py | 29 ++++++-- .../workloads/cifar/cifar_pytorch/models.py | 24 ++++++- .../workloads/cifar/cifar_pytorch/workload.py | 9 ++- algoperf/workloads/cifar/workload.py | 2 + .../imagenet_resnet/imagenet_jax/models.py | 8 ++- .../imagenet_resnet/imagenet_jax/workload.py | 28 ++++++-- .../imagenet_pytorch/models.py | 52 +++++++++++--- .../imagenet_pytorch/workload.py | 9 ++- .../workloads/imagenet_resnet/workload.py | 2 + .../imagenet_vit/imagenet_jax/models.py | 48 +++++++++---- .../imagenet_vit/imagenet_jax/workload.py | 10 ++- .../imagenet_vit/imagenet_pytorch/models.py | 71 +++++++++++++------ .../imagenet_vit/imagenet_pytorch/workload.py | 12 ++-- algoperf/workloads/ogbg/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 2 + .../pytorch_nadamw_full_budget.py | 10 +-- scoring/performance_profile.py | 2 +- submission_runner.py | 25 +++++-- 21 files changed, 288 insertions(+), 90 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index b86e55954..8dd00345c 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -6,11 +6,34 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import jax +import jax.numpy as jnp +import torch import torch.nn.functional as F from absl import logging from torch import nn +class DTYPE(enum.Enum): + FLOAT32 = 0 + FLOAT16 = 1 + BFLOAT16 = 2 + + +# Mapping from DTYPE enum to JAX dtypes +JAX_DTYPE_MAP = { + DTYPE.FLOAT32: jnp.float32, + DTYPE.FLOAT16: jnp.float16, + DTYPE.BFLOAT16: jnp.bfloat16, +} + +# Mapping from DTYPE enum to PyTorch dtypes +PYTORCH_DTYPE_MAP = { + DTYPE.FLOAT32: torch.float32, + DTYPE.FLOAT16: torch.float16, + DTYPE.BFLOAT16: torch.bfloat16, +} + + class LossType(enum.Enum): SOFTMAX_CROSS_ENTROPY = 0 SIGMOID_CROSS_ENTROPY = 1 diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 7fbc95bc6..307e9e705 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -11,7 +11,6 @@ import jax import tensorflow as tf import tensorflow_datasets as tfds -from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np @@ -186,5 
+185,4 @@ def create_input_iter( ), ds, ) - it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 95238c997..9a4f7fd96 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -31,7 +31,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: bool = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: @@ -41,7 +41,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - dtype=self.dtype, + param_dtype=self.dtype, ) x = conv( @@ -66,7 +66,9 @@ def __call__( x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + self.num_classes, + kernel_init=nn.initializers.normal(), + param_dtype=self.dtype, )(x) return x diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index defc30121..e6bc5b419 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp +import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -18,6 +19,17 @@ class CifarWorkload(BaseCifarWorkload): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + output_dtype = compute_dtype + self._mp_policy = jmp.Policy( + compute_dtype=compute_dtype, + param_dtype=param_dtype, + output_dtype=output_dtype, + ) + def _build_cifar_dataset( self, data_rng: spec.RandomState, @@ -80,7 +92,8 @@ def sync_batch_stats( def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') - model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + model = model_cls(num_classes=self._num_classes, dtype=param_dtype) self._model = model input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)( @@ -89,7 +102,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) params = jax_sharding_utils.replicate(params) return params, model_state @@ -110,24 +123,32 @@ def model_fn( del mode del rng del dropout_rate + # Cast params and inputs to compute dtype + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) + # Cast logits to output dtype + logits = self._mp_policy.cast_to_output(logits) return logits, new_model_state else: 
logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) + # Cast logits to output dtype + logits = self._mp_policy.cast_to_output(logits) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index 0e08f5c5a..b2b37c001 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -29,11 +29,13 @@ def __init__( width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer + self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -49,7 +51,13 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + 3, + self.inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False, + dtype=dtype, ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) @@ -63,7 +71,7 @@ def __init__( self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) self.reset_parameters() def reset_parameters(self) -> None: @@ -105,7 +113,15 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ( + 'conv', + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + dtype=self.dtype, + ), + ), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -122,6 +138,7 @@ def _make_layer( self.base_width, previous_dilation, norm_layer, + dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -134,6 +151,7 @@ def _make_layer( base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, + dtype=self.dtype, ) ) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..141bef922 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -25,6 +25,8 @@ def __init__(self, *args, **kwargs) -> None: # Is set in submission_runner.py for workloads with PyTorch evaluation # data loaders via the `eval_num_workers` property. 
self._eval_num_workers = None + self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype] + self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] @property def eval_num_workers(self) -> int: @@ -128,7 +130,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return self._model, None torch.random.manual_seed(rng[0]) - self._model = resnet18(num_classes=self._num_classes) + self._model = resnet18( + num_classes=self._num_classes, dtype=self._param_dtype_pt + ) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -175,7 +179,8 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt): + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index 31636807c..6866bc918 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -16,6 +16,8 @@ class BaseCifarWorkload(spec.Workload): _num_classes: int = 10 + _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 + _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index ee1ddf427..41551d4d2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -90,7 +90,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: Optional[bool] = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm @@ -99,7 +99,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - dtype=self.dtype, + param_dtype=self.dtype, ) x = conv( @@ -125,7 +125,9 @@ def __call__( )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + self.num_classes, + kernel_init=nn.initializers.normal(), + param_dtype=self.dtype, )(x) return x diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index f73a1b26e..d7a8ede67 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ import jax import jax.numpy as jnp +import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -29,6 +30,17 @@ class ImagenetResNetWorkload(BaseImagenetResNetWorkload): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + output_dtype = compute_dtype + self._mp_policy = jmp.Policy( + compute_dtype=compute_dtype, + param_dtype=param_dtype, + output_dtype=output_dtype, + ) + def _build_dataset( 
self, data_rng: spec.RandomState, @@ -89,11 +101,12 @@ def init_model_fn( else: act_fnc = nn.relu + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] model = model_cls( num_classes=self._num_classes, act=act_fnc, bn_init_scale=self.bn_init_scale, - dtype=jnp.float32, + dtype=param_dtype, ) self._model = model input_shape = (1, 224, 224, 3) @@ -159,25 +172,28 @@ def model_fn( del mode del rng del dropout_rate + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) variables = {'params': params, **model_state} if update_batch_norm: - logits, new_model_state = self._model.apply( + logits, model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) - return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) - return logits, model_state + logits = self._mp_policy.cast_to_output(logits) + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index c980faa06..f24ba66b9 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -20,6 +20,7 @@ def conv3x3( stride: int = 1, groups: int = 1, dilation: int = 1, + dtype: torch.dtype = torch.float32, ) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( @@ -31,13 +32,24 @@ def conv3x3( groups=groups, bias=False, dilation=dilation, + dtype=dtype, ) -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: +def conv1x1( + in_planes: int, + out_planes: int, + stride: int = 1, + dtype: torch.dtype = torch.float32, +) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, out_planes, kernel_size=1, stride=stride, bias=False + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False, + dtype=dtype, ) @@ -57,6 +69,7 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -67,10 +80,10 @@ def __init__( raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv3x3(inplanes, planes, stride) + self.conv1 = conv3x3(inplanes, planes, stride, dtype=dtype) self.bn1 = norm_layer(planes) self.act_fnc = act_fnc - self.conv2 = conv3x3(planes, planes) + self.conv2 = conv3x3(planes, planes, dtype=dtype) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride @@ -110,6 +123,7 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -117,11 +131,11 @@ def __init__( width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. 
- self.conv1 = conv1x1(inplanes, width) + self.conv1 = conv1x1(inplanes, width, dtype=dtype) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.conv2 = conv3x3(width, width, stride, groups, dilation, dtype=dtype) self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) + self.conv3 = conv1x1(width, planes * self.expansion, dtype=dtype) self.bn3 = norm_layer(planes * self.expansion) self.act_fnc = act_fnc self.downsample = downsample @@ -163,11 +177,13 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), bn_init_scale: float = 0.0, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer + self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -183,7 +199,13 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + dtype=dtype, ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc @@ -214,7 +236,7 @@ def __init__( dilate=replace_stride_with_dilation[2], ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -256,7 +278,15 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ( + 'conv', + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + dtype=self.dtype, + ), + ), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -274,6 +304,7 @@ def _make_layer( previous_dilation, norm_layer, act_fnc, + dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -287,6 +318,7 @@ def _make_layer( dilation=self.dilation, norm_layer=norm_layer, act_fnc=act_fnc, + dtype=self.dtype, ) ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index d5366c60d..3a88245ae 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -178,7 +178,10 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: act_fnc = torch.nn.ReLU(inplace=True) - model = resnet50(act_fnc=act_fnc, bn_init_scale=self.bn_init_scale) + param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] + model = resnet50( + act_fnc=act_fnc, bn_init_scale=self.bn_init_scale, dtype=param_dtype + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -229,8 +232,10 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + with torch.autocast(device_type='cuda', dtype=compute_dtype): + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index de8458c92..bc5982f1d 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ 
b/algoperf/workloads/imagenet_resnet/workload.py @@ -8,6 +8,8 @@ class BaseImagenetResNetWorkload(spec.Workload): _num_classes: int = 1000 + _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 + _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index e86233011..2e4630701 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -42,6 +42,7 @@ class MlpBlock(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False dropout_rate: float = DROPOUT_RATE + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -54,15 +55,15 @@ def __call__( } d = x.shape[2] - x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) + x = nn.Dense(self.mlp_dim or 4 * d, param_dtype=self.dtype, **inits)(x) x = nn.gelu(x) if self.use_glu: - y = nn.Dense(self.mlp_dim, **inits)(x) + y = nn.Dense(self.mlp_dim, param_dtype=self.dtype, **inits)(x) x = x * y x = Dropout(dropout_rate)(x, train, rate=dropout_rate) - x = nn.Dense(d, **inits)(x) + x = nn.Dense(d, param_dtype=self.dtype, **inits)(x) return x @@ -74,25 +75,30 @@ class Encoder1DBlock(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate ) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', + param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - y = nn.LayerNorm(name='LayerNorm_2')(x) + y = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3' + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dtype=self.dtype, + name='MlpBlock_3', )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y @@ -103,21 +109,23 @@ def __call__( kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', + param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) + x = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, + dtype=self.dtype, name='MlpBlock_3', dropout_rate=dropout_rate, )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) + x = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) return x @@ -130,6 +138,7 @@ class Encoder(nn.Module): num_heads: int = 12 use_glu: bool = False use_post_layer_norm: bool = False + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -143,9 +152,10 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) + return nn.LayerNorm(name='encoder_layernorm', param_dtype=self.dtype)(x) 
else: return x @@ -156,12 +166,13 @@ class MAPHead(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param( - 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), self.dtype ) probe = jnp.tile(probe, [n, 1, 1]) @@ -169,10 +180,13 @@ def __call__(self, x, dropout_rate=DROPOUT_RATE): num_heads=self.num_heads, use_bias=True, kernel_init=nn.initializers.xavier_uniform(), + param_dtype=self.dtype, )(probe, x) - y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) + y = nn.LayerNorm(param_dtype=self.dtype)(x) + x = x + MlpBlock( + mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, dtype=self.dtype + )(y) return x[:, 0] @@ -192,6 +206,7 @@ class ViT(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False use_map: bool = False + dtype: jnp.dtype = jnp.float32 def get_posemb( self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32 @@ -209,6 +224,7 @@ def __call__( strides=self.patch_size, padding='VALID', name='conv_patch_extract', + param_dtype=self.dtype, )(x) n, h, w, c = x.shape @@ -225,6 +241,7 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, name='Transformer', )(x, train=not train, dropout_rate=dropout_rate) @@ -233,18 +250,21 @@ def __call__( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, + dtype=self.dtype, )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, name='pre_logits') + hid = nn.Dense(rep_size, name='pre_logits', param_dtype=self.dtype) x = nn.tanh(hid(x)) if self.num_classes: kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense(self.num_classes, name='head', **kw) + head = nn.Dense( + self.num_classes, name='head', param_dtype=self.dtype, **kw + ) x = head(x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 8a33aeb47..6819a4862 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,11 +32,13 @@ def initialized( return params, model_state def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, + dtype=param_dtype, **decode_variant('S/16'), ) params, model_state = self.initialized(rng, self._model) @@ -62,15 +64,19 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm - del use_running_average_bn + # Cast params and inputs to compute dtype + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( {'params': params}, - augmented_and_preprocessed_input_batch['inputs'], + inputs, rngs={'dropout': rng}, train=train, dropout_rate=dropout_rate, ) + logits = self._mp_policy.cast_to_output(logits) return logits, None def _eval_model_on_split( diff 
--git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..6dfb5fddf 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -46,22 +46,24 @@ def __init__( width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. use_glu: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu + self.dtype = dtype - self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.linear1 = nn.Linear(self.width, self.mlp_dim, dtype=self.dtype) self.act_fnc = nn.GELU(approximate='tanh') if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim, dtype=self.dtype) else: self.glu_linear = None - self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.linear2 = nn.Linear(self.mlp_dim, self.width, dtype=self.dtype) self.reset_parameters() @@ -85,14 +87,18 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: return x +# TODO(rka97): switch this to built-in attention with cudnn class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__(self, width: int, num_heads: int = 8) -> None: + def __init__( + self, width: int, num_heads: int = 8, dtype: Any = torch.float32 + ) -> None: super().__init__() self.width = width self.num_heads = num_heads + self.dtype = dtype assert width % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.' @@ -101,10 +107,10 @@ def __init__(self, width: int, num_heads: int = 8) -> None: self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.query = nn.Linear(self.width, self.all_head_dim) - self.key = nn.Linear(self.width, self.all_head_dim) - self.value = nn.Linear(self.width, self.all_head_dim) - self.out = nn.Linear(self.width, self.width) + self.query = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.key = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.value = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.out = nn.Linear(self.width, self.width, dtype=self.dtype) self.reset_parameters() def reset_parameters(self) -> None: @@ -150,6 +156,7 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() @@ -158,12 +165,18 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm + self.dtype = dtype - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) - self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.self_attention1 = SelfAttention( + self.width, self.num_heads, dtype=self.dtype + ) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) self.mlp3 = MlpBlock( - width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dtype=self.dtype, ) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: @@ -203,6 +216,7 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, + dtype: Any = torch.float32, ) -> None: 
super().__init__() @@ -212,6 +226,7 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm + self.dtype = dtype self.net = nn.ModuleList( [ @@ -221,13 +236,14 @@ def __init__( self.num_heads, self.use_glu, self.use_post_layer_norm, + dtype=self.dtype, ) for _ in range(depth) ] ) if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) else: self.encoder_norm = None @@ -245,21 +261,32 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" def __init__( - self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 + self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + dtype: torch.dtype = torch.float32, ): super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.dtype = dtype self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True + self.width, + num_heads=self.num_heads, + self_attn=False, + bias=True, + dtype=self.dtype, + ) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.mlp = MlpBlock( + width=self.width, mlp_dim=self.mlp_dim, dtype=self.dtype ) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) - self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape @@ -310,7 +337,7 @@ def __init__( if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size) + self.pre_logits = nn.Linear(self.width, rep_size, dtype=self.dtype) self.conv_patch_extract = nn.Conv2d( self.channels, @@ -318,6 +345,7 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid', + dtype=self.dtype, ) self.encoder = Encoder( @@ -327,13 +355,16 @@ def __init__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, ) if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes) + self.head = nn.Linear(self.width, self.num_classes, dtype=self.dtype) if self.use_map: - self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + self.map = MAPHead( + self.width, self.mlp_dim, self.num_heads, dtype=self.dtype + ) else: self.map = None diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 9c6faf70b..bfef3e0a9 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,11 +23,13 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) + param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, + dtype=param_dtype, **decode_variant('S/16'), ) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -70,11 +72,13 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - logits_batch = model( - 
augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate, - ) + with torch.autocast(device_type='cuda', dtype=compute_dtype): + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 002576268..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..a6f36fd30 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -396,6 +396,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..285727885 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -315,13 +314,6 @@ def update_params( }, global_step, ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), - ) - return (optimizer_state, current_param_container, new_model_state) @@ -372,6 +364,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index b200c6865..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl' + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..84ae3307b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,6 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'cifar', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -409,10 +410,15 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() - - train_state['accumulated_submission_time'] += ( - train_step_end_time - train_state['last_step_end_time'] - ) + step_time = train_step_end_time - train_state['last_step_end_time'] + train_state['accumulated_submission_time'] += step_time + # Log training progress periodically + if global_step % 10 == 0: + logging.info( + f'Step: {global_step}, ' + f'\tLast step time: {step_time:.4f}s, ' + f'\tTotal time: {train_state["accumulated_submission_time"]:.2f}s' + ) # Check if submission is eligible for an untimed eval. 
if ( @@ -512,10 +518,19 @@ def train_once( latest_eval_result['accumulated_logging_time'] = train_state[ 'accumulated_logging_time' ] + # Calculate average per-step time + avg_per_step_time = ( + train_state['accumulated_submission_time'] / global_step + if global_step > 0 + else 0.0 + ) + latest_eval_result['avg_per_step_time'] = avg_per_step_time time_since_start = latest_eval_result['total_duration'] logging.info( f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}' + f'\tStep: {global_step}, ' + f'\tAvg per-step time: {avg_per_step_time:.4f}s, ' + f'\t{latest_eval_result}' ) eval_results.append((global_step, latest_eval_result)) From 9cb7f82f541812b9e3f1ca6b48b3b2dd1519075f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 2 Dec 2025 23:45:06 +0000 Subject: [PATCH 3/5] adjust step time calculation --- scoring/score_submissions.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 4b7bed2b5..83e7b8436 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -119,9 +119,13 @@ def get_summary_df(workload, workload_df, include_test_split=False): axis=1, ) - summary_df['step_time (s)'] = ( - workload_df['accumulated_submission_time'] / workload_df['global_step'] - ).iloc[-1][-1] + # compute the step times + def delta(series): + return series.shift(1, fill_value=0) - series + accumulated_time_intervals = delta(workload_df['accumulated_submission_time']) + step_intervals = delta(workload_df['global_step']) + + summary_df['step_time (s)'] = np.median((accumulated_time_intervals / step_intervals).iloc[0]) summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) From 93b47bc1e0ea7755db3b77e4e302d6b28e27adb6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 5 Dec 2025 21:20:35 +0000 Subject: [PATCH 4/5] set target to target found w narrower search space --- algoperf/workloads/finewebedu_lm/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algoperf/workloads/finewebedu_lm/workload.py b/algoperf/workloads/finewebedu_lm/workload.py index b5a258f6f..8eddfd337 100644 --- a/algoperf/workloads/finewebedu_lm/workload.py +++ b/algoperf/workloads/finewebedu_lm/workload.py @@ -52,7 +52,7 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - return 22.432 # Target perplexity + return 22.2995 # Target perplexity def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return True # No test targets From 9daa2833a0d58690ad83e4b46e2ec4b30f0e24ee Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 8 Dec 2025 18:37:19 +0000 Subject: [PATCH 5/5] add finewebedu_lm to docker startup script valid args --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index d92107e90..1cd676d2a 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "lm") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "finewebedu_lm") VALID_RULESETS=("self" "external") # Set data 
and experiment paths
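
Note (illustrative, not part of the patch): the autocast change in the ImageNet ViT PyTorch workload follows the usual torch.autocast mixed-precision pattern — parameters are kept in float32 while the forward pass runs under a lower-precision compute dtype. Below is a minimal self-contained sketch of that pattern; the toy model, input shape, and the hard-coded bfloat16 compute dtype are assumptions for the example only (the workload instead maps its configured dtype through spec.PYTORCH_DTYPE_MAP).

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 10)).to(device)
inputs = torch.randn(8, 32, device=device)

# Parameters stay in float32; only the forward computation runs in bfloat16.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
  logits = model(inputs)

print(logits.dtype)           # torch.bfloat16 -- activations in the compute dtype
print(model[0].weight.dtype)  # torch.float32  -- master weights are untouched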
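Note (illustrative, not part of the patch): the adjusted step_time (s) summary in score_submissions.py is, in effect, the median of successive differences of accumulated submission time divided by successive differences of global step. A minimal sketch of that calculation on a flat per-eval series follows; the numbers are made up and the scoring DataFrame's nested per-trial layout is deliberately ignored.

import numpy as np
import pandas as pd

# Hypothetical per-eval measurements for one trial (made-up numbers).
accumulated_submission_time = pd.Series([0.0, 12.5, 24.8, 37.4, 49.9])
global_step = pd.Series([0, 100, 200, 300, 400])

# Successive differences give wall-clock time and step count per eval interval;
# the leading NaN from .diff() is dropped before taking the ratio.
time_intervals = accumulated_submission_time.diff().dropna()
step_intervals = global_step.diff().dropna()

step_time = np.median(time_intervals.to_numpy() / step_intervals.to_numpy())
print(f'median step time: {step_time:.3f}s')  # 0.125s for the numbers above

Taking a median over interval ratios, rather than dividing the final accumulated time by the final step as before, presumably keeps one-off costs such as compilation or the first eval from dominating the per-step estimate.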