!sudo apt install cm-super dvipng texlive-latex-extra texlive-latex-recommended
!pip install einops
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from einops import rearrange
from latex import latexify
= 2)
latexify(columns
%matplotlib inline
%config InlineBackend.figure_format = "retina"
if torch.backends.mps.is_available():
= torch.device("mps")
device elif torch.cuda.is_available():
= torch.device("cuda")
device else:
= torch.device("cpu")
device
print(device)
= 999
manualSeed #manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)True) torch.use_deterministic_algorithms(
mps
Random Seed: 999
= 128
batch_size = 64
image_size = 1
nc = 100
nz = 64
ngf = 64
ndf = 5
num_epochs = 2e-4
lr = 0.5 beta1
= transforms.Compose([transforms.Resize(image_size), transforms.CenterCrop(image_size), transforms.ToTensor(), transforms.Normalize([0.5 for _ in range(nc)], [0.5 for _ in range(nc)])]) transform
= datasets.MNIST(root = "../CVDatasets", train = True, transform = transform, download = True)
mnist_train = datasets.MNIST(root = "../CVDatasets", train = False, transform = transform, download = True)
mnist_test
# celebADataset = datasets.ImageFolder(root = "../CVDatasets/CelebA", transform = transform)
# celebALoader = DataLoader(dataset = celebADataset, batch_size = batch_size, shuffle = True)
= DataLoader(dataset = mnist_train, batch_size = batch_size, shuffle = True)
mnist_train_loader = DataLoader(dataset = mnist_test, batch_size = batch_size, shuffle = False) mnist_test_loader
\[ \text{Generator}: G_{\phi} \]
\[ \text{Discriminator}: D_{\theta} \]
\[ \min_{\phi} \max_{\theta} V(G, D) = \underset{\textbf{x} \sim p_{\text{data}}}{\mathbb{E}} [\log(D_{\theta}(\textbf{x}))] + \underset{\textbf{z} \sim p_z(\textbf{z})}{\mathbb{E}} [1 - \log(D_{\theta}(G_{\phi}(\textbf{z})))] \]
Given an \(N \times N\) image, \(f \times f\) filter size, \(s\) as stride and \(p\) as the padding, the output dimension is
\[ \left\lfloor{\frac{N + 2p - f}{s}}\right\rfloor + 1 \]
Generator - Discriminator Architecture
class Discriminator(nn.Module):
def __init__(self, channels_img, feature_d):
super().__init__()
self.disc = nn.Sequential(
= 4, stride = 2, padding = 1),
nn.Conv2d(channels_img, feature_d, kernel_size # BatchNorm is skipped for the first layer
0.2),
nn.LeakyReLU(self._block(feature_d, feature_d * 2, 4, 2, 1),
self._block(feature_d * 2, feature_d * 4, 4, 2, 1),
self._block(feature_d * 4, feature_d * 8, 4, 2, 1),
* 8, 1, kernel_size = 4, stride = 2, padding = 0),
nn.Conv2d(feature_d
)
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
= False),
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias
nn.BatchNorm2d(out_channels),0.2),
nn.LeakyReLU(
)
def forward(self, x):
return self.disc(x)
class Generator(nn.Module):
def __init__(self, channels_noise, channels_img, feature_g):
super().__init__()
self.net = nn.Sequential(
self._block(channels_noise, feature_g * 16, 4, 1, 0),
self._block(feature_g * 16, feature_g * 8, 4, 2, 1),
self._block(feature_g * 8, feature_g * 4, 4, 2, 1),
self._block(feature_g * 4, feature_g * 2, 4, 2, 1),
* 2, channels_img, kernel_size = 4, stride = 2, padding = 1),
nn.ConvTranspose2d(feature_g # No BatchNorm here and Tanh is used in the last layer
nn.Tanh(),# image transformed to be in [-1, 1]
)
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
= False),
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, x):
return self.net(x)
def initialize_weights(model):
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
0.0, 0.02)
nn.init.normal_(m.weight.data, # Initialized with mean 0 and std 0.02
Model Instantiation
= Generator(nz, nc, ngf).to(device)
gen = Discriminator(nc, ndf).to(device)
disc
initialize_weights(gen) initialize_weights(disc)
Optimizers and Criterion
= torch.optim.Adam(gen.parameters(), lr = lr, betas = (beta1, 0.999))
opt_gen = torch.optim.Adam(disc.parameters(), lr = lr, betas = (beta1, 0.999))
opt_disc = nn.BCEWithLogitsLoss() criterion
= torch.randn(32, nz, 1, 1).to(device)
fixed_noise
gen.train() disc.train()
Discriminator(
(disc): Sequential(
(0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2)
(2): Sequential(
(0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(3): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(4): Sequential(
(0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
(5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
)
)
for epoch in range(num_epochs):
for batch_idx, (real, _) in enumerate(mnist_train_loader):
= real.to(device)
real = torch.randn(batch_size, nz, 1, 1).to(device)
noise = gen(noise)
fake
# Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
= disc(real).reshape(-1)
disc_real # BCE(Real, 1) = -log(Real)
= criterion(disc_real, torch.ones_like(disc_real))
loss_disc_real = disc(fake.detach()).reshape(-1)
disc_fake
# BCE(Fake, 0) = -log(1 - Fake) = -log(Fake)
= criterion(disc_fake, torch.zeros_like(disc_fake))
loss_disc_fake = (loss_disc_real + loss_disc_fake) / 2
loss_disc
disc.zero_grad()= True)
loss_disc.backward(retain_graph
opt_disc.step()
# Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
= disc(fake).reshape(-1)
output # BCE(Fake, 1) = -log(Fake)
= criterion(output, torch.ones_like(output))
loss_gen
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
if batch_idx % 100 == 0:
print(f"Epoch [{epoch}/{num_epochs}], Batch Step [{batch_idx}/{len(mnist_train_loader)}], Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")
with torch.no_grad():
# infer from the Generator by plotting the generated images
eval()
gen.= torch.randn(32, nz, 1, 1).to(device)
fake_noise = gen(fake_noise)
fake = vutils.make_grid(real[:32], normalize = True)
img_grid_real = vutils.make_grid(fake[:32], normalize = True)
img_grid_fake
= (10, 10))
plt.figure(figsize 1, 2, 1)
plt.subplot("off")
plt.axis("Real Images")
plt.title(1, 2, 0)))
plt.imshow(np.transpose(img_grid_real.cpu().detach().numpy(), (
1, 2, 2)
plt.subplot("off")
plt.axis("Fake Images")
plt.title(1, 2, 0)))
plt.imshow(np.transpose(img_grid_fake.cpu().detach().numpy(), (
plt.show()
= transforms.ToPILImage()(img_grid_fake)
img if not os.path.exists("./DCGAN/MNIST"):
"./DCGAN/MNIST")
os.makedirs(f"./DCGAN/MNIST/{epoch}_{batch_idx}.png")
img.save(
gen.train()
"mnist_gen.pth")
torch.save(gen.state_dict(), "mnist_disc.pth") torch.save(disc.state_dict(),
Epoch [0/5], Batch Step [0/469], Loss D: 0.6396, Loss G: 0.8332
Epoch [0/5], Batch Step [100/469], Loss D: 0.0153, Loss G: 4.0897
Epoch [0/5], Batch Step [200/469], Loss D: 0.6575, Loss G: 1.2781
Epoch [0/5], Batch Step [300/469], Loss D: 0.6146, Loss G: 0.8419
Epoch [0/5], Batch Step [400/469], Loss D: 0.5918, Loss G: 0.9050
Epoch [1/5], Batch Step [0/469], Loss D: 0.5435, Loss G: 1.0118
Epoch [1/5], Batch Step [100/469], Loss D: 0.5846, Loss G: 1.1505
Epoch [1/5], Batch Step [200/469], Loss D: 0.5791, Loss G: 0.8657
Epoch [1/5], Batch Step [300/469], Loss D: 0.6101, Loss G: 0.9389
Epoch [1/5], Batch Step [400/469], Loss D: 0.6121, Loss G: 0.9977
Epoch [2/5], Batch Step [0/469], Loss D: 0.5481, Loss G: 1.8410
Epoch [2/5], Batch Step [100/469], Loss D: 0.5722, Loss G: 0.6610
Epoch [2/5], Batch Step [200/469], Loss D: 0.6020, Loss G: 2.2574
Epoch [2/5], Batch Step [300/469], Loss D: 0.3735, Loss G: 1.5304
Epoch [2/5], Batch Step [400/469], Loss D: 0.5576, Loss G: 0.6515
Epoch [3/5], Batch Step [0/469], Loss D: 0.3645, Loss G: 0.9584
Epoch [3/5], Batch Step [100/469], Loss D: 0.4402, Loss G: 2.8341
Epoch [3/5], Batch Step [200/469], Loss D: 0.7291, Loss G: 0.7157
Epoch [3/5], Batch Step [300/469], Loss D: 0.3787, Loss G: 1.6210
Epoch [3/5], Batch Step [400/469], Loss D: 0.3439, Loss G: 1.8571
Epoch [4/5], Batch Step [0/469], Loss D: 0.3109, Loss G: 1.9058
Epoch [4/5], Batch Step [100/469], Loss D: 0.2812, Loss G: 2.7282
Epoch [4/5], Batch Step [200/469], Loss D: 0.4025, Loss G: 1.6386
Epoch [4/5], Batch Step [300/469], Loss D: 0.4022, Loss G: 2.9081
Epoch [4/5], Batch Step [400/469], Loss D: 0.4217, Loss G: 2.0483
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
= torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = Generator(nz, nc, ngf).to(device)
gen "../CVModels/mnist_gen.pth", map_location = device))
gen.load_state_dict(torch.load(eval() gen.
Generator(
(net): Sequential(
(0): Sequential(
(0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(1): Sequential(
(0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(2): Sequential(
(0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(3): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(4): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(5): Tanh()
)
)
= torch.randn(32, nz, 1, 1).to(device)
noise = gen(noise)
fake = vutils.make_grid(fake[:32], normalize = True)
img_grid_fake
= (10, 10))
plt.figure(figsize "off")
plt.axis("Fake Images")
plt.title(1, 2, 0)))
plt.imshow(np.transpose(img_grid_fake.cpu().detach().numpy(), ( plt.show()
4 Images in the Latent Space
= 4
n = torch.randn(n, nz, 1, 1).to(device)
interpolation
= gen(interpolation)
fake
= (10, 8))
plt.figure(figsize 2, 2, 1)
plt.subplot(0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(fake["Image 1")
plt.title("off")
plt.axis(
2, 2, 2)
plt.subplot(1].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(fake["Image 2")
plt.title("off")
plt.axis(
2, 2, 3)
plt.subplot(2].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(fake["Image 3")
plt.title("off")
plt.axis(
2, 2, 4)
plt.subplot(3].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(fake["Image 4")
plt.title("off") plt.axis(
Linear Interpolation
\[ \mathbb{w^*} = \lambda \mathbb{a} + (1 - \lambda) \mathbb{b} \]
= 13
n = torch.zeros(n, nz, 1, 1).to(device)
interpolation01 = torch.zeros(n, nz, 1, 1).to(device)
interpolation02 = torch.zeros(n, nz, 1, 1).to(device)
interpolation13 = torch.zeros(n, nz, 1, 1).to(device)
interpolation23
for i in range(n):
= (i / (n - 1)) * interpolation[1] + ((n - i - 1) / (n - 1)) * interpolation[0]
interpolation01[i] = (i / (n - 1)) * interpolation[2] + ((n - i - 1) / (n - 1)) * interpolation[0]
interpolation02[i] = (i / (n - 1)) * interpolation[3] + ((n - i - 1) / (n - 1)) * interpolation[1]
interpolation13[i] = (i / (n - 1)) * interpolation[3] + ((n - i - 1) / (n - 1)) * interpolation[2]
interpolation23[i]
= torch.cat([interpolation[0].unsqueeze(0), interpolation01, interpolation[1].unsqueeze(0)], dim = 0)
topSide = torch.cat([interpolation[2].unsqueeze(0), interpolation23, interpolation[3].unsqueeze(0)], dim = 0)
bottomSide = torch.cat([interpolation[0].unsqueeze(0), interpolation02, interpolation[2].unsqueeze(0)], dim = 0)
leftSide = torch.cat([interpolation[1].unsqueeze(0), interpolation13, interpolation[3].unsqueeze(0)], dim = 0)
rightSide
= (15 * 3, 3))
plt.figure(figsize for i in range(15):
1, 15, i + 1)
plt.subplot(0))[0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(gen(bottomSide[i].unsqueeze("off")
plt.axis( plt.show()
Interpolation along all the 4 peripheral
= 1
place = (15, 15))
plt.figure(figsize for i in range(15):
for j in range(15):
15, 15, place)
plt.subplot(if (i == 0):
0))[0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(gen(topSide[j].unsqueeze(if (i == 14):
0))[0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(gen(bottomSide[j].unsqueeze(if (i != 0 and i != 14):
if (j == 0):
0))[0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(gen(leftSide[i].unsqueeze(if (j == 14):
0))[0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(gen(rightSide[i].unsqueeze("off")
plt.axis(+= 1
place plt.show()
= torch.zeros(15, 15, nz, 1, 1).to(device)
final_interpolation 0] = topSide
final_interpolation[14] = bottomSide
final_interpolation[
for i in range(1, 14):
for j in range(15):
if (j == 0):
= leftSide[i]
final_interpolation[i][j] if (j == 14):
= rightSide[i]
final_interpolation[i][j] if (j != 0 or j != 14):
= (j / (n - 1)) * rightSide[i] + ((n - j - 1) / (n - 1)) * leftSide[i] final_interpolation[i][j]
= 1
place = (15, 15))
plt.figure(figsize for i in range(15):
for j in range(15):
15, 15, place)
plt.subplot(0))[0].cpu().detach().numpy(), (1, 2, 0)), cmap = "gray")
plt.imshow(np.transpose(gen(final_interpolation[i][j].unsqueeze("off")
plt.axis(+= 1
place f"./DCGAN/MNIST_Interpolate/Interpolation_{i}_{j}.png")
plt.savefig( plt.show()