diff --git a/bsmetadata/experiments/with_metadata.py b/bsmetadata/experiments/with_metadata.py
index b7270dc6..5ebcd64f 100644
--- a/bsmetadata/experiments/with_metadata.py
+++ b/bsmetadata/experiments/with_metadata.py
@@ -1,9 +1,10 @@
 import functools
 import logging

+from accelerate import DistributedType
 from datasets import config, load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator

 from bsmetadata.metadata_utils import add_metadata_and_chunk_examples

@@ -152,15 +153,18 @@ def create_labels_column(examples):
     logger.info(f" Num validation examples = {len(val_dataset)}")

     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
diff --git a/bsmetadata/experiments/without_metadata.py b/bsmetadata/experiments/without_metadata.py
index 8f6b197b..3e96603b 100644
--- a/bsmetadata/experiments/without_metadata.py
+++ b/bsmetadata/experiments/without_metadata.py
@@ -1,8 +1,9 @@
 import logging

+from accelerate import DistributedType
 from datasets import config, load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator


 logger = logging.getLogger(__name__)
@@ -107,7 +108,7 @@ def get_dataloaders(tokenizer, args):
     text_column_name = "text" if "text" in column_names else column_names[0]

     def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
+        return tokenizer(examples[text_column_name], truncation=True, max_length=args.max_seq_len)

     logger.info("Tokenize dataset")
     tokenized_datasets = datasets.map(
@@ -179,15 +180,18 @@ def group_texts(examples):
     logger.info(f" Num validation examples = {len(val_dataset)}")

     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
diff --git a/bsmetadata/input_pipeline.py b/bsmetadata/input_pipeline.py
index 5a18b4da..619ce024 100644
--- a/bsmetadata/input_pipeline.py
+++ b/bsmetadata/input_pipeline.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Optional

+from accelerate import DistributedType
+
 from bsmetadata.metadata_utils import MetadataConfig


@@ -53,6 +55,7 @@ class DataConfig:
             " the dataset. If you are using `with_metadata` the recommended batch size is 1.."
         },
     )
+    distributed_type: DistributedType = field(default=DistributedType.NO)


 def get_dataloaders(tokenizer, cfg: DataConfig):
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index 63447c25..0a22b920 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -71,7 +71,7 @@ def add_metadata_and_chunk_examples(
         # Get the global metadata prefix that is prepended to each training example.
         metadata_prefix = create_metadata_prefix(example, cfg)
         metadata_prefix_encoded = (
-            tokenizer.encode_plus(cfg.metadata_prefix_start_seq + metadata_prefix).input_ids
+            tokenizer.encode_plus(cfg.metadata_prefix_start_seq + metadata_prefix, truncation=True).input_ids
             if metadata_prefix
             else []
         )
@@ -89,7 +89,7 @@ def add_metadata_and_chunk_examples(
             text_with_local_metadata = " " + text_with_local_metadata
             char_level_metadata_mask = [False] + char_level_metadata_mask

-        text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata)
+        text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata, truncation=True)

         def is_metadata(idx: int) -> bool:
             char_span = text_with_local_metadata_encoded.token_to_chars(idx)
diff --git a/bsmetadata/train.py b/bsmetadata/train.py
index 49ca5ef6..86b507b2 100644
--- a/bsmetadata/train.py
+++ b/bsmetadata/train.py
@@ -136,6 +136,9 @@ def loss_fn(batch, outputs, metadata_mask=None):

 @hydra.main(config_path=None, config_name="config")
 def main(args: CFG) -> None:
+    accelerator = Accelerator()
+    args.data_config.distributed_type = accelerator.distributed_type
+
     print(OmegaConf.to_yaml(args))

     config_dict = OmegaConf.to_container(args)
@@ -144,7 +147,6 @@ def main(args: CFG) -> None:
     args = OmegaConf.to_object(args)

     set_seed(args.seed)
-    accelerator = Accelerator()
     is_local_main_process = accelerator.is_local_main_process
     tqdm = partial(original_tqdm, disable=not is_local_main_process, position=0)

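For reference only, not part of the patch: a minimal, self-contained sketch of the TPU-aware collator selection this diff introduces. The tokenizer checkpoint ("gpt2") and the fixed sequence length (512) are placeholder assumptions, not values taken from the repository.

# Sketch only: illustrates the collator-selection pattern from the patch above.
from accelerate import Accelerator, DistributedType
from transformers import AutoTokenizer, DataCollatorWithPadding, default_data_collator

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

if accelerator.distributed_type == DistributedType.TPU:
    # XLA recompiles whenever tensor shapes change, so pad every batch to a static length.
    data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=512)
else:
    # Elsewhere the cheaper default collator, which simply stacks pre-chunked examples, suffices.
    data_collator = default_data_collator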