From 57db62390bf9bef8e8fdc226d3318b7f61306810 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 8 Dec 2022 15:23:48 +0400 Subject: [PATCH 01/14] Add RHO-Loss --- .../codeparrot/scripts/arguments.py | 17 ++++++++-- .../codeparrot/scripts/codeparrot_training.py | 34 ++++++++++++++++++- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/arguments.py b/examples/research_projects/codeparrot/scripts/arguments.py index 4def9ac3b8..68882a68e0 100644 --- a/examples/research_projects/codeparrot/scripts/arguments.py +++ b/examples/research_projects/codeparrot/scripts/arguments.py @@ -20,8 +20,9 @@ class TrainingArguments: dataset_name_valid: Optional[str] = field( default="codeparrot/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."} ) - train_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for training."}) - valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."}) + train_batch_size: Optional[int] = field(default=320, metadata={"help": "Batch size for training."}) + train_batch_size_select: Optional[int] = field(default=32, metadata={"help": "Batch size to subselect for training."}) + valid_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for evaluation."}) weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."}) shuffle_buffer: Optional[int] = field( default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."} @@ -47,6 +48,18 @@ class TrainingArguments: default=1024, metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."}, ) + selection_method: Optional[str] = field( + default=None, metadata={"help": "Selection method to subselect from the batch size. Can be uniform or rholoss"} + ) + irred_losses: Optional[str] = field( + default="irred_losses.pt", metadata={"help": "Path to irreducible losses pt file. Must be supplied if selection_method is rholoss"} + ) + compute_irred_losses: Optional[bool] = field( + default=False, + metadata={ + "help": "If True irreducible losses are computed and saved to the path specified by irred_losses." + }, + ) resume_from_checkpoint: Optional[str] = field( default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."} ) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index b2af8767a2..867d957b21 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -263,11 +263,38 @@ def get_lr(): completed_steps = 0 t_start = time.time() loss_tracking = 0 + +if args.compute_irred_losses: + irred_losses = torch.zeros(len(train_dataloader) * args.train_batch_size, device="cpu") +elif args.irred_losses and not args.compute_irred_losses: + # Should be of shape len(train_dataloader) * args.train_batch_size + irred_losses = torch.load(args.irred_losses, map_location=torch.device("cpu")) + assert irred_losses.shape[0] == len(train_dataloader) * args.train_batch_size + for step, batch in enumerate(train_dataloader, start=1): if args.resume_from_checkpoint and step < resume_step: continue # we need to skip steps until we reach the resumed step + if args.selection_method: + if args.selection_method == "uniform": + batch = {k: v[:args.train_batch_size_select] for k,v in batch.items()} + elif args.selection_method == "rholoss": + with torch.no_grad(): + out = model(batch, labels=batch, use_cache=False).loss + losses = accelerator.gather(loss.repeat(args.train_batch_size)) + cur_irred_losses = irred_losses[step-1:args.train_batch_size*step] + assert losses.shape == cur_irred_losses.shape + red_losses = losses - cur_irred_losses + # Select the top args.train_batch_size_select losses & produce a new batch + top_losses, top_indices = torch.topk(red_losses, args.train_batch_size_select) + batch = {k: v[top_indices] for k,v in batch.items()} + if args.compute_irred_losses: + with torch.no_grad(): + out = model(batch, labels=batch, use_cache=False).loss + losses = accelerator.gather(loss.repeat(args.train_batch_size)) + irred_losses[step-1:args.train_batch_size*step] = losses + continue loss = model(batch, labels=batch, use_cache=False).loss - avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size_select)).mean() loss_tracking += avg_loss.item() / args.gradient_accumulation_steps log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()}) loss = loss / args.gradient_accumulation_steps @@ -313,6 +340,11 @@ def get_lr(): if completed_steps >= args.max_train_steps: break +# Save irred losses +if args.compute_irred_losses: + torch.save(irred_losses, os.path.join(args.save_dir, args.irred_losses)) + exit() + # Evaluate and save the last checkpoint logger.info("Evaluating and saving model after training") eval_loss, perplexity = evaluate(args) From 95d2009b80848009a9c5b670e855f5712640e3cb Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 8 Dec 2022 18:16:01 +0400 Subject: [PATCH 02/14] Separate valid data --- .../codeparrot/scripts/codeparrot_training.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index 867d957b21..e7a0b04133 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -124,7 +124,13 @@ def create_dataloaders(args): ds_kwargs = {"streaming": True} train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs) train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) - valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs) + if args.dataset_name_train == args.dataset_name_valid: + valid_size = int(len(train_data) * 0.05) + train_size = len(train_data) - valid_size + valid_data = train_data.select(range(valid_size)) + train_data = train_data.select(range(valid_size, train_size)) + else: + valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs) train_dataset = ConstantLengthDataset( tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized ) From 9432ee24b615c145099e785f8f0c2f358d37c7c7 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 8 Dec 2022 18:28:18 +0400 Subject: [PATCH 03/14] Train test split --- .../codeparrot/scripts/codeparrot_training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index e7a0b04133..e0ce931ef9 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -125,10 +125,10 @@ def create_dataloaders(args): train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs) train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) if args.dataset_name_train == args.dataset_name_valid: - valid_size = int(len(train_data) * 0.05) - train_size = len(train_data) - valid_size - valid_data = train_data.select(range(valid_size)) - train_data = train_data.select(range(valid_size, train_size)) + # Split the dataset into train and validation + train_data, valid_data = train_data.train_test_split( + test_size=0.05, shuffle=False, seed=args.seed + ) else: valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs) train_dataset = ConstantLengthDataset( From 48e29c6ed19a54c677f1b1e1dcf368ed69ae2400 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 8 Dec 2022 21:43:55 +0400 Subject: [PATCH 04/14] Add nostreaming kwarg --- .../research_projects/codeparrot/scripts/arguments.py | 3 +++ .../codeparrot/scripts/codeparrot_training.py | 11 +++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/arguments.py b/examples/research_projects/codeparrot/scripts/arguments.py index 68882a68e0..bb89706461 100644 --- a/examples/research_projects/codeparrot/scripts/arguments.py +++ b/examples/research_projects/codeparrot/scripts/arguments.py @@ -27,6 +27,9 @@ class TrainingArguments: shuffle_buffer: Optional[int] = field( default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."} ) + no_streaming: Optional[bool] = field( + default=False, metadata={"help": "Whether not to use streaming for the dataset."} + ) learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."}) lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."}) num_warmup_steps: Optional[int] = field( diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index e0ce931ef9..b7395da0c0 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -121,14 +121,17 @@ def setup_logging(args): def create_dataloaders(args): - ds_kwargs = {"streaming": True} + ds_kwargs = {"streaming": args.no_streaming} train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs) train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) if args.dataset_name_train == args.dataset_name_valid: # Split the dataset into train and validation - train_data, valid_data = train_data.train_test_split( - test_size=0.05, shuffle=False, seed=args.seed - ) + if args.no_streaming: + train_data = train_data.shuffle() + valid_data = train_data.select(range(int(len(train_data) * 0.05))) + train_data = train_data.select(range(int(len(train_data) * 0.05), len(train_data))) + else: + valid_data = train_data.take(5000) else: valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs) train_dataset = ConstantLengthDataset( From 6d0a5e5b2f8e20ba0ceed836ea24e949f248d933 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Fri, 9 Dec 2022 00:04:54 +0400 Subject: [PATCH 05/14] Deactivate shuffling --- .../codeparrot/scripts/codeparrot_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index b7395da0c0..d171fe51c9 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -119,7 +119,6 @@ def setup_logging(args): transformers.utils.logging.set_verbosity_error() return logger, run_name - def create_dataloaders(args): ds_kwargs = {"streaming": args.no_streaming} train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs) @@ -141,7 +140,8 @@ def create_dataloaders(args): tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized ) train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer) - train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) + # ShufflerIterDataPipe does not work in torch 1.13 hence shuffle=False + train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=False) eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size) return train_dataloader, eval_dataloader From a5bd7d422d8f90654618675258306f70acfc6b8b Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Fri, 9 Dec 2022 15:04:44 +0400 Subject: [PATCH 06/14] Update examples/research_projects/codeparrot/scripts/codeparrot_training.py Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com> --- .../codeparrot/scripts/codeparrot_training.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index d171fe51c9..f49c7a8c14 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -125,14 +125,14 @@ def create_dataloaders(args): train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) if args.dataset_name_train == args.dataset_name_valid: # Split the dataset into train and validation - if args.no_streaming: - train_data = train_data.shuffle() - valid_data = train_data.select(range(int(len(train_data) * 0.05))) - train_data = train_data.select(range(int(len(train_data) * 0.05), len(train_data))) - else: - valid_data = train_data.take(5000) + data = train_data.train_test_split( + test_size=0.005, shuffle=False, seed=args.seed + ) + train_data = data["train"] + valid_data = data["test"] + print(f"Size of train set: {len(valid_data)} and validation set {len(train_data)}") else: - valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs) + valid_data = load_dataset(args.dataset_name_valid, split="train", use_auth_token=True) train_dataset = ConstantLengthDataset( tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized ) From 9af45a80c780c2082d0de50b24274f278b808a3c Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Fri, 9 Dec 2022 15:04:49 +0400 Subject: [PATCH 07/14] Update examples/research_projects/codeparrot/scripts/codeparrot_training.py Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com> --- .../research_projects/codeparrot/scripts/codeparrot_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index f49c7a8c14..0234f31339 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -121,7 +121,7 @@ def setup_logging(args): def create_dataloaders(args): ds_kwargs = {"streaming": args.no_streaming} - train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs) + train_data = load_dataset(args.dataset_name_train, split="train", use_auth_token=True) train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) if args.dataset_name_train == args.dataset_name_valid: # Split the dataset into train and validation From e29ead46c45dab234d349123d62522406d99ff10 Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Fri, 9 Dec 2022 15:04:55 +0400 Subject: [PATCH 08/14] Update examples/research_projects/codeparrot/scripts/codeparrot_training.py Co-authored-by: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com> --- .../research_projects/codeparrot/scripts/codeparrot_training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index 0234f31339..8a78e62c92 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -120,7 +120,6 @@ def setup_logging(args): return logger, run_name def create_dataloaders(args): - ds_kwargs = {"streaming": args.no_streaming} train_data = load_dataset(args.dataset_name_train, split="train", use_auth_token=True) train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) if args.dataset_name_train == args.dataset_name_valid: From 197a23adc9676f6cdbdd40aaa698d43eca982ca0 Mon Sep 17 00:00:00 2001 From: loubna Date: Mon, 12 Dec 2022 11:36:05 +0000 Subject: [PATCH 09/14] use a fork of transformers where we use reduction=none in GPT2 loss and update accelerate --- examples/research_projects/codeparrot/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/codeparrot/requirements.txt b/examples/research_projects/codeparrot/requirements.txt index 7eff3ac7f1..70ef16d3a2 100644 --- a/examples/research_projects/codeparrot/requirements.txt +++ b/examples/research_projects/codeparrot/requirements.txt @@ -1,9 +1,9 @@ -transformers==4.19.0 +git+https://github.com/loubnabnl/transformers.git@loss-reduction-none +accelerate==0.15.0 datasets==1.16.0 wandb==0.12.0 tensorboard==2.6.0 torch==1.11.0 huggingface-hub==0.1.0 -git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe datasketch==1.5.7 dpu_utils \ No newline at end of file From 700922514f344d90b4183b598cf5043c7087c3b2 Mon Sep 17 00:00:00 2001 From: loubna Date: Mon, 12 Dec 2022 11:36:32 +0000 Subject: [PATCH 10/14] add sanity check on order of irred losses when loaded wrt batches --- .../codeparrot/scripts/utils_rholoss.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 examples/research_projects/codeparrot/scripts/utils_rholoss.py diff --git a/examples/research_projects/codeparrot/scripts/utils_rholoss.py b/examples/research_projects/codeparrot/scripts/utils_rholoss.py new file mode 100644 index 0000000000..c5d9e19db7 --- /dev/null +++ b/examples/research_projects/codeparrot/scripts/utils_rholoss.py @@ -0,0 +1,60 @@ +import os + +import torch + + +def get_control_examples( + accelerator, dataloader, train_batch_size, limit_control_examples=600, seq_length=1024 +): + """get control examples for computing the irreducible losses over the entire dataset""" + control_examples = torch.zeros((limit_control_examples, seq_length), device="cpu").type(torch.LongTensor) + for step, batch in enumerate(dataloader, start=1): + indexes = list( + range( + (step - 1) * train_batch_size * accelerator.state.num_processes, + step * train_batch_size * accelerator.state.num_processes, + ) + ) + if max(indexes) >= limit_control_examples: + break + with torch.no_grad(): + batches = accelerator.gather(batch) + control_examples[indexes] = batches.cpu() + return control_examples + + +def compute_save_control_examples( + accelerator, + dataloader, + train_batch_size, + limit_control_examples=600, + seq_length=1024, + save_dir="./control_examples", +): + """compute and save control examples for computing the irreducible losses over the entire dataset""" + control_examples = get_control_examples( + accelerator, dataloader, train_batch_size, limit_control_examples, seq_length + ) + if accelerator.is_main_process: + # save each losses and examples + os.makedirs(save_dir, exist_ok=True) + torch.save(control_examples, f"{save_dir}/control_examples.pt") + + +def sanity_check_irred_losses( + accelerator, + dataloader, + train_batch_size, + limit_control_examples=600, + seq_length=1024, + save_dir="./control_examples", +): + """sanity check of the order of loaded irreducible losses wrt batches of the current dataset""" + loaded_examples = torch.load(f"{save_dir}/control_examples.pt") + control_examples = get_control_examples( + accelerator, dataloader, train_batch_size, limit_control_examples, seq_length + ) + # check if the loaded tensors are the same as the ones we saved + assert torch.all(torch.eq(loaded_examples, control_examples)) + if accelerator.is_main_process: + print("Sanity check for irreducible loss order passed") From 0c8550e93a27b11fc33fe0cd5e00d9b63084ab7b Mon Sep 17 00:00:00 2001 From: loubna Date: Mon, 12 Dec 2022 11:36:48 +0000 Subject: [PATCH 11/14] fix issues in code and add sanity checks --- .../codeparrot/scripts/codeparrot_training.py | 109 ++++++++++++++---- 1 file changed, 88 insertions(+), 21 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index 8a78e62c92..ec59f19fcd 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -18,6 +18,7 @@ from huggingface_hub import Repository from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed +from utils_roloss import sanity_check_irred_losses, compute_save_control_examples class ConstantLengthDataset(IterableDataset): """ @@ -272,37 +273,103 @@ def get_lr(): t_start = time.time() loss_tracking = 0 +samples_per_step = args.train_batch_size * accelerator.state.num_processes +total_samples = args.max_train_steps * samples_per_step + if args.compute_irred_losses: - irred_losses = torch.zeros(len(train_dataloader) * args.train_batch_size, device="cpu") + irred_losses = torch.zeros(total_samples, device="cpu") + compute_save_control_examples( + accelerator, + train_dataloader, + args.train_batch_size, + limit_control_examples=600, + seq_length=args.seq_length, + save_dir="./control_examples", + ) + if accelerator.is_main_process: + print(f"Total number of samples: {total_samples}") + print("Irreducible loss examples were saved in ./control_examples for future sanity checks on the order") + elif args.irred_losses and not args.compute_irred_losses: - # Should be of shape len(train_dataloader) * args.train_batch_size + # Should be of shape total_samples irred_losses = torch.load(args.irred_losses, map_location=torch.device("cpu")) - assert irred_losses.shape[0] == len(train_dataloader) * args.train_batch_size + assert irred_losses.shape[0] == total_samples, print( + f"shape of irred_losses: {irred_losses.shape[0]}, len of train_dataloader: {total_samples}" + ) + +if args.selection_method and args.selection_method == "rholoss": + if accelerator.is_main_process: + print("Running sanity checks for irreducible losses") + # run sanity checks to verify the order of irreducible losses wrt current batches + sanity_check_irred_losses( + accelerator, + train_dataloader, + args.train_batch_size, + limit_control_examples=600, + seq_length=args.seq_length, + save_dir="./control_examples", + ) for step, batch in enumerate(train_dataloader, start=1): - if args.resume_from_checkpoint and step < resume_step: - continue # we need to skip steps until we reach the resumed step if args.selection_method: + # select a smaller batch of examples to train on using a selection method if args.selection_method == "uniform": - batch = {k: v[:args.train_batch_size_select] for k,v in batch.items()} + batch = batch[torch.randperm(batch.size(0))[: args.train_batch_size_select]] elif args.selection_method == "rholoss": + # in this setting global_batch_size = train_batch_size x nb_workers with torch.no_grad(): - out = model(batch, labels=batch, use_cache=False).loss - losses = accelerator.gather(loss.repeat(args.train_batch_size)) - cur_irred_losses = irred_losses[step-1:args.train_batch_size*step] - assert losses.shape == cur_irred_losses.shape - red_losses = losses - cur_irred_losses + # we use reduction="none" in GPT2LMHeadModel loss implementation (install transformers from a fork) + loss = model(batch, labels=batch, use_cache=False).loss + loss = loss.view(batch.size(0), -1).mean(dim=1) + assert loss.shape == torch.Size( + [args.train_batch_size] + ), "make sure you are using GPT2LMHeadModel with reduction=none in the loss" + # TODO check data at the end that may be duplicated to divide batch equally among all workers + losses = accelerator.gather(loss) + cur_irred_losses = irred_losses[(step - 1) * samples_per_step : step * samples_per_step] + try: + losses.shape == cur_irred_losses.shape + except: + print( + f"Size mismatch between training losses {losses.shape} and irreducible losses {cur_irred_losses.shape}" + ) + red_losses = losses - cur_irred_losses.to(losses.device) + print(f"shape red_losses {red_losses.shape} and max is {torch.max(red_losses[0])}") # Select the top args.train_batch_size_select losses & produce a new batch top_losses, top_indices = torch.topk(red_losses, args.train_batch_size_select) - batch = {k: v[top_indices] for k,v in batch.items()} + batches = accelerator.gather(batch) + batch = torch.index_select(batches, 0, top_indices) + # assert first element has the highest loss + assert torch.eq(top_losses[0], torch.max(red_losses)), "first element should have the highest loss" + print(f"new size of batch is {batch.shape}") + if args.compute_irred_losses: + # compute irreducible losses over the entire dataset and exit with torch.no_grad(): - out = model(batch, labels=batch, use_cache=False).loss - losses = accelerator.gather(loss.repeat(args.train_batch_size)) - irred_losses[step-1:args.train_batch_size*step] = losses + # we use reduction="none" in GPT2LMHeadModel loss implementation + loss = model(batch, labels=batch, use_cache=False).loss + loss = loss.view(batch.size(0), -1).mean(dim=1) + assert loss.shape == torch.Size( + [args.train_batch_size] + ), "make sure you are using GPT2LMHeadModel with reduction=none in the loss" + losses = accelerator.gather(loss) + try: + irred_losses[(step - 1) * samples_per_step : step * samples_per_step] = losses + except: + print( + f"Size mismatch, step {step} between current losses {losses.shape} and irreducible losses \ + {irred_losses[(step-1) * samples_per_step: step * samples_per_step].shape}" + ) + break + if step >= args.max_train_steps: + break continue - loss = model(batch, labels=batch, use_cache=False).loss - avg_loss = accelerator.gather(loss.repeat(args.train_batch_size_select)).mean() + + # model training + # TODO! 32 batch doesn't fit => split batch and add "proper" grad accumulation + # we are using reduction="none" in GPT2LMHeadModel loss => we add .mean() + avg_loss = model(batch, labels=batch, use_cache=False).loss.mean() + # no need to do accelerate gather we would just be repeating the loss (workers have the same batch) loss_tracking += avg_loss.item() / args.gradient_accumulation_steps log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()}) loss = loss / args.gradient_accumulation_steps @@ -321,7 +388,7 @@ def get_lr(): lr_scheduler.step() optimizer.zero_grad() elapsed_time = time.time() - t_start - tflops = compute_tflops(elapsed_time, accelerator, args) + tflops = compute_tflops(model, tokenizer, elapsed_time, accelerator, args) log_metrics( step, { @@ -342,15 +409,15 @@ def get_lr(): accelerator.wait_for_everyone() save_dir = os.path.join(args.save_dir, f"step_{step}") accelerator.save_state(save_dir) - if accelerator.is_main_process: - hf_repo.push_to_hub(commit_message=f"step {step}") model.train() if completed_steps >= args.max_train_steps: break # Save irred losses if args.compute_irred_losses: - torch.save(irred_losses, os.path.join(args.save_dir, args.irred_losses)) + if accelerator.is_main_process: + print(f"saving irred losses with shape: {irred_losses.shape}") + torch.save(irred_losses, "irred_losses.pt") exit() # Evaluate and save the last checkpoint From bbf8e4644a8a0384d34a1b9b9e6a2afc50130175 Mon Sep 17 00:00:00 2001 From: loubna Date: Mon, 12 Dec 2022 11:56:23 +0000 Subject: [PATCH 12/14] clean file and move some utility function to another file --- .../codeparrot/scripts/codeparrot_training.py | 139 ++---------------- 1 file changed, 16 insertions(+), 123 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index ec59f19fcd..2b3f7857f7 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -1,10 +1,7 @@ -import logging import os import time from argparse import Namespace -from pathlib import Path -import datasets import torch from datasets import load_dataset from torch.optim import AdamW @@ -12,13 +9,12 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe -import transformers from accelerate import Accelerator, DistributedType from arguments import TrainingArguments -from huggingface_hub import Repository from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed +from utils_rholoss import compute_save_control_examples, sanity_check_irred_losses +from utils_training import compute_tflops, evaluate, get_grouped_params, get_lr, setup_logging -from utils_roloss import sanity_check_irred_losses, compute_save_control_examples class ConstantLengthDataset(IterableDataset): """ @@ -95,42 +91,15 @@ def shuffle(self, buffer_size=1000): return ShufflerIterDataPipe(self, buffer_size=buffer_size) -def setup_logging(args): - project_name = args.model_ckpt.split("/")[-1] - logger = logging.getLogger(__name__) - log_dir = Path(args.save_dir) / "log/" - log_dir.mkdir(exist_ok=True) - filename = f"debug_{accelerator.process_index}.log" - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()], - ) - if accelerator.is_main_process: # we only want to setup logging once - accelerator.init_trackers(project_name, vars(args)) - run_name = accelerator.trackers[0].run.name - logger.setLevel(logging.INFO) - datasets.utils.logging.set_verbosity_info() - transformers.utils.logging.set_verbosity_info() - else: - run_name = "" - logger.setLevel(logging.ERROR) - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - return logger, run_name - def create_dataloaders(args): train_data = load_dataset(args.dataset_name_train, split="train", use_auth_token=True) - train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) + train_data = train_data.shuffle(seed=args.seed) if args.dataset_name_train == args.dataset_name_valid: # Split the dataset into train and validation - data = train_data.train_test_split( - test_size=0.005, shuffle=False, seed=args.seed - ) + data = train_data.train_test_split(test_size=0.005, shuffle=False, seed=args.seed) train_data = data["train"] valid_data = data["test"] - print(f"Size of train set: {len(valid_data)} and validation set {len(train_data)}") + print(f"Size of train set: {len(train_data)} and validation set {len(valid_data)}") else: valid_data = load_dataset(args.dataset_name_valid, split="train", use_auth_token=True) train_dataset = ConstantLengthDataset( @@ -140,65 +109,18 @@ def create_dataloaders(args): tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized ) train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer) - # ShufflerIterDataPipe does not work in torch 1.13 hence shuffle=False - train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=False) + # ShufflerIterDataPipe does not work in torch > 1.11 + train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size) return train_dataloader, eval_dataloader -def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]): - params_with_wd, params_without_wd = [], [] - for n, p in model.named_parameters(): - if any(nd in n for nd in no_decay): - params_without_wd.append(p) - else: - params_with_wd.append(p) - return [ - {"params": params_with_wd, "weight_decay": args.weight_decay}, - {"params": params_without_wd, "weight_decay": 0.0}, - ] - - def log_metrics(step, metrics): logger.info(f"Step {step}: {metrics}") if accelerator.is_main_process: accelerator.log(metrics, step) -def compute_tflops(elapsed_time, accelerator, args): - # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf). - config_model = accelerator.unwrap_model(model).config - checkpoint_factor = 4 if args.gradient_checkpointing else 3 - batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps - factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2) - flops_per_iteration = factor * ( - 1.0 - + (args.seq_length / (6.0 * config_model.n_embd)) - + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd)) - ) - tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12)) - return tflops - - -def evaluate(args): - model.eval() - losses = [] - for step, batch in enumerate(eval_dataloader): - with torch.no_grad(): - outputs = model(batch, labels=batch) - loss = outputs.loss.repeat(args.valid_batch_size) - losses.append(accelerator.gather(loss)) - if args.max_eval_steps > 0 and step >= args.max_eval_steps: - break - losses = torch.cat(losses) - loss = losses[: eval_dataloader.dataset.current_size].mean() - try: - perplexity = torch.exp(loss) - except OverflowError: - perplexity = float("inf") - return loss.item(), perplexity.item() - - # Settings parser = HfArgumentParser(TrainingArguments) args = parser.parse_args() @@ -211,23 +133,15 @@ def evaluate(args): samples_per_step = accelerator.state.num_processes * args.train_batch_size set_seed(args.seed) -# Clone model repository -if accelerator.is_main_process: - hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt) - # Logging -logger, run_name = setup_logging(args) +logger, run_name = setup_logging(accelerator, args) logger.info(accelerator.state) -# Checkout new branch on repo -if accelerator.is_main_process: - hf_repo.git_checkout(run_name, create_branch_ok=True) - # Load model and tokenizer -model = AutoModelForCausalLM.from_pretrained(args.save_dir) +model = AutoModelForCausalLM.from_pretrained(args.model_ckpt) if args.gradient_checkpointing: model.gradient_checkpointing_enable() -tokenizer = AutoTokenizer.from_pretrained(args.save_dir) +tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt) # Load dataset and dataloader train_dataloader, eval_dataloader = create_dataloaders(args) @@ -243,30 +157,11 @@ def evaluate(args): accelerator.register_for_checkpointing(lr_scheduler) -def get_lr(): - return optimizer.param_groups[0]["lr"] - - # Prepare everything with our `accelerator`. model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( model, optimizer, train_dataloader, eval_dataloader ) -# load in the weights and states from a previous save -if args.resume_from_checkpoint: - if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - accelerator.load_state(args.resume_from_checkpoint) - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)] - dirs.sort(key=os.path.getctime) - path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last - # Extract the step of the checkpoint to continue from there - training_difference = os.path.splitext(path)[0] - resume_step = int(training_difference.replace("step_", "")) - # Train model model.train() completed_steps = 0 @@ -367,10 +262,10 @@ def get_lr(): # model training # TODO! 32 batch doesn't fit => split batch and add "proper" grad accumulation - # we are using reduction="none" in GPT2LMHeadModel loss => we add .mean() - avg_loss = model(batch, labels=batch, use_cache=False).loss.mean() + # we are using reduction="none" in GPT2LMHeadModel loss => we add .mean() over tokens + loss = model(batch, labels=batch, use_cache=False).loss.mean() # no need to do accelerate gather we would just be repeating the loss (workers have the same batch) - loss_tracking += avg_loss.item() / args.gradient_accumulation_steps + loss_tracking += loss.item() / args.gradient_accumulation_steps log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()}) loss = loss / args.gradient_accumulation_steps if step % args.gradient_accumulation_steps != 0: @@ -381,7 +276,7 @@ def get_lr(): else: accelerator.backward(loss) else: - lr = get_lr() + lr = get_lr(optimizer) accelerator.backward(loss) accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() @@ -404,7 +299,7 @@ def get_lr(): completed_steps += 1 if step % args.save_checkpoint_steps == 0: logger.info("Evaluating and saving model checkpoint") - eval_loss, perplexity = evaluate(args) + eval_loss, perplexity = evaluate(model, accelerator, eval_dataloader, args) log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity}) accelerator.wait_for_everyone() save_dir = os.path.join(args.save_dir, f"step_{step}") @@ -422,12 +317,10 @@ def get_lr(): # Evaluate and save the last checkpoint logger.info("Evaluating and saving model after training") -eval_loss, perplexity = evaluate(args) +eval_loss, perplexity = evaluate(model, accelerator, eval_dataloader, args) log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity}) accelerator.wait_for_everyone() unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save) save_dir = os.path.join(args.save_dir, f"step_{step}") accelerator.save_state(save_dir) -if accelerator.is_main_process: - hf_repo.push_to_hub(commit_message="final model") From 013e006dafb50f5d456dcaefc13acea49db658c2 Mon Sep 17 00:00:00 2001 From: loubna Date: Mon, 12 Dec 2022 11:57:57 +0000 Subject: [PATCH 13/14] reformat utils files --- .../codeparrot/scripts/utils_rholoss.py | 6 +- .../codeparrot/scripts/utils_training.py | 84 +++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 examples/research_projects/codeparrot/scripts/utils_training.py diff --git a/examples/research_projects/codeparrot/scripts/utils_rholoss.py b/examples/research_projects/codeparrot/scripts/utils_rholoss.py index c5d9e19db7..cc8c6a7f1b 100644 --- a/examples/research_projects/codeparrot/scripts/utils_rholoss.py +++ b/examples/research_projects/codeparrot/scripts/utils_rholoss.py @@ -3,9 +3,7 @@ import torch -def get_control_examples( - accelerator, dataloader, train_batch_size, limit_control_examples=600, seq_length=1024 -): +def get_control_examples(accelerator, dataloader, train_batch_size, limit_control_examples=600, seq_length=1024): """get control examples for computing the irreducible losses over the entire dataset""" control_examples = torch.zeros((limit_control_examples, seq_length), device="cpu").type(torch.LongTensor) for step, batch in enumerate(dataloader, start=1): @@ -33,7 +31,7 @@ def compute_save_control_examples( ): """compute and save control examples for computing the irreducible losses over the entire dataset""" control_examples = get_control_examples( - accelerator, dataloader, train_batch_size, limit_control_examples, seq_length + accelerator, dataloader, train_batch_size, limit_control_examples, seq_length ) if accelerator.is_main_process: # save each losses and examples diff --git a/examples/research_projects/codeparrot/scripts/utils_training.py b/examples/research_projects/codeparrot/scripts/utils_training.py new file mode 100644 index 0000000000..632deb6b7a --- /dev/null +++ b/examples/research_projects/codeparrot/scripts/utils_training.py @@ -0,0 +1,84 @@ +import logging +from pathlib import Path + +import datasets +import torch + +import transformers + + +def setup_logging(accelerator, args): + project_name = args.model_ckpt.split("/")[-1] + logger = logging.getLogger(__name__) + log_dir = Path(args.save_dir) / "log/" + log_dir.mkdir(exist_ok=True) + filename = f"debug_{accelerator.process_index}.log" + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()], + ) + if accelerator.is_main_process: # we only want to setup logging once + accelerator.init_trackers(project_name, vars(args)) + run_name = accelerator.trackers[0].run.name + logger.setLevel(logging.INFO) + datasets.utils.logging.set_verbosity_info() + transformers.utils.logging.set_verbosity_info() + else: + run_name = "" + logger.setLevel(logging.ERROR) + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + return logger, run_name + + +def compute_tflops(model, tokenizer, elapsed_time, accelerator, args): + # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf). + config_model = accelerator.unwrap_model(model).config + checkpoint_factor = 4 if args.gradient_checkpointing else 3 + batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps + factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2) + flops_per_iteration = factor * ( + 1.0 + + (args.seq_length / (6.0 * config_model.n_embd)) + + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd)) + ) + tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12)) + return tflops + + +def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]): + params_with_wd, params_without_wd = [], [] + for n, p in model.named_parameters(): + if any(nd in n for nd in no_decay): + params_without_wd.append(p) + else: + params_with_wd.append(p) + return [ + {"params": params_with_wd, "weight_decay": args.weight_decay}, + {"params": params_without_wd, "weight_decay": 0.0}, + ] + + +def get_lr(optimizer): + return optimizer.param_groups[0]["lr"] + + +def evaluate(model, accelerator, eval_dataloader, args): + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(batch, labels=batch) + loss = outputs.loss.repeat(args.valid_batch_size) + losses.append(accelerator.gather(loss)) + if args.max_eval_steps > 0 and step >= args.max_eval_steps: + break + losses = torch.cat(losses) + loss = losses[: eval_dataloader.dataset.current_size].mean() + try: + perplexity = torch.exp(loss) + except OverflowError: + perplexity = float("inf") + return loss.item(), perplexity.item() From 2887158b48e8661883d3ed0cfdaac4387e0ebc44 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Mon, 12 Dec 2022 21:53:02 +0400 Subject: [PATCH 14/14] Compute no grad RHO-Losses in multiple iterations --- .../codeparrot/scripts/codeparrot_training.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/codeparrot/scripts/codeparrot_training.py b/examples/research_projects/codeparrot/scripts/codeparrot_training.py index 2b3f7857f7..393f7730fc 100644 --- a/examples/research_projects/codeparrot/scripts/codeparrot_training.py +++ b/examples/research_projects/codeparrot/scripts/codeparrot_training.py @@ -213,9 +213,15 @@ def log_metrics(step, metrics): elif args.selection_method == "rholoss": # in this setting global_batch_size = train_batch_size x nb_workers with torch.no_grad(): - # we use reduction="none" in GPT2LMHeadModel loss implementation (install transformers from a fork) - loss = model(batch, labels=batch, use_cache=False).loss - loss = loss.view(batch.size(0), -1).mean(dim=1) + losses = [] + # To avoid running OOM, compute the loss on a smaller batches + sub_batches = torch.split(batch, args.gradient_accumulation_steps) + for sub_batch in sub_batches: + # we use reduction="none" in GPT2LMHeadModel loss implementation (install transformers from a fork) + loss = model(sub_batch, labels=sub_batch, use_cache=False).loss + loss = loss.view(sub_batch.size(0), -1).mean(dim=1) + losses.append(loss) + losses = torch.cat(losses, dim=0) assert loss.shape == torch.Size( [args.train_batch_size] ), "make sure you are using GPT2LMHeadModel with reduction=none in the loss"