26 changes: 20 additions & 6 deletions distributed_training/averaging/avg_handler.py
@@ -197,11 +197,17 @@ async def run_validator_allreduce(
)

# Update state_avgs main params with inner optimizer params
- self.update_main_param_after_outer_step()
+ self.update_local_model_after_outer_step()

# Zero grads of outer optimizer
self.state_averager.optimizer.zero_grad()

+ bt.logging.info(
+     ":white_heavy_check_mark: Finished Outer Optimizer Step."
+ )

+ # Clip gradients again
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

# Validate weight updates
await self._validate_weight_update(initial_weights, block)
@@ -392,11 +398,19 @@ async def run_miner_allreduce(
self.state_averager.step(
increment_epoch=True, optimizer_step=True, zero_grad=False
)
- self.update_main_param_after_outer_step()

+ # Update state_avgs main params with inner optimizer params
+ self.update_local_model_after_outer_step()

# Zero grads of outer optimizer
self.state_averager.optimizer.zero_grad()

+ bt.logging.info(
+     ":white_heavy_check_mark: Finished Outer Optimizer Step."
+ )

+ # Clip gradients again
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

# Validate weight updates
await self._validate_weight_update(initial_weights, block)
@@ -419,14 +433,14 @@ async def run_miner_allreduce(
bt.logging.success("Averaging Round Finished Succesfully")
return synapse

- def update_main_param_after_outer_step(self):
+ def update_local_model_after_outer_step(self):
"""Update the main parameters with the inner optimizer step"""
opt_parameters = [
param
for group in self.inner_optimizer.param_groups
for param in group["params"]
]
- for main_param, opt_param in zip(
-     self.state_averager.main_parameters, opt_parameters
+ for local_model_param, avg_param in zip(
+     opt_parameters, self.state_averager.main_parameters
):
- main_param.data.copy_(opt_param.data, non_blocking=True)
+ local_model_param.data.copy_(avg_param.data.to(local_model_param.device), non_blocking=True)
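The renamed helper now copies in the opposite direction: after the outer optimizer step, the globally averaged parameters held by the state averager are written back into the parameters trained by the inner optimizer, the outer optimizer's gradients are zeroed, and the local gradients are re-clipped. A minimal sketch of that post-step flow, assuming the hivemind-style `state_averager`/`inner_optimizer` objects used above; the standalone function and its arguments are illustrative only:

    # Sketch only: mirrors the new copy direction and the follow-up steps added
    # in both all-reduce paths; names and the wrapper function are illustrative.
    import torch

    def sync_local_model_after_outer_step(state_averager, inner_optimizer, model):
        # Gather the parameters the inner optimizer actually trains.
        local_params = [
            p for group in inner_optimizer.param_groups for p in group["params"]
        ]

        # Copy the averaged parameters back into the local model, moving each
        # tensor onto the device of its local counterpart first.
        for local_p, avg_p in zip(local_params, state_averager.main_parameters):
            local_p.data.copy_(avg_p.data.to(local_p.device), non_blocking=True)

        # Clear the outer optimizer's gradients and re-clip local gradients.
        state_averager.optimizer.zero_grad()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)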
2 changes: 1 addition & 1 deletion distributed_training/utils/config.py
@@ -196,7 +196,7 @@ def add_args(cls, parser, prefix=None):
"--neuron.local_batch_size_train_effective",
type=int,
help="Amount of micro batches for gradient accumulation",
- default=2048,
+ default=512,
)

parser.add_argument(
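The lowered default interacts with the per-micro-batch loss scaling in miner.py, where each micro-batch loss is divided by `self.number_of_local_steps`. A rough arithmetic sketch, assuming `number_of_local_steps` comes directly from this flag and an assumed per-micro-batch sample count (neither is shown in this diff):

    # Illustrative arithmetic only; both values below are assumptions.
    micro_batch_size = 4           # samples per forward/backward pass (assumed)
    number_of_local_steps = 512    # micro-batches accumulated per inner step (new default)

    # Dividing each micro-batch loss by the accumulation count keeps the summed
    # gradient equivalent to a single batch of this many samples:
    samples_per_inner_step = micro_batch_size * number_of_local_steps  # 4 * 512 = 2048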
9 changes: 1 addition & 8 deletions distributed_training/utils/state_loader.py
@@ -16,7 +16,6 @@
import hivemind
import psutil
import torch
- from memory_profiler import profile
from datetime import datetime

from hivemind.compression import deserialize_torch_tensor
@@ -32,11 +31,7 @@
scan_cache_dir,
upload_folder,
)
- from huggingface_hub.utils import (
-     HfHubHTTPError,
-     RepositoryNotFoundError,
-     EntryNotFoundError,
- )
+ from huggingface_hub.utils import HfHubHTTPError
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
AutoModelForCausalLM,
@@ -547,8 +542,6 @@ def load_model_optimizer_gradient_averager(
self.device,
)

- self.scaler = torch.amp.GradScaler(enabled=True)

if (self.local_progress.inner_step != 0) and ("." in revision):
self.state_averager.reset_main_parameters(
model_name,
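With the import narrowed, a single `HfHubHTTPError` handler can stand in for the removed `RepositoryNotFoundError`/`EntryNotFoundError` branches, since both of those are raised as subclasses of `HfHubHTTPError` in `huggingface_hub`. A minimal sketch under that assumption; the wrapper function and its arguments are illustrative, not code from this PR:

    from typing import Optional

    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import HfHubHTTPError

    def try_download_state(repo_id: str, revision: str) -> Optional[str]:
        try:
            # Fetch the model snapshot for the requested revision.
            return snapshot_download(repo_id=repo_id, revision=revision)
        except HfHubHTTPError as exc:
            # Missing repos and missing revisions also surface as HTTP errors,
            # so one handler covers the cases previously caught separately.
            print(f"Failed to fetch {repo_id}@{revision}: {exc}")
            return None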
7 changes: 3 additions & 4 deletions neurons/miner.py
@@ -545,7 +545,7 @@ def _process_training_batch(self, dataset):
outputs = self.model(input_ids=inputs, labels=labels)
loss = outputs[1] / self.number_of_local_steps

- self.scaler.scale(loss).backward()
+ loss.backward()

self.running_loss += loss.item() * self.number_of_local_steps
self.batch_count += 1
@@ -581,9 +581,8 @@ def _process_training_batch(self, dataset):

def inner_optimizer_step(self):
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
- self.scaler.unscale_(optimizer=self.inner_optimizer)
- self.scaler.step(self.inner_optimizer)
- self.scaler.update()

+ self.inner_optimizer.step()

self.scheduler.step()

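Together with dropping the `GradScaler` construction in state_loader.py, these changes move the miner to an unscaled backward/step path. A minimal sketch of the resulting accumulation loop, assuming full-precision (or bf16 autocast) training where loss scaling is unnecessary; the function, loop, and final `zero_grad` placement are illustrative rather than the actual miner code:

    import torch

    def run_inner_step(model, inner_optimizer, scheduler, batches, number_of_local_steps):
        for inputs, labels in batches:
            outputs = model(input_ids=inputs, labels=labels)
            loss = outputs[1] / number_of_local_steps  # scale loss for gradient accumulation
            loss.backward()                            # plain backward; no scaler.scale(...)

        # One inner optimizer step per accumulated group of micro-batches.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        inner_optimizer.step()
        scheduler.step()
        inner_optimizer.zero_grad()  # assumed here; the miner may clear grads elsewhere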