At least I don't think so. Summing weights on the log scale feels off, and the LLMs are screaming about it. Here's the current function:
def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor:
    """Apply dropout to the weights.

    Args:
        weights (torch.Tensor): Current weights in log space.
        p (float): Probability of dropping weights.

    Returns:
        torch.Tensor: Weights after applying dropout.
    """
    if p == 0:
        return weights
    total_weight = weights.sum()
    mask = torch.rand_like(weights) < p
    masked_weights = weights.clone()
    masked_weights[mask] = 0
    masked_weights = masked_weights / masked_weights.sum() * total_weight
    return masked_weights
Here's Gemini's explanation: "The function rescales the modified weights so their new sum equals the original sum of logs. This means it's preserving the product of the original weights, which is not the standard or desired behavior for dropout. The goal of dropout scaling (inverted dropout) is to ensure that the expected sum of the weights remains the same, not the product."
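A quick numerical check makes this concrete. The sketch below (purely illustrative, assuming the weights are random log-probabilities) reproduces the masking-and-rescale step from the current function: the sum of the logs stays fixed, while the linear-space weight mass, which is what dropout is supposed to preserve in expectation, drifts. It also shows that zeroing an entry in log space sets its weight to exp(0) = 1 rather than removing it.

import torch

torch.manual_seed(0)
log_w = torch.log(torch.rand(5))   # hypothetical weights, stored in log space
total = log_w.sum()

# Reproduce the masking and rescaling from dropout_weights:
masked = log_w.clone()
masked[torch.rand_like(log_w) < 0.4] = 0.0   # a "dropped" entry becomes log(1), not log(0)
masked = masked / masked.sum() * total

print(torch.isclose(masked.sum(), total))               # True: sum of logs (product of weights) is preserved
print(torch.exp(log_w).sum(), torch.exp(masked).sum())  # linear-space sums differ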
Here's its proposed fix:
import torch

def correct_dropout_weights(log_weights: torch.Tensor, p: float) -> torch.Tensor:
    """
    Correctly apply dropout to weights stored in log space.

    Args:
        log_weights (torch.Tensor): Current weights in log space.
        p (float): Probability of dropping weights.

    Returns:
        torch.Tensor: Weights in log space after applying dropout.
    """
    if p < 0 or p >= 1:
        raise ValueError("dropout probability has to be in [0, 1)")
    if p == 0:
        return log_weights

    # 1. Convert from log space to linear space
    weights = torch.exp(log_weights)

    # 2. Apply inverted dropout in linear space
    # Create a mask of 0s and 1s and apply it.
    mask = (torch.rand_like(weights) > p).float()
    # Scale the remaining weights to maintain the same expected sum.
    # This is the standard "inverted dropout" technique.
    scaled_weights = mask * weights / (1 - p)

    # 3. Convert back to log space
    # torch.log(0) is -inf, which correctly represents a dropped weight.
    return torch.log(scaled_weights)
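For what it's worth, a rough sanity check of the proposed version (a sketch continuing from the block above, assuming correct_dropout_weights is in scope and the weights are again arbitrary log-probabilities) shows the expected linear-space sum is preserved and dropped entries come back as -inf:

torch.manual_seed(0)
log_w = torch.log(torch.rand(1000))
orig_sum = torch.exp(log_w).sum()

# Average the linear-space sum over many dropout draws; it should hover near the original.
sums = torch.stack([
    torch.exp(correct_dropout_weights(log_w, p=0.4)).sum()
    for _ in range(200)
])
print(orig_sum, sums.mean())

# Dropped entries are -inf in log space, i.e. exactly zero weight after exp().
print(torch.isinf(correct_dropout_weights(log_w, p=0.4)).any())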