import os
import yaml
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
import glob
import pickle
from PIL import Image
from Blocks import *
from VQVAE import *
from Loader import *
from Unet import *
from Scheduler import *
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
%matplotlib inline
%config InlineBackend.figure_format = "retina"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
batch_size_autoenc = 8
batch_size_ldm = 16
# Model Parameters
nc = 3
image_size = 256
Loading Config
def get_config_value(config, key, default_value):
    return config[key] if key in config else default_value
config_path = "../celebhq_cond.yaml"
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)
diffusion_model_config = config['ldm_params']
condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
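get_config_value is a guarded dictionary lookup; the call above is equivalent to the dict.get one-liner below, shown only for clarity (it assumes nothing about the actual schema of celebhq_cond.yaml):
# Equivalent lookup: dict.get returns the default when the key is absent
condition_config = diffusion_model_config.get('condition_config', None)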
path = "../CelebAMask-HQ"
im_dataset = CelebDataset(split='train',
im_path=path,
im_size=image_size,
im_channels=nc,
use_latents=False,
latent_path="../vqvaelatents",
condition_config=condition_config)
celebAloader = DataLoader(im_dataset,
batch_size=batch_size_ldm,
shuffle=True)
100%|██████████| 30000/30000 [00:00<00:00, 43585.02it/s]
Found 30000 images
Found 30000 masks
Found 30000 captions
print(im_dataset[0][1]["image"].shape)
print(im_dataset[0][0].shape)
torch.Size([1, 512, 512])
torch.Size([3, 256, 256])
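Each dataset item is a pair: the image tensor at index 0 and a condition dictionary at index 1, whose "image" entry holds the one-channel 512 x 512 hair mask. A minimal unpacking sketch:
# Unpack one item: (image, condition); the mask lives under the "image" key
image0, cond0 = im_dataset[0]
hair_mask0 = cond0["image"]  # torch.Size([1, 512, 512])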
indices = np.random.choice(len(im_dataset), 16, replace = False)
Images \((B \times C \times H \times W) \to (16 \times 3 \times 256 \times 256)\)
images = torch.stack([im_dataset[i][0] for i in indices], dim = 0)
mask = torch.stack([im_dataset[i][1]["image"] for i in indices], dim = 0)
# mask = F.interpolate(mask, size=(32, 32), mode='nearest')
grid = vutils.make_grid(images, nrow = 4, normalize = True)
grid1 = vutils.make_grid(mask, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid.cpu().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
Hair Mask \((B \times C \times H \times W) \to (16 \times 1 \times 512 \times 512)\)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid1.cpu().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
Inverted Hair Mask \((1 - \texttt{mask})\) \((B \times C \times H \times W) \to (16 \times 1 \times 512 \times 512)\)
grid3 = vutils.make_grid((1 - mask), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid3.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
Only Hair with Mask Pooled to Image Dimension \((\texttt{mask\_pooled} \times \texttt{images})\)
# mask * images output
maskNN = F.interpolate(mask, size=(256, 256), mode='nearest')
gridNN = vutils.make_grid((maskNN * images), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridNN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
No Hair \(((1 - \texttt{mask\_pooled}) \times \texttt{images})\)
gridNNN = vutils.make_grid(((1 - maskNN) * images), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridNNN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
Hair Mask Pooled to Latent Dimension \((B \times C \times H \times W) \to (16 \times 1 \times 32 \times 32)\)
maskN = F.interpolate(mask, size=(32, 32), mode='nearest')
gridN = vutils.make_grid(maskN, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
gridN = vutils.make_grid((1 - maskN), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
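Nearest-neighbor interpolation only copies existing pixel values, so a binary mask stays binary after pooling; a quick sanity check, assuming the mask is 0/1-valued:
# The pooled mask should contain only values already present in the original mask
print(maskN.unique())  # expected: tensor([0., 1.]) for a binary mask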
model = VQVAE().to(device)
model.load_state_dict(torch.load("../vqvaeCeleb/vqvae_autoencoder.pth", map_location = device))
model.eval()
Encoded Images \((B \times C \times H \times W) \to (16 \times 3 \times 32 \times 32)\)
im = model.encode(images.to(device))[0]
grid2 = vutils.make_grid(im, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid2.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
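model.encode(...) evidently returns a tuple with the latent at index 0 (hence the [0] above); the 16 x 3 x 32 x 32 shape from the heading can be confirmed directly:
print(im.shape)  # expected: torch.Size([16, 3, 32, 32])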
Only Hair in Encoded Dimension \((\texttt{mask\_pooled} \times \texttt{images})\)
only_hair = maskN.to(device) * im
grid4 = vutils.make_grid(only_hair, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid4.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
No Hair in Encoded Dimension \(((1 - \texttt{mask\_pooled}) \times \texttt{images})\)
no_hair = (1 - maskN.to(device)) * im
grid5 = vutils.make_grid(no_hair, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid5.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
num_grid_rows = 4
num_samples = 16
# Train Parameters
lr = 5e-5
# Diffusion Parameters
beta_start = 1e-4
beta_end = 2e-2
T = 1000
# Model Parameters
nc = 3
image_size = 256
z_channels = 3
scheduler = LinearNoiseScheduler(T, beta_start, beta_end)
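For reference, add_noise in a linear-schedule DDPM implements the closed-form forward process \(q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\; (1 - \bar{\alpha}_t)\mathbf{I}\big)\). A minimal sketch of what LinearNoiseScheduler is assumed to compute (the actual implementation lives in Scheduler.py):
# Sketch of the standard DDPM forward step -- an assumption, not taken from Scheduler.py
betas_sketch = torch.linspace(beta_start, beta_end, T)
alpha_bars_sketch = torch.cumprod(1. - betas_sketch, dim = 0)
def add_noise_sketch(x0, noise, t):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    ab = alpha_bars_sketch.to(x0.device)[t].view(-1, 1, 1, 1)
    return ab.sqrt() * x0 + (1. - ab).sqrt() * noise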
# Visualization only: noise the raw 256 x 256 images (not the latents) at t = 200
xt = torch.randn((num_samples, z_channels, image_size, image_size)).to(device)
noise = xt
noisy_images = scheduler.add_noise(images.to(device), noise, torch.full((num_samples,), 200, device = device))
noisy_images = torch.clamp(noisy_images, -1., 1.).detach().cpu()
noisy_images = (noisy_images + 1) / 2
gridNoise = vutils.make_grid(noisy_images, nrow = num_grid_rows, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridNoise.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
\[ \mathbf{x}_{t-1}^{(32 \times 32)} = \underbrace{\mathbf{x}_{t}^{\text{enc}}}_{\sim\, q(\mathbf{x}_t \mid \mathbf{x}_0)} \odot \big(1 - \texttt{mask}_{(32 \times 32)}\big) + \underbrace{\mathbf{x}_{t-1}^{\text{denoise}}}_{\sim\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)} \odot \texttt{mask}_{(32 \times 32)} \]
Only the mask is interpolated/pooled from 512 × 512 down to 32 × 32; the images themselves reach the latent resolution through the VQ-VAE encoder.
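sample_prev_timestep is assumed to implement the usual DDPM reverse update from the predicted noise,
\[ \mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(\mathbf{x}_t, t) \right) + \sigma_t \mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \]
with \(\sigma_t = 0\) at \(t = 0\); the masked blend above then stitches the re-noised source latents into the unmasked region at every step.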
scheduler = LinearNoiseScheduler(T, beta_start, beta_end)
model = Unet(im_channels = z_channels).to(device)
model.load_state_dict(torch.load("../ldmCeleb/denoiseLatentModelCeleb.pth", map_location = device))
model.eval()
vae = VQVAE().to(device)
vae.load_state_dict(torch.load("../vqvaeCeleb/vqvae_autoencoder.pth", map_location = device), strict = True)
vae.eval()
with torch.no_grad():
    im_size = image_size // (2 ** sum(model.down_sample))
    xt = torch.randn((num_samples, z_channels, im_size, im_size)).to(device)
    noise = xt
    # encoded images -> im (32 x 32)
    # mask -> maskN (32 x 32)
    maskN = maskN.to(device)
    # reuse the same fixed noise at every timestep so the re-noised
    # source latents stay consistent across the reverse process
    for t in reversed(range(T)):
        noise_pred = model(xt, torch.as_tensor(t).unsqueeze(0).to(device))
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device))
        # re-noise the encoded source images to the current timestep
        noisy_encoded_images = scheduler.add_noise(im, noise, torch.full((num_samples,), t, device = device))
        # outside the mask: keep the (re-noised) source latents; inside: keep the denoised sample
        xt = ((1 - maskN) * noisy_encoded_images) + (maskN * xt)
        ims_raw = torch.clamp(xt, -1., 1.).detach().cpu()
        ims_raw = (ims_raw + 1) / 2
        ims = vae.decode(xt)
        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid_latent = vutils.make_grid(ims_raw, nrow = num_grid_rows, normalize = True)
        grid_reconstructed = vutils.make_grid(ims, nrow = num_grid_rows, normalize = True)
        if t % 100 == 0 or t == T - 1:
            plt.figure(figsize = (15, 15))
            plt.subplot(1, 2, 1)
            plt.axis("off")
            plt.imshow(np.transpose(grid_latent.cpu().detach().numpy(), (1, 2, 0)))
            plt.subplot(1, 2, 2)
            plt.axis("off")
            plt.imshow(np.transpose(grid_reconstructed.cpu().detach().numpy(), (1, 2, 0)))
            plt.show()
gridFinal = vutils.make_grid(ims, nrow = num_grid_rows, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridFinal.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
# Pixel-space composite of the decoded output with the originals:
# mask * decoded_output + (1 - mask) * original_images
images = (images + 1) / 2  # rescale originals from [-1, 1] to [0, 1] to match ims
images = images.to(device)
ims = ims.to(device)
maskNN = maskNN.to(device)
final_images = (maskNN * ims) + ((1 - maskNN) * images)
grid_final = vutils.make_grid(final_images, nrow = num_grid_rows, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid_final.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
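In equation form, the final pixel-space composite is
\[ \mathbf{x}^{\text{final}} = \texttt{mask}_{(256 \times 256)} \odot \mathbf{x}^{\text{decoded}} + \big(1 - \texttt{mask}_{(256 \times 256)}\big) \odot \mathbf{x}^{\text{orig}}, \]
so everything outside the hair mask is copied verbatim from the original image, removing any VQ-VAE reconstruction error in the unmasked region.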