.env.example (10 additions, 0 deletions)

@@ -6,3 +6,13 @@ export R2_WRITE_ACCESS_KEY_ID=
 export R2_WRITE_SECRET_ACCESS_KEY=
 export R2_ADMIN_ACCESS_KEY_ID=
 export R2_ADMIN_SECRET_ACCESS_KEY=
+
+# R2 Dataset Configuration (for loading training data from R2 instead of HuggingFace)
+# To use R2 for training data, set these variables and run with --neuron.data_source r2
+export R2_DATASET_ACCOUNT_ID=
+export R2_DATASET_BUCKET_NAME=
+export R2_DATASET_ACCESS_KEY_ID=
+export R2_DATASET_SECRET_ACCESS_KEY=
+
+# Data source for eval scripts (huggingface or r2)
+export DATA_SOURCE=huggingface
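
The code that consumes these new variables is not part of this diff. As a rough sketch of how an R2-backed loader might read them, assuming an S3-compatible client (boto3) and Cloudflare R2's standard endpoint format (make_r2_client is a hypothetical helper name, not from the repository):

import os

import boto3  # R2 exposes an S3-compatible API, so a stock S3 client works


def make_r2_client():
    """Hypothetical helper: build an S3 client pointed at the R2 dataset bucket."""
    account_id = os.environ["R2_DATASET_ACCOUNT_ID"]
    return boto3.client(
        "s3",
        # Cloudflare R2's S3-compatible endpoint is derived from the account ID.
        endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com",
        aws_access_key_id=os.environ["R2_DATASET_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["R2_DATASET_SECRET_ACCESS_KEY"],
    )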
distributed_training/averaging/avg_handler.py (8 additions, 3 deletions)

@@ -10,7 +10,7 @@
 from typing import Any, Dict, List, Tuple
 from distributed_training.averaging.exceptions import AllReduceError, ModelStateError
 from distributed_training.protocol import AllReduce
-from distributed_training.data.dataset import DatasetLoader
+from distributed_training.data.dataset import DatasetLoader, get_dataset_loader
 from distributed_training.utils.dendrite import (
     async_dendrite_forward,
 )
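
get_dataset_loader itself lives in distributed_training/data/dataset.py and is not shown in this diff. Judging from how it is called below (classmethods next_pages and create on the returned class), it is presumably a small dispatcher along these lines; R2DatasetLoader is an assumed name for the R2-backed counterpart:

from distributed_training.data.dataset import DatasetLoader  # shown in the diff


class R2DatasetLoader(DatasetLoader):
    """Hypothetical stand-in for the R2-backed loader; not part of this diff."""


def get_dataset_loader(data_source: str) -> type:
    """Sketch of the imported helper: map a config string to a loader class."""
    loaders = {"huggingface": DatasetLoader, "r2": R2DatasetLoader}
    try:
        return loaders[data_source]
    except KeyError:
        raise ValueError(f"Unknown data source: {data_source!r}") from None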
@@ -36,6 +36,7 @@ def __init__(
         device,
         logger,
         parameters_list=None,
+        data_source="huggingface",
     ):
         self.model = model
         self.inner_optimizer = optimizer
@@ -55,6 +56,7 @@ def __init__(
         self.logger = logger
         self.parameters_list = parameters_list
         self.master = True
+        self.data_source = data_source
 
     def _get_weights_sample(self) -> List[float]:
         """Get a sample of model weights for validation."""
@@ -83,17 +85,20 @@ async def _validate_weight_update(
     async def fetch_training_data(self, block):
         """Async function to fetch training data"""
         attempt = 0
+        # Get the appropriate loader based on config
+        LoaderClass = get_dataset_loader(self.data_source)
+
         while attempt < self.retry_limit:
             try:
-                pages = await DatasetLoader.next_pages(
+                pages = await LoaderClass.next_pages(
                     offset=block,
                     n_pages=5,
                     seed=self.uid,
                 )
                 random.seed(self.uid)
                 random.shuffle(pages)
 
-                dataset = await DatasetLoader.create(
+                dataset = await LoaderClass.create(
                     batch_size=4,
                     sequence_length=1024,
                     pages_info=pages,
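
Since fetch_training_data only ever calls next_pages and create on whichever class the dispatcher returns, any backend selected by --neuron.data_source must expose that same classmethod surface. A sketch of the implied interface, using typing.Protocol purely for illustration (the repository may not declare one explicitly, and create likely takes further keyword arguments beyond those visible in the hunk above, which ends mid-call):

from typing import Any, List, Protocol


class PageLoader(Protocol):
    """Interface implied by fetch_training_data for any data-source backend."""

    @classmethod
    async def next_pages(cls, offset: int, n_pages: int, seed: int) -> List[Any]:
        """Deterministically pick n_pages pages from a block offset and seed."""
        ...

    @classmethod
    async def create(
        cls,
        batch_size: int,
        sequence_length: int,
        pages_info: List[Any],
        **kwargs: Any,  # further arguments are cut off in the hunk above
    ) -> Any:
        """Build a batched dataset over the selected pages."""
        ...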