Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
from algoperf.workloads.criteo1tb.workload import \
BaseCriteo1TbDlrmSmallWorkload
from custom_pytorch_jax_converter import use_pytorch_weights



class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
Expand Down Expand Up @@ -103,6 +105,7 @@ def init_model_fn(
{'params': params_rng, 'dropout': dropout_rng},
jnp.ones(input_shape, jnp.float32))
initial_params = initial_variables['params']
initial_params = use_pytorch_weights(file_name="/results/pytorch_base_model_criteo1tb_1_july.pth")
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None
Expand Down
1 change: 1 addition & 0 deletions algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def init_model_fn(
dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)
torch.save(model.state_dict(), '/results/pytorch_base_model_criteo1tb_1_july.pth')
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
Expand Down
99 changes: 99 additions & 0 deletions custom_pytorch_jax_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch
import numpy as np
import jax
import jax.numpy as jnp
import logging
import copy
import copy
from jax.tree_util import tree_map


def use_pytorch_weights(file_name: str):
    """Load a PyTorch DLRM state dict and convert it to the JAX parameter layout.

    JAX default parameter structure:
    dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])

    PyTorch state_dict structure:
    dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])

    Mapping performed here:
      * every `embedding_chunk_<i>` entry is concatenated along axis 0, in
        increasing numeric order of `<i>`, into `embedding_table`;
      * `bot_mlp.{0,2,4}` become `Dense_0` .. `Dense_2`;
      * `top_mlp.{0,2,4,6,8}` become `Dense_3` .. `Dense_7`;
      * each `weight` is transposed to form the `kernel`, `bias` is copied as is.

    Args:
        file_name: Path of a file readable by `torch.load` that contains the
            state dict described above.

    Returns:
        A dict with key `embedding_table` (a `jnp` array) and keys
        `Dense_0` .. `Dense_7`, each mapping to `{'kernel': ..., 'bias': ...}`.

    Raises:
        KeyError: If no embedding chunk or an expected MLP entry is missing.
    """
    # NOTE(security): torch.load unpickles the file; only point this at
    # checkpoints you produced yourself. Consider `weights_only=True` if the
    # installed PyTorch version supports it.
    state_dict = torch.load(file_name, map_location='cpu')
    logging.debug('Loaded PyTorch state_dict keys: %s', list(state_dict.keys()))

    # detach() keeps the conversion working even if a tensor tracks gradients.
    numpy_weights = {
        name: tensor.detach().cpu().numpy()
        for name, tensor in state_dict.items()
    }
    del state_dict

    # --- Embedding table ---
    # Discover the chunks instead of assuming there are exactly four, and sort
    # them numerically so that e.g. chunk 10 comes after chunk 9.
    chunk_prefix = 'embedding_chunk_'
    chunk_ids = sorted(
        int(name[len(chunk_prefix):])
        for name in numpy_weights
        if name.startswith(chunk_prefix)
        and name[len(chunk_prefix):].isdecimal()
    )
    if not chunk_ids:
        raise KeyError(
            f"No '{chunk_prefix}<i>' entries found in state dict {file_name!r}")

    jax_copy = {
        'embedding_table':
            jnp.array(
                np.concatenate(
                    [numpy_weights[f'{chunk_prefix}{i}'] for i in chunk_ids],
                    axis=0)),
    }

    # --- MLPs ---
    # Bottom MLP layers fill Dense_0..Dense_2, top MLP layers Dense_3..Dense_7.
    # The PyTorch weight is transposed to obtain the JAX kernel.
    dense_index = 0
    for mlp_name, layer_ids in (('bot_mlp', (0, 2, 4)),
                                ('top_mlp', (0, 2, 4, 6, 8))):
        for layer_id in layer_ids:
            jax_copy[f'Dense_{dense_index}'] = {
                'kernel':
                    jnp.array(numpy_weights[f'{mlp_name}.{layer_id}.weight'].T),
                'bias':
                    jnp.array(numpy_weights[f'{mlp_name}.{layer_id}.bias']),
            }
            dense_index += 1

    return jax_copy


def maybe_unreplicate(pytree):
    """Strip a leading device axis from every array leaf that appears replicated.

    A leaf is treated as replicated when it is a `jax.Array` with at least one
    dimension whose leading dimension equals `jax.device_count()`; such a leaf
    is replaced by its first slice (`leaf[0]`). All other leaves, including
    0-d arrays and non-array objects, are returned unchanged.

    NOTE(review): this is a heuristic. An unreplicated array whose first
    dimension happens to equal the device count is stripped as well.

    Args:
        pytree: Any JAX pytree.

    Returns:
        A pytree with the same structure as `pytree`.
    """
    num_devices = jax.device_count()

    def _strip_leading_axis(leaf):
        # The ndim guard matters: 0-d arrays have an empty shape, so indexing
        # shape[0] on them would raise IndexError.
        if (isinstance(leaf, jax.Array) and leaf.ndim > 0 and
                leaf.shape[0] == num_devices):
            return leaf[0]
        return leaf

    return jax.tree_util.tree_map(_strip_leading_axis, pytree)


def move_to_cpu(tree):
    """Return a copy of `tree` with every leaf placed on the first CPU device.

    Args:
        tree: Any JAX pytree whose leaves are accepted by `jax.device_put`.

    Returns:
        A pytree with the same structure whose leaves are the results of
        `jax.device_put(leaf, device=<first CPU device>)`.
    """
    # Look the device up once instead of once per leaf.
    cpu_device = jax.devices("cpu")[0]
    return jax.tree_util.tree_map(
        lambda leaf: jax.device_put(leaf, device=cpu_device), tree)


def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
    """Compare two JAX pytrees of weights and log every leaf that differs.

    Both trees are first passed through `maybe_unreplicate` and `move_to_cpu`.
    Corresponding leaves are considered equal when they have the same shape
    and `jnp.allclose` holds with the given tolerances. The shape check is
    explicit because `jnp.allclose` broadcasts its arguments and would
    otherwise accept leaves of different shapes.

    Args:
        params1: First pytree of weights.
        params2: Second pytree of weights.
        atol: Absolute tolerance forwarded to `jnp.allclose`.
        rtol: Relative tolerance forwarded to `jnp.allclose`.

    Returns:
        True if every pair of leaves matches; False if any pair differs or if
        the comparison itself raises (for example on a tree structure
        mismatch).
    """
    params1 = move_to_cpu(maybe_unreplicate(params1))
    params2 = move_to_cpu(maybe_unreplicate(params2))

    all_equal = True

    def compare_fn(p1, p2):
        nonlocal all_equal
        shape1, shape2 = jnp.shape(p1), jnp.shape(p2)
        same_shape = shape1 == shape2
        # Evaluate allclose once, and only when the shapes agree.
        is_close = same_shape and bool(
            jnp.allclose(p1, p2, atol=atol, rtol=rtol))
        if not is_close:
            logging.info("❌ Mismatch found:")
            logging.info("Shape : %s, Shape 2: %s", shape1, shape2)
            if same_shape:
                logging.info("Max diff: %s", jnp.max(jnp.abs(p1 - p2)))
            all_equal = False
        return is_close

    try:
        jax.tree_util.tree_map(compare_fn, params1, params2)
    except Exception:  # Best-effort diagnostic: report and signal inequality.
        logging.exception("❌ Structure mismatch or error during comparison:")
        return False

    if all_equal:
        logging.info("✅ All weights are equal (within tolerance)")
    return all_equal
226 changes: 226 additions & 0 deletions reference_algorithms/schedule_free/jax/submission.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""Submission file for an Schedule Free AdamW optimizer in Jax."""

import functools
from typing import Dict, Iterator, List, Tuple
import optax

from flax import jax_utils
import jax
from jax import lax
import jax.numpy as jnp
from optax.contrib import schedule_free_adamw
from algoperf import spec
from custom_pytorch_jax_converter import use_pytorch_weights, are_weights_equal

# Added to the gradient norm in pmapped_train_step before dividing by it.
_GRAD_CLIP_EPS = 1.0e-6

# Fixed hyperparameters of this submission. init_optimizer_state reads
# learning_rate, warmup_factor, one_minus_beta1, beta2, eps, weight_decay and
# weight_lr_power from this table.
HPARAMS = dict(
    dropout_rate=0.1,
    learning_rate=0.0025,
    one_minus_beta1=0.1,
    beta2=0.9955159689799007,
    weight_decay=0.08121616522670176,
    warmup_factor=0.02,
    weight_lr_power=2,
    label_smoothing=0.2,
    r=0.75,
    eps=1e-8,
)

def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
    """Creates a Schedule-Free AdamW optimizer and its initial state.

    The optimizer is configured from the module-level `HPARAMS` table; the
    `hyperparameters` argument is not consulted here.

    Args:
        workload: Workload whose `step_hint` is used to size the warmup.
        model_params: Replicated model parameters; unreplicated before being
            passed to the optimizer's init function.
        model_state: Unused.
        hyperparameters: Unused.
        rng: Unused.

    Returns:
        A tuple `(replicated_optimizer_state, opt_update_fn)`.
    """
    del model_state
    del hyperparameters
    del rng

    warmup_steps = int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75)
    opt_init_fn, opt_update_fn = schedule_free_adamw(
        learning_rate=HPARAMS['learning_rate'],
        warmup_steps=warmup_steps,
        b1=1.0 - HPARAMS['one_minus_beta1'],
        b2=HPARAMS['beta2'],
        eps=HPARAMS['eps'],
        weight_decay=HPARAMS['weight_decay'],
        weight_lr_power=HPARAMS['weight_lr_power'],
    )

    # The init function runs on a single, unreplicated copy of the parameters;
    # the resulting state is replicated again before it is returned.
    model_params = jax_utils.unreplicate(model_params)
    optimizer_state = opt_init_fn(model_params)

    return jax_utils.replicate(optimizer_state), opt_update_fn


@functools.partial(
    jax.pmap,
    axis_name='batch',
    in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
    static_broadcasted_argnums=(0, 1),
    donate_argnums=(2, 3, 4))
def pmapped_train_step(workload,
                       opt_update_fn,
                       model_state,
                       optimizer_state,
                       current_param_container,
                       batch,
                       rng,
                       grad_clip,
                       label_smoothing):
    """Runs one pmapped training step.

    Computes the summed loss and its gradient on each device, sums them over
    the 'batch' axis, normalizes by the number of valid examples, optionally
    rescales the gradient by `grad_clip / (grad_norm + _GRAD_CLIP_EPS)`
    clamped to [0, 1], and applies the optimizer update.

    Args:
        workload: Provides `model_fn` and `loss_fn` (static argument).
        opt_update_fn: Optimizer update function (static argument).
        model_state: Auxiliary model state passed to `workload.model_fn`.
        optimizer_state: Current optimizer state.
        current_param_container: Current model parameters.
        batch: Mapping with at least 'targets' and optionally 'weights'.
        rng: Per-device random key.
        grad_clip: Gradient-norm threshold, or None to disable clipping.
        label_smoothing: Forwarded to `workload.loss_fn`.

    Returns:
        `(new_optimizer_state, updated_params, new_model_state, loss,
        grad_norm)`.
    """

    def _loss_fn(params):
        """Loss function used for training."""
        logits, new_model_state = workload.model_fn(
            params,
            batch,
            model_state,
            spec.ForwardPassMode.TRAIN,
            rng,
            update_batch_norm=True)
        loss_dict = workload.loss_fn(
            label_batch=batch['targets'],
            logits_batch=logits,
            mask_batch=batch.get('weights'),
            label_smoothing=label_smoothing)
        summed_loss = loss_dict['summed']
        n_valid_examples = loss_dict['n_valid_examples']
        return summed_loss, (n_valid_examples, new_model_state)

    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
    (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
        current_param_container)
    # Get correct global mean loss and grad.
    (summed_loss, n_valid_examples, grad) = lax.psum(
        (summed_loss, n_valid_examples, grad), axis_name='batch')
    loss = summed_loss / n_valid_examples
    # jax.tree_util.tree_map is used instead of the deprecated jax.tree_map
    # alias.
    grad = jax.tree_util.tree_map(lambda x: x / n_valid_examples, grad)

    grad_norm = jnp.sqrt(
        sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))

    if grad_clip is not None:
        grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
        grad_scaling_factor = jax.lax.clamp(
            min=0.0, x=grad_scaling_factor, max=1.0)
        grad = jax.tree_util.tree_map(lambda x: x * grad_scaling_factor, grad)

    updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
                                                 current_param_container)
    updated_params = optax.apply_updates(current_param_container, updates)
    return new_optimizer_state, updated_params, new_model_state, loss, grad_norm


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params, updated_model_state).

    Runs one `pmapped_train_step`. Every 100 steps it additionally logs the
    loss and gradient norm (when the workload has a metrics logger) and
    compares the new parameters against a PyTorch checkpoint saved for the
    same step.
    """
    del current_params_types
    del loss_type
    del eval_results

    optimizer_state, opt_update_fn = optimizer_state
    per_device_rngs = jax.random.split(rng, jax.local_device_count())

    # Optional hyperparameters fall back to "disabled" values.
    label_smoothing = getattr(hyperparameters, 'label_smoothing', 0.0)
    grad_clip = getattr(hyperparameters, 'grad_clip', None)

    step_outputs = pmapped_train_step(
        workload,
        opt_update_fn,
        model_state,
        optimizer_state,
        current_param_container,
        batch,
        per_device_rngs,
        grad_clip,
        label_smoothing,
    )
    (new_optimizer_state,
     new_params,
     new_model_state,
     loss,
     grad_norm) = step_outputs

    if global_step % 100 == 0:
        # Log loss, grad_norm.
        if workload.metrics_logger is not None:
            scalars = {
                'loss': loss[0],
                'grad_norm': grad_norm[0],
            }
            workload.metrics_logger.append_scalar_metrics(scalars, global_step)

        # Compare the updated parameters with the PyTorch checkpoint written
        # for this step; are_weights_equal logs any differences it finds.
        date_ = "2025-07-01"
        file_name = f"/results/schedule_free_pytorch_weights/criteo1tb_{date_}_after_{global_step}_steps.pth"
        reference_params = use_pytorch_weights(file_name=file_name)
        are_weights_equal(new_params, reference_params)
        del reference_params

    return (new_optimizer_state, opt_update_fn), new_params, new_model_state


# Global batch size for each supported workload name.
_GLOBAL_BATCH_SIZES = {
    'criteo1tb': 262_144,
    'fastmri': 32,
    'imagenet_resnet': 1024,
    'imagenet_resnet_silu': 512,
    'imagenet_resnet_gelu': 512,
    'imagenet_vit': 1024,
    'librispeech_conformer': 256,
    'librispeech_deepspeech': 256,
    'ogbg': 512,
    'wmt': 128,
    'mnist': 16,
}


def get_batch_size(workload_name):
    """Return the global batch size for `workload_name`.

    Raises:
        ValueError: If `workload_name` is not one of the supported workloads.
    """
    try:
        return _GLOBAL_BATCH_SIZES[workload_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are unsupported names
        # too.
        raise ValueError(
            f'Unsupported workload name: {workload_name}.') from None


def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
    """Return the next batch from the input queue.

    The queue is described by the caller as infinitely repeating and
    pre-shuffled; each element is a batch of training examples and labels.
    No selection logic is applied: all other arguments are ignored.
    """
    # Only `input_queue` is used; drop everything else explicitly.
    del (workload,
         optimizer_state,
         current_param_container,
         model_state,
         hyperparameters,
         global_step,
         rng)
    return next(input_queue)

Loading
Loading