$$ \newcommand{\bone}{\mathbf{1}} \newcommand{\bbeta}{\mathbf{\beta}} \newcommand{\bdelta}{\mathbf{\delta}} \newcommand{\bepsilon}{\mathbf{\epsilon}} \newcommand{\blambda}{\mathbf{\lambda}} \newcommand{\bomega}{\mathbf{\omega}} \newcommand{\bpi}{\mathbf{\pi}} \newcommand{\bphi}{\mathbf{\phi}} \newcommand{\bvphi}{\mathbf{\varphi}} \newcommand{\bpsi}{\mathbf{\psi}} \newcommand{\bsigma}{\mathbf{\sigma}} \newcommand{\btheta}{\mathbf{\theta}} \newcommand{\btau}{\mathbf{\tau}} \newcommand{\ba}{\mathbf{a}} \newcommand{\bb}{\mathbf{b}} \newcommand{\bc}{\mathbf{c}} \newcommand{\bd}{\mathbf{d}} \newcommand{\be}{\mathbf{e}} \newcommand{\boldf}{\mathbf{f}} \newcommand{\bg}{\mathbf{g}} \newcommand{\bh}{\mathbf{h}} \newcommand{\bi}{\mathbf{i}} \newcommand{\bj}{\mathbf{j}} \newcommand{\bk}{\mathbf{k}} \newcommand{\bell}{\mathbf{\ell}} \newcommand{\bm}{\mathbf{m}} \newcommand{\bn}{\mathbf{n}} \newcommand{\bo}{\mathbf{o}} \newcommand{\bp}{\mathbf{p}} \newcommand{\bq}{\mathbf{q}} \newcommand{\br}{\mathbf{r}} \newcommand{\bs}{\mathbf{s}} \newcommand{\bt}{\mathbf{t}} \newcommand{\bu}{\mathbf{u}} \newcommand{\bv}{\mathbf{v}} \newcommand{\bw}{\mathbf{w}} \newcommand{\bx}{\mathbf{x}} \newcommand{\by}{\mathbf{y}} \newcommand{\bz}{\mathbf{z}} \newcommand{\bA}{\mathbf{A}} \newcommand{\bB}{\mathbf{B}} \newcommand{\bC}{\mathbf{C}} \newcommand{\bD}{\mathbf{D}} \newcommand{\bE}{\mathbf{E}} \newcommand{\bF}{\mathbf{F}} \newcommand{\bG}{\mathbf{G}} \newcommand{\bH}{\mathbf{H}} \newcommand{\bI}{\mathbf{I}} \newcommand{\bJ}{\mathbf{J}} \newcommand{\bK}{\mathbf{K}} \newcommand{\bL}{\mathbf{L}} \newcommand{\bM}{\mathbf{M}} \newcommand{\bN}{\mathbf{N}} \newcommand{\bP}{\mathbf{P}} \newcommand{\bQ}{\mathbf{Q}} \newcommand{\bR}{\mathbf{R}} \newcommand{\bS}{\mathbf{S}} \newcommand{\bT}{\mathbf{T}} \newcommand{\bU}{\mathbf{U}} \newcommand{\bV}{\mathbf{V}} \newcommand{\bW}{\mathbf{W}} \newcommand{\bX}{\mathbf{X}} \newcommand{\bY}{\mathbf{Y}} \newcommand{\bZ}{\mathbf{Z}} \newcommand{\bsa}{\boldsymbol{a}} 
\newcommand{\bsb}{\boldsymbol{b}} \newcommand{\bsc}{\boldsymbol{c}} \newcommand{\bsd}{\boldsymbol{d}} \newcommand{\bse}{\boldsymbol{e}} \newcommand{\bsoldf}{\boldsymbol{f}} \newcommand{\bsg}{\boldsymbol{g}} \newcommand{\bsh}{\boldsymbol{h}} \newcommand{\bsi}{\boldsymbol{i}} \newcommand{\bsj}{\boldsymbol{j}} \newcommand{\bsk}{\boldsymbol{k}} \newcommand{\bsell}{\boldsymbol{\ell}} \newcommand{\bsm}{\boldsymbol{m}} \newcommand{\bsn}{\boldsymbol{n}} \newcommand{\bso}{\boldsymbol{o}} \newcommand{\bsp}{\boldsymbol{p}} \newcommand{\bsq}{\boldsymbol{q}} \newcommand{\bsr}{\boldsymbol{r}} \newcommand{\bss}{\boldsymbol{s}} \newcommand{\bst}{\boldsymbol{t}} \newcommand{\bsu}{\boldsymbol{u}} \newcommand{\bsv}{\boldsymbol{v}} \newcommand{\bsw}{\boldsymbol{w}} \newcommand{\bsx}{\boldsymbol{x}} \newcommand{\bsy}{\boldsymbol{y}} \newcommand{\bsz}{\boldsymbol{z}} \newcommand{\bsA}{\boldsymbol{A}} \newcommand{\bsB}{\boldsymbol{B}} \newcommand{\bsC}{\boldsymbol{C}} \newcommand{\bsD}{\boldsymbol{D}} \newcommand{\bsE}{\boldsymbol{E}} \newcommand{\bsF}{\boldsymbol{F}} \newcommand{\bsG}{\boldsymbol{G}} \newcommand{\bsH}{\boldsymbol{H}} \newcommand{\bsI}{\boldsymbol{I}} \newcommand{\bsJ}{\boldsymbol{J}} \newcommand{\bsK}{\boldsymbol{K}} \newcommand{\bsL}{\boldsymbol{L}} \newcommand{\bsM}{\boldsymbol{M}} \newcommand{\bsN}{\boldsymbol{N}} \newcommand{\bsP}{\boldsymbol{P}} \newcommand{\bsQ}{\boldsymbol{Q}} \newcommand{\bsR}{\boldsymbol{R}} \newcommand{\bsS}{\boldsymbol{S}} \newcommand{\bsT}{\boldsymbol{T}} \newcommand{\bsU}{\boldsymbol{U}} \newcommand{\bsV}{\boldsymbol{V}} \newcommand{\bsW}{\boldsymbol{W}} \newcommand{\bsX}{\boldsymbol{X}} \newcommand{\bsY}{\boldsymbol{Y}} \newcommand{\bsZ}{\boldsymbol{Z}} \newcommand{\calA}{\mathcal{A}} \newcommand{\calB}{\mathcal{B}} \newcommand{\calC}{\mathcal{C}} \newcommand{\calD}{\mathcal{D}} \newcommand{\calE}{\mathcal{E}} \newcommand{\calF}{\mathcal{F}} \newcommand{\calG}{\mathcal{G}} \newcommand{\calH}{\mathcal{H}} \newcommand{\calI}{\mathcal{I}} 
\newcommand{\calJ}{\mathcal{J}} \newcommand{\calK}{\mathcal{K}} \newcommand{\calL}{\mathcal{L}} \newcommand{\calM}{\mathcal{M}} \newcommand{\calN}{\mathcal{N}} \newcommand{\calO}{\mathcal{O}} \newcommand{\calP}{\mathcal{P}} \newcommand{\calQ}{\mathcal{Q}} \newcommand{\calR}{\mathcal{R}} \newcommand{\calS}{\mathcal{S}} \newcommand{\calT}{\mathcal{T}} \newcommand{\calU}{\mathcal{U}} \newcommand{\calV}{\mathcal{V}} \newcommand{\calW}{\mathcal{W}} \newcommand{\calX}{\mathcal{X}} \newcommand{\calY}{\mathcal{Y}} \newcommand{\calZ}{\mathcal{Z}} \newcommand{\R}{\mathbb{R}} \newcommand{\C}{\mathbb{C}} \newcommand{\N}{\mathbb{N}} \newcommand{\Z}{\mathbb{Z}} \newcommand{\F}{\mathbb{F}} \newcommand{\Q}{\mathbb{Q}} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\argmin}{arg\,min} \newcommand{\nnz}[1]{\mbox{nnz}(#1)} \newcommand{\dotprod}[2]{\langle #1, #2 \rangle} \newcommand{\ignore}[1]{} \let\Pr\relax \DeclareMathOperator*{\Pr}{\mathbf{Pr}} \newcommand{\E}{\mathbb{E}} \DeclareMathOperator*{\Ex}{\mathbf{E}} \DeclareMathOperator*{\Var}{\mathbf{Var}} \DeclareMathOperator*{\Cov}{\mathbf{Cov}} \DeclareMathOperator*{\stddev}{\mathbf{stddev}} \DeclareMathOperator*{\avg}{avg} \DeclareMathOperator{\poly}{poly} \DeclareMathOperator{\polylog}{polylog} \DeclareMathOperator{\size}{size} \DeclareMathOperator{\sgn}{sgn} \DeclareMathOperator{\dist}{dist} \DeclareMathOperator{\vol}{vol} \DeclareMathOperator{\spn}{span} \DeclareMathOperator{\supp}{supp} \DeclareMathOperator{\tr}{tr} \DeclareMathOperator{\Tr}{Tr} \DeclareMathOperator{\codim}{codim} \DeclareMathOperator{\diag}{diag} \newcommand{\PTIME}{\mathsf{P}} \newcommand{\LOGSPACE}{\mathsf{L}} \newcommand{\ZPP}{\mathsf{ZPP}} \newcommand{\RP}{\mathsf{RP}} \newcommand{\BPP}{\mathsf{BPP}} \newcommand{\P}{\mathsf{P}} \newcommand{\NP}{\mathsf{NP}} \newcommand{\TC}{\mathsf{TC}} \newcommand{\AC}{\mathsf{AC}} \newcommand{\SC}{\mathsf{SC}} \newcommand{\SZK}{\mathsf{SZK}} \newcommand{\AM}{\mathsf{AM}} \newcommand{\IP}{\mathsf{IP}} 
\newcommand{\PSPACE}{\mathsf{PSPACE}} \newcommand{\EXP}{\mathsf{EXP}} \newcommand{\MIP}{\mathsf{MIP}} \newcommand{\NEXP}{\mathsf{NEXP}} \newcommand{\BQP}{\mathsf{BQP}} \newcommand{\distP}{\mathsf{dist\textbf{P}}} \newcommand{\distNP}{\mathsf{dist\textbf{NP}}} \newcommand{\eps}{\epsilon} \newcommand{\lam}{\lambda} \newcommand{\dleta}{\delta} \newcommand{\simga}{\sigma} \newcommand{\vphi}{\varphi} \newcommand{\la}{\langle} \newcommand{\ra}{\rangle} \newcommand{\wt}[1]{\widetilde{#1}} \newcommand{\wh}[1]{\widehat{#1}} \newcommand{\ol}[1]{\overline{#1}} \newcommand{\ul}[1]{\underline{#1}} \newcommand{\ot}{\otimes} \newcommand{\zo}{\{0,1\}} \newcommand{\co}{:} %\newcommand{\co}{\colon} \newcommand{\bdry}{\partial} \newcommand{\grad}{\nabla} \newcommand{\transp}{^\intercal} \newcommand{\inv}{^{-1}} \newcommand{\symmdiff}{\triangle} \newcommand{\symdiff}{\symmdiff} \newcommand{\half}{\tfrac{1}{2}} \newcommand{\mathbbm}{\Bbb} \newcommand{\bbone}{\mathbbm 1} \newcommand{\Id}{\bbone} \newcommand{\SAT}{\mathsf{SAT}} \newcommand{\bcalG}{\boldsymbol{\calG}} \newcommand{\calbG}{\bcalG} \newcommand{\bcalX}{\boldsymbol{\calX}} \newcommand{\calbX}{\bcalX} \newcommand{\bcalY}{\boldsymbol{\calY}} \newcommand{\calbY}{\bcalY} \newcommand{\bcalZ}{\boldsymbol{\calZ}} \newcommand{\calbZ}{\bcalZ} $$

Neural Networks from Maximizing Rate Reduction

Mt. Aoraki (Mt. Cook), South Island, New Zealand

The quest for a new white-box modeling paradigm

While deep learning has clear empirical success, much of that success is attributable to trial and error rather than to underlying mathematical principles. At ICLR this year I attended Yi Ma’s keynote, Pursuing the Nature of Intelligence, which took a statistical lens and urged the community to view model training as learning to compress.

I was especially struck by the novelty of his recent work on using coding rate reduction as a learning objective. The remainder of this post is a high-level overview of his paper ReduNet: A White-box Deep Network from the Principle of Maximizing Rate Reduction.

The paper asks: how can we develop a principled mathematical framework for better understanding and designing deep networks?

To this end, they argue that the predictable information in high-dimensional data is encoded in the low-dimensional supports on which its distribution concentrates. Hence, they claim, compression is the only inductive bias that our models need. They propose a constructive modeling approach called ReduNet that uses the principle of maximal coding rate reduction.

ReduNet

ReduNet is motivated by three desiderata for the learned features:

- Features of samples from the same class should belong to the same low-dimensional linear subspace.
- Features from different classes should belong to different, uncorrelated subspaces.
- For diversity, the features within each class should have as large a variance as possible, as long as they stay uncorrelated with those of other classes.

Together, these yield features that can be easily discriminated by a linear model.

How can we measure the “compactness” of the distribution of these latent features to achieve the goals above? We can’t use cross entropy, since we want a measure that doesn’t depend on class labels, so that it also supports unsupervised learning. We also can’t use information-theoretic measures like entropy or information gain, because they are not well-defined for degenerate distributions. For example, the differential entropy of a distribution supported on a low-dimensional subspace diverges to $-\infty$. In addition, we want something that we can compute tractably from a finite number of samples.

To do this, they use the coding rate of the features. Take a set of learned representations $\boldsymbol{Z} = [\boldsymbol{z}^1, \cdots, \boldsymbol{z}^m]$ with each $\boldsymbol{z}^i \in \R^n$. The number of bits needed to encode them, so that each can be recovered up to a distortion $\epsilon$ via a codebook, is $\mathcal{L}(\boldsymbol{Z}, \epsilon) \doteq\left(\frac{m+n}{2}\right) \log \operatorname{det}\left(\boldsymbol{I}+\frac{n}{m \epsilon^2} \boldsymbol{Z} \boldsymbol{Z}^*\right)$. Dividing by $m$ and assuming $m \gg n$, as is typically the case, gives the average coding rate per sample:

\[R(\boldsymbol{Z}, \epsilon) \doteq \frac{1}{2} \log \operatorname{det}\left(\boldsymbol{I}+\frac{n}{m \epsilon^2} \boldsymbol{Z} \boldsymbol{Z}^*\right)\]
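As a quick sanity check of these two formulas, here is a minimal NumPy sketch. It assumes real-valued features (so $\boldsymbol{Z}^* = \boldsymbol{Z}^\top$) stored as the columns of an $n \times m$ array. It also uses natural logarithms, i.e., nats rather than bits, which only rescales by a constant. The function names `coding_length` and `coding_rate` are my own, not from the paper's code.

```python
import numpy as np


def coding_rate(Z: np.ndarray, eps: float) -> float:
    """Average coding rate R(Z, eps) = 1/2 logdet(I + n/(m eps^2) Z Z^T), in nats."""
    n, m = Z.shape
    alpha = n / (m * eps**2)
    # slogdet is numerically safer than log(det(...)); the matrix is positive definite
    _, logdet = np.linalg.slogdet(np.eye(n) + alpha * Z @ Z.T)
    return 0.5 * logdet


def coding_length(Z: np.ndarray, eps: float) -> float:
    """Total coding length L(Z, eps) = (m + n)/2 logdet(...) = (m + n) * R(Z, eps)."""
    n, m = Z.shape
    return (m + n) * coding_rate(Z, eps)


# Four unit-norm features in R^2: two along e1, two along e2
Z = np.array([[1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0]])
print(coding_rate(Z, eps=1.0))                  # approximately log(2) = 0.693
print(coding_length(Z, eps=1.0) / Z.shape[1])   # (1 + n/m) * R
```

Because $\mathcal{L} = (m+n) R$, the per-sample length $\mathcal{L}/m = (1 + n/m) R$ approaches $R$ as $m$ grows.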

We can similarly define the average coding rate of the features within each class. Here $\boldsymbol{\Pi}^j \in \R^{m \times m}$ is a diagonal matrix whose $i$-th diagonal entry is the probability that sample $i$ belongs to class $j$:

\[R_c(\boldsymbol{Z}, \epsilon \mid \boldsymbol{\Pi}) \doteq \sum_{j=1}^k \frac{\operatorname{tr}\left(\boldsymbol{\Pi}^j\right)}{2 m} \log \operatorname{det}\left(\boldsymbol{I}+\frac{n}{\operatorname{tr}\left(\boldsymbol{\Pi}^j\right) \epsilon^2} \boldsymbol{Z} \boldsymbol{\Pi}^j \boldsymbol{Z}^*\right)\]

The point of defining $R$ and $R_c$ this way is that we want to maximize the coding rate over the entire dataset while minimizing the coding rate within each class, which encourages the subspaces of different classes to be orthogonal. Furthermore, controlling the distortion $\epsilon$ allows us to preserve the intra-class diversity of features.
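To see the orthogonality argument numerically, the sketch below computes $\Delta R = R - R_c$ for two toy labelings in $\R^2$ with $\epsilon = 1$. It is my own NumPy code, not the authors'. The first labeling puts the two classes on orthogonal lines; the second puts both classes on the same line. Features are real, unit-norm columns, and the array `Pi` stores the diagonals of the $\boldsymbol{\Pi}^j$ as its rows.

```python
import numpy as np


def coding_rate(Z, eps):
    """R(Z, eps) = 1/2 logdet(I + n/(m eps^2) Z Z^T) for real features (columns of Z)."""
    n, m = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(n) + n / (m * eps**2) * Z @ Z.T)
    return 0.5 * logdet


def class_coding_rate(Z, eps, Pi):
    """R_c(Z, eps | Pi); Pi has shape (k, m) and row j holds the diagonal of Pi^j."""
    n, m = Z.shape
    total = 0.0
    for pi_j in Pi:
        tr_j = pi_j.sum()            # tr(Pi^j): (soft) number of samples in class j
        cov_j = (Z * pi_j) @ Z.T     # Z Pi^j Z^T without forming an m x m matrix
        _, logdet = np.linalg.slogdet(np.eye(n) + n / (tr_j * eps**2) * cov_j)
        total += tr_j / (2 * m) * logdet
    return total


def rate_reduction(Z, eps, Pi):
    """Delta R = R - R_c."""
    return coding_rate(Z, eps) - class_coding_rate(Z, eps, Pi)


e1, e2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
Pi = np.array([[1.0, 1.0, 0.0, 0.0],    # samples 0 and 1 belong to class 0
               [0.0, 0.0, 1.0, 1.0]])   # samples 2 and 3 belong to class 1

Z_orth = np.column_stack([e1, e1, e2, e2])    # classes on orthogonal lines
Z_shared = np.column_stack([e1, e1, e1, e1])  # both classes on the same line

print(rate_reduction(Z_orth, 1.0, Pi))    # log 2 - 0.5 log 3, about 0.144
print(rate_reduction(Z_shared, 1.0, Pi))  # 0 up to rounding
```

Working through the determinants by hand confirms this. The orthogonal configuration gives $\Delta R = \log 2 - \tfrac{1}{2}\log 3$. The shared line gives $R = R_c = \tfrac{1}{2}\log 3$, and hence no reduction at all.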

This gives us the following objective:

\[\max_{\boldsymbol{\theta}, \boldsymbol{\Pi}} \Delta R(\boldsymbol{Z}(\boldsymbol{\theta}), \boldsymbol{\Pi}, \epsilon) = R(\boldsymbol{Z}(\boldsymbol{\theta}), \epsilon) - R_c(\boldsymbol{Z}(\boldsymbol{\theta}), \epsilon \mid \boldsymbol{\Pi}), \quad \text{s.t.} \quad \left\|\boldsymbol{Z}^j(\boldsymbol{\theta})\right\|_F^2 = m_j,\ \boldsymbol{\Pi} \in \Omega,\]

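Here $\boldsymbol{Z}^j(\boldsymbol{\theta})$ collects the features of the $m_j$ samples in class $j$, and $\Omega$ is the set of valid membership matrices (nonnegative diagonal $\boldsymbol{\Pi}^j$ that sum to $\boldsymbol{I}$). The Frobenius-norm constraint normalizes the scale of the features, since otherwise $\Delta R$ could be increased simply by scaling the features up.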

Intuitively, the picture looks like the one below. We try to maximize the size of our codebook (i.e., the number of $\epsilon$-balls) needed to cover all the data, whilst minimizing the codebook needed for each individual class:

Training

The neat thing about ReduNet is that it is trained layer by layer, sequentially, using only forward propagation (no backpropagation).
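One way to see why forward propagation suffices is to differentiate the two terms of $\Delta R$. This is my own derivation, assuming real features (so $\boldsymbol{Z}^* = \boldsymbol{Z}^\top$) and using the shorthand $\alpha = \frac{n}{m\epsilon^2}$, $\alpha_j = \frac{n}{\operatorname{tr}(\boldsymbol{\Pi}^j)\epsilon^2}$ and $\gamma_j = \frac{\operatorname{tr}(\boldsymbol{\Pi}^j)}{m}$, introduced here for brevity. The gradients are built from matrices that depend only on the current features:

\[\frac{\partial R}{\partial \boldsymbol{Z}} = \underbrace{\alpha\left(\boldsymbol{I}+\alpha\boldsymbol{Z}\boldsymbol{Z}^\top\right)^{-1}}_{\boldsymbol{E}}\boldsymbol{Z}, \qquad \frac{\partial R_c}{\partial \boldsymbol{Z}} = \sum_{j=1}^k \gamma_j \underbrace{\alpha_j\left(\boldsymbol{I}+\alpha_j\boldsymbol{Z}\boldsymbol{\Pi}^j\boldsymbol{Z}^\top\right)^{-1}}_{\boldsymbol{C}^j}\boldsymbol{Z}\boldsymbol{\Pi}^j.\]

A layer can then be read as one projected gradient-ascent step on $\Delta R$ with step size $\eta$, where $\mathcal{P}$ renormalizes each column to the unit sphere. As far as I can tell, this matches the form of the layers in the paper:

\[\boldsymbol{Z}_{\ell+1} = \mathcal{P}\left(\boldsymbol{Z}_\ell + \eta\left(\boldsymbol{E}_\ell \boldsymbol{Z}_\ell - \sum_{j=1}^k \gamma_j \boldsymbol{C}^j_\ell \boldsymbol{Z}_\ell \boldsymbol{\Pi}^j\right)\right).\]

Here $\boldsymbol{E}_\ell$ and $\boldsymbol{C}^j_\ell$ are evaluated at $\boldsymbol{Z}_\ell$, so each layer is fully determined by the features coming out of the previous one.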

In plaintext, it works as follows:

  1. Suppose we want to train a ReduNet with $L$ layers
  2. Initialize the features to be the data itself
  3. Compute the matrices $\bE_{\ell}$ (from the $R$ term) and $\bC^j_{\ell}$ (from the $R_c$ term) that make up the gradient of $\Delta R$ (note that these depend only on the features at layer $\ell$)
  4. Compute soft assignments of the features to classes (I didn’t fully understand why this is required, but I think it stems from the need to support the unsupervised setting where the true class labels aren’t known)
  5. Output the new features by taking a gradient-ascent step with these matrices (the expression in line 6 of the paper’s algorithm), normalized by the projection $\mathcal{P}$ onto the unit sphere
  6. Repeat for each subsequent layer
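Here is a minimal NumPy sketch of the steps above. It rests on my own assumptions rather than on the authors' implementation:

- features are real-valued columns;
- known one-hot labels are used while constructing the layers;
- the step size `eta` is fixed;
- an unlabeled sample gets soft assignments of the softmax form $\hat{\pi}^j(\boldsymbol{z}) \propto \exp(-\lambda \|\boldsymbol{C}^j \boldsymbol{z}\|)$, which is how I read the paper's estimate (treat it as an assumption).

All function names and the hyperparameter values (`eps`, `eta`, `lam`, the number of layers) are hypothetical.

```python
import numpy as np


def normalize_columns(Z):
    """Project each feature (column of Z) onto the unit sphere."""
    return Z / np.linalg.norm(Z, axis=0, keepdims=True)


def rate_reduction(Z, Pi, eps):
    """Delta R = R - R_c for real features Z (n x m) and memberships Pi (k x m)."""
    n, m = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(n) + n / (m * eps**2) * Z @ Z.T)
    r_c = 0.0
    for pi_j in Pi:
        tr_j = pi_j.sum()
        _, logdet_j = np.linalg.slogdet(np.eye(n) + n / (tr_j * eps**2) * (Z * pi_j) @ Z.T)
        r_c += tr_j / (2 * m) * logdet_j
    return 0.5 * logdet - r_c


def build_layer(Z, Pi, eps):
    """Compute E and C^j from the current features alone (no backpropagation)."""
    n, m = Z.shape
    alpha = n / (m * eps**2)
    E = alpha * np.linalg.inv(np.eye(n) + alpha * Z @ Z.T)
    Cs, gammas = [], []
    for pi_j in Pi:
        tr_j = pi_j.sum()
        alpha_j = n / (tr_j * eps**2)
        Cs.append(alpha_j * np.linalg.inv(np.eye(n) + alpha_j * (Z * pi_j) @ Z.T))
        gammas.append(tr_j / m)
    return E, Cs, gammas


def forward_labeled(Z, Pi, layer, eta):
    """One projected gradient-ascent step on Delta R using known memberships."""
    E, Cs, gammas = layer
    # (C @ Z) * pi_j scales column i by pi_j[i], i.e. computes C Z Pi^j
    grad = E @ Z - sum(g * (C @ Z) * pi_j for g, C, pi_j in zip(gammas, Cs, Pi))
    return normalize_columns(Z + eta * grad)


def soft_assign(z, Cs, lam):
    """Assumed soft membership: softmax over classes of -lam * ||C^j z||."""
    scores = -lam * np.array([np.linalg.norm(C @ z) for C in Cs])
    w = np.exp(scores - scores.max())  # shift by the max for numerical stability
    return w / w.sum()


def forward_unlabeled(z, layer, eta, lam):
    """Same step for a single feature vector, with Pi replaced by soft assignments."""
    E, Cs, gammas = layer
    pi_hat = soft_assign(z, Cs, lam)
    grad = E @ z - sum(g * p * (C @ z) for g, p, C in zip(gammas, pi_hat, Cs))
    z_next = z + eta * grad
    return z_next / np.linalg.norm(z_next)


# Hypothetical toy setup: 2 classes of 10 random unit vectors in R^3
rng = np.random.default_rng(0)
n, per_class, k = 3, 10, 2
eps, eta, lam, num_layers = 0.5, 0.1, 10.0, 10
labels = np.repeat(np.arange(k), per_class)
Pi = np.eye(k)[labels].T                                         # (k, m) one-hot memberships
Z0 = normalize_columns(rng.standard_normal((n, k * per_class)))  # features start as the data

Z, layers = Z0, []
for _ in range(num_layers):
    layer = build_layer(Z, Pi, eps)   # layer "parameters" come from a forward pass only
    layers.append(layer)
    Z = forward_labeled(Z, Pi, layer, eta)

print(rate_reduction(Z0, Pi, eps), rate_reduction(Z, Pi, eps))

# Push a new, unlabeled sample through the same stack of layers
z = normalize_columns(rng.standard_normal((n, 1)))[:, 0]
for layer in layers:
    z = forward_unlabeled(z, layer, eta, lam)
```

Each step ascends $\Delta R$ with a small step size, so the rate reduction of the training features should increase from layer to layer. Nothing in the construction uses backpropagation, only matrix inverses computed from the current features.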

Geometrically, at each step the features of different classes become increasingly orthogonal, whereas those within each class are contracted together:

One may worry that the features within a class will collapse together. However, the authors argue that such (neural) collapse would give rise to a suboptimal overall coding rate, so the objective avoids it.

Some Complaints

Given the novelty of the paper’s ideas, it took me some time to digest, and I felt some parts were under-explained, such as why soft assignments are needed. It was also not immediately clear why we can’t simply keep training a single wide layer repeatedly with this setup to get good features. I also didn't see experiments showing whether doing so works well or badly.

Concluding Thoughts

It takes a lot of determination and courage to push novel approaches to fundamental problems that people have long taken for granted under traditional approaches. I think coding rate reduction is just one technique. It is a single step in the grander scheme of designing more explainable and interpretable neural network architectures from mathematical principles rather than discovering them by accident, and we still have a long (but exciting) way to go in this direction.



