Notes

MNIST-CIFAR Dominoes

Here are some dominoes based on [1]. The idea behind this dataset is that there are two "patterns" in the data: the MNIST image and the CIFAR image.

Pasted image 20230316175757.png

Notice that some of the dominoes have only one "pattern" present. By tracking training/test loss on these one-sided dominoes, we can tease apart how quickly the model learns the two different patterns.
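
To make the setup concrete, here's a minimal sketch of how a single domino might be assembled (the exact layout used in [1] may differ; `make_domino` is a hypothetical helper, and a random array stands in for a real digit):

```python
import numpy as np

def make_domino(mnist_img=None, cifar_img=None):
    """Stack an (optional) 28x28 MNIST image on top of an (optional) 32x32 CIFAR image.
    Passing None for a side produces a one-sided domino with a blank half."""
    top = np.zeros((32, 32, 3))
    if mnist_img is not None:
        padded = np.pad(mnist_img, 2)                    # 28x28 -> 32x32
        top = np.repeat(padded[..., None], 3, axis=-1)   # grayscale -> 3 channels
    bottom = np.zeros((32, 32, 3)) if cifar_img is None else cifar_img
    return np.concatenate([top, bottom], axis=0)         # (64, 32, 3)

# A one-sided (MNIST-only) domino, with a random array standing in for a real digit.
domino = make_domino(mnist_img=np.random.rand(28, 28))
print(domino.shape)  # (64, 32, 3)
```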

We'd like to compare these pattern-learning curves to the curves predicted by the toy model of [2]. In particular, we'd like to compare predictions to the empirical curves as we change the relevant macroscopic parameters (e.g., prevalence, reliability, and simplicity¹).

That means running sweeps over these macroscopic parameters.

Prevalence

What happens as we change the relative incidence of MNIST vs CIFAR images in the dataset? We can accomplish this by varying the frequency of one-sided MNIST dominoes vs. one-sided CIFAR dominoes.

We control two parameters:

  • $p_m$, the probability of a domino containing an MNIST image (either one-sided or two-sided), and
  • $p_c$, the probability of a domino containing a CIFAR image (either one-sided or two-sided).

Two parameters are fixed by our datasets:

  • $N_m$, the number of samples in the MNIST dataset.
  • $N_c$, the number of samples in the CIFAR dataset.

Given these parameters, we have to determine:

  • $r_{m0}$, the fraction of the MNIST dataset that we reject,
  • $r_{m1}$, the fraction of the MNIST dataset that ends up in one-sided dominoes,
  • $r_{m2}$, the fraction of the MNIST dataset that ends up in two-sided dominoes,

and, similarly, $r_{c0}$, $r_{c1}$, and $r_{c2}$ for the CIFAR dataset.

IMG_495F2C6C4A1E-1.jpeg

Here's the corresponding Sankey diagram (in terms of numbers of samples rather than probabilities, but it's totally equivalent).

Six unknowns means we need six constraints.

We get the first two from the requirement that probabilities are normalized,

$$r_{m0} + r_{m1} + r_{m2} = r_{c0} + r_{c1} + r_{c2} = 1,$$

and another from the requirement that two-sided dominoes use the same number of samples from both datasets,

$$r_{m2} N_m = r_{c2} N_c.$$

For convenience, we'll introduce an additional variable, which we immediately constrain,

$$N = r_{c1} N_c + r_{m1} N_m + r_{m2} N_m,$$

the number of samples in the resulting dominoes dataset.

We get the last three constraints from our choices of $p_m$, $p_c$, and $p_1$ (writing $N_{m1}$, $N_{c1}$, and $N_2$ for the numbers of one-sided MNIST, one-sided CIFAR, and two-sided dominoes):

$$N p_m = N_{m1} + N_2 = r_{m1} N_m + r_{m2} N_m,$$
$$N p_c = N_{c1} + N_2 = r_{c1} N_c + r_{c2} N_c,$$
$$N p_1 = N_{m1} + N_{c1} = r_{m1} N_m + r_{c1} N_c.$$

In matrix format,

$$\begin{pmatrix} 1 & 1 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 1 & 1 & 0 \\ 0 & 0 & N_m & 0 & 0 & -N_c & 0 \\ 0 & N_m & N_m & 0 & N_c & 0 & -1 \\ 0 & N_m & N_m & 0 & 0 & 0 & -p_m \\ 0 & 0 & 0 & 0 & N_c & N_c & -p_c \\ 0 & N_m & 0 & 0 & N_c & 0 & -p_1 \end{pmatrix} \cdot \begin{pmatrix} r_{m0} \\ r_{m1} \\ r_{m2} \\ r_{c0} \\ r_{c1} \\ r_{c2} \\ N \end{pmatrix} = \begin{pmatrix} 1 \\ 1 \\ 0 \\ 0 \\ 0 \\ 0 \\ 0 \end{pmatrix},$$

where $p_1 = 2 - p_c - p_m$.

Unfortunately, this yields trivial answers where $r_{m0} = r_{c0} = 1$ and all other values are 0. The solution seems to be to just allow there to be empty dominoes.
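
For reference, here's a minimal numpy sketch of the system above (`domino_system` is just an illustrative name); the rank check makes the degeneracy visible numerically:

```python
import numpy as np

def domino_system(N_m, N_c, p_m, p_c):
    """Build the 7x7 system above; unknowns ordered (r_m0, r_m1, r_m2, r_c0, r_c1, r_c2, N)."""
    p_1 = 2 - p_c - p_m
    A = np.array([
        [1, 1,   1,   0, 0,   0,    0   ],  # MNIST fractions sum to 1
        [0, 0,   0,   1, 1,   1,    0   ],  # CIFAR fractions sum to 1
        [0, 0,   N_m, 0, 0,  -N_c,  0   ],  # two-sided dominoes use equal sample counts
        [0, N_m, N_m, 0, N_c, 0,   -1   ],  # definition of N
        [0, N_m, N_m, 0, 0,   0,   -p_m ],  # MNIST prevalence
        [0, 0,   0,   0, N_c, N_c, -p_c ],  # CIFAR prevalence
        [0, N_m, 0,   0, N_c, 0,   -p_1 ],  # one-sided prevalence
    ], dtype=float)
    b = np.array([1.0, 1.0, 0, 0, 0, 0, 0])
    return A, b

A, b = domino_system(N_m=60_000, N_c=50_000, p_m=0.7, p_c=0.5)
print(np.linalg.matrix_rank(A))              # 6, not 7: the p_1 row is implied by the others
x, *_ = np.linalg.lstsq(A, b, rcond=None)    # minimum-norm solution of the singular system
print(np.round(x, 4))
```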

Reliability

We can vary the reliability by inserting "wrong" dominoes, i.e., with some probability, make one of the two sides display a class that does not match the label.
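
Here's one possible sketch of that corruption step. It assumes each domino carries a target label plus per-side class ids that are later used to pick the actual images; `side_classes` and the exact corruption scheme are illustrative choices, not a prescribed method:

```python
import numpy as np

def side_classes(labels, p_wrong, num_classes=10, seed=0):
    """Given each domino's target label, return the class to display on each side.
    With probability p_wrong, one randomly chosen side is given a wrong class."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    n = len(labels)
    mnist_side = labels.copy()
    cifar_side = labels.copy()
    corrupt = rng.random(n) < p_wrong
    corrupt_mnist = corrupt & (rng.random(n) < 0.5)   # which side gets the wrong class
    corrupt_cifar = corrupt & ~corrupt_mnist
    # Shift by a random nonzero offset so the new class never equals the label.
    wrong = (labels + rng.integers(1, num_classes, size=n)) % num_classes
    mnist_side[corrupt_mnist] = wrong[corrupt_mnist]
    cifar_side[corrupt_cifar] = wrong[corrupt_cifar]
    return mnist_side, cifar_side

# Example: 20% "wrong" dominoes.
m_cls, c_cls = side_classes(labels=np.arange(10).repeat(3), p_wrong=0.2)
```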

Simplicity

One of the downsides of this task is that we don't have much control over the simplicity of the feature. MNIST is simpler than CIFAR, sure, but how much? How might we control this?

Footnotes

  1. Axes conceived by Ekdeep Singh Lubana.

How to Perturb Weights

I'm running a series of experiments that involve some variation of: (1) perturb a weight initialization; (2) train the perturbed and baseline models in parallel, and (3) track how the perturbation grows/shrinks over time.

Naively, if we're interested in a perturbation analysis of the choice of weight initialization, we prepare some baseline initialization, $\mathbf w_0$, and then apply i.i.d. Gaussian noise, $\boldsymbol\delta$, to each of its elements, $\delta_i \sim \mathcal N(0, \epsilon^2)$. (If we want, we can let this vary layer-by-layer and let it depend on, for example, the norm of the layer it's being applied to.)
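
As a concrete sketch of this naive strategy (assuming PyTorch; `perturb_init` and the per-layer scaling are illustrative choices):

```python
import copy
import torch

def perturb_init(model, eps, seed=0):
    """Return a copy of `model` with i.i.d. Gaussian noise of std eps * rms(layer)
    added to every parameter tensor (so the noise scale tracks each layer's norm)."""
    g = torch.Generator().manual_seed(seed)
    perturbed = copy.deepcopy(model)
    with torch.no_grad():
        for p in perturbed.parameters():
            std = eps * p.pow(2).mean().sqrt()            # per-layer scale
            p.add_(torch.randn(p.shape, generator=g) * std)
    return perturbed

baseline = torch.nn.Linear(784, 256)
perturbed = perturb_init(baseline, eps=0.01)
```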

The problem with this strategy is that the perturbed weights $\mathbf w = \mathbf w_0 + \boldsymbol\delta$ are, in general, no longer sampled from the same distribution as the baseline weights.

There is nothing wrong with this per se, but it introduces a possible confounder (the thickness of the distribution, which we'll get to below). This is especially relevant if we're interested specifically in how behavior changes with the size of a perturbation. As responsible experimentalists, we don't like confounders.

Fortunately, there's an easy way to "clean up" Kaiming (He) initialization to make it better suited to this perturbative analysis.

Kaiming initialization lives in a hyperspherical shell

Consider a matrix, $\mathbf w^{(l)}$, representing the weights of a particular layer $l$ with shape $(D_\mathrm{in}^{(l)}, D_\mathrm{in}^{(l+1)})$. $D_\mathrm{in}^{(l)}$ is also called the fan-in of the layer, and $D_\mathrm{in}^{(l+1)}$ the fan-out. For ease of presentation, we'll ignore the bias, though the following reasoning applies equally well to the bias.

We're interested in the vectorized form of this matrix, $\vec w^{(l)} \in \mathbb R^{D^{(l)}}$, where $D^{(l)} = D_\mathrm{in}^{(l)} \times D_\mathrm{in}^{(l+1)}$.

In Kaiming initialization, we sample the components, $w_i^{(l)}$, of this vector i.i.d. from a normal distribution with mean 0 and variance $\sigma^2$ (where $\sigma^2 = \frac{2}{D_\mathrm{in}^{(l)}}$).

Geometrically, this is equivalent to sampling from a hyperspherical shell, $S^{D-1}$, with radius $\sqrt D\,\sigma$ and (fuzzy) thickness $\delta$. (Ok, so technically, because the radius can vary from layer to layer, it's a hyperellipsoidal shell.)

This follows from some straightforward algebra (dropping the superscript $l$ for simplicity):

$$\mathbb E[|\mathbf w|^2] = \mathbb E\left[\sum_{i=1}^D w_i^2\right] = \sum_{i=1}^D \mathbb E[w_i^2] = \sum_{i=1}^D \sigma^2 = D\sigma^2,$$

and

$$\begin{align} \delta^2 \propto \mathrm{var}[|\mathbf w|^2] &= \mathbb E\left[\left(\sum_{i=1}^D w_i^2\right)^2\right] - \mathbb E\left[\sum_{i=1}^D w_i^2\right]^2 \\ &= \sum_{i,j=1}^D \mathbb E[w_i^2 w_j^2] - (D\sigma^2)^2 \\ &= \sum_{i \neq j}^D \mathbb E[w_i^2]\,\mathbb E[w_j^2] + \sum_{i=1}^D \mathbb E[w_i^4] - (D\sigma^2)^2 \\ &= D(D-1)\sigma^4 + D(3\sigma^4) - (D\sigma^2)^2 \\ &= 2D\sigma^4. \end{align}$$

So the thickness as a fraction of the radius is

$$\frac{\delta}{\sqrt D\,\sigma} = \frac{\sqrt{2D}\,\sigma}{\sqrt D} = \sqrt 2\,\sigma = \frac{2}{\sqrt{D_\mathrm{in}^{(l)}}},$$

where the last equality follows from the choice of σ\sigma for Kaiming initialization.

This means that for suitably wide networks ($D_\mathrm{in}^{(l)} \to \infty$), the thickness of this shell goes to $0$.
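
A quick numerical sanity check of this picture (assuming numpy; the layer sizes are arbitrary): the mean norm sits at $\sqrt D\,\sigma$, and the relative spread shrinks as the layer widens.

```python
import numpy as np

rng = np.random.default_rng(0)
for d_in, d_out in [(64, 64), (256, 256)]:
    D = d_in * d_out
    sigma = np.sqrt(2.0 / d_in)
    w = rng.normal(0.0, sigma, size=(200, D))    # 200 independent Kaiming-style inits
    norms = np.linalg.norm(w, axis=1)
    print(f"d_in={d_in:4d}  mean |w| / (sqrt(D) sigma) = {norms.mean() / (np.sqrt(D) * sigma):.4f}"
          f"  relative spread of |w| = {norms.std() / norms.mean():.5f}")
```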

Taking the thickness to 0

This suggests an alternative initialization strategy. What if we immediately take the limit $D_\mathrm{in}^{(l)} \to \infty$ and sample directly from the boundary of a hypersphere with radius $\sqrt D\,\sigma$, i.e., take the shell thickness to be $0$?

This can easily be done by sampling each component from a normal distribution with mean 0 and variance 1 and then normalizing the resulting vector to have length $\sqrt D\,\sigma$ (this is known as the Muller method).
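
A minimal sketch of this zero-thickness initialization (assuming numpy; `sample_on_sphere` is a hypothetical helper):

```python
import numpy as np

def sample_on_sphere(D, radius, rng=None):
    """Muller's method: draw D i.i.d. standard normals, then rescale to the target radius.
    The result is uniform on the (D-1)-sphere of that radius."""
    rng = np.random.default_rng(rng)
    v = rng.normal(size=D)
    return radius * v / np.linalg.norm(v)

# Zero-thickness analogue of Kaiming init for a layer with fan-in 784, fan-out 256:
d_in, d_out = 784, 256
sigma = np.sqrt(2.0 / d_in)
w = sample_on_sphere(d_in * d_out, radius=np.sqrt(d_in * d_out) * sigma)
print(np.linalg.norm(w))   # exactly sqrt(D) * sigma, regardless of the draw
```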

Perturbing weight initialization

The modification we made to Kaiming initialization was to sample directly from the boundary of a hypersphere, rather than from a hyperspherical shell. This is a more natural choice when conducting a perturbation analysis, because it makes it easier to ensure that the perturbed weights are sampled from the same distribution as the baseline weights.

Geometrically, the intersection of a hypersphere $S^{D-1}$ of radius $w_0 = |\mathbf w_0|$ with a hypersphere $S^{D-1}$ of radius $\epsilon$ that is centered at some point on the boundary of the first hypersphere is a lower-dimensional hypersphere $S^{D-2}$ of a modified radius $\epsilon'$. If we sample uniformly from this intersection sphere, the resulting points follow the same (uniform) distribution over the original hypersphere, conditioned on lying at distance $\epsilon$ from the baseline point.

This suggests a procedure to sample from the intersection of the weight initialization hypersphere and the perturbation hypersphere.

First, we sample from a hypersphere of dimension $D-2$ and radius $\epsilon'$ (using the same technique we used to sample the baseline weights). From a bit of trigonometry (see the figure below), we know that the radius of this hypersphere will be $\epsilon' = w_0\sin\theta$, where $\theta = \cos^{-1}\left(1 - \frac{\epsilon^2}{2w_0^2}\right)$.

(Figure: the trigonometry of the intersecting baseline and perturbation hyperspheres.)

Next, we rotate the vector so it is orthogonal to the baseline vector $\mathbf w_0$. This is done with a Householder reflection, $H$, that maps the hyperplane orthogonal to $\hat{\mathbf n} = (0, \dots, 0, 1)$ (in which our sample currently lies) onto the hyperplane orthogonal to $\mathbf w_0$:

$$H = \mathbf I - 2\frac{\mathbf c\,\mathbf c^T}{\mathbf c^T\mathbf c},$$

where

$$\mathbf c = \hat{\mathbf n} + \hat{\mathbf w}_0,$$

and $\hat{\mathbf w}_0 = \frac{\mathbf w_0}{|\mathbf w_0|}$ is the unit vector in the direction of the baseline weights.

Implementation note: for the sake of tractability (the full $D \times D$ matrix would be enormous), we apply the reflection directly via:

$$H\mathbf y = \mathbf y - 2\,\frac{\mathbf c^T\mathbf y}{\mathbf c^T\mathbf c}\,\mathbf c.$$

Finally, we translate the rotated intersection sphere along the baseline vector, so that it passes through the intersection of the two hyperspheres. From the figure above, we find that the translation has magnitude $w_0' = w_0\cos\theta$.

Because both the baseline vector and the intersection sphere are sampled uniformly, the resulting perturbed vector has the same distribution as the baseline vector when restricted to the intersection sphere.

sampling-perturation 1.png
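
Putting the three steps together, here's a minimal numpy sketch of the whole sampler (`perturb_on_sphere` is a hypothetical name; it assumes $0 < \epsilon \le 2|\mathbf w_0|$ so that the two hyperspheres actually intersect):

```python
import numpy as np

def perturb_on_sphere(w0, eps, rng=None):
    """Sample a vector uniformly from the points that lie on the same hypersphere
    as w0 and at (chordal) distance eps from it."""
    rng = np.random.default_rng(rng)
    w0 = np.asarray(w0, dtype=float)
    D = w0.size
    r = np.linalg.norm(w0)
    w0_hat = w0 / r
    theta = np.arccos(1 - eps**2 / (2 * r**2))

    # Step 1: sample uniformly from a sphere of radius r*sin(theta) in the
    # hyperplane orthogonal to n_hat = (0, ..., 0, 1).
    y = np.zeros(D)
    v = rng.normal(size=D - 1)
    y[:-1] = r * np.sin(theta) * v / np.linalg.norm(v)

    # Step 2: Householder reflection taking that hyperplane to the hyperplane
    # orthogonal to w0 (applied without ever forming the D x D matrix H).
    n_hat = np.zeros(D)
    n_hat[-1] = 1.0
    c = n_hat + w0_hat
    y = y - 2 * (c @ y) / (c @ c) * c

    # Step 3: translate along w0 onto the intersection of the two hyperspheres.
    return y + r * np.cos(theta) * w0_hat

# Sanity check: stays on the original sphere, at distance eps from the baseline.
rng = np.random.default_rng(0)
w0 = rng.normal(size=1_000)
w0 *= 5.0 / np.linalg.norm(w0)                     # baseline on a sphere of radius 5
w = perturb_on_sphere(w0, eps=0.1, rng=rng)
print(np.linalg.norm(w), np.linalg.norm(w - w0))   # ~5.0 and ~0.1
```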

Random Walks on Hyperspheres

Neural networks begin their lives on a hypersphere.

As it turns out, they live out the remainder of their lives on hyperspheres as well. Take the norm of the vectorized weights, $\mathbf w$, and plot it as a function of the number of training steps. Across many different choices of weight initialization, this norm will grow (or shrink, depending on weight decay) remarkably consistently.

In the following figure, I've plotted this behavior for 10 independent initializations across 70 epochs of training a simple MNIST classifier. The same results hold more generally for a wide range of hyperparameters and architectures.

Pasted image 20230124102455.png

The norm during training, divided by the norm upon weight initialization. The parameter ϵ controls how far apart (as a fraction of the initial weight norm) the different weight initializations are.
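
For reference, a minimal sketch of how such a curve can be produced (assuming PyTorch; the architecture, optimizer, and random stand-in data are placeholders, not the setup behind the figure above):

```python
import torch
import torch.nn as nn

def weight_norm(model):
    """Norm of the full vectorized weights."""
    return torch.sqrt(sum(p.detach().pow(2).sum() for p in model.parameters()))

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
opt = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()
norm0 = weight_norm(model)

history = []
for step in range(1_000):
    x = torch.randn(64, 1, 28, 28)               # stand-in for MNIST batches
    y = torch.randint(0, 10, (64,))
    loss = loss_fn(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 100 == 0:
        history.append((step, (weight_norm(model) / norm0).item()))
print(history)
```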

All this was making me think that I disregarded Brownian motion as a model of SGD dynamics a little too quickly. Yes, on lattices/planes, increasing the number of dimensions makes no difference to the overall behavior (beyond decreasing the variance in how displacement grows with time). But maybe that was just the wrong model. Maybe Brownian motion on hyperspheres is considerably different and is actually the thing I should have been looking at.

Asymptotics ($t, n \to \infty$)

We'd like to calculate the expected displacement (=distance from our starting point) as a function of time. We're interested in both distance on the sphere and distance in the embedding space.

Before we tackle this in full, let's try to understand Brownian motion on hyperspheres in a few simplifying limits.

Unlike the Euclidean case, where having more than one dimension means you'll never return to your starting point, living on a finite sphere means we'll revisit every neighborhood infinitely often. The stationary distribution is the uniform distribution over the sphere, which we'll denote $\mathcal U^{(n)}$ for $S^n$.

IMG_671D5890C083-1.jpeg

So our expectations end up being relatively simple integrals over the sphere.

The distance from the pole at the bottom to any other point on the surface depends only on the angle between the line through the origin and the pole and the line through the origin and the other point,

$$z(\theta) = \sqrt{(r\sin\theta)^2 + (r - r\cos\theta)^2} = 2r\sin\left(\frac{\theta}{2}\right).$$

This symmetry means we only have to consider a single integral over $S^{n-1}$ "slices" from the bottom of the sphere to the top. Each slice carries as much probability as the corresponding $S^{n-1}$ has surface area, times the measure associated to $\theta$ ($= r\,\mathrm d\theta$). This surface area varies with the angle from the center, as does the distance of that point from the starting pole. Put it together and you get some nasty integrals.

IMG_63F07F66DFB1-1.jpeg

From this point on, I believe I made a mistake somewhere (the analysis for $n \to \infty$ is correct).

If we let $S_n(x)$ denote the surface area of $S^n$ with radius $x$, we can fill in the formula from Wikipedia:

$$S_{n-1}(x) = \frac{2\pi^{\frac n 2}}{\Gamma(\frac n 2)}\, x^{n-1}.$$

Filling in $x(\theta) = r\sin\theta$, we get an expression for the surface area of each slice as a function of $\theta$,

$$S_{n-1}(x(\theta)) = S_{n-1}(r)\sin^{n-1}(\theta),$$

so the measure associated to each slice is

$$\mathrm d\mu(\theta) = S_{n-1}(r)\sin^{n-1}(\theta)\; r\,\mathrm d\theta.$$

$n \to \infty$

We can simplify further: as $n \to \infty$, $\sin^{n-1}(\theta)$ tends towards a "comb function" that is zero everywhere except at $\theta = (m + 1/2)\pi$ for integer $m$, where it equals $1$ if $m$ is even and $-1$ if $m$ is odd.

Within the above integration limits, $(0, \pi)$, this means all of the probability becomes concentrated at the middle, where $\theta = \pi/2$.

So in the infinite-dimensional limit, the expected distance from the starting point is the distance between a point on the equator and one of the poles, which is $\sqrt 2\, r$.
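
A quick numerical check of this limit, using the slice measure derived above (assuming scipy; `expected_displacement` is a hypothetical helper):

```python
import numpy as np
from scipy.integrate import quad

def expected_displacement(n, r=1.0):
    """E[z] under the uniform (stationary) distribution on S^n, using the slice
    measure dmu(theta) ~ sin^(n-1)(theta) dtheta derived above."""
    weight = lambda th: np.sin(th) ** (n - 1)
    z = lambda th: 2 * r * np.sin(th / 2)
    num, _ = quad(lambda th: z(th) * weight(th), 0, np.pi)
    den, _ = quad(weight, 0, np.pi)
    return num / den

for n in [2, 10, 100, 500]:
    print(n, round(expected_displacement(n), 4))   # tends to sqrt(2) ≈ 1.4142
```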

General case

An important lesson I've learned here is: "try more keyword variations in Google Scholar and Elicit before going off and solving these nasty integrals yourself because you will probably make a mistake somewhere and someone else has probably already done the thing you're trying to do."

Anyway, here's the thing where someone else has already done the thing I was trying to do. [1] gives us the following form for the expected cosine of the angular displacement with time:

$$\langle\cos\theta(t)\rangle = \exp\left(-D(n-1)\,t/R^2\right),$$

where $D$ is now the diffusion constant and $R$ is the sphere's radius.

A little trig tells us that the root-mean-square displacement in the embedding space is

$$\sqrt{\langle z(t)^2\rangle} = r\sqrt{2\,(1 - \langle\cos\theta\rangle)},$$

which gives us a set of nice curves depending on the diffusion constant $D$:

output.png

The major difference from Brownian motion on the plane is that the expected displacement asymptotes to the limit we argued for above ($\sqrt 2\, r$).
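
For completeness, a minimal sketch that reproduces these curves from the formula above (assuming numpy; the parameter values are arbitrary):

```python
import numpy as np

def rms_displacement(t, D, n, r=1.0):
    """RMS embedding-space displacement implied by <cos theta(t)> = exp(-D (n-1) t / r^2)."""
    cos_theta = np.exp(-D * (n - 1) * t / r**2)
    return r * np.sqrt(2 * (1 - cos_theta))

t = np.linspace(0.0, 5.0, 6)
for D in [0.1, 0.5, 1.0]:                                  # diffusion constant
    print(D, np.round(rms_displacement(t, D, n=10), 3))    # saturates at sqrt(2) * r
```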

When I initially looked at curves like the following (this one's from a [numerical simulation](https://github.com/jqhoogland/experiminis/blob/main/brownian_hyperspheres.ipynb), not the formula), I thought I had struck gold.

rw_sphere.png

I've been looking for a model that predicts square-root-like growth at the very start, followed by a long period of nearly linear growth, followed by asymptotic behavior. Eyeballing this, it looked like my search for the long plateau of linear growth had come to an end.

Of course, this was entirely wishful thinking. Human eyes aren't great discriminators of lines. That "straight line" is actually pretty much parabolic. As I should have known (from the locally Euclidean nature of manifolds), the initial behavior is incredibly well modeled by a square root.

As you get further away from the nearly flat regime, the concavity only increases. This is not the curve I was looking for.

rw-comparison.png

So my search continues.