Applying ZerO to a denoising diffusion model
Denoising diffusion models are a relatively recent class of deep generative models that have become extremely popular with the rapid advance of text-to-image models such as Stable Diffusion. In this notebook, we train a diffusion model on the CIFAR-10 dataset to generate images unconditionally, closely following this tutorial (which uses a different dataset) and adding Horovod support. To stay within our time constraints, we train for only 10 epochs, so our models will not converge, but the trends should still be clearly visible for comparison.
The question we investigate here is whether ZerO generalises to this task and architecture.
%pip install datasets diffusers accelerate
As usual, we define the model, load the data (CIFAR-10) and set up the training script.
from tqdm import tqdm
from functools import partialmethod
import datasets
datasets.disable_progress_bar()
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
# ------------------------------------------------------------------------------------------------------------
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
NUM_WORKERS = 1
@dataclass
class TrainingConfig:
image_size = 32 # the generated image resolution
train_batch_size = 32 * NUM_WORKERS
eval_batch_size = 32 # how many images to sample during evaluation
num_epochs = 10 // NUM_WORKERS
gradient_accumulation_steps = 1
learning_rate = 1e-4 * NUM_WORKERS
lr_warmup_steps = 500
save_image_epochs = 5 // NUM_WORKERS
save_model_epochs = 5 // NUM_WORKERS
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "/dbfs/ml/Group_7/diffusion"
push_to_hub = False # whether to upload the saved model to the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0
dataset_name = "cifar10"
config = TrainingConfig()
preprocess = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["img"]]
return {"images": images}
from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
import os
def seed_everything(seed: int = 42):
import random, os
import numpy as np
import torch
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
from accelerate import Accelerator
from diffusers import DDPMPipeline
from diffusers.hub_utils import init_git_repo, push_to_hub
from PIL import Image
from tqdm.auto import tqdm
import horovod.torch as hvd
from torch.utils.data.distributed import DistributedSampler
from torch import distributed as dist
def make_grid(images, rows, cols):
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, image in enumerate(images):
grid.paste(image, box=(i % cols * w, i // cols * h))
return grid
def evaluate(config, epoch, pipeline):
# Sample some images from random noise (this is the backward diffusion process).
# The default pipeline output type is `List[PIL.Image]`
images = pipeline(
batch_size=config.eval_batch_size,
generator=torch.manual_seed(config.seed),
).images
# Make a grid out of the images
image_grid = make_grid(images, rows=8, cols=4)
# Save the images
test_dir = os.path.join(config.output_dir, "samples")
os.makedirs(test_dir, exist_ok=True)
print(test_dir)
image_grid.save(f"{test_dir}/{epoch:04d}.png")
def train_loop(config, zero_init=False):
seed_everything()
print("Initialising...")
hvd.init()
torch.cuda.set_device(hvd.local_rank())
print(hvd.rank())
print("Loading dataset")
dataset = load_dataset(config.dataset_name, split="train")
dataset.set_transform(transform)
sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=True)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, sampler=sampler, worker_init_fn=seed_everything)
print("Creating model")
model = UNet2DModel(
sample_size=config.image_size, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
        block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"DownBlock2D",
),
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
if zero_init:
zerO_init_model_(model)
print("Setting up training objects...")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=config.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * config.num_epochs),
)
# Initialize accelerator and tensorboard logging
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
logging_dir=os.path.join(config.output_dir, "logs"),
)
if accelerator.is_main_process:
if config.push_to_hub:
repo = init_git_repo(config, at_init=True)
accelerator.init_trackers("train_example")
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
global_step = 0
print(config.output_dir)
# Now you train the model
for epoch in range(config.num_epochs):
sampler.set_epoch(epoch)
for step, batch in enumerate(train_dataloader):
clean_images = batch["images"]
# Sample noise to add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bs = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
).long()
# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
with accelerator.accumulate(model):
# Predict the noise residual
noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if step % 200 == 0 and accelerator.is_main_process:
                print(f"[Epoch {epoch}, {step}/{len(train_dataloader)} steps]", *[f"{k}: {v}" for k, v in logs.items()])
accelerator.log(logs, step=global_step)
global_step += 1
# After each epoch you optionally sample some demo images with evaluate() and save the model
if accelerator.is_main_process:
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
evaluate(config, epoch, pipeline)
if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
if config.push_to_hub:
push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
else:
pipeline.save_pretrained(config.output_dir)
Then, we run the script in a distributed manner with HorovodRunner.
import horovod.torch as hvd
from sparkdl import HorovodRunner
import json
hr = HorovodRunner(np=NUM_WORKERS, driver_log_verbosity='all')
hr.run(train_loop, config=config)
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
img = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0009.png')
plt.figure(figsize=[10,20])
imgplot = plt.imshow(img)
plt.show()
After this, we define ZerO initialisation for the diffusion model. We initialise the convolutional layers, the linear layers and the attention modules by reusing the corresponding functions from the previous notebooks and adapting them to this architecture.
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
import numpy as np
from scipy.linalg import hadamard
from torch import nn
def zerO_init_conv_layer_(weight):
"""
In-place initialise the given convolutional layer with zerO-init
using the following equation:
---------------------------------
W[:,:,n,n] := c * I_p * H_m * I_p
---------------------------------
    where W: out_dim x in_dim x k x k (conv weight, n is the centre tap index)
          I_p: out_dim x 2^m (partial identity)
          H_m: 2^m x 2^m (Hadamard matrix)
          I_p: 2^m x in_dim (partial identity)
"""
out_dim, in_dim, k = weight.shape[:3]
n = int(np.floor(k / 2))
if out_dim == in_dim:
weight.data[..., n, n] = torch.eye(in_dim)
elif out_dim < in_dim:
weight.data[..., n, n] = partial_identity(out_dim, in_dim).type_as(weight)
else:
m = int(np.ceil(np.log2(out_dim)))
c = 2 ** (-(m - 1) / 2)
H = lambda dim: torch.tensor(hadamard(dim)).type_as(weight)
I = lambda outd, ind: partial_identity(outd, ind).type_as(weight)
# NOTE: scipy's hadamard function differs from the paper's definition
# in that we need to pass 2^m as its size input instead of m
weight.data[..., n, n] = (
c * I(out_dim, 2**m) @ H(2**m) @ I(2**m, in_dim)
)
def zerO_init_linear(weight):
"""
Algorithm 1.
hadamard: c * I_p * H_m * I_p
I_p: out_dim * m
H_m: m * m
I_p: m * in_dim
"""
out_dim, in_dim = weight.shape
device = weight.device
if out_dim == in_dim:
        weight.data = torch.eye(in_dim).type_as(weight)
elif out_dim < in_dim:
weight.data = partial_identity(out_dim, in_dim).type_as(weight)
else:
m = int(np.ceil(np.log2(out_dim)))
c = 2 ** (-(m - 1) / 2)
H = lambda dim: torch.tensor(hadamard(dim)).type_as(weight)
weight.data = (
c * H(2**m)[:out_dim, :in_dim]
)
weight.data = weight.data.to(device)
def partial_identity(out_dim, in_dim):
if out_dim < in_dim:
I = torch.eye(out_dim)
O = torch.zeros(out_dim, (in_dim - out_dim))
return torch.cat((I, O), 1)
elif out_dim > in_dim:
I = torch.eye(in_dim)
O = torch.zeros((out_dim - in_dim), in_dim)
return torch.cat((I, O), 0)
else:
return torch.eye(out_dim)
def zerO_init_model_(model):
    # First set all convolutional and linear weights to zero; the loop below then
    # overwrites the layers that ZerO initialises explicitly.
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.zeros_(m.weight)
    for name, m in model.named_modules():
        # Linear layers, including the attention projections (Q, K, V):
        # K and V keep the zero init above, while Q is set to the identity
        if isinstance(m, nn.Linear):
            if name.endswith(".query"):
                nn.init.eye_(m.weight)
            elif not (name.endswith(".key") or name.endswith(".value")):
                zerO_init_linear(m.weight)
        # Convolutions: ignore the last conv layer in each residual block (kept at zero)
        elif isinstance(m, nn.Conv2d):
            if not name.endswith(".conv2"):
                zerO_init_conv_layer_(m.weight)
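As a quick sanity check (an illustrative sketch only, with hypothetical channel counts and input shapes, not part of the training run), a ZerO-initialised 3x3 convolution with equal input and output channels should act as an identity map at initialisation, since only the centre tap of each filter is non-zero:
# Sanity check (illustrative only): after zeroing the weight and applying
# zerO_init_conv_layer_, a 3x3 conv with in_channels == out_channels
# should reproduce its input exactly.
conv = nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False)
nn.init.zeros_(conv.weight)           # zerO_init_model_ zeroes all weights first
zerO_init_conv_layer_(conv.weight)    # then sets the centre taps to the identity
x = torch.randn(2, 16, 8, 8)
with torch.no_grad():
    assert torch.allclose(conv(x), x, atol=1e-6)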
We train the model with this initialisation as well:
import horovod.torch as hvd
from sparkdl import HorovodRunner
import json
config.output_dir += "_zero_init"
print(config.output_dir)
hr = HorovodRunner(np=NUM_WORKERS, driver_log_verbosity='all')
hr.run(train_loop, config=config, zero_init=True)
Visualising samples halfway through and at the end of training
Below we visually compare the samples from the default-initialised and the ZerO-initialised models. We conclude that ZerO initialisation harms the model slightly: the images tend to have less contrast between foreground and background, and the objects are sometimes less well-developed. See the highlighted differences near the end of the notebook. Note that the two sets of samples look similar only because we fix the random seed during sampling, so both models start from the same noise; without a fixed seed, the samples would be unrelated and a visual comparison would be much harder.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
fig, axes = plt.subplots(1,2, figsize=[30,40])
img1 = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0004.png')
img2 = mpimg.imread('/dbfs/ml/Group_7/diffusion_zero_init_zero_init/samples/0004.png')
axes[0].imshow(img1)
axes[0].set_title("default")
axes[1].imshow(img2)
axes[1].set_title("zero_init")
plt.show()
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
fig, axes = plt.subplots(1,2, figsize=[30,40])
img1 = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0009.png')
img2 = mpimg.imread('/dbfs/ml/Group_7/diffusion_zero_init_zero_init/samples/0009.png')
axes[0].imshow(img1)
axes[0].set_title("default")
axes[1].imshow(img2)
axes[1].set_title("zero_init")
plt.show()
Difference between the samples
To make comparison easier, we highlight the differences between the samples.
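One way to do this (a minimal sketch using the same sample paths as above) is to plot the absolute per-pixel difference between the two sample grids as a heat map:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
img1 = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0009.png')
img2 = mpimg.imread('/dbfs/ml/Group_7/diffusion_zero_init_zero_init/samples/0009.png')
# Absolute per-pixel difference, averaged over the colour channels
diff = np.abs(img1 - img2).mean(axis=-1)
plt.figure(figsize=[10, 20])
plt.imshow(diff, cmap='inferno')
plt.colorbar()
plt.title('absolute difference (default vs zero_init)')
plt.show()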
Loss curves
Below we open the tensorboard logs to inspect the training curves. We can see that ZerO is consistently worse than the default initialisation, further corroborating the finding that ZerO is not applicable as an out-of-the-box replacement for the default initialisation.
%load_ext tensorboard
%tensorboard --logdir /dbfs/ml/Group_7/
# To view the logs, please filter for 'diffusion' in the search bar in tensorboard!
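Alternatively, the loss curves can be read programmatically from the event files and plotted side by side. This is a minimal sketch assuming the log layout written by accelerate above (output_dir/logs/train_example) and the scalar tag "loss" used in the training loop; the exact directories depend on how output_dir was set for each run.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt

def load_scalar(logdir, tag="loss"):
    # Read one scalar time series from a tensorboard event directory
    acc = EventAccumulator(logdir)
    acc.Reload()
    events = acc.Scalars(tag)
    return [e.step for e in events], [e.value for e in events]

# Assumed log directories, mirroring the sample paths used above
runs = {
    "default": "/dbfs/ml/Group_7/diffusion/logs/train_example",
    "zero_init": "/dbfs/ml/Group_7/diffusion_zero_init_zero_init/logs/train_example",
}
for label, logdir in runs.items():
    steps, values = load_scalar(logdir)
    plt.plot(steps, values, label=label)
plt.xlabel("step")
plt.ylabel("training loss")
plt.legend()
plt.show()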