perf: collator with padding for tpu #29
base: master
Changes from all commits
Dataloader construction (`get_dataloaders`):

```diff
@@ -1,8 +1,9 @@
 import logging
 
+from accelerate import DistributedType
 from datasets import config, load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator
 
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -107,7 +108,7 @@ def get_dataloaders(tokenizer, args):
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
+        return tokenizer(examples[text_column_name], truncation=True, max_length=args.max_seq_len)
 
     logger.info("Tokenize dataset")
     tokenized_datasets = datasets.map(
```

Collaborator (Author): Truncate to `args.max_seq_len`.
```diff
@@ -179,15 +180,18 @@ def group_texts(examples):
     logger.info(f" Num validation examples = {len(val_dataset)}")
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
```

Collaborator (Author): Same question as #29 (comment).
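Padding every example to the same `max_length` gives the TPU constant-shape batches, so XLA compiles the training step once instead of recompiling for each new sequence length. A minimal sketch of the collator the TPU branch selects (the `gpt2` tokenizer and `max_length=1024` are assumptions; GPT-2 needs a pad token assigned first):

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=1024)

features = [
    tokenizer("a short example", truncation=True, max_length=1024),
    tokenizer("a slightly longer example sentence", truncation=True, max_length=1024),
]
batch = collator(features)
# Every batch comes out with the same (batch_size, 1024) shape,
# which is what keeps XLA from recompiling per batch.
print(batch["input_ids"].shape)       # torch.Size([2, 1024])
print(batch["attention_mask"].shape)  # torch.Size([2, 1024])
```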
Metadata preprocessing (`add_metadata_and_chunk_examples`):

```diff
@@ -71,7 +71,7 @@ def add_metadata_and_chunk_examples(
     # Get the global metadata prefix that is prepended to each training example.
     metadata_prefix = create_metadata_prefix(example, cfg)
     metadata_prefix_encoded = (
-        tokenizer.encode_plus(cfg.metadata_prefix_start_seq + metadata_prefix).input_ids
+        tokenizer.encode_plus(cfg.metadata_prefix_start_seq + metadata_prefix, truncation=True).input_ids
         if metadata_prefix
         else []
     )
```
```diff
@@ -89,7 +89,7 @@ def add_metadata_and_chunk_examples(
         text_with_local_metadata = " " + text_with_local_metadata
         char_level_metadata_mask = [False] + char_level_metadata_mask
 
-    text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata)
+    text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata, truncation=True)
 
     def is_metadata(idx: int) -> bool:
         char_span = text_with_local_metadata_encoded.token_to_chars(idx)
```

Collaborator (Author): Truncate to the model's max seq len (1024) for the whole sequence with metadata, cf. #29 (comment).
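As the comment notes, `encode_plus(..., truncation=True)` with no explicit `max_length` truncates to the tokenizer's `model_max_length` (1024 for GPT-2). A minimal sketch, assuming a `gpt2` fast tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.model_max_length)  # 1024 for GPT-2

text_with_local_metadata = "some very long document " * 2000
encoded = tokenizer.encode_plus(text_with_local_metadata, truncation=True)
# Truncated to the model's max sequence length.
print(len(encoded.input_ids) <= tokenizer.model_max_length)  # True

# token_to_chars (used by is_metadata in the diff) still maps the surviving
# tokens back to character spans in the original string.
print(encoded.token_to_chars(0))  # e.g. CharSpan(start=0, end=4)
```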
Training entry point (`main`):

```diff
@@ -136,6 +136,9 @@ def loss_fn(batch, outputs, metadata_mask=None):
 
 
 @hydra.main(config_path=None, config_name="config")
 def main(args: CFG) -> None:
+    accelerator = Accelerator()
+    args.data_config.distributed_type = accelerator.distributed_type
+
     print(OmegaConf.to_yaml(args))
     config_dict = OmegaConf.to_container(args)
```

Collaborator (Author): Not really related to this PR, but I need it for testing.
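A rough sketch of the intent behind this hunk: the `Accelerator` is now created before the dataloaders are built so that the data config can record which backend is in use (the `DataConfig` dataclass and field names below are illustrative assumptions, not the project's actual config):

```python
from dataclasses import dataclass
from typing import Optional

from accelerate import Accelerator, DistributedType


@dataclass
class DataConfig:
    max_seq_len: int = 1024
    distributed_type: Optional[DistributedType] = None


accelerator = Accelerator()  # created once, early
data_config = DataConfig()
data_config.distributed_type = accelerator.distributed_type

# The dataloader code can now branch on the backend, as in the diff above.
use_fixed_padding = data_config.distributed_type == DistributedType.TPU
print(accelerator.distributed_type, use_fixed_padding)
```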
```diff
@@ -144,7 +147,6 @@ def main(args: CFG) -> None:
     args = OmegaConf.to_object(args)
 
     set_seed(args.seed)
-    accelerator = Accelerator()
     is_local_main_process = accelerator.is_local_main_process
     tqdm = partial(original_tqdm, disable=not is_local_main_process, position=0)
 
```
Do we want to pad to longest for GPU?
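For reference on that question, a minimal sketch of the dynamic-padding alternative (`padding="longest"`), again assuming a `gpt2` tokenizer; this is only the alternative raised in the closing question, not something the PR implements:

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# padding="longest" pads each batch only up to its longest example, which
# wastes less compute on GPU but produces varying shapes (bad for XLA/TPU).
dynamic_collator = DataCollatorWithPadding(tokenizer, padding="longest")

features = [
    tokenizer("a short example", truncation=True, max_length=1024),
    tokenizer("a slightly longer example sentence", truncation=True, max_length=1024),
]
print(dynamic_collator(features)["input_ids"].shape)  # padded to the longest in the batch
```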