Image Inpainting using the learnt LDM

Generative Models
Author

Guntas Singh Saran

Published

July 7, 2024

import os
import yaml
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
import glob
import pickle
from PIL import Image

from Blocks import *
from VQVAE import *
from Loader import *
from Unet import *
from Scheduler import *

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

%matplotlib inline
%config InlineBackend.figure_format = "retina"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)
cuda
batch_size_autoenc = 8
batch_size_ldm = 16
# Model Parameters
nc = 3
image_size = 256

Loading Config

def get_config_value(config, key, default_value):
    return config[key] if key in config else default_value
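
As an aside, this helper matches Python's built-in dict.get, so an equivalent one-liner would be:

# equivalent one-liner using the built-in dict.get
def get_config_value(config, key, default_value):
    return config.get(key, default_value)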

config_path = "../celebhq_cond.yaml"
with open(config_path, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)

diffusion_model_config = config['ldm_params']
condition_config = get_config_value(diffusion_model_config, key='condition_config', default_value=None)
path = "../CelebAMask-HQ"

im_dataset = CelebDataset(split='train',
                          im_path=path,
                          im_size=image_size,
                          im_channels=nc,
                          use_latents=False,
                          latent_path="../vqvaelatents",
                          condition_config=condition_config)

celebAloader = DataLoader(im_dataset,
                          batch_size=batch_size_ldm,
                          shuffle=True)
Found 30000 images
Found 30000 masks
Found 30000 captions
print(im_dataset[0][1]["image"].shape)
print(im_dataset[0][0].shape)
torch.Size([1, 512, 512])
torch.Size([3, 256, 256])
indices = np.random.choice(30000, 16, replace = False)

Images \((B \times C \times H \times W) \to (16 \times 3 \times 256 \times 256)\)

images = torch.stack([im_dataset[i][0] for i in indices], dim = 0)
mask = torch.stack([im_dataset[i][1]["image"] for i in indices], dim = 0)
# mask = F.interpolate(mask, size=(32, 32), mode='nearest')
grid = vutils.make_grid(images, nrow = 4, normalize = True)
grid1 = vutils.make_grid(mask, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

Hair Mask \((B \times C \times H \times W) \to (16 \times 1 \times 512 \times 512)\)

plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid1.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

Inverted Hair Mask \((1 - \texttt{mask})\) \((B \times C \times H \times W) \to (16 \times 1 \times 512 \times 512)\)

grid3 = vutils.make_grid((1 - mask), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid3.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

Only Hair with Mask Pooled to Image Dimension \((\texttt{mask\_pooled} \times \texttt{images})\)

# mask * images output
maskNN = F.interpolate(mask, size=(256, 256), mode='nearest')
gridNN = vutils.make_grid((maskNN * images), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridNN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

No Hair \(((1 - \texttt{mask\_pooled}) \times \texttt{images})\)

gridNNN = vutils.make_grid(((1 - maskNN) * images), nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridNNN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

Hair Mask Pooled to Latent Dimension \((B \times C \times H \times W) \to (16 \times 1 \times 32 \times 32)\)

maskN = F.interpolate(mask, size=(32, 32), mode='nearest')
gridN = vutils.make_grid(maskN, nrow = 4, normalize = True)

plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

Inverted Hair Mask Pooled to Latent Dimension \((1 - \texttt{mask\_pooled})\) \((B \times C \times H \times W) \to (16 \times 1 \times 32 \times 32)\)

gridN = vutils.make_grid((1 - maskN), nrow = 4, normalize = True)

plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridN.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

model = VQVAE().to(device)
model.load_state_dict(torch.load("../vqvaeCeleb/vqvae_autoencoder.pth", map_location = device))
model.eval()
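
The calls below assume roughly this interface for the VQ-VAE (defined in VQVAE.py): encode returns a tuple whose first element is the quantized latent, and decode maps a latent back to pixel space. A quick shape check under that assumption:

# sanity check of the assumed VQVAE interface (names from VQVAE.py)
with torch.no_grad():
    z = model.encode(images.to(device))[0]
    x_rec = model.decode(z)
print(z.shape)      # expected: torch.Size([16, 3, 32, 32])
print(x_rec.shape)  # expected: torch.Size([16, 3, 256, 256])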

Encoded Images \((B \times C \times H \times W) \to (16 \times 3 \times 32 \times 32)\)

with torch.no_grad():
    im = model.encode(images.to(device))[0]
grid2 = vutils.make_grid(im, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid2.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

Only Hair in Encoded Dimension \((\texttt{mask\_pooled} \times \texttt{encoded\_images})\)

only_hair = maskN.to(device) * im
grid4 = vutils.make_grid(only_hair, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid4.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

No Hair in Encoded Dimension \(((1 - \texttt{mask\_pooled}) \times \texttt{encoded\_images})\)

no_hair = (1 - maskN.to(device)) * im
grid5 = vutils.make_grid(no_hair, nrow = 4, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid5.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

del model
num_grid_rows = 4
num_samples = 16
# Train Parameters
lr = 5e-5

# Diffusion Parameters
beta_start = 1e-4
beta_end = 2e-2
T = 1000

# Model Parameters
nc = 3
image_size = 256
z_channels = 3
scheduler = LinearNoiseScheduler(T, beta_start, beta_end)
xt = torch.randn((num_samples, z_channels, image_size, image_size)).to(device)
noise = xt  # reuse the same fixed noise for the forward process
# visualize the forward process: noise the raw images to timestep t = 200
noisy_images = scheduler.add_noise(images.to(device), noise, torch.full((num_samples,), 200, device = device))
noisy_images = torch.clamp(noisy_images, -1., 1.).detach().cpu()
noisy_images = (noisy_images + 1) / 2
gridNoise = vutils.make_grid(noisy_images, nrow = num_grid_rows, normalize = True)

plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridNoise.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
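
For reference, LinearNoiseScheduler.add_noise is assumed to implement the closed-form DDPM forward process \(q(\textbf{x}_t \mid \textbf{x}_0)\). A minimal sketch (attribute names are illustrative; the real implementation lives in Scheduler.py):

# sketch of the assumed forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
class LinearNoiseSchedulerSketch:
    def __init__(self, T, beta_start, beta_end):
        self.betas = torch.linspace(beta_start, beta_end, T)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def add_noise(self, x0, noise, t):
        # gather alpha_bar_t per sample and broadcast over (B, C, H, W)
        abar = self.alpha_bars.to(x0.device)[t].view(-1, 1, 1, 1)
        return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise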

\[ \textbf{x}_{t-1}^{(32 \times 32)} = \left[ \textbf{x}_{t}^{\text{enc}} \sim q(\textbf{x}_{t} \mid \textbf{x}_{0}^{\text{enc}}) \right] \odot (1 - \texttt{mask}) + \left[ \textbf{x}_{t-1}^{\text{denoise}} \sim p_\theta(\textbf{x}_{t-1} \mid \textbf{x}_{t}) \right] \odot \texttt{mask} \]

Every term lives at the \(32 \times 32\) latent resolution: the known (non-hair) region is taken from the forward-noised encoding of the original image, and the hair region from the reverse-process sample.

Only the mask is interpolated (nearest-neighbour pooled) from \(512 \times 512\) down to \(32 \times 32\).
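
Likewise, scheduler.sample_prev_timestep is assumed to implement the standard DDPM reverse update: estimate \(\textbf{x}_0\) from the predicted noise, then sample \(\textbf{x}_{t-1}\) from the posterior. A sketch with the same illustrative attribute names as above:

# sketch of the assumed reverse step p_theta(x_{t-1} | x_t)
def sample_prev_timestep_sketch(scheduler, xt, noise_pred, t):
    abar_t = scheduler.alpha_bars[t]
    # x0 estimate implied by the predicted noise
    x0_pred = (xt - (1 - abar_t).sqrt() * noise_pred) / abar_t.sqrt()
    x0_pred = torch.clamp(x0_pred, -1.0, 1.0)
    # posterior mean of q(x_{t-1} | x_t, x_0)
    mean = (xt - scheduler.betas[t] * noise_pred / (1 - abar_t).sqrt()) / scheduler.alphas[t].sqrt()
    if t == 0:
        return mean, x0_pred
    variance = (1 - scheduler.alpha_bars[t - 1]) / (1 - abar_t) * scheduler.betas[t]
    return mean + variance.sqrt() * torch.randn_like(xt), x0_pred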

scheduler = LinearNoiseScheduler(T, beta_start, beta_end)

model = Unet(im_channels = z_channels).to(device)
model.load_state_dict(torch.load("../ldmCeleb/denoiseLatentModelCeleb.pth", map_location = device))
model.eval()

vae = VQVAE().to(device)
vae.load_state_dict(torch.load("../vqvaeCeleb/vqvae_autoencoder.pth", map_location=device), strict=True)
vae.eval()


with torch.no_grad():
    # latent resolution: 256 // 2^(number of downsamples) = 32
    im_size = image_size // (2 ** sum(model.down_sample))
    xt = torch.randn((num_samples, z_channels, im_size, im_size)).to(device)
    noise = xt  # fixed noise, reused for the forward process at every step
    # encoded images -> im (32 x 32)
    # mask -> maskN (32 x 32)
    maskN = maskN.to(device)
    # the same fixed noise is reused at every step so that the known region
    # follows a single consistent forward trajectory

    for t in reversed(range(T)):
        # reverse step: predict the noise and sample x_{t-1}
        noise_pred = model(xt, torch.as_tensor(t).unsqueeze(0).to(device))
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device))

        # forward step: noise the encoded original images to timestep t
        noisy_encoded_images = scheduler.add_noise(im, noise, torch.full((num_samples,), t, device = device))

        # composite: keep the known (non-hair) region from the noised encoding,
        # take the hair region from the denoised sample
        xt = ((1 - maskN) * noisy_encoded_images) + ((maskN) * xt)

        ims_raw = torch.clamp(xt, -1., 1.).detach().cpu()
        ims_raw = (ims_raw + 1) / 2
        
        ims = vae.decode(xt)
        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        
        grid_latent = vutils.make_grid(ims_raw, nrow = num_grid_rows, normalize = True)
        grid_reconstructed = vutils.make_grid(ims, nrow = num_grid_rows, normalize = True)
        
        if (t % 100 == 0 or t == T - 1):
            plt.figure(figsize = (15, 15))
            plt.subplot(1, 2, 1)
            plt.axis("off")
            plt.imshow(np.transpose(grid_latent.cpu().detach().numpy(), (1, 2, 0)))
            
            plt.subplot(1, 2, 2)
            plt.axis("off")
            plt.imshow(np.transpose(grid_reconstructed.cpu().detach().numpy(), (1, 2, 0)))
            
            plt.show()

gridFinal = vutils.make_grid(ims, nrow = num_grid_rows, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(gridFinal.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

# final composite at image resolution:
# mask * decoded_output + (1 - mask) * original_images
images = (images + 1) / 2
images = images.to(device)
ims = ims.to(device)
maskNN = maskNN.to(device)
final_images = (maskNN * ims) + ((1 - maskNN) * images)
grid_final = vutils.make_grid(final_images, nrow = num_grid_rows, normalize = True)
plt.figure(figsize = (8, 8))
plt.imshow(np.transpose(grid_final.cpu().detach().numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()
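
In equation form, the final composite at image resolution is

\[ \textbf{x}^{\text{final}} = \texttt{mask}_{256 \times 256} \odot \text{decode}(\textbf{x}_{0}^{\text{denoise}}) + (1 - \texttt{mask}_{256 \times 256}) \odot \textbf{x}^{\text{orig}} \]

so the hair region comes from the inpainted sample, and everything outside the mask is copied verbatim from the original image.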