1 change: 1 addition & 0 deletions .gitignore
@@ -26,3 +26,4 @@ scoring/plots/
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

algoperf/_version.py
+core*
2 changes: 1 addition & 1 deletion algoperf/workloads/fastmri/workload.py
@@ -104,7 +104,7 @@ def eval_period_time_sec(self) -> int:
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 18_094
+    return 36_189

  def _build_input_queue(
    self,
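Note on step_hint (not part of the diff itself): a minimal sketch of how a training algorithm might consume the step hint, assuming it is used to size a learning-rate schedule with optax. The function name, warmup fraction, and learning-rate values below are illustrative assumptions, not settings from this repository.

# Sketch: sizing an optax learning-rate schedule from a workload's step hint.
# 'workload' stands for any workload instance exposing step_hint (e.g., the
# FastMRI workload above, whose hint becomes 36_189 in this change).
import optax


def make_lr_schedule(workload, peak_lr=1e-3, warmup_fraction=0.05):
  total_steps = workload.step_hint
  warmup_steps = int(warmup_fraction * total_steps)
  return optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,  # decay over the full step budget
    end_value=0.0,
  )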
2 changes: 1 addition & 1 deletion algoperf/workloads/imagenet_resnet/workload.py
@@ -145,4 +145,4 @@ def _build_input_queue(
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 195_999
+    return 186_666
2 changes: 1 addition & 1 deletion algoperf/workloads/imagenet_vit/workload.py
@@ -121,4 +121,4 @@ def _build_dataset(
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 167_999
+    return 186_666
2 changes: 1 addition & 1 deletion algoperf/workloads/librispeech_conformer/workload.py
@@ -89,4 +89,4 @@ def eval_period_time_sec(self) -> int:
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 76_000
+    return 80_000
@@ -96,7 +96,7 @@ def test_target_value(self) -> float:
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 38_400
+    return 48_000

  @property
  def max_allowed_runtime_sec(self) -> int:
@@ -92,7 +92,7 @@ def test_target_value(self) -> float:
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 38_400
+    return 48_000

  @property
  def max_allowed_runtime_sec(self) -> int:
2 changes: 1 addition & 1 deletion algoperf/workloads/ogbg/workload.py
@@ -144,7 +144,7 @@ def loss_fn(
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 52_000
+    return 80_000

  @abc.abstractmethod
  def _normalize_eval_metrics(
2 changes: 1 addition & 1 deletion algoperf/workloads/wmt/workload.py
@@ -98,7 +98,7 @@ def eval_period_time_sec(self) -> int:
  @property
  def step_hint(self) -> int:
    """Approx. steps the baseline can do in the allowed runtime budget."""
-    return 120_000
+    return 133_333

  @property
  def pre_ln(self) -> bool:
128 changes: 128 additions & 0 deletions docker/scripts/plot.py
@@ -0,0 +1,128 @@
# plot.py

import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import collections

# --- Configuration ---
# The base directory to start searching from.
# The script assumes it's run from the 'experiment_runs' directory.
SEARCH_DIR = Path("~/experiment_runs/tests/regression_tests/adamw").expanduser()

# The directory where the output plots will be saved.
OUTPUT_DIR = Path("ssim_plots")

# The columns to use for the x and y axes.
X_AXIS_COL = "global_step"
Y_AXIS_CANDIDATES = ["validation/loss", "validation/ctc_loss"]
# ---------------------


def generate_plots():
  """
  Finds all 'measurements.csv' files, groups them by workflow,
  and generates a JAX vs. PyTorch plot for each.
  """
  # Create the output directory if it doesn't already exist.
  OUTPUT_DIR.mkdir(exist_ok=True)
  print(f"📊 Plots will be saved to the '{OUTPUT_DIR}' directory.")

  # Use a dictionary to group file paths by their workflow name,
  # e.g., {'fastmri': [...], 'wmt': [...]}.
  workflow_files = collections.defaultdict(list)

  # Recursively find all 'measurements.csv' files in the search directory.
  for csv_path in SEARCH_DIR.rglob("measurements.csv"):
    try:
      # The directory name looks like 'fastmri_jax' or 'wmt_pytorch'.
      # We get this from the parent of the parent of the csv file,
      # e.g., .../fastmri_jax/trial_1/measurements.csv.
      workflow_framework_name = csv_path.parent.parent.name

      # Split the name to get the framework (last part) and the
      # workflow (everything else).
      parts = workflow_framework_name.split('_')
      framework = parts[-1]
      workflow = '_'.join(parts[:-1])

      # Store the path and framework for this workflow.
      if framework in ['jax', 'pytorch']:
        workflow_files[workflow].append({'path': csv_path, 'framework': framework})

    except IndexError:
      # Handles directory names that do not match the expected pattern.
      print(f"⚠️ Could not parse workflow/framework from path: {csv_path}")
      continue

  if not workflow_files:
    print("❌ No 'measurements.csv' files found. Check the SEARCH_DIR variable and your folder structure.")
    return

  print(f"\nFound {len(workflow_files)} workflows. Generating plots...")

  # Iterate through each workflow and its associated files to create a plot.
  for workflow, files in workflow_files.items():
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))

    print(f" -> Processing workflow: '{workflow}'")

    y_axis_col_used = None  # The y-axis column name used for the plot labels.

    # Plot data for each framework (JAX and PyTorch) on the same figure.
    for item in files:
      try:
        df = pd.read_csv(item['path'])

        # Pick the first candidate y-axis column present in this CSV.
        y_axis_col = None
        for candidate in Y_AXIS_CANDIDATES:
          if candidate in df.columns:
            y_axis_col = candidate
            if not y_axis_col_used:
              y_axis_col_used = y_axis_col  # Set the label from the first file.
            break  # Found a valid column, no need to check further.

        # Check that the required columns exist in the CSV.
        if X_AXIS_COL in df.columns and y_axis_col:
          # 1. Forward-fill 'global_step' to propagate the last valid step downwards.
          df[X_AXIS_COL] = df[X_AXIS_COL].ffill()

          # 2. Drop any rows where the selected y-axis column is empty (NaN).
          df_cleaned = df.dropna(subset=[y_axis_col])

          # Plot the cleaned data.
          ax.plot(
            df_cleaned[X_AXIS_COL],
            df_cleaned[y_axis_col],
            label=item['framework'].capitalize(),  # e.g., 'Jax'
            marker='.',
            linestyle='-',
            alpha=0.8
          )
        else:
          print(f" - Skipping {item['path']} (missing required columns).")

      except Exception as e:
        print(f" - ❗️ Error reading {item['path']}: {e}")

    # Customize and save the plot.
    ax.set_title(f'Validation loss vs. Global Step for {workflow.replace("_", " ").title()}', fontsize=16)
    ax.set_xlabel("Global Step", fontsize=12)
    ax.set_ylabel("Validation loss", fontsize=12)
    ax.legend(title="Framework", fontsize=10)
    plt.tight_layout()
    plt.yscale('log')

    # Define the output filename and save the figure.
    output_filename = OUTPUT_DIR / f"{workflow}_comparison.png"
    plt.savefig(output_filename, dpi=150)
    plt.close(fig)  # Close the figure to free up memory.

  print("\n✅ All plots generated successfully!")


if __name__ == "__main__":
  generate_plots()
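The grouping logic above assumes an on-disk layout of the form <experiment_dir>/<workload>_<framework>/trial_N/measurements.csv. A small, self-contained illustration of that parsing step (the path below is hypothetical):

# Hypothetical path, mirroring the layout plot.py expects.
from pathlib import Path

csv_path = Path("~/experiment_runs/tests/regression_tests/adamw/fastmri_jax/trial_1/measurements.csv")
name = csv_path.parent.parent.name  # 'fastmri_jax'
parts = name.split('_')
framework = parts[-1]               # 'jax'
workflow = '_'.join(parts[:-1])     # 'fastmri'
print(workflow, framework)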
30 changes: 16 additions & 14 deletions submission_runner.py
@@ -22,6 +22,7 @@
import os
import struct
import time
+import optax
from inspect import signature
from types import MappingProxyType
from typing import Any, Dict, Optional, Tuple
@@ -79,20 +80,19 @@
  help='Which tuning ruleset to use.',
)
flags.DEFINE_string(
-  'tuning_search_space',
-  None,
-  'The path to the JSON file describing the external tuning search space.',
-)
-flags.DEFINE_integer(
-  'num_tuning_trials', 1, 'The number of external hyperparameter trials to run.'
-)
+    'tuning_search_space',
+    None,
+    'The path to the JSON file describing the external tuning search space.')
+flags.DEFINE_integer('num_tuning_trials',
+                     1,
+                     'The number of external hyperparameter trials to run.')
flags.DEFINE_string('data_dir', '~/data', 'Dataset location.')
-flags.DEFINE_string(
-  'imagenet_v2_data_dir', None, 'Dataset location for ImageNet-v2.'
-)
-flags.DEFINE_string(
-  'librispeech_tokenizer_vocab_path', '', 'Location to librispeech tokenizer.'
-)
+flags.DEFINE_string('imagenet_v2_data_dir',
+                    None,
+                    'Dataset location for ImageNet-v2.')
+flags.DEFINE_string('librispeech_tokenizer_vocab_path',
+                    '',
+                    'Location to librispeech tokenizer.')

flags.DEFINE_enum(
  'framework',
@@ -861,9 +861,11 @@ def main(_):


if __name__ == '__main__':
+  print(optax.__version__)
+  print("!!!!")
  flags.mark_flag_as_required('workload')
  flags.mark_flag_as_required('framework')
  flags.mark_flag_as_required('submission_path')
  flags.mark_flag_as_required('experiment_dir')
  flags.mark_flag_as_required('experiment_name')
-  app.run(main)
+  app.run(main)
Empty file.
76 changes: 76 additions & 0 deletions tests/test_algorithms/schedule_free_adamw/compare_plot.py
@@ -0,0 +1,76 @@
import pandas as pd
import matplotlib.pyplot as plt
import os
import subprocess

# Paths to CSV files
pytorch_csv = '/root/experiments/sfadamw6/ogbg_pytorch/trial_1/measurements.csv'
jax_csv = '/root/experiments/sfadamw7/ogbg_jax/trial_1/measurements.csv'

# Read CSVs
try:
  pytorch_df = pd.read_csv(pytorch_csv)
  print("PyTorch CSV columns:", pytorch_df.columns)
  jax_df = pd.read_csv(jax_csv)
  print("JAX CSV columns:", jax_df.columns)
except FileNotFoundError as e:
  print(f"Error: Could not find CSV file: {e}")
  exit()
except pd.errors.EmptyDataError as e:
  print(f"Error: CSV file is empty: {e}")
  exit()
except Exception as e:
  print(f"Error reading CSV file: {e}")
  exit()

# Define the column names based on inspection of the CSVs
x_column = 'global_step'
y_column = 'validation/loss'

# Create output directory in /tmp, which is writable
output_dir = '/tmp/plots'
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, 'adam_ogbg_plot.png')

# Clean the data by dropping rows with NaN values in the y column
pytorch_df_cleaned = pytorch_df.dropna(subset=[y_column])
jax_df_cleaned = jax_df.dropna(subset=[y_column])

# Print data after cleaning
print("\nCleaned PyTorch data:")
print(pytorch_df_cleaned[[x_column, y_column]].head())
print("\nCleaned JAX data:")
print(jax_df_cleaned[[x_column, y_column]].head())

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(pytorch_df_cleaned[x_column], pytorch_df_cleaned[y_column], label='PyTorch', marker='o')
plt.plot(jax_df_cleaned[x_column], jax_df_cleaned[y_column], label='JAX', marker='x')

plt.xlabel('Step')
plt.ylabel('Metric')
plt.title('Comparison of PyTorch and JAX Metrics')
plt.legend()
plt.grid(True)
plt.tight_layout()

# Save the plot
try:
  plt.savefig(output_path)
  print(f"Plot saved to: {output_path}")
except Exception as e:
  print(f"Error saving plot: {e}")

plt.close()

# Open the plot in the browser pointed to by $BROWSER (if set)
browser = os.environ.get('BROWSER')
try:
  if browser:
    subprocess.run([browser, output_path], check=True)
  else:
    print("No $BROWSER environment variable set; not opening the image.")
except FileNotFoundError:
  print("No browser found to open the image.")
except subprocess.CalledProcessError:
  print("Error opening the image in the browser.")
except Exception as e:
  print(f"Error opening plot in browser: {e}")
@@ -0,0 +1,23 @@
import pandas as pd

# Paths to CSV files
pytorch_csv = '/root/experiments/sfadamw6/ogbg_pytorch/trial_1/measurements.csv'
jax_csv = '/root/experiments/sfadamw6/ogbg_jax/trial_1/measurements.csv'

# Read CSVs
try:
  pytorch_df = pd.read_csv(pytorch_csv)
  jax_df = pd.read_csv(jax_csv)
except FileNotFoundError as e:
  print(f"Error: Could not find CSV file: {e}")
  exit()
except pd.errors.EmptyDataError as e:
  print(f"Error: CSV file is empty: {e}")
  exit()
except Exception as e:
  print(f"Error reading CSV file: {e}")
  exit()

# Print the number of data points
print(f"Number of data points in PyTorch CSV: {len(pytorch_df)}")
print(f"Number of data points in JAX CSV: {len(jax_df)}")
Empty file.