DDPM Implementation using UNet Architecture

Generative Models
Author

Guntas Singh Saran

Published

June 18, 2024

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models


import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from einops import rearrange

from latex import latexify
latexify(columns = 2)

%matplotlib inline
%config InlineBackend.figure_format = "retina"

# Pick the compute backend by priority: Apple MPS first, then CUDA, otherwise CPU.
# The chained conditional short-circuits, so CUDA is only probed when MPS is unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)
mps

Forward Process

\[ \textbf{x}_t = \sqrt{1 - \beta_t} \textbf{x}_{t - 1} + \sqrt{\beta_t} \epsilon_{t - 1} \]

\[ \boxed{\textbf{x}_t = \sqrt{\bar{\alpha_t}} \textbf{x}_0 + \sqrt{1 - \bar{\alpha_t}} \epsilon} \]

\[ \boxed{q(\textbf{x}_t | \textbf{x}_0) = \mathcal{N}(\textbf{x}_t; \sqrt{\bar{\alpha_t}} \textbf{x}_0, (1 - \bar{\alpha_t}) \mathbb{I})} \]

Reverse Distribution

\[ \boxed{q(\textbf{x}_{t - 1} | \textbf{x}_t, \textbf{x}_0) = \mathcal{N}(\textbf{x}_{t - 1}; \boldsymbol{\mu}_q(\textbf{x}_0, \textbf{x}_t), \mathbf{\Sigma}_q(t))} \]

\[ \boxed{\boldsymbol{\mu}_q(\textbf{x}_t, \textbf{x}_0) = \frac{(1 - \bar{\alpha}_{t - 1}) \sqrt{\alpha_t}}{1 - \bar{\alpha}_t} \textbf{x}_t + \frac{(1 - \alpha_t)\sqrt{\bar{\alpha}_{t - 1}}}{1 - \bar{\alpha}_t}\textbf{x}_0} \]

\[ \boxed{\boldsymbol{\mu}_q(t) = \frac{1}{\sqrt{\alpha_t}} \left(\textbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_t \right)} \]

Loss Function

\[ L_{VN} = \sum_{t = 2}^T D_{KL} \left(q(\textbf{x}_{t - 1} | \textbf{x}_t, \textbf{x}_0) \parallel p_{\theta}(\textbf{x}_{t - 1} | \textbf{x}_t) \right) - \log p_{\theta}(\textbf{x}_0 | \textbf{x}_1) \]

\[ p_{\theta}(\textbf{x}_{t - 1} | \textbf{x}_t) = \mathcal{N}(\textbf{x}_{t - 1}; \boldsymbol{\mu}_{\theta}(\textbf{x}_t, t), \mathbf{\Sigma}_{\theta}(\textbf{x}_t, t)) \]

\[ L_t = \mathbb{E}_{\textbf{x}_0, \boldsymbol{\epsilon}} \left[ \frac{1}{2 \lVert \mathbf{\Sigma}_{\theta}(\textbf{x}_t, t) \rVert^2_2} \lVert \boldsymbol{\mu}_{\theta}(\textbf{x}_t, t) - \boldsymbol{\mu}_q(\textbf{x}_t, \textbf{x}_0) \rVert^2 \right] \]

\[ \boldsymbol{\mu}_{\theta}(\textbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left(\textbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_{\theta}(\textbf{x}_t, t) \right) \]

\[ L_t = \mathbb{E}_{\textbf{x}_0, \boldsymbol{\epsilon}} \left[ \frac{(1 - \alpha_t)^2}{2 \alpha_t (1 - \bar{\alpha}_t) \lVert \mathbf{\Sigma}_{\theta} \rVert^2_2} \lVert \boldsymbol{\epsilon}_{\theta}(\textbf{x}_t, t) - \boldsymbol{\epsilon}_t \rVert^2 \right] \]

\[ \color{OrangeRed}{\boxed{L_t^{\text{Simple}} = \mathbb{E}_{t \sim [1, T], \textbf{x}_0, \boldsymbol{\epsilon}_t} \left[\lVert \boldsymbol{\epsilon}_{\theta}((\sqrt{\bar{\alpha_t}}) \textbf{x}_0 + (\sqrt{1 - \bar{\alpha_t}})\boldsymbol{\epsilon}_t , t) - \boldsymbol{\epsilon}_t \rVert^2 \right]}} \]

Noise Scheduler

Task 1: Getting the noisy image \(\textbf{x}_t\) given \(\textbf{x}_0, t, \epsilon\)

\[ \textbf{x}_0, t, \epsilon \to \textbf{x}_t = \sqrt{\bar{\alpha_t}} \textbf{x}_0 + \sqrt{1 - \bar{\alpha_t}} \epsilon \]

\[ \alpha_t = 1 - \beta_t \]

\[ \bar{\alpha}_t = \prod_{i = 1}^t \alpha_i \]

We use a linear noise schedule in which \(\beta_t\) increases linearly from \(\beta_1 = 10^{-4}\) to \(\beta_T = 0.02\) over \(T = 1000\) steps.

Task2: Given \(\textbf{x}_t\) get \(\textbf{x}_{t - 1}\)

\[ p_{\theta}(\textbf{x}_{t - 1} | \textbf{x}_t) = \mathcal{N}(\textbf{x}_{t - 1}; \boldsymbol{\mu}_{\theta}(\textbf{x}_t, t), \mathbf{\Sigma}_{\theta}(\textbf{x}_t, t)) \]

\[ \boldsymbol{\mu}_{\theta} = \frac{1}{\sqrt{\alpha_t}} \left(\textbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_{\theta} \right) \]

\[ \Sigma_q(t) = \frac{(1 - \alpha_t) \cdot (1 - \bar{\alpha}_{t - 1})}{(1 - \bar{\alpha}_t)} \mathbb{I} \]

\[ \begin{align} \textbf{x}_{t - 1} = \boldsymbol{\mu}_{\theta} + \sigma_t \textbf{z} && \textbf{z} \sim \mathcal{N}(\textbf{0}, \mathbb{I}) \end{align} \]

class LinearNoiseScheduler:
    """Linear-beta DDPM noise schedule.

    Precomputes ``beta_t`` (linearly spaced), ``alpha_t = 1 - beta_t``, the
    cumulative product ``alpha_bar_t`` and the two square-root terms derived
    from it.  These tables drive forward noising (``add_noise``) and a single
    reverse step (``sample_prev_timestep``).

    Args:
        T: number of diffusion steps; every precomputed table has length ``T``.
        beta_start: first value of the linearly spaced beta schedule.
        beta_end: last value of the linearly spaced beta schedule.
    """

    def __init__(self, T, beta_start, beta_end):
        self.T = T
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, T)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim = 0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def add_noise(self, original, noise, t):
        """Return ``sqrt(alpha_bar_t) * original + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            original: batch of clean samples; dimension 0 is the batch.
            noise: tensor combined element-wise with ``original``.
            t: timestep indices used to index the schedule tables; must select
                exactly one entry per batch element.

        Returns:
            The noised batch, on ``original``'s device.
        """
        batch_size = original.shape[0]
        # One coefficient per sample, shaped (B, 1, 1, ...) so it broadcasts
        # over every non-batch dimension of ``original``.
        coeff_shape = (batch_size,) + (1,) * (len(original.shape) - 1)

        sqrt_alpha_cum_prod = self.sqrt_alphas_cumprod.to(original.device)[t].reshape(coeff_shape)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_1m_alphas_cumprod.to(original.device)[t].reshape(coeff_shape)

        return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Take one reverse-diffusion step from ``x_t`` towards ``x_{t-1}``.

        Args:
            xt: current noisy sample batch.
            noise_pred: noise estimate for ``xt`` at step ``t``.
            t: the current timestep; compared against 0, so it must be a
                scalar (Python int or 0-dim tensor).

        Returns:
            ``(x_prev, x0)``: ``x_prev`` is the posterior mean, plus
            ``sigma_t * z`` with fresh Gaussian ``z`` when ``t != 0``;
            ``x0`` is the clean-sample estimate clamped to ``[-1, 1]``.
        """
        device = xt.device
        beta_t = self.betas.to(device)[t]
        alpha_t = self.alphas.to(device)[t]
        alpha_bar_t = self.alphas_cumprod.to(device)[t]
        sqrt_alpha_bar_t = self.sqrt_alphas_cumprod.to(device)[t]
        sqrt_1m_alpha_bar_t = self.sqrt_1m_alphas_cumprod.to(device)[t]

        # Invert x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps for x_0.
        # The noise term must be scaled by sqrt(1 - a_bar), not sqrt(a_bar).
        x0 = (xt - sqrt_1m_alpha_bar_t * noise_pred) / sqrt_alpha_bar_t
        x0 = torch.clamp(x0, -1., 1.)

        # mu_theta = (x_t - beta_t / sqrt(1 - a_bar_t) * eps_theta) / sqrt(alpha_t)
        mean = (xt - (beta_t * noise_pred) / sqrt_1m_alpha_bar_t) / torch.sqrt(alpha_t)

        if t == 0:
            # Final step: return the mean without adding noise.
            return mean, x0

        # Posterior variance (1 - a_bar_{t-1}) * beta_t / (1 - a_bar_t).
        variance = ((1 - self.alphas_cumprod.to(device)[t - 1]) * beta_t) / (1. - alpha_bar_t)
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(device)
        return mean + sigma * z, x0

Model Architecture - UNet inspired

\[ \text{DownBlock} \to \text{MidBlock} \to \text{UpBlock} \]

Time Embedding Block

Positional encoding for a position \(t\), where \(d_{\text{model}}\) is the dimension of the output embedding space and \(i\) is the column index:

\[ P(t, 2i) = \sin \left( \frac{t}{10000^{2i/d_{\text{model}}}} \right) \] \[ P(t, 2i + 1) = \cos \left( \frac{t}{10000^{2i/d_{\text{model}}}} \right) \]

Positional Encoding Block \(\to\) FC Layer \(\to\) \(\text{SiLU}\) Activation \(\to\) FC Layer

def get_time_embedding(T, d_model):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        T: 1-D tensor of timesteps (it is indexed as ``T[:, None]``).
        d_model: embedding width; the sine half and the cosine half each
            contribute ``d_model // 2`` columns.

    Returns:
        Tensor with one row per timestep: all sine features first, followed
        by all cosine features.
    """
    half = d_model // 2
    # Exponents 0/half, 1/half, ..., (half - 1)/half on the same device as T.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=T.device) / half
    frequencies = 10000 ** exponents
    # (B, 1) / (half,) broadcasts to (B, half).
    angles = T[:, None] / frequencies
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

DownBlock of UNet

There can be multiple layers of this \(\text{ResNet} + \text{Self Attention}\)

image

image

class DownBlock(nn.Module):
    """Encoder stage: time-conditioned ResNet block, self-attention, optional downsample.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        d_model: width of the time embedding consumed by ``t_emb_layers``.
        down_sample: if truthy, finish with a stride-2 convolution; otherwise
            the final layer is an identity.
        num_heads: number of heads of the ``nn.MultiheadAttention`` layer.
    """

    def __init__(self, in_channels, out_channels, d_model, down_sample, num_heads):
        super().__init__()
        self.down_sample = down_sample

        # norm -> SiLU -> 3x3 conv, changing the channel count
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        # Projects the time embedding to one bias value per output channel.
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, out_channels),
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # 1x1 conv so the skip path matches out_channels.
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
        else:
            self.down_sample_conv = nn.Identity()

    def forward(self, x, t_emb):
        """Apply the block to feature map ``x`` conditioned on ``t_emb``."""
        # ResNet block: conv -> add time bias -> conv -> add projected input
        h = self.resnet_conv_first(x)
        time_bias = self.t_emb_layers(t_emb)
        h = h + time_bias[:, :, None, None]
        h = self.resnet_conv_second(h)
        h = h + self.residual_input_conv(x)

        # Self-attention over the flattened spatial positions, added residually.
        n, c, height, width = h.shape
        tokens = self.attention_norm(h.reshape(n, c, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attention(tokens, tokens, tokens)
        h = h + attended.transpose(1, 2).reshape(n, c, height, width)

        return self.down_sample_conv(h)
class MidBlock(nn.Module):
    """Bottleneck stage: ResNet block -> self-attention -> ResNet block.

    All convolutions use stride 1 with matching padding, so the block does
    not change the spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both ResNet blocks.
        d_model: width of the time embedding consumed by ``t_emb_layers``.
        num_heads: number of heads of the ``nn.MultiheadAttention`` layer.
    """

    def __init__(self, in_channels, out_channels, d_model, num_heads):
        super().__init__()
        # Input width of each of the two ResNet blocks: the first maps
        # in_channels -> out_channels, the second keeps out_channels.
        stage_inputs = (in_channels, out_channels)

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for width in stage_inputs
        ])

        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(d_model, out_channels),
            )
            for _ in stage_inputs
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in stage_inputs
        ])

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # 1x1 convs that project each ResNet block's input onto its output width.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(width, out_channels, kernel_size=1)
            for width in stage_inputs
        ])

    def _resnet(self, index, x, t_emb):
        """Run ResNet block ``index``: conv, add time bias, conv, add projected input."""
        h = self.resnet_conv_first[index](x)
        h = h + self.t_emb_layers[index](t_emb)[:, :, None, None]
        h = self.resnet_conv_second[index](h)
        return h + self.residual_input_conv[index](x)

    def _self_attention(self, x):
        """Return ``x`` plus self-attention over its flattened spatial positions."""
        n, c, height, width = x.shape
        tokens = self.attention_norm(x.reshape(n, c, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attention(tokens, tokens, tokens)
        return x + attended.transpose(1, 2).reshape(n, c, height, width)

    def forward(self, x, t_emb):
        """Apply ResNet block 0, self-attention, then ResNet block 1."""
        h = self._resnet(0, x, t_emb)
        h = self._self_attention(h)
        return self._resnet(1, h, t_emb)
class UpBlock(nn.Module):
    """Decoder stage: optional upsample, skip concatenation, ResNet block, self-attention.

    Args:
        in_channels: channel count after concatenating the upsampled input
            with the skip tensor; the upsampling layer itself operates on
            ``in_channels // 2`` channels.
        out_channels: channels produced by the block.
        d_model: width of the time embedding consumed by ``t_emb_layers``.
        up_sample: if truthy, start with a stride-2 transposed convolution;
            otherwise the first layer is an identity.
        num_heads: number of heads of the ``nn.MultiheadAttention`` layer.
    """

    def __init__(self, in_channels, out_channels, d_model, up_sample, num_heads):
        super().__init__()
        self.up_sample = up_sample

        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        # Projects the time embedding to one bias value per output channel.
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, out_channels),
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        if self.up_sample:
            half = in_channels // 2
            self.up_sample_conv = nn.ConvTranspose2d(half, half, kernel_size=4, stride=2, padding=1)
        else:
            self.up_sample_conv = nn.Identity()

    def forward(self, x, out_down, t_emb):
        """Upsample ``x``, concatenate ``out_down`` on the channel axis, then refine."""
        merged = torch.cat([self.up_sample_conv(x), out_down], dim=1)

        # ResNet block: conv -> add time bias -> conv -> add projected input
        h = self.resnet_conv_first(merged)
        h = h + self.t_emb_layers(t_emb)[:, :, None, None]
        h = self.resnet_conv_second(h)
        h = h + self.residual_input_conv(merged)

        # Self-attention over the flattened spatial positions, added residually.
        n, c, height, width = h.shape
        tokens = self.attention_norm(h.reshape(n, c, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attention(tokens, tokens, tokens)
        return h + attended.transpose(1, 2).reshape(n, c, height, width)
class UNet(nn.Module):
    """Noise-prediction network: conv stem, down blocks, mid blocks, up blocks, conv head.

    The channel plan is fixed in ``__init__``: down path 32 -> 64 -> 128 -> 256
    (downsampling after the first two stages), mid path 256 -> 256 -> 128, and
    an up path that mirrors the down path and ends at 16 channels.

    Args:
        im_channels: number of channels of the input image; the output has
            the same number of channels.
    """

    def __init__(self, im_channels):
        super().__init__()
        self.down_channels = [32, 64, 128, 256]
        self.mid_channels = [256, 256, 128]
        self.d_model = 128
        self.down_sample = [True, True, False]

        # MLP applied to the sinusoidal time embedding before it reaches the blocks.
        self.t_proj = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.SiLU(),
            nn.Linear(self.d_model, self.d_model),
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        # One DownBlock per consecutive pair of down_channels.
        down_plan = zip(self.down_channels[:-1], self.down_channels[1:], self.down_sample)
        self.downs = nn.ModuleList([
            DownBlock(c_in, c_out, self.d_model, down_sample=flag, num_heads=4)
            for c_in, c_out, flag in down_plan
        ])

        # One MidBlock per consecutive pair of mid_channels.
        mid_plan = zip(self.mid_channels[:-1], self.mid_channels[1:])
        self.mids = nn.ModuleList([
            MidBlock(c_in, c_out, self.d_model, num_heads=4)
            for c_in, c_out in mid_plan
        ])

        # Walk the down stages deepest-first; each UpBlock takes twice the
        # stage width as input (upsampled features concatenated with the skip).
        self.ups = nn.ModuleList([])
        for stage in range(len(self.down_channels) - 2, -1, -1):
            width_out = 16 if stage == 0 else self.down_channels[stage - 1]
            self.ups.append(
                UpBlock(
                    self.down_channels[stage] * 2,
                    width_out,
                    self.d_model,
                    up_sample=self.down_sample[stage],
                    num_heads=4,
                )
            )

        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timestep(s) ``t``.

        Args:
            x: batch of noisy images.
            t: timestep(s); converted with ``torch.as_tensor(t).long()`` and
                passed to ``get_time_embedding``.
        """
        h = self.conv_in(x)
        steps = torch.as_tensor(t).long()
        t_emb = self.t_proj(get_time_embedding(steps, self.d_model))

        # Store the input of every down block; the up path consumes them in reverse.
        skips = []
        for down in self.downs:
            skips.append(h)
            h = down(h, t_emb)

        for mid in self.mids:
            h = mid(h, t_emb)

        for up in self.ups:
            h = up(h, skips.pop(), t_emb)

        h = self.norm_out(h)
        h = nn.SiLU()(h)
        return self.conv_out(h)
# Train Parameters
batch_size = 64
num_epochs = 40
lr = 1e-4
num_grid_rows = 10
num_samples = 100

# Diffusion Parameters
beta_start = 1e-4
beta_end = 0.02
T = 1000

# Model Parameters
nc = 1
image_size = 28

# Resize, convert to tensor, then normalize every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * nc, [0.5] * nc),
])

# Both MNIST splits are downloaded if missing and trained on together.
mnist_train = datasets.MNIST(root="../CVDatasets", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="../CVDatasets", train=False, transform=transform, download=True)

mnist_combined_loader = DataLoader(
    dataset=mnist_train + mnist_test,
    batch_size=batch_size,
    shuffle=True,
)

# NOTE(review): this rebinds `device` to CUDA-or-CPU only, so an earlier MPS selection
# is discarded from here on - confirm this is intentional.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scheduler = LinearNoiseScheduler(T, beta_start, beta_end)

model = UNet(nc).to(device)
model.train()
UNet(
  (t_proj): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): SiLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  )
  (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): DownBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 32, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=64, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (residual_input_conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (down_sample_conv): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): DownBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=128, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 128, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (residual_input_conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (down_sample_conv): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (2): DownBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=256, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 256, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 256, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (residual_input_conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
      (down_sample_conv): Identity()
    )
  )
  (mids): ModuleList(
    (0): MidBlock(
      (resnet_conv_first): ModuleList(
        (0-1): 2 x Sequential(
          (0): GroupNorm(8, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (t_emb_layers): ModuleList(
        (0-1): 2 x Sequential(
          (0): SiLU()
          (1): Linear(in_features=128, out_features=256, bias=True)
        )
      )
      (resnet_conv_second): ModuleList(
        (0-1): 2 x Sequential(
          (0): GroupNorm(8, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (attention_norm): GroupNorm(8, 256, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (residual_input_conv): ModuleList(
        (0-1): 2 x Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): MidBlock(
      (resnet_conv_first): ModuleList(
        (0): Sequential(
          (0): GroupNorm(8, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): Sequential(
          (0): GroupNorm(8, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (t_emb_layers): ModuleList(
        (0-1): 2 x Sequential(
          (0): SiLU()
          (1): Linear(in_features=128, out_features=128, bias=True)
        )
      )
      (resnet_conv_second): ModuleList(
        (0-1): 2 x Sequential(
          (0): GroupNorm(8, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (attention_norm): GroupNorm(8, 128, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (residual_input_conv): ModuleList(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      )
    )
  )
  (ups): ModuleList(
    (0): UpBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 256, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=64, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (residual_input_conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (up_sample_conv): Identity()
    )
    (1): UpBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=32, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 32, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 32, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (residual_input_conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
      (up_sample_conv): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (2): UpBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=16, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 16, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 16, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
      )
      (residual_input_conv): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (up_sample_conv): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
  )
  (norm_out): GroupNorm(8, 16, eps=1e-05, affine=True)
  (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
# Training: teach the model to predict the noise that was mixed into each image.
optimizer = torch.optim.Adam(model.parameters(), lr = lr)
criterion = nn.MSELoss()

for epoch in range(num_epochs):
    # Per-batch losses, averaged for the epoch summary printed below.
    losses = []
    for i, (images, _) in enumerate(mnist_combined_loader):
        optimizer.zero_grad()
        images = images.float().to(device)

        # Gaussian noise with the same shape as the image batch.
        noise = torch.randn_like(images).to(device)

        # One random timestep in [0, T) per image in the batch.
        t = torch.randint(0, T, (images.shape[0],)).to(device)

        # Noise the clean images to step t, then let the model predict that noise.
        noisy_img = scheduler.add_noise(images, noise, t)
        noise_pred = model(noisy_img, t)

        # MSE between predicted and actual noise.
        loss = criterion(noise_pred, noise)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {np.mean(losses)}")
    # The checkpoint is overwritten after every epoch (same file name each time).
    torch.save(model.state_dict(), "denoiseModel.pth")
Epoch [1/40], Loss: 0.08707543365466638
Epoch [2/40], Loss: 0.03461716449402368
Epoch [3/40], Loss: 0.030367023231533154
Epoch [4/40], Loss: 0.028854768694305224
Epoch [5/40], Loss: 0.027427057980796428
Epoch [6/40], Loss: 0.026461833401819797
Epoch [7/40], Loss: 0.02577751178645156
Epoch [8/40], Loss: 0.0254848291399383
Epoch [9/40], Loss: 0.02512882234549092
Epoch [10/40], Loss: 0.02472435644107583
Epoch [11/40], Loss: 0.024483303790588537
Epoch [12/40], Loss: 0.024122810720804896
Epoch [13/40], Loss: 0.024349671430116217
Epoch [14/40], Loss: 0.023659083380157868
Epoch [15/40], Loss: 0.023655956190494886
Epoch [16/40], Loss: 0.02371771662398982
Epoch [17/40], Loss: 0.02339242411338106
Epoch [18/40], Loss: 0.023123688979663458
Epoch [19/40], Loss: 0.02335826940306716
Epoch [20/40], Loss: 0.022931612077998285
Epoch [21/40], Loss: 0.022918104677643266
Epoch [22/40], Loss: 0.022703514175601916
Epoch [23/40], Loss: 0.022731547490770245
Epoch [24/40], Loss: 0.022536548953571475
Epoch [25/40], Loss: 0.02256252928134701
Epoch [26/40], Loss: 0.022766522776294616
Epoch [27/40], Loss: 0.022404040800751455
Epoch [28/40], Loss: 0.022666319886348475
Epoch [29/40], Loss: 0.02232284536517108
Epoch [30/40], Loss: 0.022360837320291472
Epoch [31/40], Loss: 0.022248711810707773
KeyboardInterrupt: 
# Sampling: start from pure Gaussian noise and apply the reverse step T times,
# saving a grid of the current samples after every step.
scheduler = LinearNoiseScheduler(T, beta_start, beta_end)

model = UNet(nc).to(device)
model.load_state_dict(torch.load("../CVModels/denoiseModel.pth", map_location = device))

model.eval()

# Loop-invariant setup: create the output directory once (exist_ok avoids the
# separate existence check) and build the tensor-to-PIL converter once.
output_dir = "./DDPM/MNIST"
os.makedirs(output_dir, exist_ok = True)
to_pil = transforms.ToPILImage()

with torch.no_grad():
    xt = torch.randn((num_samples, nc, image_size, image_size)).to(device)

    for t in reversed(range(T)):
        # The model receives the step as a 1-element tensor, which broadcasts over the batch.
        noise_pred = model(xt, torch.as_tensor(t).unsqueeze(0).to(device))

        # The scheduler compares t with 0, so it receives a 0-dim tensor.
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device))

        # Map the current samples from [-1, 1] to [0, 1] and tile them into one image.
        ims = torch.clamp(xt, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = vutils.make_grid(ims, nrow = num_grid_rows)

        # Show every 50th step inline.
        if (t % 50 == 0):
            plt.figure(figsize = (10, 10))
            plt.axis("off")
            plt.title(f"t = {t}")
            plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
            plt.show()

        # The saved image is the current sample x_t, despite the "x0_" file prefix.
        img = to_pil(grid)
        img.save(f"{output_dir}/x0_{t}.png")
        img.close()