vqvae的loss计算

发布时间 2023-11-24 17:32:12作者: 张博的博客
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
z_q是codebook 找到的最接近z的向量.
z是encoder生成的向量.
L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach())     # 这个部分对于encoder做了训练.
L对z_q求导=2(z_q - z.detach())                                       #这个部分对于codebook做了训练.
所以这个detach对于变量x虽然对x不求导,但是计算其他变量时候参与计算.
 
很早之前,在RVQ那篇文章里说到过,VQ-VAE中是通过在codebook中选择欧式距离最近的embedding对应的index作为离散token的。即其中涉及到argmin操作,该操作是不可导的。因此重建loss的梯度是无法传递到encoder网络的。
如果我们写成 loss= torch.mean((z-z_q)**2)
那么L对z_q求导=2(z_q-z)
对z求导=2(z_q-z)*(-1)=2(z-z_q). 这里面的两个导数是算不了的.因为argmin不可导. 导数没法从z_q传到变量x上.(x是输入网络参数)
所以我们只能用上面的方法来计算.
 
我们上面的方法.
L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach())     # 这个部分对于encoder做了训练. 
L对z_q求导=2(z_q - z.detach())                                       #这个部分对于codebook做了训练. 我们再计算z_q 对x求导.即可. z是没法对x求导的.
 
参考这个: https://zhuanlan.zhihu.com/p/644091516  讲的很好.