月小白
Classifier-guidance and -free diffusion

Classifier-guidance and -free diffusion

背景

DDPM和SMLD都是没有条件引导的,因此最终生成的结果并非可控,但是我们在宋飏博士的SDE中其实已经知道,可以通过控制分数来实现对最终目标的把控。根据贝叶斯定理,我们有

第一项我们知道是 11αˉtϵ-\dfrac{1}{\sqrt{1-\bar\alpha_t}}\epsilon ,而第二项我们可以通过训练一个分类器来求到

Classifier Guidance

因此呢,就有了工作【1】,他们将采样过程设计为

pθ,ϕ(xtxt+1,y)=Zpθ(xxt+1)pϕ(yxt)p_{\theta,\phi}(x_t|x_{t+1},y)=Z\cdot p_\theta(x|x_{t+1})p_\phi(y|x_t)

其中, ZZ 是一个归一化常量,这其实和前面的公式是一致的

文中指出,对于DDIM(DDPM也一样)有

ϵ^(xt)=ϵθ(xt)1αˉtxtlogpϕ(yxt)\hat\epsilon(x_t)=\epsilon_\theta(x_t)-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p_\phi(y|x_t)

然后用这个去进行采样就行

在官方代码guided-diffusion/scripts/classifier_sample.py中,有这么一段对应的代码实现

1
2
3
4
5
6
7
8
9
10
11
12
def cond_fn(x, t, y=None):
assert y is not None
# th = torch
with th.enable_grad():
x_in = x.detach().requires_grad_(True)
# get logits
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
# get the true label prediction
selected = log_probs[range(len(logits)), y.view(-1)]
# scale the grad
return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

其中最后一行给分类器的梯度乘了一个缩放因子,实验表明,缩放因子越大,重建效果越好

对于SMLD和SDE,也是一样的,在原来的分数基础上再加上分类器的梯度方向即可

SMLD:

xk+1xk+ϵ22(xlogpθ(xk)+xlogpθ(yxk))+ϵzk,k=0,1,...,K1x^{k+1}\leftarrow x^k+\dfrac{\epsilon^2}{2} (\nabla_x\log p_\theta(x^k) + \nabla_x\log p_\theta(y|x^k))+\epsilon z^k, k=0,1,...,K-1

SDE:

dx=[f(x,t)g(t)2(xlogpt(x)+xlogpt(yx))]dt+g(t)dwˉ\mathsf{d}x=[f(x,t)-g(t)^2 (\nabla_x\log p_t(x) + \nabla_x\log p_t(y|x))]\mathsf{d}t+g(t)\mathsf{d}\bar w

Classifier-free

但用一个显式的分类器引导生成有几个比较大的问题:一个是你需要额外训练一个在噪声图像上做分类判别的分类器,费时费力。其次该分类器的质量决定了你按类别生成的效果。

而由此引申的最重要的一个缺陷是,这样的一个分类器在多样性和生成效果此消彼长的权衡博弈,并没有直接反映到我们的评判测度上(例如IS这个测度的输出本身就依赖于分类器)。

论文【2】指出了这一问题,并通过下图进行了直观展示

guidance.png

对于一个来自于三个高斯分布混合而成的分布,我们通过分类器引导的采样过程导致了采样结果严重受限于该分布的局部领域,且分类器引导强度越强,远离其他类别的质心的表现越明显,使得结果越加集中在局部空间。

为了避免这一问题,论文【2】提出了classifier-free的引导方式

首先,在SMLD中,我们有

x~log[qσ(x~x)]=ϵσ\nabla_{\tilde x} \log[q_\sigma(\tilde x | x)]=-\dfrac{\epsilon}{\sigma}

我们改写一下

ϵθ(x~)σx~logp(x~)\epsilon_\theta(\tilde x) \approx -\sigma \nabla_{\tilde x}\log p(\tilde x)

这和论文中的公式其实是一样的,都是对噪声的估计

然后我们根据【1】可以知道有分类器引导的梯度等于原梯度加上分类器的梯度,假设上面的缩放比例为 w+1w+1 ,则有:

ϵθ(x~)(w+1)σx~logpθ(yx~)σx~[logp(x~)+(w+1)logpθ(yx~)]=σx~[logp(x~y)+wlogpθ(yx~)]\epsilon_\theta(\tilde x)-(w+1)\sigma \nabla_{\tilde x}\log p_\theta(y|\tilde x) \approx - \sigma \nabla_{\tilde x} [\log p(\tilde x) +(w+1)\log p_\theta(y|\tilde x)] \\ = - \sigma \nabla_{\tilde x} [\log p(\tilde x | y) + w\log p_\theta(y|\tilde x)]

这个公式表明,对无条件生成的网络以 w+1w+1 的强度进行有分类器引导的扩散生成等价于对有条件生成的网络以 ww 的强度进行有分类器引导。

我们再通过背景中的公式对 x~logpθ(yx~)\nabla_{\tilde x}\log p_\theta(y|\tilde x) 进行代换,得到

x~logpθ(yx~)=x~logp(x~y)x~logp(x~)\nabla_{\tilde x}\log p_\theta(y|\tilde x) = \nabla_{\tilde x}\log p(\tilde x|y) - \nabla_{\tilde x}\log p(\tilde x)

带入进去,最终得到

ϵ~(x~y)=(1+w)ϵθ(x~y)wϵθ(x~)\tilde\epsilon(\tilde x|y)=(1+w)\epsilon_\theta(\tilde x|y)-w\epsilon_\theta(\tilde x)

第一项是基于条件的梯度预估模型,第二项是无条件生成的梯度预估模型。而我们甚至可以使用同一个模型同时表示两者,只需要在无条件生成时将条件向量设置为零。而如果不想要让模型生成具有某一条件的特征,那设置 ww 为负就可以。

一方面这大大减轻了条件生成的训练代价(无需训练额外的分类器,只需要在训练时进行随机drop out condition来同时训练两个目标),另一方面这样的条件生成并不是以一个类似于对抗攻击(博弈)的方式进行的。

参考资料

[1] Diffusion Models Beat GANs on Image Synthesis

[2] Classifier-Free Diffusion Guidance

[3] 浅谈扩散模型的有分类器引导和无分类器引导 - 知乎 (zhihu.com)

[4] 条件控制扩散模型(Classifier-Guidance,Classifier-Free)基础理论总结 - 知乎 (zhihu.com)

本文作者:月小白
本文链接:http://example.com/2023/05/04/guidance/
版权声明:本文采用 CC BY-NC-SA 3.0 CN 协议进行许可