# 〈 Diffusion Model 論文研究與實作心得 Part.3 〉 模型訓練、照片修復與結果呈現 (Finale)


一、前言

在前兩篇文章〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理〈 Diffusion Model 論文研究與實作心得 Part.2 〉 U-Net 模型架構介紹與實作中,我完成了資料前處理與模型的搭建。因此在Part.3(最終篇)就要來進行模型的訓練和結果呈現。

二、模型訓練

我們可以參考一下ddpm作者的sudo code,這樣對實作的步驟有很大的幫助。

我們的模型輸出是預測圖片的雜訊(對,不是修復後的圖),拿去和加在上面的雜訊進行比較。所以get_loss函數應該有三個參數,X_0,timestep和model。

def get_loss(x_0, t, model):
    pass

比較所需要的有三個東西

  1. 某個timestep的x
  2. 實際加上的雜訊
  3. 模型預測的雜訊

1.和2.可以用〈 Diffusion Model 論文研究與實作心得 Part.1 〉 前言與圖片雜訊前處理 裡定義的

def forward_diffuse_process(x_0, t):
    '''
    回傳第t個timestep的圖片和加上的雜訊
    '''
    noise = torch.randn_like(x_0) #回傳與X_0相同size的noise tensor,也就是reparameterization的epsilon
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t]
    sqrt_oneminus_alphas_cumprod_t = sqrt_oneminus_alphas_cumprod[t]

    #element-wise的運算
    return sqrt_alphas_cumprod_t*x_0 + sqrt_oneminus_alphas_cumprod_t*noise, noise

這個函數會回傳前兩點需要的東西。

def get_loss(x_0, t, model):
    x_t, noise = forward_diffuse_process(x_0, t)

而3. 則需要使用我們上次架構的U-net模型

def get_loss(x_0, t, model):
    x_t, noise = forward_diffuse_process(x_0, t)
    noise_prediction = model(x_t, t)

最後對noise和noise_prediction進行比較就能得到Loss了,這邊選用L2 Loss

def get_loss(x_0, t, model):
    x_t, noise = forward_diffuse_process(x_0, t)
    noise_prediction = model(x_t, t)
    return F.l2_loss(noise, noise_prediction)

如此一來就能開始進行訓練了!

optimizer選用Adam,epochs先選用20 (colab的資源讓我一次只敢做這麼多QQ)

from torch.optim import Adam

dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 20 # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
      loss = get_loss(model, batch[0], t)
      loss.backward() 
      optimizer.step()

      if epoch % 5 == 0 and step == 0:
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")

訓練output:

Epoch 0 | step 000 Loss: 0.8118380904197693 
Epoch 5 | step 000 Loss: 0.2767971158027649 
Epoch 10 | step 000 Loss: 0.29156017303466797 
Epoch 15 | step 000 Loss: 0.24683958292007446 
Epoch 20 | step 000 Loss: 0.22735241055488586

這個專案的心臟,Diffusion Model正式訓練完成 (感動

三、圖片修復與成果呈現

說到底,我們模型輸出的終究只是對雜訊的預測,因此還需要一點點的數學才能將這個雜訊預測用於修復原圖。

還記得第一篇提到的q(Xt|Xt-1)嗎?那是用於破壞照片的forward process,而現在的backward process(修復圖片)ddpm的論文作者使用p(Xt-1|Xt)代表。

這部分牽涉到很複雜的數學(我也不太懂),所以我就放一部份的筆記和完整數學算式的連結

總之經過一點魔法我們能透過最底下框起來的式子計算出前一步timestep的圖。

論文中作者好像"憑經驗"省略了一堆數學還得到更好的結果,所以實作的部分就依照上面的sudo code就行了。

此外,這邊會用到第一篇定義的變數,我放在下面方便理解。

# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

#新定義的
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
@torch.no_grad() #記得寫這行,在sample的時候才不會逆向傳遞梯度
def sample_timestep(x, t):
    """
    給一個被破壞的圖片x和timestep,回傳修復後的圖片
    """
    #這邊基本都是按照sudo code的算式
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

@torch.no_grad()
def sample_and_plot_image():
    #首先,生成隨機雜訊
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    #這部分是用plt來呈現成果
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    #從第T個timestep修復到第0個
    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize+1))
            show_tensor_image(img.detach().cpu()) #第一篇的函式
    plt.show()

來看看epoch=80時候的成果:

呃...雖然有點抽象,但多少能看出類似臉、眼睛、頭髮的色塊,如果將epochs調高一點應該能得到更好的成果。

四、結語(系列總結)

第一次寫這種系列文,從資料前處裡到訓練模型,雖然省略了很多細節,很多地方可能做得不夠好,但我對自己踏出的第一步感到挺滿意的。

我之後可能會再寫一篇外傳 (?,講講怎麼改造這個模型,讓他能產出更高畫質的圖片或變成prompt-to-image模型,又或者我搞了一張顯卡把epochs跑完再來看看成果之類的。一樣,都是後話了。

相關資料

https://www.youtube.com/watch?v=a4Yfz2FxXiY
https://www.youtube.com/watch?v=HoKDTa5jHvg&t=1338s
https://huggingface.co/blog/annotated-diffusion
https://arxiv.org/pdf/2102.09672.pdf
https://arxiv.org/pdf/1503.03585.pdf
https://arxiv.org/pdf/2006.11239.pdf
https://theaisummer.com/latent-variable-models/#reparameterization-trick
https://theaisummer.com/diffusion-models/
https://brohrer.mcknote.com/zh-
https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#nice

#AI #Deep Learning #Diffusion Model







你可能感興趣的文章

Cyberpunk 風格按鈕動畫

Cyberpunk 風格按鈕動畫

Leetcode 刷題 pattern - Bitwise XOR

Leetcode 刷題 pattern - Bitwise XOR

[筆記] 最重要的小事:輸入範圍

[筆記] 最重要的小事:輸入範圍






留言討論