刚硬矩阵 (2) Walsh–Hadamard 变换的 "更快" 算法

发布时间 2023-11-30 20:25:53作者: EntropyIncreaser

\(\newcommand{\sfT}{\mathsf T}\newcommand{\rank}{\operatorname{rank}}\)

为了避免歧义, 我们这里约定

\[H = \begin{bmatrix}1 & 1 \\ 1 & -1\end{bmatrix}, \]

以及 \(2^n\times 2^n\) 的 Hadamard 矩阵写作 \(H^{\otimes n}\). 令 \(N = 2^n\).

低深度电路的算法

这里我们约定一个计算 \(M\) 的深度为 \(d\) 的电路是将矩阵 \(M\) 写成乘积

\[M = A_1 A_2 \cdots A_d, \]

然后代价是所有 \(A_i\) 的非零元素个数之和.

显然我们知道的一点是, 如果我们可以做 \(n\) 层计算, 也即 \(\log N\), 那么经典 "逐位计算" 的算法, 也即 FWT, 或者理论界中称呼的 Yates 算法, 可以做到代价 \(2N\log N\). 这个算法有一个用代数上的描述, 首先注意到张量积满足

\[(A\otimes B) (C\times D) = (AC) \otimes (BD), \]

我们想计算的是 \(H^{\otimes n}\), 于是我们可以将其写成, 以 \(n=3\) 为例:

\[H^{\otimes 3} = (H\otimes I\otimes I)(I\otimes H\otimes I)(I\otimes I\otimes H). \]

如此推广, 也就是 \(H^{\otimes n} = A_1\cdots A_n\), 其中

\[A_i = I\otimes \cdots\otimes I \otimes \overbrace{H}^i\otimes I \otimes \cdots \otimes I. \]

如果我们固定一个常数 \(d \geq 1\), 然后考虑层数为 \(d\) 的电路呢?

显然, 我们可以把 \(H^{\otimes n}\) 写成 \(d\) 层, 也就是

\[A_i = I^{\otimes (n/d)(i-1)} \otimes H^{\otimes(n/d)} \otimes I^{\otimes (n/d)(d-i)}, \]

这样的话, 每层的非零元数量是 \(N^{1+1/d}\), 于是总代价是 \(dN^{1+1/d}\).

这个 \(1/d\) 的阶数是最优的吗? Pudlak 证明了, 如果 矩阵所有元素都是有界的, 那么总是需要 \(\Omega(d N^{1+1/d})\) 的代价, 这个论证的策略基本上就是想办法用非零元的数量控制最后算出来的矩阵的行列式. 不过我们这里按下不表, 因为

我们有如下结果:

定理 (Alman, 2021) 存在 \(O(d N^{1 + 0.96/d})\) 的算法, 在偶数 \(d > 1\) 层电路模型下计算 \(H^{\otimes n}\).

Remark. 论文里说的是对任意 \(d>1\), 但我发现论证好像有个 gap, 但对偶数来说是成立的.

特定矩阵的非刚性

考虑 Hadamard 矩阵

\[ H^{\otimes 4} = \begin{pmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 \\ 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 \\ 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 \\ 1 & 1 & 1 & 1 & -1 & -1 & -1 & -1 & 1 & 1 & 1 & 1 & -1 & -1 & -1 & -1 \\ 1 & -1 & 1 & -1 & -1 & 1 & -1 & 1 & 1 & -1 & 1 & -1 & -1 & 1 & -1 & 1 \\ 1 & 1 & -1 & -1 & -1 & -1 & 1 & 1 & 1 & 1 & -1 & -1 & -1 & -1 & 1 & 1 \\ 1 & -1 & -1 & 1 & -1 & 1 & 1 & -1 & 1 & -1 & -1 & 1 & -1 & 1 & 1 & -1 \\ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & 1 \\ 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 \\ 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & -1 \\ 1 & 1 & 1 & 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 & 1 & 1 & 1 & 1 \\ 1 & -1 & 1 & -1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 & 1 & 1 & -1 & 1 & -1 \\ 1 & 1 & -1 & -1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & 1 & 1 & -1 & -1 \\ 1 & -1 & -1 & 1 & -1 & 1 & 1 & -1 & -1 & 1 & 1 & -1 & 1 & -1 & -1 & 1 \\ \end{pmatrix}, \]

\(L = \left( \begin{array}{cccc} -1 & 1 & 1 & 1 \\ 1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 \\ \end{array} \right)^{\otimes 2}\), 显然 \(\rank L = 1\). 经过计算可以发现, \(H^{\otimes 4} = L + S\), 其中

\[ S = \begin{pmatrix} 0 & 2 & 2 & 2 & 2 & 0 & 0 & 0 & 2 & 0 & 0 & 0 & 2 & 0 & 0 & 0 \\ 2 & -2 & 0 & -2 & 0 & 0 & 2 & 0 & 0 & 0 & 2 & 0 & 0 & 0 & 2 & 0 \\ 2 & 0 & -2 & -2 & 0 & 2 & 0 & 0 & 0 & 2 & 0 & 0 & 0 & 2 & 0 & 0 \\ 2 & -2 & -2 & 0 & 0 & 0 & 0 & 2 & 0 & 0 & 0 & 2 & 0 & 0 & 0 & 2 \\ 2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & 0 & 2 & 2 & 2 & -2 & 0 & 0 & 0 \\ 0 & 0 & 2 & 0 & 0 & 0 & -2 & 0 & 2 & -2 & 0 & -2 & 0 & 0 & -2 & 0 \\ 0 & 2 & 0 & 0 & 0 & -2 & 0 & 0 & 2 & 0 & -2 & -2 & 0 & -2 & 0 & 0 \\ 0 & 0 & 0 & 2 & 0 & 0 & 0 & -2 & 2 & -2 & -2 & 0 & 0 & 0 & 0 & -2 \\ 2 & 0 & 0 & 0 & 0 & 2 & 2 & 2 & -2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 \\ 0 & 0 & 2 & 0 & 2 & -2 & 0 & -2 & 0 & 0 & -2 & 0 & 0 & 0 & -2 & 0 \\ 0 & 2 & 0 & 0 & 2 & 0 & -2 & -2 & 0 & -2 & 0 & 0 & 0 & -2 & 0 & 0 \\ 0 & 0 & 0 & 2 & 2 & -2 & -2 & 0 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & -2 \\ 2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & 0 & 2 & 2 & 2 \\ 0 & 0 & 2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & -2 & 0 & 2 & -2 & 0 & -2 \\ 0 & 2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & -2 & 0 & 0 & 2 & 0 & -2 & -2 \\ 0 & 0 & 0 & 2 & 0 & 0 & 0 & -2 & 0 & 0 & 0 & -2 & 2 & -2 & -2 & 0 \\ \end{pmatrix}, \]

数一数, 得到 \(\|S\|_0 \leq 96\), 因此我们有

\[R_{H^{\otimes 4}}(1) \leq 96. \]

我们将利用这个事实来给出更好的算法.

非刚性矩阵给出的电路

这个想法其实很简单, 我们先放宽到一个稍微一般一点的情况, 对于一个 \(q\times q\) 的矩阵 \(M\), 可以写成 \(M = L + S\), 其中 \(\rank L \leq r\), \(\|S\|_0 \leq s\), 那么我们首先可以构造一个矩阵

\[M = U V^\sfT + S = \begin{pmatrix} U & I\end{pmatrix} \begin{pmatrix} V^\sfT \\ S\end{pmatrix}, \]

那么这给出了一个分解 \(M = AB\), 其中 \(\|A\|_0 \leq q(r+1)\), 而 \(\|B\|_0 \leq qr + s\).

类似地, 如果我们把 \(S\) 挪到前面, 就能得到 \(M = B'A'\), 其中 \(\|B'\|_0 \leq qr+s\), 而 \(\|A'\|_0 \leq q(r+1)\).

这样一来, 我们就可以写成

\[ \begin{align*} M^{\otimes 2} &= M\otimes M\\ &= (A\otimes B') \times (B \otimes A') \end{align*}, \]

易见, 每个张量积得到的矩阵的非零元数量都是

\[q(r+1)\cdot (qr+s) = q^2 \cdot(r+1)(r + s/q). \]

类似的, 对于 \(d\) 是偶数, 我们可以在对角上复制这个构造, 得到代价

\[q^d \cdot(r+1)(r + s/q). \]

现在我们得到了 \(M^{\otimes d} = A_1 \cdots A_d\), 然后就有

\[M^{\otimes n} = (A_1^{\otimes(n/d)}) \cdots (A_d^{\otimes(n/d)}), \]

总共的非零元个数之和不超过

\[d \cdot q^n \cdot [(r+1)(r+s/q)]^{n/d}, \]

换元可以得到 \(d \cdot N^{1 + c/d},\) 其中

\[c = \log_q [(r+1) (r+s/q)]. \]

带入 \(q = 16\), \(r=1\), \(s = 96\), 我们就得到了

\[c = \log_{16} 14 < 0.95184. \]

动态维护 FWT

根据本能反应, 我们可以进一步检查这个做法能不能让我们维护一个数组的单点修改, 单点查询 FWT.

注意 \(A\) 矩阵每行有 \(r+1\) 个非零元, \(B'\) 矩阵每行有 \(r+s/q\) 个非零元 (一般来说不一定, 但是我们检查发现前面构造的矩阵 \(S\) 是成立的!), 类似地, \(B\)\(A'\) 的每一列也满足类似的性质,

所以得到了 \(M^{\otimes 2} = A_1 A_2\) 之后, 有 \(A_1^{\otimes n/2}\)\(A_2^{\otimes n/2}\) 分别有 \([(r+1)(r+s/q)]^{n/2} = N^{c/2}\) 个每行, 每列非零元素.

所以我们可以做到 \(O(N^{0.476})\) 单点修改, 单点查询 FWT.

进一步的结果

现在 \(N^{1.476}\) 不是最优的结果, 注意到我们原来的构造里 \(A\) 分成 \(L\)\(S\) 两部分, 它们的非零元是数量是不太平衡的, 一些后续的工作通过在做张量积的时候一步步消除这个不平衡, 得到了更好的结果.

定理. (Alman, Guan, Padaki, 2022) Hadamard 矩阵可以用 \(O(N^{1.446})\) 大小的两层电路进行计算.

常数更小的 Walsh–Hadamard 变换

如果我们考虑只把加法和数乘看做运算的基本时间单位, 那么容易发现计算 \(H^{\otimes n} x\) 的时间, 经典的做法刚好有 \(n2^n\) 次基本运算, 也就是 \(N\log N\).

有趣的是, 这个结果也是可以被改进的.

定理. (Alman, Rao, 2023)
可以在 \(\frac{23}{24}N\log N +O(N)\) 次基本运算内计算 \(H^{\otimes n} x\). 特别地, 其中有 \(\frac{11}{12}N\log N + O(N)\) 次是加减法, \(\frac{1}{24}N\log N\) 次是 \(/2\), 和 \(O(N)\) 次数乘.

特定矩阵的非刚性

我们考虑 \(H^{\otimes 3}\),

\[ H^{\otimes 3} = \begin{pmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 \\ 1 & 1 & -1 & -1 & 1 & 1 & -1 & -1 \\ 1 & -1 & -1 & 1 & 1 & -1 & -1 & 1 \\ 1 & 1 & 1 & 1 & -1 & -1 & -1 & -1 \\ 1 & -1 & 1 & -1 & -1 & 1 & -1 & 1 \\ 1 & 1 & -1 & -1 & -1 & -1 & 1 & 1 \\ 1 & -1 & -1 & 1 & -1 & 1 & 1 & -1 \\ \end{pmatrix} \]

它可以分解成 \(L+S\), 其中

\[ L = \begin{pmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \\ 1 & -1 & -1 & -1 & -1 & -1 & -1 & -1 \end{pmatrix}, \]

\[ S = \begin{pmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 2 & 0 & 2 & 0 & 2 & 0 \\ 0 & 2 & 0 & 0 & 2 & 2 & 0 & 0 \\ 0 & 0 & 0 & 2 & 2 & 0 & 0 & 2 \\ 0 & 2 & 2 & 2 & 0 & 0 & 0 & 0 \\ 0 & 0 & 2 & 0 & 0 & 2 & 0 & 2 \\ 0 & 2 & 0 & 0 & 0 & 0 & 2 & 2 \\ 0 & 0 & 0 & 2 & 0 & 2 & 2 & 0 \end{pmatrix}. \]

我们看到, \(\rank L = 2\), 而 \(\|S\|_0 = 21\).

算法

注意, 原始的计算 \(H^{\otimes 3}\) 的算法使用了 \(24\) 次运算, 原则上说我们的目的是在 \(23\) 次运算内解决.

一个直觉是, \(S\) 这部分有一堆乘以 \(2\), 我们可以把这些 \(\times 2\) 操作能不做的都不做, 对齐到最后的部分.

首先, 我们 \(Lx\) 的部分首先希望算出 \(x_1 + \cdots + x_7\), 注意到

\[ S = \begin{pmatrix} 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 2 & 0 & \textcolor{blue} 2 & 0 & 2 & 0 \\ 0 & 2 & 0 & 0 & \textcolor{blue} 2 & 2 & 0 & 0 \\ 0 & 0 & 0 & \textcolor{red} 2 & \textcolor{blue} 2 & 0 & 0 & \textcolor{red} 2 \\ 0 & \textcolor{red} 2 & \textcolor{red} 2 & \textcolor{blue} 2 & 0 & 0 & 0 & 0 \\ 0 & 0 & 2 & 0 & 0 & 2 & 0 & \textcolor{blue} 2 \\ 0 & 2 & 0 & 0 & 0 & 0 & 2 & \textcolor{blue} 2 \\ 0 & 0 & 0 & \textcolor{blue} 2 & 0 & \textcolor{red} 2 & \textcolor{red} 2 & 0 \end{pmatrix}. \]

红色的这三个对子需要先被计算, 需要 \(3\) 次运算.

那么只需要把 \(x_1 + \cdots + x_7\) 加起来, 因为有了前面的对子, 需要 \(3\) 次运算.

然后记 \(t = x_0 - (x_1 + \cdots + x_7)\), 需要 \(1\) 次计算得到. 另外 \(x_0 + (x_1 + \cdots + x_7)\) 也只需要额外的 \(1\) 次运算.

然后我们选取三列覆盖每一行, 也即图中蓝色的这些列, 也就是计算出 \(t + 2x_3, t + 2x_4, t+ 2x_7\), 但注意我们最后会发现这些\(2\) 操作是会被省略掉的, 所以现在先把它们当做 \(3\) 次操作.

然后我们可以拼出剩下的 \(7\) 个数了, 红色的那四行只需要一次加法, 没有红色的那三行需要两次加法, 一共 \(11\) 次加法.

我们看看最后总共有多少次运算: \(3 + 3 + 1 + 1 + 3 + 9 = 22\).

但是我们还没有修补刚刚提到的问题: 我们中间很多运算是形如 \(x + 2y\), 但我们只当做了一次运算. 我们的策略是这样的: 实际上, 我们希望计算 \(\frac 12 Hx\), 最后再把错误的幂都乘回来.

于是我们实际应该计算的是 \(t' = t/2\), 这实际上需要 \(2\) 次操作了. 然后 \(\frac 12 (Hx)_0 = \frac 12(x_0 + (x_1 + \cdots + x_7)) = t' + (x_1+\cdots+x_7)\), 这个依然只需要 \(1\) 次操作.

然后所有的 \(t' + x_3, t' + x_4, t' + x_7\) 什么的都是正确的了. 我们可以正确地在前述时间内计算出所有 \(\frac 12 (Hx)_i\).

这样, 我们就在 \(\frac{23}{24} N\log N\) 次运算内计算出了 \(\frac 1{2^{n/3}}H^{\otimes n} x\). 最后逐点乘以 \(2^{n/3}\) 就可以了.

进一步的结果

Alman 和 Rao 还在同一篇文章里进一步得到了 \(3.75N\log N + O(N)\) 次复数运算的 FFT, 与之相比, 之前最优的结果是 \(3.76875N\log N + O(N)\). 这里就不介绍了.