CDDM公式汇总


rayleigh信道下MMSE均衡:

x_t = x_hat * (torch.conj(h)) / (torch.abs(h) ** 2 + sigma_square_fix)

数学解释:

x̂ = y × h* / (|h|² + σ²)

符号说明:

  • :估计的发送信号( x_t
  • y:接收信号
  • h*:信道增益的复共轭(torch.conj(h)
  • |h|²:信道增益的模长平方(torch.abs(h) ** 2
  • σ²:噪声方差(sigma_square_fix

CDDM中实现了从信道条件(SNR)到扩散时间步(t)的智能映射,这是CDDM相比传统扩散模型的重要创新,下面是线性调度相关代码。

self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double()) #beta_1=1e-4 beta_T=0.02 T=1000 
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0) #累乘
self.register_buffer('snr', -10 * torch.log10((1 - alphas_bar) / alphas_bar))

符号说明:

β_t = 噪声调度参数 (从 β_1β_T 线性增长)
α_t = 1 - β_t
ᾱ_t = ∏(i=1 to t) α_i # 累积乘积

数值举例:

import torch

T = 1000
beta_1, beta_T = 1e-4, 0.02

# 1. 噪声调度
betas = torch.linspace(beta_1, beta_T, T)
# betas = [0.0001, 0.0001, ..., 0.0199, 0.0200]

# 2. 计算alphas
alphas = 1. - betas
# alphas = [0.9999, 0.9999, ..., 0.9801, 0.9800]

# 3. 累积乘积
alphas_bar = torch.cumprod(alphas, dim=0)
# alphas_bar[0] ≈ 0.9999
# alphas_bar[500] ≈ 0.5
# alphas_bar[999] ≈ 0.001

# 4. 计算SNR
snr = -10 * torch.log10((1 - alphas_bar) / alphas_bar)

# 可以打印部分关键点
key_steps = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999]
for step in key_steps:
    print(f"时间步 {step:3d}: SNR = {snr[step].item():8.4f}")

# 时间步   0: SNR =  39.9988
# 时间步 100: SNR =   9.3129
# 时间步 200: SNR =   2.8101
# 时间步 300: SNR =  -1.8696
# 时间步 400: SNR =  -6.1972
# 时间步 500: SNR = -10.7387
# 时间步 600: SNR = -15.8106
# 时间步 700: SNR = -21.6016
# 时间步 800: SNR = -28.2108
# 时间步 900: SNR = -35.6813
# 时间步 999: SNR = -43.9405

逆映射过程:给定信道SNR → 在预计算的SNR数组中找最接近值 → 返回对应时间步

def match_snr_t(self, snr):
        out = torch.argmin(torch.abs(self.snr - snr))
        return out

去噪过程:

if time_step > 0:
                x_t = extract(self.sqrt_alphas_bar, t - 1, x_t.shape) / extract(self.sqrt_alphas_bar, t, 	                 x_t.shape) * (x_t - extract(self.sqrt_one_minus_alphas_bar, t, x_t.shape) * eps) + 
                extract(self.sqrt_one_minus_alphas_bar, t - 1, x_t.shape) * eps
else:
                x_t = (x_t - extract(self.sqrt_one_minus_alphas_bar, t, x_t.shape) * eps) / extract(
                self.sqrt_alphas_bar, t, x_t.shape)

数学解释:

x_{t-1} = (√ᾱ_{t-1} / √ᾱ_t) × (x_t - √(1-ᾱ_t) × ε_θ) + √(1-ᾱ_{t-1}) × ε_θ

系数A: √ᾱ_{t-1} / √ᾱ_t

  • 时间步间的信号缩放因子
  • 补偿不同时间步的信号强度差异

项B: x_t - √(1-ᾱ_t) × ε_θ

  • 从当前状态中减去预测的噪声
  • 得到”去噪后的信号”

项C: √(1-ᾱ_{t-1}) × ε_θ

  • 添加适量的噪声到下一时间步
  • 保持扩散过程的随机性

文章作者: ycx
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 ycx !
评论
  目录