From fe95da3e72055dfc4e2dd6036772283e7b8d93a2 Mon Sep 17 00:00:00 2001 From: tslwn Date: Sat, 28 Dec 2024 13:10:33 +0000 Subject: [PATCH 1/2] handle other activation types --- .../activations/activations_computation.py | 71 +++++++++++++++++-- .../activations/exemplars_wrapper.py | 31 +++++++- project/expgen/.gitignore | 1 + project/expgen/scripts/compute_exemplars.py | 23 ++++-- 4 files changed, 113 insertions(+), 13 deletions(-) create mode 100644 project/expgen/.gitignore diff --git a/lib/activations/activations/activations_computation.py b/lib/activations/activations/activations_computation.py index 093633c..cb6328a 100644 --- a/lib/activations/activations/activations_computation.py +++ b/lib/activations/activations/activations_computation.py @@ -4,24 +4,83 @@ from util.subject import Subject -def get_activations_computing_func(subject: Subject, activation_type: Literal["MLP"], layer: int): +# TODO(timl): reduce duplication +def get_activations_computing_func( + subject: Subject, + # TODO(timl): move to enum? + activation_type: Literal["resid", "mlp_in", "mlp_out", "attn_out", "neurons"], + layer: int, +): """ Returns a function that computes activations for a given input: input_ids: torch.Tensor attn_mask: torch.Tensor """ - if activation_type == "MLP": - mlp_acts_for_layer = subject.w_outs[layer] + if activation_type == "resid": + acts_for_layer = subject.layers[layer] - def get_mlp_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): with torch.no_grad(): with subject.model.trace( {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore ): - acts = mlp_acts_for_layer.input.save() + acts = acts_for_layer.output.save() return acts - return get_mlp_activations + return get_activations + + if activation_type == "mlp_in": + acts_for_layer = subject.mlps[layer] + + def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): + with torch.no_grad(): + with subject.model.trace( + {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore + ): + acts = acts_for_layer.input.save() + return acts + + return get_activations + + if activation_type == "mlp_out": + acts_for_layer = subject.mlps[layer] + + def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): + with torch.no_grad(): + with subject.model.trace( + {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore + ): + acts = acts_for_layer.output.save() + return acts + + return get_activations + + if activation_type == "attn_out": + acts_for_layer = subject.attns[layer] + + def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): + with torch.no_grad(): + with subject.model.trace( + {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore + ): + acts = acts_for_layer.output[0].save() + return acts + + return get_activations + + if activation_type == "neurons": + acts_for_layer = subject.w_outs[layer] + + def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): + with torch.no_grad(): + with subject.model.trace( + {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore + ): + acts = acts_for_layer.input.save() + return acts + + return get_activations + else: raise ValueError(f"Unknown activation type: {activation_type}") diff --git a/lib/activations/activations/exemplars_wrapper.py b/lib/activations/activations/exemplars_wrapper.py index 9aa323b..5acceb4 100644 --- a/lib/activations/activations/exemplars_wrapper.py +++ 
b/lib/activations/activations/exemplars_wrapper.py @@ -321,7 +321,14 @@ class ExemplarConfig(BaseModel): batch_size: int = 512 rand_seqs: int = 10 seed: int = 64 - activation_type: Literal["MLP"] = "MLP" + # TODO(timl): move to enum? + activation_type: Literal[ + "resid", + "mlp_in", + "mlp_out", + "attn_out", + "neurons", + ] = "neurons" class ExemplarsWrapper: @@ -347,6 +354,8 @@ def __init__( if subject.is_chat_model: folder_name_components.append("chat") folder_name_components.append(f"{config.seq_len}seqlen") + if config.activation_type != "neurons": + folder_name_components.append(config.activation_type) assert subject.tokenizer.padding_side == "left" folder_name = "_".join(folder_name_components) @@ -430,8 +439,16 @@ def load_layer_checkpoint(self, layer: int, split: ExemplarSplit) -> ( ExemplarSplit.RANDOM_TEST, ) - if self.config.activation_type == "MLP": + # TODO(timl): move to function + if self.config.activation_type == "neurons": num_features = self.subject.I + elif self.config.activation_type in ( + "resid", + "mlp_in", + "mlp_out", + "attn_out", + ): + num_features = self.subject.D else: raise ValueError(f"Invalid activation type: {self.config.activation_type}") num_top_feats_to_save = self.config.num_top_acts_to_save @@ -496,8 +513,16 @@ def save_layer_checkpoint( layer_dir = self.get_layer_dir(layer, split) os.makedirs(layer_dir, exist_ok=True) - if self.config.activation_type == "MLP": + # TODO(timl): move to function + if self.config.activation_type == "neurons": num_features = self.subject.I + elif self.config.activation_type in ( + "resid", + "mlp_in", + "mlp_out", + "attn_out", + ): + num_features = self.subject.D else: raise ValueError(f"Invalid activation type: {self.config.activation_type}") num_top_feats_to_save = self.config.num_top_acts_to_save diff --git a/project/expgen/.gitignore b/project/expgen/.gitignore new file mode 100644 index 0000000..8fce603 --- /dev/null +++ b/project/expgen/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/project/expgen/scripts/compute_exemplars.py b/project/expgen/scripts/compute_exemplars.py index 7f24b64..4d3c072 100644 --- a/project/expgen/scripts/compute_exemplars.py +++ b/project/expgen/scripts/compute_exemplars.py @@ -4,8 +4,9 @@ """ import argparse +from typing import Any -from activations.dataset import fineweb_dset_config, lmsys_dset_config +from activations.dataset import HFDatasetWrapperConfig, fineweb_dset_config, lmsys_dset_config from activations.exemplars import ExemplarSplit from activations.exemplars_computation import ( compute_exemplars_for_layer, @@ -15,12 +16,25 @@ from util.subject import Subject, get_subject_config parser = argparse.ArgumentParser() +parser.add_argument( + "--activation_type", + type=str, + # TODO(timl): move to enum? 
+ choices=[ + "resid", + "mlp_in", + "mlp_out", + "attn_out", + "neurons", + ], + help="Type of activations from which we pick indices to compute exemplars for.", +) parser.add_argument( "--layer_indices", type=int, nargs="+", default=None, - help="Layers from which we pick neurons to compute exemplars for.", + help="Layers from which we pick indices to compute exemplars for.", ) parser.add_argument( "--subject_hf_model_id", @@ -87,7 +101,7 @@ subject_config = get_subject_config(args.subject_hf_model_id) subject = Subject(subject_config, nnsight_lm_kwargs={"dispatch": True}) -hf_dataset_configs = [] +hf_dataset_configs: list[HFDatasetWrapperConfig] = [] for hf_dataset in args.hf_datasets: if hf_dataset == "fineweb": hf_dataset_configs.append(fineweb_dset_config) @@ -106,13 +120,14 @@ num_top_acts_to_save=args.num_top_acts_to_save, batch_size=args.batch_size, seed=args.seed, + activation_type=args.activation_type, ) exemplars_wrapper = ExemplarsWrapper(args.data_dir, exemplar_config, subject) layer_indices = args.layer_indices if args.layer_indices else range(subject.L) for layer in layer_indices: print(f"============ Layer {layer} ============") - kwargs = { + kwargs: dict[str, Any] = { "exemplars_wrapper": exemplars_wrapper, "layer": layer, "split": ExemplarSplit(args.split), From 309ba42f38aa8c1e826a2e3060009b6d5448f606 Mon Sep 17 00:00:00 2001 From: tslwn Date: Mon, 30 Dec 2024 12:25:24 +0000 Subject: [PATCH 2/2] add enum and refactor comp funcs --- .../activations/activations_computation.py | 132 ++++++++---------- .../activations/exemplars_wrapper.py | 49 +++---- project/expgen/scripts/compute_exemplars.py | 13 +- 3 files changed, 82 insertions(+), 112 deletions(-) diff --git a/lib/activations/activations/activations_computation.py b/lib/activations/activations/activations_computation.py index cb6328a..a431aee 100644 --- a/lib/activations/activations/activations_computation.py +++ b/lib/activations/activations/activations_computation.py @@ -1,86 +1,70 @@ -from typing import Literal +from enum import Enum +from typing import Callable import torch +from nnsight.envoy import Envoy # type: ignore +from nnsight.intervention import InterventionProxy # type: ignore from util.subject import Subject -# TODO(timl): reduce duplication +class ActivationType(str, Enum): + RESID = "resid" + MLP_IN = "mlp_in" + MLP_OUT = "mlp_out" + ATTN_OUT = "attn_out" + NEURONS = "neurons" + + +def _get_activations_funcs( + subject: Subject, activation_type: ActivationType, layer: int +) -> tuple[Callable[[], Envoy], Callable[[Envoy], InterventionProxy]]: + if activation_type == ActivationType.RESID: + return ( + lambda: subject.layers[layer], + lambda component: component.output[0], + ) + if activation_type == ActivationType.MLP_IN: + return ( + lambda: subject.mlps[layer], + lambda component: component.input, + ) + if activation_type == ActivationType.MLP_OUT: + return ( + lambda: subject.mlps[layer], + lambda component: component.output, + ) + if activation_type == ActivationType.ATTN_OUT: + return ( + lambda: subject.attns[layer], + lambda component: component.output[0], + ) + if activation_type == ActivationType.NEURONS: + return ( + lambda: subject.w_outs[layer], + lambda component: component.input, + ) + raise ValueError(f"Unknown activation type: {activation_type}") + + def get_activations_computing_func( - subject: Subject, - # TODO(timl): move to enum? 
- activation_type: Literal["resid", "mlp_in", "mlp_out", "attn_out", "neurons"], - layer: int, -): + subject: Subject, activation_type: ActivationType, layer: int +) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: """ Returns a function that computes activations for a given input: input_ids: torch.Tensor attn_mask: torch.Tensor """ - if activation_type == "resid": - acts_for_layer = subject.layers[layer] - - def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): - with torch.no_grad(): - with subject.model.trace( - {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore - ): - acts = acts_for_layer.output.save() - return acts - - return get_activations - - if activation_type == "mlp_in": - acts_for_layer = subject.mlps[layer] - - def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): - with torch.no_grad(): - with subject.model.trace( - {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore - ): - acts = acts_for_layer.input.save() - return acts - - return get_activations - - if activation_type == "mlp_out": - acts_for_layer = subject.mlps[layer] - - def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): - with torch.no_grad(): - with subject.model.trace( - {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore - ): - acts = acts_for_layer.output.save() - return acts - - return get_activations - - if activation_type == "attn_out": - acts_for_layer = subject.attns[layer] - - def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): - with torch.no_grad(): - with subject.model.trace( - {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore - ): - acts = acts_for_layer.output[0].save() - return acts - - return get_activations - - if activation_type == "neurons": - acts_for_layer = subject.w_outs[layer] - - def get_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor): - with torch.no_grad(): - with subject.model.trace( - {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore - ): - acts = acts_for_layer.input.save() - return acts - - return get_activations - - else: - raise ValueError(f"Unknown activation type: {activation_type}") + get_component, get_activations = _get_activations_funcs(subject, activation_type, layer) + + def activations_computing_func( + input_ids: torch.Tensor, attn_mask: torch.Tensor + ) -> torch.Tensor: + with torch.no_grad(): + with subject.model.trace( + {"input_ids": input_ids, "attention_mask": attn_mask} # type: ignore + ): + acts: torch.Tensor = get_activations(get_component()).save() # type: ignore + return acts + + return activations_computing_func diff --git a/lib/activations/activations/exemplars_wrapper.py b/lib/activations/activations/exemplars_wrapper.py index 5acceb4..294f3d8 100644 --- a/lib/activations/activations/exemplars_wrapper.py +++ b/lib/activations/activations/exemplars_wrapper.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np from activations.activations import ActivationRecord +from activations.activations_computation import ActivationType from activations.dataset import ( ChatDataset, HFDatasetWrapper, @@ -321,14 +322,7 @@ class ExemplarConfig(BaseModel): batch_size: int = 512 rand_seqs: int = 10 seed: int = 64 - # TODO(timl): move to enum? 
- activation_type: Literal[ - "resid", - "mlp_in", - "mlp_out", - "attn_out", - "neurons", - ] = "neurons" + activation_type: ActivationType = ActivationType.NEURONS class ExemplarsWrapper: @@ -439,18 +433,7 @@ def load_layer_checkpoint(self, layer: int, split: ExemplarSplit) -> ( ExemplarSplit.RANDOM_TEST, ) - # TODO(timl): move to function - if self.config.activation_type == "neurons": - num_features = self.subject.I - elif self.config.activation_type in ( - "resid", - "mlp_in", - "mlp_out", - "attn_out", - ): - num_features = self.subject.D - else: - raise ValueError(f"Invalid activation type: {self.config.activation_type}") + num_features = self.num_features num_top_feats_to_save = self.config.num_top_acts_to_save k, seq_len = self.config.k, self.config.seq_len @@ -513,18 +496,7 @@ def save_layer_checkpoint( layer_dir = self.get_layer_dir(layer, split) os.makedirs(layer_dir, exist_ok=True) - # TODO(timl): move to function - if self.config.activation_type == "neurons": - num_features = self.subject.I - elif self.config.activation_type in ( - "resid", - "mlp_in", - "mlp_out", - "attn_out", - ): - num_features = self.subject.D - else: - raise ValueError(f"Invalid activation type: {self.config.activation_type}") + num_features = self.num_features num_top_feats_to_save = self.config.num_top_acts_to_save k, seq_len = self.config.k, self.config.seq_len @@ -908,6 +880,19 @@ def visualize_neuron_exemplars( ) display(HTML(html_content)) # type: ignore + @property + def num_features(self) -> int: + if self.config.activation_type == ActivationType.NEURONS: + return self.subject.I + if self.config.activation_type in ( + ActivationType.RESID, + ActivationType.MLP_IN, + ActivationType.MLP_OUT, + ActivationType.ATTN_OUT, + ): + return self.subject.D + raise ValueError(f"Invalid activation type: {self.config.activation_type}") + ################### # Example Configs # diff --git a/project/expgen/scripts/compute_exemplars.py b/project/expgen/scripts/compute_exemplars.py index 4d3c072..ae017f3 100644 --- a/project/expgen/scripts/compute_exemplars.py +++ b/project/expgen/scripts/compute_exemplars.py @@ -6,6 +6,7 @@ import argparse from typing import Any +from activations.activations_computation import ActivationType from activations.dataset import HFDatasetWrapperConfig, fineweb_dset_config, lmsys_dset_config from activations.exemplars import ExemplarSplit from activations.exemplars_computation import ( @@ -19,14 +20,14 @@ parser.add_argument( "--activation_type", type=str, - # TODO(timl): move to enum? choices=[ - "resid", - "mlp_in", - "mlp_out", - "attn_out", - "neurons", + ActivationType.RESID, + ActivationType.MLP_IN, + ActivationType.MLP_OUT, + ActivationType.ATTN_OUT, + ActivationType.NEURONS, ], + default="neurons", help="Type of activations from which we pick indices to compute exemplars for.", ) parser.add_argument(
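
For reference, a minimal usage sketch of the post-refactor API from PATCH 2/2.
The model ID ("gpt2"), the layer index, and the prompt below are illustrative
assumptions, not taken from this series; the construction of `Subject` mirrors
compute_exemplars.py, and `subject.tokenizer` is assumed to be a standard
Hugging Face tokenizer (its `padding_side` is asserted in exemplars_wrapper.py).

    from activations.activations_computation import (
        ActivationType,
        get_activations_computing_func,
    )
    from util.subject import Subject, get_subject_config

    # Build the subject model the same way compute_exemplars.py does.
    subject_config = get_subject_config("gpt2")  # hypothetical model ID
    subject = Subject(subject_config, nnsight_lm_kwargs={"dispatch": True})

    # Residual-stream activations at layer 3 (illustrative layer choice).
    compute_acts = get_activations_computing_func(
        subject, ActivationType.RESID, layer=3
    )

    enc = subject.tokenizer(["Hello world"], return_tensors="pt", padding=True)
    acts = compute_acts(enc["input_ids"], enc["attention_mask"])
    # For resid/mlp_in/mlp_out/attn_out the feature dimension is subject.D;
    # for neurons it is subject.I (see ExemplarsWrapper.num_features).
    print(acts.shape)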