From f8762d596ba188a4013f04b87fcceaae260f5129 Mon Sep 17 00:00:00 2001
From: motiwari
Date: Wed, 1 May 2024 21:16:32 -0700
Subject: [PATCH] Partial review done

---
 adaptive_softmax/bandits_softmax.py |  10 +-
 adaptive_softmax/sftm.py            | 159 +++++++++++++++-------------
 2 files changed, 91 insertions(+), 78 deletions(-)

diff --git a/adaptive_softmax/bandits_softmax.py b/adaptive_softmax/bandits_softmax.py
index 6f5209d0..154454fa 100644
--- a/adaptive_softmax/bandits_softmax.py
+++ b/adaptive_softmax/bandits_softmax.py
@@ -18,6 +18,7 @@ def generate_weighted_permutation(weights: np.ndarray, gen=np.random.default_rng
   logits = np.log(weights) - np.log(np.sum(weights))
   perturbed_logits = logits + gen.gumbel(size=logits.size)
   permutation = perturbed_logits.argsort()[::-1]
+
   return permutation, logits, perturbed_logits
 
 class BanditsSoftmax:
@@ -65,8 +66,9 @@ def __init__(
       query_importance_sampling=True,
       randomized_hadamard_transform=False,
       verbose=False,
-      seed=42):
-
+      seed=42,
+  ):
+    # TODO(colin): Why are so many class members prefixed with underscores?
     assert len(A.shape) == 2, 'A must be a 2D array'
 
     self.n = A.shape[0]
@@ -94,7 +96,7 @@ def __init__(
     self._permutation, self._logits, self._perturbed_logits = generate_weighted_permutation(self._atom_weights, gen=self._gen)
 
     q = (self._atom_weights / (np.sum(self._atom_weights)) )[np.newaxis, :]
-    q[q == 0 | np.isnan(q)] = 1 # NOTE 0-weight columns will never be selected
+    q[(q == 0) | np.isnan(q)] = 1  # NOTE 0-weight columns will never be selected
     self._est_atom_sig2 = np.max(np.sum((self._A / q / self.d) ** 2 * q, axis=1))
     self._est_query_sig2 = None
     self._sparse_columns = None
@@ -111,8 +113,10 @@ def __init__(
       print(f'Query importance sampling: {self.query_importance_sampling}')
       print(f'Randomized Hadamard transform: {self.randomized_hadamard_transform}')
       print(f'Permutation:\n{self._permutation}')
+
       if atom_importance_sampling:
         print(f'Atom weights:\n{self._atom_weights}')
+
       if randomized_hadamard_transform:
         print(f'Columns 0-padded: {A.shape[1]} --> {self.d}')
 
diff --git a/adaptive_softmax/sftm.py b/adaptive_softmax/sftm.py
index dff2a9e6..19ffdcef 100644
--- a/adaptive_softmax/sftm.py
+++ b/adaptive_softmax/sftm.py
@@ -1,6 +1,6 @@
 import numpy as np
 from typing import Tuple
-from math import log, ceil, sqrt, exp
+from math import log, ceil, sqrt
 from scipy.special import logsumexp, softmax
 
 from adaptive_softmax.bandits_softmax import BanditsSoftmax
@@ -18,48 +18,46 @@ class SFTM:
   A : np.ndarray
     The atom matrix A of shape (n, d) for the matrix-vector multiplication.
   temperature : float, optional
-    The temperature parameter for the softmax function, by default 1.0.
+    The temperature parameter for the softmax function (default 1.0).
   multiplicative_error : float, optional
-    The multiplicative error parameter for the PAC guarantee, by default 3e-1.
+    The multiplicative error parameter for the PAC guarantee, epsilon (default 3e-1).
   failure_probability : float, optional
-    The failure probability parameter for the PAC guarantee, by default 1e-1.
+    The failure probability parameter for the PAC guarantee, delta (default 1e-1).
   noise_bound : float, optional
-    The noise bound parameter for entries of the matrix-vector multiplication,
-    by default None.
+    The noise bound parameter for entries of the matrix-vector multiplication (default None).
   fudge_pull : float, optional
     The multiplier for the number of pulls used in the bandits algorithm to
-    account for loose bounds, by default 1.0.
+    account for loose bounds (default 1.0).
   fudge_sigma2 : float, optional
     The multiplier for the variance used in the bandits algorithm to account
-    for loose bounds, by default 1.0.
+    for loose bounds (default 1.0).
   atom_importance_sampling : bool, optional
-    The flag to enable atom-based importance sampling in the bandits algorithm,
-    by default True.
+    The flag to enable atom-based importance sampling in the bandits algorithm (default True).
   query_importance_sampling : bool, optional
-    The flag to enable query-based importance sampling in the bandits algorithm,
-    by default True.
+    The flag to enable query-based importance sampling in the bandits algorithm (default True).
   randomized_hadamard_transform : bool, optional
-    The flag to enable randomized Hadamard transform of the atom matrix A
+    The flag to enable randomized Hadamard transform of the atom matrix A (default False).
   verbose : bool, optional
-    The flag to enable verbose output, by default False.
+    The flag to enable verbose output (default False).
   seed : int, optional
-    The seed for the random number generator used in the bandits algorithm, by
-    default 42.
+    The seed for the random number generator used in the bandits algorithm (default 42).
   """
-  def __init__(self,
-               A: np.ndarray,
-               temperature: float = 1.0,
-               multiplicative_error: float = 3e-1,
-               failure_probability: float = 1e-1,
-               noise_bound: float = None,
-               fudge_pull: float = 1.0,
-               fudge_sigma2: float = 1.0,
-               atom_importance_sampling: bool = True,
-               query_importance_sampling: bool = True,
-               randomized_hadamard_transform: bool = False,
-               verbose: bool = False,
-               seed=42):
+  def __init__(
+      self,
+      A: np.ndarray,
+      temperature: float = 1.0,
+      multiplicative_error: float = 3e-1,
+      failure_probability: float = 1e-1,
+      noise_bound: float = None,
+      fudge_pull: float = 1.0,
+      fudge_sigma2: float = 1.0,
+      atom_importance_sampling: bool = True,
+      query_importance_sampling: bool = True,
+      randomized_hadamard_transform: bool = False,
+      verbose: bool = False,
+      seed: int = 42,
+  ):
     self.A = A
     self.n = A.shape[0]
     self.d = A.shape[1]
@@ -70,11 +68,11 @@ def __init__(self,
     self.verbose = verbose
 
     if self.verbose:
-      print(f"Initializing SFTM for a matrix of shape ({self.n}, {self.d})...")
+      print(f"Initializing SFTM for a matrix of shape ({self.n} x {self.d})...")
       print("Parameters:")
-      print(f"\t-temperature: {self.temperature}")
-      print(f"\t-multiplicative_error: {self.multiplicative_error}")
-      print(f"\t-failure_probability: {self.failure_probability}")
+      print(f"\t- temperature: {self.temperature}")
+      print(f"\t- multiplicative_error: {self.multiplicative_error}")
+      print(f"\t- failure_probability: {self.failure_probability}")
 
     self.bandits = BanditsSoftmax(
       A,
@@ -92,19 +90,19 @@ def __init__(self,
       print("SFTM initialized.")
       print("")
 
-  def softmax(self, x: np.ndarray, k: int=1) -> Tuple[np.ndarray, np.ndarray]:
+  def softmax(self, x: np.ndarray, k: int = 1) -> Tuple[np.ndarray, np.ndarray]:
    """
     Computes the true softmax, returning the top-k indices and the softmax.
 
     @param x: The query vector x of shape (d,).
-    @param k: The number of elements to return, by default 1.
+    @param k: The number of elements to return (default 1).
     @return: The top-k indices and the softmax.
""" mu = (self.A @ x) * self.temperature top_k = np.sort(np.argpartition(mu, -k)[-k:]) return top_k, softmax(mu) - def adaptive_softmax(self, x: np.ndarray, k: int=1) -> Tuple[int, float]: + def adaptive_softmax(self, x: np.ndarray, k: int = 1) -> Tuple[int, float]: """ Computes the approximate softmax using the SFTM algorithm, returning the top-k indices, the approximate softmax for these indices, and the @@ -114,9 +112,9 @@ def adaptive_softmax(self, x: np.ndarray, k: int=1) -> Tuple[int, float]: Efficient Softmax Approximation." @param x: The query vector x of shape (d,). - @param k: The number of elements to return, by default 1. + @param k: The number of elements to return (default 1). @return: The top-k indices, the approximate softmax, and the normalizing - constant. + constant Z. """ if self.verbose: @@ -124,35 +122,38 @@ def adaptive_softmax(self, x: np.ndarray, k: int=1) -> Tuple[int, float]: self.bandits.set_query(x) - bta = self.temperature + beta = self.temperature eps = self.multiplicative_error - dlt = self.failure_probability + delta = self.failure_probability sig2 = self.noise_bound if self.noise_bound is not None else self.bandits.variance if self.verbose: print(f"Noise bound: {sig2}") - i_star_hat = self.best_arms(dlt/2, bta, sig2, k) + # TODO(@colin): Did we decide whether this should be delta/2 or delta? + i_star_hat = self.best_arms(delta/2, beta, sig2, k) + # TODO(@colin): if i_star_hat is wrong, won't mu_star_hat also be the wrong value? mu_star_hat = self.bandits.exact_values(i_star_hat) - log_S_hat = self.log_norm_estimation(bta, eps, dlt/2, sig2) + # TODO(@colin): Did we decide whether this should be delta/2 or delta? + log_S_hat = self.log_norm_estimation(beta, eps, delta/2, sig2) if self.verbose: print(f"Top-{k} arms: {i_star_hat}") print(f"Estimated logit values: {mu_star_hat}") print(f"Estimated log normalizing constant: {log_S_hat}") - return i_star_hat, np.exp(bta * mu_star_hat - log_S_hat), np.exp(log_S_hat) - - def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray: + return i_star_hat, np.exp(beta * mu_star_hat - log_S_hat), np.exp(log_S_hat) + + def best_arms(self, delta: float, beta: float, sig2: float, k: int) -> np.ndarray: """ Finds the top-k arms with the highest estimated logit values. This method uses a round-based PAC bandits algorithm based on Algorithm 3 from the paper, "Distributed Exploration in Multi-Armed Bandits" by Hillel et al. (2013). - - @param dlt: The failure probability parameter. - @param bta: The temperature parameter. + + @param delta: The failure probability parameter. + @param beta: The temperature parameter. @param sig2: The noise bound parameter. @param k: The number of arms to return. @return: The top-k arms with the highest estimated logit values. 
@@ -162,7 +163,7 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:
     n = self.n
     d = self.bandits.max_pulls
 
-    T0 = int(ceil(17 * (bta ** 2) * sig2 * log(6 * n / dlt)))
+    T0 = int(ceil(17 * (beta ** 2) * sig2 * log(6 * n / delta)))
 
     if self.verbose:
       print(f"Initial number of pulls: {T0}")
@@ -170,17 +171,23 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:
     # initialize parameters
     confidence_set = np.arange(n)
     num_pulls = T0
-    estimates = np.zeros(n)
-
-    while True:
-
+    keep_pulling = True
+    while keep_pulling:
       # pull arms and update confidence interval
      estimates = self.bandits.batch_pull(confidence_set, it=fpc(num_pulls, d))
-      confidence_interval = sqrt(2 * sig2 * log(6 * n * log(d) / dlt) / num_pulls)
+      confidence_interval = sqrt(2 * sig2 * log(6 * n * log(d) / delta) / num_pulls)
 
       # update confidence set
       keep = estimates >= np.max(estimates) - confidence_interval
 
+      # TODO(@colin): I don't think this is exactly correct. It may be the case that an arm is
+      # removed at some point, but then np.max(estimates) moves down and the arm gets added back later.
+      # The current implementation would say that arm has been pulled num_pulls times, but it's been pulled
+      # fewer times. For this reason, I think it's actually better to make num_pulls an *array* of how many
+      # times each arm has been pulled, and then update the confidence interval for each arm separately according
+      # to its number of pulls. This is how we did it in several other projects, see BanditPAM, FastForest, or BanditMIPS
+      # for examples. (A rough sketch of this bookkeeping is appended after this patch.)
+
       if self.verbose:
         print(f"Number of pulls: {num_pulls}")
         print(f"FPC-adjusted number of pulls: {fpc(num_pulls, d)}")
@@ -190,6 +197,7 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:
 
       # check stopping condition
       if np.sum(keep) <= k or fpc(num_pulls, d) >= d:
+        keep_pulling = False
         break
 
       # update parameters
@@ -198,7 +206,7 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:
 
     return confidence_set[np.argsort(estimates)[-k:]]
 
-  def estimate_arm_logits(self, arms: np.ndarray, bta: float, eps: float, dlt: float, sig2: float) -> np.ndarray:
+  def estimate_arm_logits(self, arms: np.ndarray, beta: float, eps: float, delta: float, sig2: float) -> np.ndarray:
     """
     Estimates the logit values of the specified arms with PAC guarantees.
 
@@ -207,29 +215,30 @@ def estimate_arm_logits(self, arms: np.ndarray, bta: float, eps: float, dlt: flo
     paper.
 
     @param arms: The indices of the arms to estimate.
-    @param bta: The temperature parameter.
-    @param eps: The multiplicative error parameter.
-    @param dlt: The failure probability parameter.
+    @param beta: The temperature parameter.
+    @param eps: The multiplicative error parameter.
+    @param delta: The failure probability parameter.
     @param sig2: The noise bound parameter.
     @return: The estimated logit values of the specified arms.
""" if self.verbose: print(f"Estimating logit values for arms {arms}...") + d = self.bandits.max_pulls - T = int(ceil(32 * (sig2) * (bta ** 2) * log(2 / dlt) / (eps ** 2))) + T = int(ceil(32 * sig2 * (beta ** 2) * log(2 / delta) / (eps ** 2))) return self.bandits.pull(arms, its=np.array(fpc(T, d))) - - def log_norm_estimation(self, bta: float, eps: float, dlt: float, sig2: float) -> float: + + def log_norm_estimation(self, beta: float, eps: float, delta: float, sig2: float) -> float: """ - Estimates the log normalizing constant of the softmax function with PAC + Estimates the log normalizing constant of the softmax function with PAC guarantees. This method is based on Algorithm 2 of the paper, "Adaptive Sampling for Efficient Softmax Approximation." - @param bta: The temperature parameter. + @param beta: The temperature parameter. @param eps: The multiplicative error parameter. - @param dlt: The failure probability parameter. + @param delta: The failure probability parameter. @param sig2: The noise bound parameter. @return: The estimated log normalizing constant of the softmax function. """ @@ -237,29 +246,29 @@ def log_norm_estimation(self, bta: float, eps: float, dlt: float, sig2: float) - n = self.n d = self.bandits.max_pulls - T0 = int(ceil(17 * (bta ** 2) * sig2 * log(6 * n / dlt))) - C = np.sqrt(2 * sig2 * log(6 * n / dlt) / T0) + T0 = int(ceil(17 * (beta ** 2) * sig2 * log(6 * n / delta))) + C = np.sqrt(2 * sig2 * log(6 * n / delta) / T0) if self.verbose: print("Estimating log normalizing constant of the softmax function...") print(f"Initial number of pulls: {T0}") print(f"Confidence interval constant: {C}") - + # initial estimates mu_hat = self.bandits.pull(np.arange(n), its=np.full(shape=n, fill_value=fpc(T0, d))) if self.verbose: print(f"Initial estimates: {mu_hat}") - log_alpha = bta * (mu_hat - C) - log_gamma = bta * (mu_hat - C) / 2 + log_alpha = beta * (mu_hat - C) + log_gamma = beta * (mu_hat - C) / 2 log_alpha_sum = logsumexp(log_alpha) log_gamma_sum = logsumexp(log_gamma) # adapt sample sizes based on initial estimates - log_b = log(17 * (bta ** 2) * sig2 * log(6 * n / dlt)) - log_c = log(16 * sqrt(2) * sig2 * log(6 * n / dlt) / eps) + 2 * log_gamma_sum - log_alpha_sum - log_d = log(16 * sig2 * log(12 / dlt) / (eps ** 2)) + log_b = log(17 * (beta ** 2) * sig2 * log(6 * n / delta)) + log_c = log(16 * sqrt(2) * sig2 * log(6 * n / delta) / eps) + 2 * log_gamma_sum - log_alpha_sum + log_d = log(16 * sig2 * log(12 / delta) / (eps ** 2)) it = np.exp(log_b) it = np.maximum(it, np.exp(log_c + log_gamma - log_gamma_sum)) @@ -269,12 +278,12 @@ def log_norm_estimation(self, bta: float, eps: float, dlt: float, sig2: float) - if self.verbose: print(f"Adaptive sample sizes: {it}") - # make updated estimates + # make updated estimates mu_hat = self.bandits.pull(np.arange(n), its=fpc(it, d)) if self.verbose: print(f"Updated estimates: {mu_hat}") - print(f"Estimated log normalizing constant: {logsumexp(bta * mu_hat)}") + print(f"Estimated log normalizing constant: {logsumexp(beta * mu_hat)}") - return logsumexp(bta * mu_hat) + return logsumexp(beta * mu_hat) \ No newline at end of file