10 changes: 7 additions & 3 deletions bsmetadata/experiments/with_metadata.py
@@ -1,9 +1,10 @@
import functools
import logging

from accelerate import DistributedType
from datasets import config, load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
from transformers import DataCollatorWithPadding, default_data_collator

from bsmetadata.metadata_utils import add_metadata_and_chunk_examples

@@ -152,15 +153,18 @@ def create_labels_column(examples):
logger.info(f" Num validation examples = {len(val_dataset)}")

# DataLoaders creation:
data_collator = default_data_collator
Collaborator Author: Do we want to pad to longest for GPU?

if args.distributed_type == DistributedType.TPU:
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=default_data_collator,
collate_fn=data_collator,
batch_size=args.per_device_train_batch_size,
)
val_dataloader1 = DataLoader(
val_dataset,
collate_fn=default_data_collator,
collate_fn=data_collator,
batch_size=args.per_device_eval_batch_size,
)
return train_dataloader, {"val1": val_dataloader1}
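A minimal sketch (not part of this diff) of the "pad to longest" alternative the comment above asks about, with a GPT-2 tokenizer and literal values standing in for `args.max_seq_len` and `args.distributed_type`:

```python
from accelerate import DistributedType
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # GPT-2 has no pad token by default

max_seq_len = 512                           # stand-in for args.max_seq_len
distributed_type = DistributedType.NO       # stand-in for args.distributed_type

if distributed_type == DistributedType.TPU:
    # TPU/XLA prefers static shapes, so every batch is padded to the same fixed length.
    data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=max_seq_len)
else:
    # On GPU, padding each batch only to its longest member wastes less compute.
    data_collator = DataCollatorWithPadding(tokenizer, padding="longest")

batch = data_collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
print(batch["input_ids"].shape)  # torch.Size([2, 3]) with dynamic padding
```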
12 changes: 8 additions & 4 deletions bsmetadata/experiments/without_metadata.py
@@ -1,8 +1,9 @@
import logging

from accelerate import DistributedType
from datasets import config, load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
from transformers import DataCollatorWithPadding, default_data_collator


logger = logging.getLogger(__name__)
@@ -107,7 +108,7 @@ def get_dataloaders(tokenizer, args):
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
return tokenizer(examples[text_column_name])
return tokenizer(examples[text_column_name], truncation=True, max_length=args.max_seq_len)
@tianjianjiang (Collaborator, Author), Sep 8, 2021: Truncate to max_seq_len (512 by default) to preserve space for metadata, cf. #29 (comment).


logger.info("Tokenize dataset")
tokenized_datasets = datasets.map(
@@ -179,15 +180,18 @@ def group_texts(examples):
logger.info(f" Num validation examples = {len(val_dataset)}")

# DataLoaders creation:
data_collator = default_data_collator
Collaborator Author: Same question as #29 (comment).

if args.distributed_type == DistributedType.TPU:
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=default_data_collator,
collate_fn=data_collator,
batch_size=args.per_device_train_batch_size,
)
val_dataloader1 = DataLoader(
val_dataset,
collate_fn=default_data_collator,
collate_fn=data_collator,
batch_size=args.per_device_eval_batch_size,
)
return train_dataloader, {"val1": val_dataloader1}
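The truncation comment above describes a token budget: plain text is cut to max_seq_len (512 by default) so the rest of the model's 1024-token window stays free for the metadata prefix. A small illustration of that split, assuming a GPT-2 tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # model_max_length == 1024

max_seq_len = 512            # the PR's default for args.max_seq_len
text = "word " * 2000        # deliberately longer than any limit

# Plain text is capped at max_seq_len ...
text_ids = tokenizer(text, truncation=True, max_length=max_seq_len).input_ids
assert len(text_ids) == max_seq_len

# ... leaving the remainder of the model window for the metadata prefix.
print(tokenizer.model_max_length - len(text_ids))   # 512 tokens of headroom
```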
3 changes: 3 additions & 0 deletions bsmetadata/input_pipeline.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from typing import Optional

from accelerate import DistributedType

from bsmetadata.metadata_utils import MetadataConfig


@@ -53,6 +55,7 @@ class DataConfig:
" the dataset. If you are using `with_metadata` the recommended batch size is 1.."
},
)
distributed_type: DistributedType = field(default=DistributedType.NO)


def get_dataloaders(tokenizer, cfg: DataConfig):
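The new `distributed_type` field is just an enum-valued dataclass field, which a structured config can carry alongside the existing options. A stand-alone check (hypothetical `DataConfigSketch`, not the real `DataConfig`) that OmegaConf accepts the `DistributedType` enum with `NO` as the default:

```python
from dataclasses import dataclass, field

from accelerate import DistributedType
from omegaconf import OmegaConf


@dataclass
class DataConfigSketch:
    # Trimmed-down stand-in for DataConfig; only the new field matters here.
    per_device_eval_batch_size: int = 2
    distributed_type: DistributedType = field(default=DistributedType.NO)


cfg = OmegaConf.structured(DataConfigSketch)
cfg.distributed_type = DistributedType.TPU   # in train.py this comes from the Accelerator
print(OmegaConf.to_yaml(cfg))
```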
4 changes: 2 additions & 2 deletions bsmetadata/metadata_utils.py
@@ -71,7 +71,7 @@ def add_metadata_and_chunk_examples(
# Get the global metadata prefix that is prepended to each training example.
metadata_prefix = create_metadata_prefix(example, cfg)
metadata_prefix_encoded = (
tokenizer.encode_plus(cfg.metadata_prefix_start_seq + metadata_prefix).input_ids
tokenizer.encode_plus(cfg.metadata_prefix_start_seq + metadata_prefix, truncation=True).input_ids
if metadata_prefix
else []
)
@@ -89,7 +89,7 @@ def add_metadata_and_chunk_examples(
text_with_local_metadata = " " + text_with_local_metadata
char_level_metadata_mask = [False] + char_level_metadata_mask

text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata)
text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata, truncation=True)
@tianjianjiang (Collaborator, Author), Sep 8, 2021: Truncate to the model's max seq len (1024) for the whole sequence with metadata, cf. #29 (comment).


def is_metadata(idx: int) -> bool:
char_span = text_with_local_metadata_encoded.token_to_chars(idx)
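The call above relies on the tokenizer's default truncation target: with `truncation=True` and no explicit `max_length`, `encode_plus` falls back to `tokenizer.model_max_length` (1024 for GPT-2), so the whole "metadata prefix + text" sequence is cut to the model window. A quick check, assuming a GPT-2 tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

long_text = "metadata prefix plus document text " * 500   # far beyond 1024 tokens

# No max_length given: truncation falls back to tokenizer.model_max_length (1024 here).
encoded = tokenizer.encode_plus(long_text, truncation=True)
assert len(encoded.input_ids) <= tokenizer.model_max_length
```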
4 changes: 3 additions & 1 deletion bsmetadata/train.py
@@ -136,6 +136,9 @@ def loss_fn(batch, outputs, metadata_mask=None):

@hydra.main(config_path=None, config_name="config")
Collaborator Author: Not really related to this PR, but I need it for testing.

def main(args: CFG) -> None:
accelerator = Accelerator()
args.data_config.distributed_type = accelerator.distributed_type

print(OmegaConf.to_yaml(args))
config_dict = OmegaConf.to_container(args)

@@ -144,7 +147,6 @@ def main(args: CFG) -> None:
args = OmegaConf.to_object(args)

set_seed(args.seed)
accelerator = Accelerator()
is_local_main_process = accelerator.is_local_main_process
tqdm = partial(original_tqdm, disable=not is_local_main_process, position=0)

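Taken together with the config change, the reordering above means the `Accelerator` exists before anything that depends on the distributed setup. A condensed, stand-alone sketch of that flow (hypothetical `DataCfg` and `build_collator` stand-ins, not the project's real functions):

```python
from dataclasses import dataclass, field

from accelerate import Accelerator, DistributedType


@dataclass
class DataCfg:
    # Hypothetical stand-in for bsmetadata's DataConfig.
    distributed_type: DistributedType = field(default=DistributedType.NO)


def build_collator(cfg: DataCfg) -> str:
    # Stand-in for the branching added in the experiments' get_dataloaders.
    return "pad_to_max_length" if cfg.distributed_type == DistributedType.TPU else "default"


# Ordering after this PR: create the Accelerator first, copy its distributed_type
# into the data config, and only then build anything that depends on it.
accelerator = Accelerator()
cfg = DataCfg()
cfg.distributed_type = accelerator.distributed_type
print(build_collator(cfg))   # "default" on CPU/GPU, "pad_to_max_length" on TPU
```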