Towards a Probabilistic Disentanglement of Transformer Activations Part 2

This is the second post in our series dedicated to exploring methods around dictionary learning, and the possibility of a probabilistic take on the disentanglement of transformer activations into interpretable monosemantic features.

Building on the first post, we introduce some of the most relevant behaviors of high-dimensional neural-like data and the training dynamics of a sparsity-penalized autoencoder with an overcomplete basis.

Replication code is based on the GitHub repo GitHub, but with important modifications to ensure consistent replication and logging of the experiments.

Setup of the experiment

In this post, we will keep working with a synthetic dataset of neural-like activations, generated in a similar fashion to the first post in the series.

Taking Conjecture's initial article as a reference, we will use two different dataset generation methodologies.

1) Generate the ground truth features with decay in their occurrence probabilities. This replicates a property that real activations likely have: some features are much more likely to occur than others.

2) Generate the ground truth features with decay and correlation. This tries to extend the behavior of the first data generation method to include another property that real activations probably have. This means that some features co-occur with a higher probability than others.
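
The two generation schemes above can be sketched roughly as follows. This is a minimal illustration, not the code from the replication repo: the function name `make_dataset`, the uniform coefficient magnitudes, and the Gaussian-copula trick used to correlate features are all assumptions made for the example.

```python
import math
import numpy as np

def make_dataset(n_samples=1000, dim=256, n_features=512, decay=0.99,
                 avg_active=5, correlated=False, seed=0):
    """Hypothetical generator for synthetic neural-like activations."""
    rng = np.random.default_rng(seed)

    # Ground truth features: random unit vectors in activation space.
    features = rng.normal(size=(n_features, dim))
    features /= np.linalg.norm(features, axis=1, keepdims=True)

    # Geometric decay of occurrence probability, rescaled so the expected
    # number of active features per sample equals avg_active.
    probs = decay ** np.arange(n_features)
    probs *= avg_active / probs.sum()

    if correlated:
        # Assumption: Gaussian copula. Each z[:, i] is standard normal, and
        # pairs of columns are correlated through the shared mixing matrix.
        mixing = rng.normal(size=(n_features, n_features))
        mixing /= np.linalg.norm(mixing, axis=1, keepdims=True)
        z = rng.normal(size=(n_samples, n_features)) @ mixing.T
        # Normal CDF maps each column to a uniform marginal on [0, 1].
        uniforms = 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))
    else:
        uniforms = rng.random((n_samples, n_features))

    # Feature i is active when its uniform draw falls below probs[i].
    active = uniforms < probs
    # Assumption: active features get a random magnitude in [0, 1).
    coeffs = active * rng.random((n_samples, n_features))
    return coeffs @ features, features, active
```

In both branches the marginal probability of each feature is the same; only the co-occurrence structure changes when `correlated=True`.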

Parameters of the experiments

| Parameter | Value |
| --- | --- |
| Activation Dimensionality | 256 |
| Number of ground truth features | 512 |
| Feature probability Decay | 0.99 |
| Average number of features per sample | 5 |

All the experiments will be performed on both data generation methodologies.

Dataset Exploration

To introduce the datasets generated, we created some visualizations to get intuitions on the effect of some of the parameters.

When working with high-dimensional data, it’s important to generate helpful and understandable visualizations.

We used t-SNE (t-distributed stochastic neighbor embedding) to project the 256-dimensional dataset down to two dimensions.
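
A projection of this kind can be sketched with scikit-learn's `TSNE`. The random stand-in data, the perplexity, and the PCA initialization below are assumptions chosen for the example, not the settings used for the plots in this post.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Stand-in for the synthetic activations: 300 samples in 256 dimensions.
activations = rng.normal(size=(300, 256))

# Perplexity must be smaller than the number of samples.
tsne = TSNE(n_components=2, perplexity=30.0, init="pca", random_state=0)
embedding = tsne.fit_transform(activations)  # one 2-D point per sample
```

The resulting `embedding` array can be passed directly to a scatter plot.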

With this, we mainly wanted to explore the variation of the feature probability decay and the correlation, or lack thereof, between the ground truth features.

Uncorrelated features

We can see how, as we increase the decay, the data points tend to evenly distribute in the plane.

This is because, as the feature probabilities approach equality, there is no way of embedding their distances more compactly. The small agglomerations of points in the later plot should not be confused with single features. They are superpositions of randomly close features that happen to co-occur.

Correlated features

In the case of correlated ground truth components, we can also observe how, as we increase the feature probability decay, the points more evenly distribute in the plane.

An interesting phenomenon can be observed, clearly distinct from the uncorrelated case. This is the formation of clearly defined clusters in the plane.

This can only be attributed to the correlation of ground truth features; this correlation results in the formation of clusters that efficiently embed the presence of predominantly correlated features in a set of datapoints.

We can see how this phenomenon evolves through the increase of the decay. We observe how the clusters become more compact and frequent as we increase the feature decay. This is because the presence of really common features acts as a way of reducing the noise in the encoding process.

Baseline Methods

To assess the performance of the disentanglement process of the SAEs (Sparse Autoencoders), we define and test some widely used methods for disentanglement that have trivial implementations and do not require costly training.

We use the MMCS metric defined in the previous post to assess the performance of these baseline methods.
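
For reference, a mean max cosine similarity can be computed along these lines. The exact definition is the one given in the previous post; averaging over the ground truth features, as done here, is an assumption of this sketch, and the function name `mmcs` is illustrative.

```python
import numpy as np

def mmcs(learned, ground_truth):
    """Mean, over ground truth features, of the best cosine match among learned features."""
    a = learned / np.linalg.norm(learned, axis=1, keepdims=True)
    b = ground_truth / np.linalg.norm(ground_truth, axis=1, keepdims=True)
    cos = b @ a.T  # shape (n_ground_truth, n_learned)
    return float(cos.max(axis=1).mean())
```

Because both sets of vectors are normalized, the score is invariant to the scale and ordering of the learned features.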

Performance

| Method | MMCS | Feature Correlation |
| --- | --- | --- |
| KMeans | 0.56 | No |
| KMedoids | 0.21 | No |
| PCA | 0.17 | No |
| ICA | 0.17 | No |
| KMeans | 0.59 | Yes |
| KMedoids | 0.45 | Yes |
| PCA | 0.18 | Yes |
| ICA | 0.19 | Yes |

To make sure that the performance was stable relative to the dataset size, we ran the baseline methods across a set of dataset sizes.

Uncorrelated features

The baseline method with the best performance on uncorrelated neural-like activations is k-means. Its MMCS increases with dataset size and converges to around 0.55 when tested on up to 409,600 data points.
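
One way to turn k-means into a dictionary baseline is to treat the cluster centers as candidate features. The sketch below assumes scikit-learn's `KMeans`, one cluster per desired feature, and unit-normalized centroids; whether the experiments here normalized the centroids is not stated, so treat that step as an assumption.

```python
import numpy as np
from sklearn.cluster import KMeans

def kmeans_dictionary(activations, n_features, seed=0):
    """Hypothetical baseline: use normalized k-means centroids as the dictionary."""
    km = KMeans(n_clusters=n_features, n_init=10, random_state=seed).fit(activations)
    centers = km.cluster_centers_
    norms = np.linalg.norm(centers, axis=1, keepdims=True)
    return centers / np.maximum(norms, 1e-12)  # guard against a zero centroid

# Toy usage on random data (not the dataset from this post).
rng = np.random.default_rng(0)
toy_activations = rng.normal(size=(500, 8))
dictionary = kmeans_dictionary(toy_activations, n_features=16)
```

The returned rows can then be scored against the ground truth features with MMCS.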

We found the KMedoids implementation in scikit-learn to be very unstable, especially when the ground truth features had different probabilities. Also, during the experiments, we encountered random errors and artifacts.

The last two methods, PCA and ICA, were not suitable for the task because they only retrieve linear combinations of the most frequent features.

Correlated features

When experimenting with the dataset with correlated features, we found that the MMCS values were more stable to changes in the dataset size, with KMeans and KMedoids being the best performing baseline methods.

Feature Representedness

Uncorrelated features

Correlated features

Training of SAEs

We trained a set of Sparse Autoencoders for datasets generated using correlated and uncorrelated ground truth features.
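
As a reminder of what is being optimized, the objective of a sparsity-penalized autoencoder can be sketched as a reconstruction term plus an $L_1$ term on the codes. The tied encoder and decoder weights, the unit-norm dictionary rows, and the ReLU encoder below are assumptions of this sketch; the architecture in the replication repo may differ.

```python
import numpy as np

def sae_loss(x, dictionary, bias, l1_penalty):
    """Forward pass and loss of a hypothetical tied-weight sparse autoencoder.

    x: (batch, dim) activations; dictionary: (n_dict, dim); bias: (n_dict,).
    """
    # Normalize dictionary rows so sparsity cannot be gamed by rescaling.
    d = dictionary / np.linalg.norm(dictionary, axis=1, keepdims=True)
    codes = np.maximum(x @ d.T + bias, 0.0)  # ReLU encoder, (batch, n_dict)
    recon = codes @ d                        # linear decoder, (batch, dim)
    mse = np.mean((x - recon) ** 2)
    sparsity = np.mean(np.abs(codes).sum(axis=1))
    return mse + l1_penalty * sparsity, codes, recon
```

A larger `l1_penalty` trades reconstruction quality for sparser codes, which is the trade-off explored by the sweep in this section.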

The dataset and SAEs have the following specifications:

| Parameter | Value |
| --- | --- |
| Activation Dimensionality | 256 |
| Number of Ground Truth Features | 512 |
| Feature Probability Decay | 0.99 |
| Average Number of Features per Sample | 5 |
| $L_1$ Penalty | [0.01, 0.02, 0.03, 0.06, 0.1, 0.18] |
| Dictionary Ratios | [0.25, 0.5, 1, 2, 4, 8, 16, 32] |
| Batch Size | 4096 |
| Epochs | 30000 |

Hence, we trained 48 SAEs for the uncorrelated and 48 for the correlated dataset, for a total of 96.
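
The count follows from crossing the 6 $L_1$ penalties with the 8 dictionary ratios for each dataset. A sweep of this shape could be enumerated as below; the config keys are illustrative, and computing the dictionary size as ratio times activation dimensionality is an assumption, since the post does not state what the ratio is relative to.

```python
from itertools import product

ACTIVATION_DIM = 256
L1_PENALTIES = [0.01, 0.02, 0.03, 0.06, 0.1, 0.18]
DICT_RATIOS = [0.25, 0.5, 1, 2, 4, 8, 16, 32]
DATASETS = ["uncorrelated", "correlated"]

configs = [
    {
        "dataset": dataset,
        "l1_penalty": l1,
        "dict_ratio": ratio,
        # Assumption: ratio is relative to the activation dimensionality.
        "dict_size": int(ratio * ACTIVATION_DIM),
    }
    for dataset, l1, ratio in product(DATASETS, L1_PENALTIES, DICT_RATIOS)
]

print(len(configs))  # 2 datasets * 6 penalties * 8 ratios
```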

This large training run was made possible due to the services provided by vast.ai. The runs were conducted on an A6000 for 2 hours, and all the metrics were recorded using the wandb API.

During the training of each SAE, some metrics were recorded; some of them proved to be more interesting than others, primarily:

Summary Plots

We plot some of the metrics used in the original paper for both the correlated and uncorrelated datasets.

Uncorrelated

Correlated

We can see that while the plots in both cases have similar structures in terms of the distribution of values, the values themselves are somewhat different, with a better reconstruction in the case of the correlated dataset but with a slightly worse MMCS.

These plots show how important the tuning of both hyperparameters is, especially when dealing with correlated and noisy features, as is the case in language models.

Dead Neurons

We can observe that, for appropriately sized dictionaries, the number of dead neurons plateaus, whereas for oversized dictionaries it keeps growing throughout training.

This is an important fact that can help us create more efficient SAEs.
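
A simple way to count dead neurons is to check which dictionary entries never activate on a batch of data. The definition below (never strictly above a threshold `eps` on any sample) is an assumption; the logged metric may use a different window or threshold.

```python
import numpy as np

def count_dead_neurons(codes, eps=0.0):
    """codes: (n_samples, n_dict) non-negative SAE activations."""
    ever_active = (codes > eps).any(axis=0)
    return int((~ever_active).sum())
```

Tracking this count over epochs for each dictionary size gives curves of the kind discussed above.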

Feature tracking plots

We show some of the most interesting figures illustrating the training dynamics of the SAEs.

Unlearned Features

One of the most interesting things we observed during training was that some features were learned and later unlearned. This was most common in undersized dictionaries, probably due to the gradient pressure to learn more important ground truth features.

Still Learned Features

Finally, we can see how the number of features that are learned and keep the same dictionary entry increases with the epochs, especially for appropriately sized dictionaries with the right $L_1$ penalty.
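
The two tracking quantities can be sketched by comparing dictionaries at two checkpoints. Here a ground truth feature counts as "learned" when its best cosine match exceeds a threshold; the value 0.9 and the function names are hypothetical choices for this example, not the criteria used for the figures.

```python
import numpy as np

def match_features(dictionary, ground_truth, threshold=0.9):
    """For each ground truth feature: is it learned, and by which dictionary entry?"""
    d = dictionary / np.linalg.norm(dictionary, axis=1, keepdims=True)
    g = ground_truth / np.linalg.norm(ground_truth, axis=1, keepdims=True)
    cos = g @ d.T  # (n_ground_truth, n_dict)
    return cos.max(axis=1) >= threshold, cos.argmax(axis=1)

def track_features(prev_dict, curr_dict, ground_truth, threshold=0.9):
    """Count features unlearned, and features still held by the same entry."""
    prev_learned, prev_idx = match_features(prev_dict, ground_truth, threshold)
    curr_learned, curr_idx = match_features(curr_dict, ground_truth, threshold)
    unlearned = prev_learned & ~curr_learned
    still_learned = prev_learned & curr_learned & (prev_idx == curr_idx)
    return int(unlearned.sum()), int(still_learned.sum())
```

Both dictionaries must have the same number of entries for the index comparison to be meaningful, which holds for checkpoints of the same SAE.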

In the following post, we will explore the possibility of disentangling the features in a probabilistic fashion using Variational Autoencoders with sparse priors.
