
dropout does not handle weights in log-space correctly #62

@baogorek

Description

At least I don't think so. Summing weights on the log scale feels off, and the LLMs are screaming about it. Here's the current function:

    def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor:
        """Apply dropout to the weights.

        Args:
            weights (torch.Tensor): Current weights in log space.
            p (float): Probability of dropping weights.

        Returns:
            torch.Tensor: Weights after applying dropout.
        """
        if p == 0:
            return weights
        total_weight = weights.sum()
        mask = torch.rand_like(weights) < p
        masked_weights = weights.clone()
        masked_weights[mask] = 0
        masked_weights = masked_weights / masked_weights.sum() * total_weight
        return masked_weights

Here's Gemini's explanation: "The function rescales the modified weights so their new sum equals the original sum of logs. This means it's preserving the product of the original weights, which is not the standard or desired behavior for dropout. The goal of dropout scaling (inverted dropout) is to ensure that the expected sum of the weights remains the same, not the product."
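
To make that concrete, here's a small sketch (my own, assuming the `dropout_weights` above is in scope and using made-up weight values): the sum of the log-weights is preserved by construction, but the sum of the actual weights is not.

    import torch

    torch.manual_seed(0)

    # Hypothetical weights, stored in log space as in the function above.
    log_w = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4]))

    out = dropout_weights(log_w, p=0.5)

    # The sum of logs (i.e., the product of the weights) is preserved by construction...
    print(log_w.sum().item(), out.sum().item())
    # ...but the sum of the actual weights changes, so the expected total weight is not maintained.
    print(torch.exp(log_w).sum().item(), torch.exp(out).sum().item())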

Here's its proposed fix:

    import torch

    def correct_dropout_weights(log_weights: torch.Tensor, p: float) -> torch.Tensor:
        """
        Correctly apply dropout to weights stored in log space.

        Args:
            log_weights (torch.Tensor): Current weights in log space.
            p (float): Probability of dropping weights.

        Returns:
            torch.Tensor: Weights in log space after applying dropout.
        """
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1)")

        if p == 0:
            return log_weights

        # 1. Convert from log space to linear space.
        weights = torch.exp(log_weights)

        # 2. Apply inverted dropout in linear space.
        # Create a mask of 0s and 1s and apply it.
        mask = (torch.rand_like(weights) > p).float()

        # Scale the remaining weights to maintain the same expected sum.
        # This is the standard "inverted dropout" technique.
        scaled_weights = mask * weights / (1 - p)

        # 3. Convert back to log space.
        # No epsilon is needed: torch.log(0) is -inf, which correctly represents a dropped weight.
        return torch.log(scaled_weights)
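
As a rough sanity check on the proposal (my addition, not Gemini's, using the same hypothetical weights as above): averaging the linear-space sum over many draws shows it is preserved in expectation, which is exactly the property the current function lacks.

    import torch

    torch.manual_seed(0)
    log_w = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4]))  # hypothetical weights

    # exp(-inf) == 0, so dropped entries contribute nothing to the linear-space sum.
    sums = torch.stack([
        torch.exp(correct_dropout_weights(log_w, p=0.5)).sum()
        for _ in range(10_000)
    ])

    print(torch.exp(log_w).sum().item())  # 1.0
    print(sums.mean().item())             # close to 1.0 in expectation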
