diff --git a/.env.example b/.env.example
index e0af16c5..81b2aa22 100644
--- a/.env.example
+++ b/.env.example
@@ -6,3 +6,13 @@ export R2_WRITE_ACCESS_KEY_ID=
 export R2_WRITE_SECRET_ACCESS_KEY=
 export R2_ADMIN_ACCESS_KEY_ID=
 export R2_ADMIN_SECRET_ACCESS_KEY=
+
+# R2 Dataset Configuration (for loading training data from R2 instead of HuggingFace)
+# To use R2 for training data, set these variables and run with --neuron.data_source r2
+export R2_DATASET_ACCOUNT_ID=
+export R2_DATASET_BUCKET_NAME=
+export R2_DATASET_ACCESS_KEY_ID=
+export R2_DATASET_SECRET_ACCESS_KEY=
+
+# Data source for eval scripts (huggingface or r2)
+export DATA_SOURCE=huggingface
diff --git a/distributed_training/averaging/avg_handler.py b/distributed_training/averaging/avg_handler.py
index 75476b76..b01541c5 100644
--- a/distributed_training/averaging/avg_handler.py
+++ b/distributed_training/averaging/avg_handler.py
@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Tuple
 
 from distributed_training.averaging.exceptions import AllReduceError, ModelStateError
 from distributed_training.protocol import AllReduce
-from distributed_training.data.dataset import DatasetLoader
+from distributed_training.data.dataset import DatasetLoader, get_dataset_loader
 from distributed_training.utils.dendrite import (
     async_dendrite_forward,
 )
@@ -36,6 +36,7 @@ def __init__(
         device,
         logger,
         parameters_list=None,
+        data_source="huggingface",
     ):
         self.model = model
         self.inner_optimizer = optimizer
@@ -55,6 +56,7 @@ def __init__(
         self.logger = logger
         self.parameters_list = parameters_list
         self.master = True
+        self.data_source = data_source
 
     def _get_weights_sample(self) -> List[float]:
         """Get a sample of model weights for validation."""
@@ -83,9 +85,12 @@ async def _validate_weight_update(
     async def fetch_training_data(self, block):
         """Async function to fetch training data"""
         attempt = 0
+        # Get the appropriate loader based on config
+        LoaderClass = get_dataset_loader(self.data_source)
+
         while attempt < self.retry_limit:
             try:
-                pages = await DatasetLoader.next_pages(
+                pages = await LoaderClass.next_pages(
                     offset=block,
                     n_pages=5,
                     seed=self.uid,
@@ -93,7 +98,7 @@ async def fetch_training_data(self, block):
                 random.seed(self.uid)
                 random.shuffle(pages)
 
-                dataset = await DatasetLoader.create(
+                dataset = await LoaderClass.create(
                     batch_size=4,
                     sequence_length=1024,
                     pages_info=pages,
diff --git a/distributed_training/data/dataset.py b/distributed_training/data/dataset.py
index 39751593..8f8fd350 100644
--- a/distributed_training/data/dataset.py
+++ b/distributed_training/data/dataset.py
@@ -15,17 +15,36 @@
 # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
 import asyncio
+import io
+import os
 import random
 import typing
 
 import aiohttp
 import numpy as np
+import pandas as pd
+import pyarrow.parquet as pq
 import torch
 import time
 import bittensor as bt
 from torch.utils.data import IterableDataset
 from transformers import AutoTokenizer
 
+# Optional: aiobotocore for async S3/R2 operations
+try:
+    from aiobotocore.session import get_session as get_aio_session
+
+    HAS_AIOBOTOCORE = True
+except ImportError:
+    HAS_AIOBOTOCORE = False
+
+# Fallback to boto3 for sync operations
+try:
+    import boto3
+    from botocore.config import Config as BotoConfig
+
+    HAS_BOTO3 = True
+except ImportError:
+    HAS_BOTO3 = False
+
 
 class SubsetLoader(IterableDataset):
     """
@@ -553,3 +572,545 @@ async def next_pages_async(
             )
             result.append((str(config), int(choice), configs_data[config]["split"]))
         return result
+
+
+class R2DatasetLoader(SubsetLoader):
+    """
+    DatasetLoader that fetches data from Cloudflare R2 instead of the HuggingFace API.
+
+    This loader is designed to work with parquet files stored in R2 that contain
+    the same data format as HuggingFace's fineweb-edu dataset. The parquet files
+    should have a 'text' column containing the training text.
+
+    R2 Configuration:
+        The loader reads R2 configuration from environment variables:
+        - R2_DATASET_ACCOUNT_ID: Cloudflare account ID
+        - R2_DATASET_BUCKET_NAME: Name of the R2 bucket containing the dataset
+        - R2_DATASET_ACCESS_KEY_ID: Access key for reading the bucket
+        - R2_DATASET_SECRET_ACCESS_KEY: Secret key for reading the bucket
+
+    Bucket Structure:
+        The bucket should contain parquet files organized by config/shard:
+        - {config_name}/{shard_id}.parquet
+        - Each parquet file should have a 'text' column with the content
+
+    A metadata file 'configs.json' at the bucket root should describe the available
+    configs with their row counts and splits.
+    """
+
+    # R2 endpoint URL pattern
+    R2_ENDPOINT_PATTERN: str = "https://{account_id}.r2.cloudflarestorage.com"
+
+    retry_limit: int = 5
+    retry_delay: int = 60
+    num_rows_per_page: int = 100
+
+    logger = bt.logging
+
+    # Cache for configs data
+    _configs_cache: typing.Dict[str, typing.Dict] = None
+    _configs_cache_time: float = 0
+    _configs_cache_ttl: float = 300  # 5 minutes
+
+    def __init__(
+        self,
+        batch_size=None,
+        sequence_length=None,
+        num_pages=None,
+        pages_info=None,
+        tokenizer: AutoTokenizer = None,
+        pack_samples: bool = False,
+        r2_account_id: str = None,
+        r2_bucket_name: str = None,
+        r2_access_key_id: str = None,
+        r2_secret_access_key: str = None,
+    ):
+        super().__init__(
+            batch_size, sequence_length, num_pages, tokenizer, pack_samples
+        )
+
+        # R2 configuration from parameters or environment
+        self.r2_account_id = r2_account_id or os.getenv("R2_DATASET_ACCOUNT_ID")
+        self.r2_bucket_name = r2_bucket_name or os.getenv("R2_DATASET_BUCKET_NAME")
+        self.r2_access_key_id = r2_access_key_id or os.getenv("R2_DATASET_ACCESS_KEY_ID")
+        self.r2_secret_access_key = r2_secret_access_key or os.getenv("R2_DATASET_SECRET_ACCESS_KEY")
+
+        # Validate configuration
+        if not all([self.r2_account_id, self.r2_bucket_name, self.r2_access_key_id, self.r2_secret_access_key]):
+            raise ValueError(
+                "R2 configuration incomplete. Please set R2_DATASET_ACCOUNT_ID, "
+                "R2_DATASET_BUCKET_NAME, R2_DATASET_ACCESS_KEY_ID, and R2_DATASET_SECRET_ACCESS_KEY "
+                "environment variables or pass them as parameters."
+            )
+
+        self.r2_endpoint = self.R2_ENDPOINT_PATTERN.format(account_id=self.r2_account_id)
+
+        # Initialize properties
+        self.configs_data = None
+        self.pages = []
+        self.buffer = []
+        self.lock = asyncio.Lock()
+
+    def _get_s3_client(self):
+        """Create a boto3 S3 client configured for R2."""
+        if not HAS_BOTO3:
+            raise ImportError(
+                "boto3 is required for R2DatasetLoader. Please install it with: pip install boto3"
+            )
+
+        return boto3.client(
+            "s3",
+            endpoint_url=self.r2_endpoint,
+            aws_access_key_id=self.r2_access_key_id,
+            aws_secret_access_key=self.r2_secret_access_key,
+            region_name="auto",
+            config=BotoConfig(
+                retries={"max_attempts": 3, "mode": "adaptive"},
+                connect_timeout=30,
+                read_timeout=60,
+            ),
+        )
+
+    @classmethod
+    async def create(
+        cls,
+        batch_size=None,
+        sequence_length=None,
+        num_pages=None,
+        pages_info=None,
+        tokenizer: AutoTokenizer = None,
+        pack_samples: bool = False,
+        r2_account_id: str = None,
+        r2_bucket_name: str = None,
+        r2_access_key_id: str = None,
+        r2_secret_access_key: str = None,
+    ):
+        """Factory method to create an R2DatasetLoader asynchronously."""
+        self = cls(
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            num_pages=num_pages,
+            tokenizer=tokenizer,
+            pack_samples=pack_samples,
+            r2_account_id=r2_account_id,
+            r2_bucket_name=r2_bucket_name,
+            r2_access_key_id=r2_access_key_id,
+            r2_secret_access_key=r2_secret_access_key,
+        )
+
+        # Fetch dataset configs asynchronously
+        self.configs_data = await self.fetch_dataset_configs()
+
+        if pages_info is not None:
+            await self._fetch(pages_info)
+        elif self.num_pages:
+            await self._fetch_data_to_buffer(self.num_pages)
+
+        return self
+
+    async def fetch_dataset_configs(self) -> typing.Dict[str, typing.Dict]:
+        """
+        Fetch dataset configuration from R2.
+
+        Expects a 'configs.json' file at the bucket root with structure:
+        {
+            "config_name": {
+                "num_rows": 123456,
+                "split": "train",
+                "num_shards": 10,
+                "rows_per_shard": 12345
+            },
+            ...
+        }
+
+        If configs.json doesn't exist, it will try to discover configs
+        by listing parquet files in the bucket.
+        """
+        # Check cache
+        current_time = time.time()
+        if (
+            R2DatasetLoader._configs_cache is not None
+            and (current_time - R2DatasetLoader._configs_cache_time)
+            < R2DatasetLoader._configs_cache_ttl
+        ):
+            return R2DatasetLoader._configs_cache
+
+        attempt = 0
+        while attempt < self.retry_limit:
+            try:
+                configs_data = await asyncio.to_thread(self._fetch_configs_sync)
+                R2DatasetLoader._configs_cache = configs_data
+                R2DatasetLoader._configs_cache_time = current_time
+                return configs_data
+            except Exception as e:
+                attempt += 1
+                if attempt < self.retry_limit:
+                    self.logger.debug(
+                        f"Retrying configs fetch due to error: {e}. Attempt {attempt}"
+                    )
+                    await asyncio.sleep(self.retry_delay * attempt)
+                else:
+                    raise
+
+    def _fetch_configs_sync(self) -> typing.Dict[str, typing.Dict]:
+        """Synchronously fetch configs from R2."""
+        import json
+
+        from botocore.exceptions import ClientError
+
+        s3 = self._get_s3_client()
+
+        try:
+            # Try to get configs.json first
+            response = s3.get_object(Bucket=self.r2_bucket_name, Key="configs.json")
+            configs_data = json.loads(response["Body"].read().decode("utf-8"))
+            return configs_data
+        except ClientError as e:
+            # If configs.json doesn't exist, discover configs by listing the bucket
+            if e.response.get("Error", {}).get("Code") in ("NoSuchKey", "404"):
+                return self._discover_configs_sync(s3)
+            raise
+
+    def _discover_configs_sync(self, s3) -> typing.Dict[str, typing.Dict]:
+        """
+        Discover available configs by listing parquet files in the bucket.
+        Assumes structure: {config_name}/{shard_id}.parquet
+        """
+        configs_data = {}
+        paginator = s3.get_paginator("list_objects_v2")
+
+        for page in paginator.paginate(Bucket=self.r2_bucket_name):
+            for obj in page.get("Contents", []):
+                key = obj["Key"]
+                if key.endswith(".parquet"):
+                    parts = key.rsplit("/", 1)
+                    if len(parts) == 2:
+                        config_name = parts[0]
+                        if config_name not in configs_data:
+                            configs_data[config_name] = {
+                                "num_rows": 0,
+                                "split": "train",
+                                "shards": [],
+                            }
+                        configs_data[config_name]["shards"].append(key)
+
+        # Estimate row counts by reading the first shard of each config
+        for config_name, config_info in configs_data.items():
+            if config_info["shards"]:
+                first_shard = config_info["shards"][0]
+                try:
+                    response = s3.get_object(
+                        Bucket=self.r2_bucket_name, Key=first_shard
+                    )
+                    parquet_data = response["Body"].read()
+                    table = pq.read_table(io.BytesIO(parquet_data))
+                    rows_per_shard = len(table)
+                    config_info["num_rows"] = rows_per_shard * len(config_info["shards"])
+                    config_info["rows_per_shard"] = rows_per_shard
+                except Exception as e:
+                    self.logger.warning(f"Could not read shard {first_shard}: {e}")
+                    config_info["num_rows"] = len(config_info["shards"]) * 10000  # Estimate
+
+        return configs_data
+
+    @staticmethod
+    async def next_pages(
+        offset: int,
+        n_pages: int,
+        seed: str,
+        num_rows_per_page: int = 100,
+        r2_account_id: str = None,
+        r2_bucket_name: str = None,
+        r2_access_key_id: str = None,
+        r2_secret_access_key: str = None,
+    ):
+        """
+        Deterministically select pages based on offset and seed.
+
+        Uses the same seeding mechanism as DatasetLoader for consistency:
+        - seed: typically uid + local_rank
+        - offset: blockchain block number
+
+        This ensures the same UID gets the same data pages for a given block.
+ """ + # Create a temporary instance to fetch configs + loader = R2DatasetLoader( + batch_size=1, + sequence_length=1024, + tokenizer=None, # Not needed for page selection + r2_account_id=r2_account_id or os.getenv("R2_DATASET_ACCOUNT_ID"), + r2_bucket_name=r2_bucket_name or os.getenv("R2_DATASET_BUCKET_NAME"), + r2_access_key_id=r2_access_key_id or os.getenv("R2_DATASET_ACCESS_KEY_ID"), + r2_secret_access_key=r2_secret_access_key or os.getenv("R2_DATASET_SECRET_ACCESS_KEY"), + ) + + configs_data = await loader.fetch_dataset_configs() + keys = sorted(configs_data.keys()) + + # Use the same RNG seeding as DatasetLoader + rng = np.random.default_rng(hash(seed) & 0xFFFFFFFF) + rng.bit_generator.advance(offset) + + result = [] + for _ in range(n_pages): + idx = rng.integers(0, len(keys)) + cfg = keys[idx] + config = rng.choice(list(configs_data.keys())) + max_row = configs_data[cfg]["num_rows"] - 1 - num_rows_per_page + if max_row < 0: + max_row = 0 + choice = rng.integers(0, max(1, max_row)) + result.append((cfg, int(choice), configs_data[cfg].get("split", "train"))) + + return result + + async def _fetch(self, page_info: typing.Tuple[str, int, str], batch_size: int = 5): + """Fetch data for specified pages.""" + self.pages = list(page_info) + + for i in range(0, len(self.pages), batch_size): + batch = self.pages[i : i + batch_size] + tasks = [ + self._fetch_data_for_page((config_name, page, split)) + for (config_name, page, split) in batch + ] + await asyncio.gather(*tasks) + + async def _fetch_data_to_buffer(self, num_pages): + """Randomly sample pages and add their data to the buffer.""" + self.pages = [] + pages_to_fetch = self.get_random_pages(num_pages) + + tasks = [self._fetch_data_for_page(page) for page in pages_to_fetch] + await asyncio.gather(*tasks) + + async def _fetch_data_for_page(self, page): + """ + Fetch data for a single page from R2 parquet files. + + Args: + page: Tuple of (config_name, row_offset, split) + """ + retry_limit = self.retry_limit + attempt = 0 + + while attempt < retry_limit: + config_name, row_offset, split = page + + try: + # Fetch rows from parquet file(s) + rows = await asyncio.to_thread( + self._fetch_rows_from_parquet_sync, + config_name, + row_offset, + self.num_rows_per_page, + ) + + # Tokenize the content + buffer_to_append = [] + for text in rows: + if self.tokenizer is not None: + input_ids = await asyncio.to_thread( + self.tokenizer.encode, + text, + truncation=True, + max_length=self.sequence_length, + ) + input_ids.append(self.tokenizer.eos_token_id) + buffer_to_append.extend(input_ids) + + async with self.lock: + self.buffer.extend(buffer_to_append) + self.pages.append((config_name, row_offset, split)) + break + + except Exception as e: + attempt += 1 + if attempt < retry_limit: + self.logger.debug( + f"Retrying page {page} due to error: {e}. Attempt {attempt}" + ) + await asyncio.sleep(self.retry_delay * attempt) + else: + raise Exception( + f"Maximum retry attempts exceeded for page {page}" + ) from e + + def _fetch_rows_from_parquet_sync( + self, config_name: str, row_offset: int, num_rows: int + ) -> typing.List[str]: + """ + Synchronously fetch rows from parquet files in R2. + + The parquet files can be organized in two ways: + 1. Single file per config: {config_name}.parquet + 2. 
+
+        Args:
+            config_name: Name of the config/subset
+            row_offset: Starting row number
+            num_rows: Number of rows to fetch
+
+        Returns:
+            List of text strings
+        """
+        s3 = self._get_s3_client()
+
+        config_info = self.configs_data.get(config_name, {})
+        shards = config_info.get("shards", [])
+        rows_per_shard = config_info.get("rows_per_shard", 10000)
+
+        if shards:
+            # Sharded structure - find the right shard
+            shard_idx = row_offset // rows_per_shard
+            local_offset = row_offset % rows_per_shard
+
+            if shard_idx >= len(shards):
+                shard_idx = len(shards) - 1
+                local_offset = 0
+
+            shard_key = shards[shard_idx]
+        else:
+            # Try the single-file structure
+            shard_key = f"{config_name}.parquet"
+            local_offset = row_offset
+
+        try:
+            response = s3.get_object(Bucket=self.r2_bucket_name, Key=shard_key)
+            parquet_data = response["Body"].read()
+
+            # Read the parquet file
+            table = pq.read_table(io.BytesIO(parquet_data))
+            df = table.to_pandas()
+
+            # Extract rows, clamping offsets that fall past the end of the file
+            end_offset = min(local_offset + num_rows, len(df))
+            if local_offset >= len(df):
+                local_offset = max(0, len(df) - num_rows)
+                end_offset = len(df)
+
+            # Get the text column (try common column names)
+            text_column = None
+            for col_name in ["text", "content", "data", "raw_content"]:
+                if col_name in df.columns:
+                    text_column = col_name
+                    break
+
+            if text_column is None:
+                # Fall back to the first string column
+                for col in df.columns:
+                    if df[col].dtype == object:
+                        text_column = col
+                        break
+
+            if text_column is None:
+                raise ValueError(f"No text column found in parquet file {shard_key}")
+
+            rows = df[text_column].iloc[local_offset:end_offset].tolist()
+            return [str(row) for row in rows if row is not None]
+
+        except Exception as e:
+            self.logger.error(f"Error fetching from {shard_key}: {e}")
+            raise
+
+    def get_random_pages(self, num_pages):
+        """Randomly sample pages."""
+        pages = []
+
+        for _ in range(num_pages):
+            config_name = random.choice(list(self.configs_data.keys()))
+            config_info = self.configs_data[config_name]
+            max_page = config_info["num_rows"] - 1 - self.num_rows_per_page
+            if max_page < 0:
+                max_page = 0
+            page = random.randint(0, max(0, max_page))
+            split = config_info.get("split", "train")
+            pages.append((config_name, page, split))
+
+        return pages
+
+    def get_page_names(self):
+        """Return page names as strings."""
+        page_names = []
+        if hasattr(self, "pages"):
+            page_names = [
+                f"{cfg_name}_{row_offset}_{split}"
+                for cfg_name, row_offset, split in self.pages
+            ]
+        return page_names
+
+
+def get_dataset_loader(data_source: str = "huggingface"):
+    """
+    Factory function returning the appropriate DatasetLoader class based on config.
+
+    Args:
+        data_source: Either "huggingface" or "r2"
+
+    Returns:
+        The appropriate DatasetLoader class (not an instance)
+
+    Usage:
+        DataLoader = get_dataset_loader(config.neuron.data_source)
+        pages = await DataLoader.next_pages(offset=block, n_pages=n_pages, seed=seed)
+        dataset = await DataLoader.create(...)
+    """
+    if data_source == "r2":
+        return R2DatasetLoader
+    else:
+        return DatasetLoader
+
+
+async def create_dataset(
+    data_source: str = "huggingface",
+    batch_size: int = None,
+    sequence_length: int = None,
+    num_pages: int = None,
+    pages_info=None,
+    tokenizer=None,
+    pack_samples: bool = False,
+):
+    """
+    Convenience function to create a dataset using the appropriate loader.
+
+    Args:
+        data_source: Either "huggingface" or "r2"
+        batch_size: Batch size for training
+        sequence_length: Sequence length for tokenization
+        num_pages: Number of pages to fetch (if pages_info not provided)
+        pages_info: Pre-selected pages to fetch
+        tokenizer: Tokenizer to use
+        pack_samples: Whether to pack samples
+
+    Returns:
+        An instance of either DatasetLoader or R2DatasetLoader
+    """
+    LoaderClass = get_dataset_loader(data_source)
+    return await LoaderClass.create(
+        batch_size=batch_size,
+        sequence_length=sequence_length,
+        num_pages=num_pages,
+        pages_info=pages_info,
+        tokenizer=tokenizer,
+        pack_samples=pack_samples,
+    )
+
+
+async def get_next_pages(
+    data_source: str = "huggingface",
+    offset: int = 0,
+    n_pages: int = 1,
+    seed: str = "",
+    num_rows_per_page: int = 100,
+):
+    """
+    Get next pages using the appropriate loader.
+
+    Args:
+        data_source: Either "huggingface" or "r2"
+        offset: Block number offset for RNG
+        n_pages: Number of pages to select
+        seed: Seed for RNG (typically uid + local_rank)
+        num_rows_per_page: Rows per page
+
+    Returns:
+        List of (config_name, row_offset, split) tuples
+    """
+    LoaderClass = get_dataset_loader(data_source)
+    return await LoaderClass.next_pages(
+        offset=offset,
+        n_pages=n_pages,
+        seed=seed,
+        num_rows_per_page=num_rows_per_page,
+    )
diff --git a/distributed_training/utils/config.py b/distributed_training/utils/config.py
index eb6e2e78..72ef8b4d 100644
--- a/distributed_training/utils/config.py
+++ b/distributed_training/utils/config.py
@@ -179,6 +179,14 @@ def add_args(cls, parser, prefix=None):
         default="dstrbtd/llama-1b",
     )
 
+    parser.add_argument(
+        "--neuron.data_source",
+        type=str,
+        choices=["huggingface", "r2"],
+        help="Data source for training data: 'huggingface' (default) or 'r2' (Cloudflare R2)",
+        default="huggingface",
+    )
+
     parser.add_argument(
         "--neuron.master_ss58_address",
         type=str,
diff --git a/distributed_training/utils/state_loader.py b/distributed_training/utils/state_loader.py
index ee5b8377..b35a20f0 100644
--- a/distributed_training/utils/state_loader.py
+++ b/distributed_training/utils/state_loader.py
@@ -796,6 +796,7 @@ def load_model_optimizer_gradient_averager(
         self.device,
         self.logger,
         # parameters_list,
+        data_source=getattr(self.config.neuron, "data_source", "huggingface"),
     )
     if (
         (self.master)
diff --git a/distributed_training/validator/reward.py b/distributed_training/validator/reward.py
index bebc85bd..f7557612 100644
--- a/distributed_training/validator/reward.py
+++ b/distributed_training/validator/reward.py
@@ -36,7 +36,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM
 
 from distributed_training import __run__
-from distributed_training.data.dataset import DatasetLoader
+from distributed_training.data.dataset import DatasetLoader, get_dataset_loader
 from distributed_training.utils.progress_tracker import get_progress, get_r2_client
 from distributed_training.utils.state_loader import (
     cleanup_old_cache,
@@ -81,9 +85,13 @@ async def fetch_training_data(
     """
    attempt = 0
 
+    # Get the appropriate loader based on config
+    data_source = getattr(self.config.neuron, "data_source", "huggingface")
+    LoaderClass = get_dataset_loader(data_source)
+
     while attempt < self.retry_limit:
         try:
-            pages = await DatasetLoader.next_pages(
+            pages = await LoaderClass.next_pages(
                 offset=block,
                 n_pages=n_pages,
                 seed=uid + self.local_rank,
@@ -93,7 +97,7 @@ async def fetch_training_data(
 
             self.logger.debug(pages)
 
-            dataset = await DatasetLoader.create(
+            dataset = await LoaderClass.create(
                 batch_size=self.config.neuron.local_batch_size_train,
                 sequence_length=1024,
                 pages_info=pages,
diff --git a/eval/eval_loss.py b/eval/eval_loss.py
index 789ed0bc..154907ab 100644
--- a/eval/eval_loss.py
+++ b/eval/eval_loss.py
@@ -15,7 +15,7 @@
 import json
 import torch.distributed as dist
 from distributed_training import __run__
-from distributed_training.data.dataset import DatasetLoader
+from distributed_training.data.dataset import DatasetLoader, get_dataset_loader
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import HF_HUB_CACHE
@@ -242,6 +242,11 @@ async def fetch_training_data(tokenizer):
     retry_delay = 60
     attempt = 0
     local_batch_size_train = 4
+
+    # Get data source from environment variable (default: huggingface)
+    data_source = os.getenv("DATA_SOURCE", "huggingface")
+    LoaderClass = get_dataset_loader(data_source)
+
     if dist.get_rank() == 0:
         current_block = random.randint(6193881 * 2, 6193881 * 4)
         uid = random.randint(300, 1000000)
@@ -256,7 +261,7 @@ async def fetch_training_data(tokenizer):
     # print(SELF.local_rank, f"Fetched block {current_block} with uid {uid}")
     while attempt < retry_limit:
         try:
-            pages = await DatasetLoader.next_pages(
+            pages = await LoaderClass.next_pages(
                 offset=current_block,
                 n_pages=5,
                 seed=uid,
@@ -264,7 +269,7 @@ async def fetch_training_data(tokenizer):
             random.seed(uid)
             random.shuffle(pages)
 
-            dataset = await DatasetLoader.create(
+            dataset = await LoaderClass.create(
                 batch_size=local_batch_size_train,
                 sequence_length=1024,
                 pages_info=pages,
diff --git a/neurons/miner.py b/neurons/miner.py
index c7e7f552..de76ac56 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -76,7 +76,7 @@
 
 # from distributed_training.averaging.avg_handler import AllReduceError
 from distributed_training.base.miner import BaseMinerNeuron, TrainingStatus
-from distributed_training.data.dataset import DatasetLoader
+from distributed_training.data.dataset import DatasetLoader, get_dataset_loader
 from distributed_training.utils.chain import log_r2_to_chain
 from distributed_training.utils.misc import (
     init_dht,
@@ -569,11 +569,15 @@ def resume_training(self):
     async def fetch_training_data(self):
         """Async function to fetch training data"""
         attempt = 0
+        # Get the appropriate loader based on config
+        data_source = getattr(self.config.neuron, "data_source", "huggingface")
+        LoaderClass = get_dataset_loader(data_source)
+
         while attempt < self.retry_limit:
             try:
                 self.set_current_block_across_ranks()
 
-                pages = await DatasetLoader.next_pages(
+                pages = await LoaderClass.next_pages(
                     offset=self.current_block,
                     n_pages=5,
                     seed=self.uid + self.local_rank,
@@ -583,7 +587,7 @@ async def fetch_training_data(self):
 
                 self.logger.debug(pages)
 
-                dataset = await DatasetLoader.create(
+                dataset = await LoaderClass.create(
                     batch_size=self.config.neuron.local_batch_size_train,
                     sequence_length=1024,
                     pages_info=pages,
diff --git a/requirements.txt b/requirements.txt
index b552072e..9cd30cbb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,3 +21,5 @@ muon-optimizer @ git+https://github.com/KellerJordan/Muon@f90a42b28e00b8d9d2d058
 rich==14.1.0
 bittensor-cli==9.11.2
 boto3==1.40.45
+pyarrow>=14.0.0
+pandas>=2.0.0
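
Usage sketch (not part of the patch): how a caller would combine the new factory helpers. This assumes the R2_DATASET_* variables from .env.example are exported; the tokenizer name and the literal offset/seed values are illustrative only, and batch iteration follows the existing SubsetLoader protocol.

import asyncio

from transformers import AutoTokenizer

from distributed_training.data.dataset import create_dataset, get_next_pages


async def main():
    data_source = "r2"  # or "huggingface" (the default)

    # Deterministic page selection: the same (offset, seed) yields the same pages
    pages = await get_next_pages(data_source=data_source, offset=1000, n_pages=5, seed=42)

    dataset = await create_dataset(
        data_source=data_source,
        batch_size=4,
        sequence_length=1024,
        pages_info=pages,
        tokenizer=AutoTokenizer.from_pretrained("distilgpt2"),  # illustrative tokenizer
    )
    for batch in dataset:  # iterates like any other SubsetLoader
        ...  # feed the batch to the training step


asyncio.run(main())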
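
On the operator side, the patch only reads from the bucket; something must still upload the shards and, ideally, configs.json, otherwise every loader pays for the listing-and-probing fallback in _discover_configs_sync. A sketch of publishing that metadata, with every identifier and count a placeholder:

import json

import boto3

# Placeholders throughout: substitute a real account ID, bucket, write
# credentials, and measured row counts.
s3 = boto3.client(
    "s3",
    endpoint_url="https://<account_id>.r2.cloudflarestorage.com",
    aws_access_key_id="<write_access_key_id>",
    aws_secret_access_key="<write_secret_access_key>",
    region_name="auto",
)

configs = {
    "sample-config": {  # hypothetical config name; one entry per {config_name}/ prefix
        "num_rows": 120000,
        "split": "train",
        "num_shards": 10,
        "rows_per_shard": 12000,
    }
}

s3.put_object(
    Bucket="<dataset_bucket>",
    Key="configs.json",
    Body=json.dumps(configs).encode("utf-8"),
)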
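
The shard addressing in _fetch_rows_from_parquet_sync is plain integer arithmetic. A worked example with made-up numbers (shards holds object keys in the order the discovery listing returned them):

rows_per_shard = 12000  # from configs.json, or probed from the first shard
shards = ["sample-config/0.parquet", "sample-config/1.parquet", "sample-config/2.parquet"]

row_offset = 30500  # global row chosen by next_pages
shard_idx = row_offset // rows_per_shard  # 30500 // 12000 = 2 -> "sample-config/2.parquet"
local_offset = row_offset % rows_per_shard  # 30500 % 12000 = 6500 within that shard

# Offsets past the last shard are clamped to the start of the last shard
if shard_idx >= len(shards):
    shard_idx, local_offset = len(shards) - 1, 0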
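
Because validators re-derive a miner's pages from the same (block, uid) inputs, page selection must be reproducible across processes regardless of data source. Note that callers pass integer seeds (uid + local_rank), which hash stably; a str seed would be subject to Python's per-process hash randomization. A minimal parity check (sketch; assumes R2 credentials are configured, values illustrative):

import asyncio

from distributed_training.data.dataset import R2DatasetLoader


async def check(block: int, uid: int):
    pages_a = await R2DatasetLoader.next_pages(offset=block, n_pages=5, seed=uid)
    pages_b = await R2DatasetLoader.next_pages(offset=block, n_pages=5, seed=uid)
    assert pages_a == pages_b, "page selection must be deterministic per (block, uid)"


asyncio.run(check(block=1000, uid=7))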