《White-Box Transformers via Sparse Rate Reduction》论文学习

发布时间 2023-09-12 17:28:21作者: 郑瀚Andrew

一、Introduction

近年来,深度学习在处理大量高维多模态数据方面取得了巨大的实证成功。其中很大一部分成功归功于对数据分布的有效学习,然后将分布转化为简洁的结构化和紧凑的表示形式,这有助于许多下游任务(例如视觉、分类、识别和分割以及生成。为此,已提出和实践了许多模型和方法,每种方法都有其优点和局限性。

在这里,我们对几种流行的方法进行简要介绍,作为我们在这项工作中寻求完整理解和统一的背景。

Transformer模型和自注意力(Transformer models and self-attention)

Transformer是最新流行的用于学习高维结构化数据表示的模型之一,例如文本、图像和其他类型的信号。

在第一个块之后,将每个数据点(例如文本语料库或图像)转换为一组或序列的tokens标记,并以一种介质不可知的方式对标记集进行进一步处理。

Transformer模型的一个基石是所谓的自注意力层,它利用tokens标记序列中的统计相关性来改进tokens标记表示。

Transformer在学习性能良好的紧凑表示方面取得了巨大成功。然而,Transformer网络架构是经验设计的,缺乏严格的数学解释。事实上,注意力层的输出本身有几种竞争的解释。因此,数据分布与Transformer学到的最终表示之间的统计和几何关系在很大程度上仍然是一个神秘的黑盒子。

扩散模型和去噪(Diffusion models and denoising)

扩散模型最近成为学习数据分布的一种流行方法,特别是用于”生成任务(generative tasks)“和高度结构化但难以有效建模的自然图像数据。

扩散模型的核心概念是从高斯噪声分布(或其他标准模板)中采样特征,并迭代地去噪和变形特征分布,直到收敛到原始数据分布。如果将这个过程建模为一步是计算上不可行的,因此通常将其分为多个增量步骤。每个步骤的关键是所谓的评分函数,或者说是“最佳去噪函数”的估计。在实践中,这个函数是使用通用的黑盒深度网络建模的。

扩散模型已经显示出在学习和从数据分布中采样方面的有效性。然而,尽管近期进行了一些努力,它们通常没有建立起初始特征与数据样本之间的清晰对应关系。因此,扩散模型本身并没有提供对数据分布的简洁或可解释的表示。

结构寻求模型和采样降维(Structure-seeking models and rate reduction)

在前两种方法中都是通过使用深度网络解决下游任务(例如分类或生成/抽样)的副产品来隐式构建的。然而,我们也可以直接显式地学习数据分布,作为任务本身的目的。

  • 最常见的方法是尝试识别和表示输入数据中的低维结构。这一范式的经典示例包括基于模型的方法,如稀疏编码和字典学习,这些方法促使了早期的深度网络架构的设计和解释。
  • 近年来的方法则更多地从无模型的角度出发,通过一个足够信息丰富的预训练任务来学习表示(例如在对比学习中压缩相似和分离不相似的数据,或者在最大编码速率减少方法类别中最大化信息增益)。

与黑盒深度学习方法相比,基于模型和无模型的表示学习方案具有更好的可解释性:

  • 首先,它们允许用户明确设计所学表示的期望属性。
  • 此外,它们允许用户通过展开表示学习目标的优化策略来构建新的白盒前向深度网络架构,使得构建网络的每一层实现优化算法的迭代。

然而,不幸的是,在这种范式中,如果所需属性的定义狭窄,可能很难在大规模真实数据集上实现良好的实际性能。

主要贡献和本文的概述

在本文中,我们旨在通过更统一的框架来解决这些现有方法的局限性,设计类似transformer的网络架构,从而实现数学可解释性和良好的实际性能。为此,我们提出学习一系列增量映射,以获得输入数据(或其令牌集)的最紧凑和稀疏表示,优化统一的目标函数,即稀疏率降维。映射的目标在下图中进行了说明。

在这个框架内,我们将上述三种看似不相关的方法统一起来,并展示了类似transformer的深度网络层可以自然地从展开迭代优化方案中派生出来,以逐步优化稀疏率降维目标。 

The ‘main loop’ of the CRATE white-box deep network design. After encoding input data X as a sequence of tokens Z0, CRATE constructs a deep network that transforms the data to a canonical configuration of low-dimensional subspaces by successive compression against a local model for the distribution, generating Zℓ+1/2, and sparsification against a global dictionary, generating Zℓ+1. Repeatedly stacking these blocks and training the model parameters via backpropagation yields a powerful and interpretable representation of the data.

具体而言,我们的贡献和本文的概述如下: 

  • 我们使用一个理想化的令牌分布模型,证明了如果将令牌迭代地向低维子空间去噪声,相关的评分函数会呈现出类似于Transformer中的自注意运算符的显式形式。
  • 我们推导出多头自注意层作为一个展开的梯度下降步骤,以最小化有损编码率部分的速率降低,从而展示了自注意层的另一种解释,即对令牌表示进行压缩。
  • 我们展示了紧随多头自注意的多层感知器可以被解释为(并被替换为)一个层,通过构建令牌表示的稀疏编码来逐步优化稀疏率降维目标的剩余部分。
  • 我们利用这个理解来创建一个新的白盒(完全数学可解释的)Transformer架构,称为CRATE(即Coding RAte reduction TransformEr),其中每个层执行交替最小化算法的单步操作,以优化稀疏率降维目标。

因此,在我们的框架中,学习目标函数、深度学习架构和最终学习到的表示都成为完全数学可解释的白盒。

CRATE网络虽然简单,但已经可以在大规模真实数据集上学习到所需的压缩和稀疏表示,并在各种任务(如分类和迁移学习)上实现与更复杂的Transformer网络(如ViT)相当的性能。

参考链接:

https://ma-lab-berkeley.github.io/CRATE/ 
https://arxiv.org/pdf/2306.01129.pdf

 

二、Technical Approach and Justification

0x1:Objective and Approach

我们考虑一个与现实世界信号相关的一般学习任务。

我们有一些随机变量X = [x1, . . . , xN] ∈ RD×N,它是我们的数据来源。每个xi ∈ RD被解释为一个令牌token,xi的相关结构可以是任意的。

我们使用Z = [z1, . . . , zN] ∈ Rd×N来表示定义输入表示的随机变量。每个zi ∈ Rd是相应令牌xi的表示。

我们给出了B ≥ 1 i.i.d.(独立同分布)的样本X1, . . . , XB ∼ X,其令牌为xi,b

我们样本的表示表示为Z1, . . . , ZB ∼ Z,令牌的表示为zi,b

最后,对于给定的网络,当输入为X时,我们使用Z来表示前ℓ层的输出。相应地,样本的输出为Zi,令牌的输出为zi,b

1、Objective for learning a structured and compact representation

根据稀疏率降维框架,我们认为表示学习(representation learning)的目标是找到一个特征映射f:

将具有潜在非线性和多模态分布的输入数据X∈RD×N转换为一个(分段)线性化和紧凑的特征表示Z∈Rd×N

虽然对应的特征表示Z的联合分布(ziNi=1的联合分布可能很复杂(并且任务特定),但我们进一步认为要求单个token标记zi的目标边缘分布应该高度压缩和结构化,便于紧凑编码,这是合理和实用的。

特别地,我们要求该分布是低维(比如K)高斯分布的混合,其中第k个高斯分布的均值为0∈Rd,协方差Σk⪰0∈Rd×d,并且由正交基Uk∈Rd×p组成。

我们用U[K] = (Uk)Kk=1表示所有高斯分布的基集。

因此,为了最大化最终token标记表示的信息增益,我们希望最大化token标记的编码率降低,即:

其中R和Rc是损失编码率的估计。

这也促使来自不同高斯分布的标记表示zi不相关。由于编码率降低是表示正例(goodness)的内在度量,它对表示的任意旋转是不变的。因此,为了确保最终的表示适于更紧凑的编码,我们希望将表示(及其支撑子空间)转换为相对于结果表示空间的标准坐标而言是稀疏的。这种结合速率降低和稀疏化的过程在下图中示意。

从计算上讲,我们可以将上述两个目标合并为一个统一的优化目标:

其中ℓ0范数 ||Z||0提升了最终token标记表示Z = f(X)的稀疏性。

我们将这个目标称为“稀疏编码率降低”。

2、White-box deep architecture as unrolled incremental optimization

虽然很容易陈述,但上述目标的每个术语在计算上都很难优化。因此,自然而然地采用一种近似方法,通过多个简单的增量和局部操作f的连续组合来实现全局转换f优化,从而将表示分布推向期望的简约模型分布。

其中 f0 : RD → Rd 是将输入token令牌 xi ∈ RD 转换为它们的token令牌表示 z1i ∈ Rd 的预处理映射。

每个增量前向映射 Zℓ+1 = f(Z),或称为“层”,根据其输入token令牌的分布 Z,优化上述稀疏率降维目标函数。

与其他展开优化方法(如ReduNet)不同,我们明确地对每个层的输入分布 Z 进行建模,例如将其建模为线性子空间的混合或由字典稀疏生成。模型参数通过数据学习(例如通过端对端训练进行反向传播)。前向“优化”和后向“学习”的区分明确了每个层作为操作符,转换其输入分布的数学角色,而输入分布则由层的参数建模(并随后学习)。

0x2:Self-Attention via Denoising Tokens Towards Multiple Subspaces

有很多不同的方式可以逐步优化目标函数。在这项工作中,我们提出了可能是最基本的方案。为了帮助澄清我们推导和近似的直觉,在本节中,我们研究了一个在很大程度上理想化的模型,尽管如此,它仍然捕捉到了几乎整个过程的本质,并特别揭示了为什么在许多情况下会出现类似于自注意力的运算符的原因。

假设N = 1,并且单个令牌x是从一个未知的高斯混合中独立同分布地抽取的,该混合分布在低维子空间上支持具有正交基,并且受到加性高斯噪声的干扰,即:

其中z按照混合分布进行分布。

我们的目标仅仅是将带有噪声的令牌x的分布转化为低维高斯分布的混合。

根据上述增量构建表示f的目标,我们进行归纳推理:如果z 是一个噪声令牌,在噪声水平σ 下进行去噪是自然的。从均方意义上讲,最优估计是E[z | z ],它具有变分特征: 

One layer of the CRATE architecture. The full architecture is simply a concatenation of such layers, with some initial tokenizer and final task-specific architecture (i.e., a classification head).

 

三、Application Experiments

Classification

下面是CRATE用于分类任务的架构。它与流行的vision transformer几乎完全相同。

我们使用软最大交叉熵损失来训练监督图像分类任务。与通常用于分类训练的视觉变换器(ViT)相比,我们获得了具有竞争力的性能,并具有类似的规模行为,包括在ImageNet-1K上超过80%的top-1准确度,而仅使用了ViT参数的25%。

五、代码示例

ImageNet Dataset prepare

git clone https://github.com/Ma-Lab-Berkeley/CRATE.git
cd CRATE/

screen wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar --no-check-certificate && wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate

mkdir imagenet
mkdir imagenet/train && mv ILSVRC2012_img_train.tar imagenet/train/ && cd imagenet/train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done

cd ../..
mkdir imagenet/val && mv ILSVRC2012_img_val.tar imagenet/val/ && cd imagenet/val && tar -xvf ILSVRC2012_img_val.tar && rm -f ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash

Training CRATE on ImageNet

screen python3 main.py --arch CRATE_tiny --batch-size 512 --epochs 200 --optimizer Lion --lr 0.0002 --weight-decay 0.05 --print-freq 25 --data /data_vdb1/CRATE/imagenet

Finetuning pretrained / training random initialized CRATE on CIFAR10

cd CRATE/
python3 finetune.py 
  --bs 256 
  --net CRATE_tiny 
  --opt adamW  
  --lr 5e-5 
  --n_epochs 200 
  --randomaug 1 
  --data cifar10 
  --ckpt_dir /data_vdb1/CRATE/checkpoints
  --data_dir /data_vdb1/CRATE/imagenet

参考链接:

https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data
https://cloud.google.com/tpu/docs/imagenet-setup?hl=zh-cn 
https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh
https://stackoverflow.com/questions/64714119/valid-url-for-downloading-imagenet-dataset