Applying ZerO to a denoising diffusion model

Denoising diffusion models are a relatively recent class of deep generative models that have become extremely popular with the rapid advance of text-to-image systems such as Stable Diffusion. In this notebook we train a diffusion model on the CIFAR-10 dataset to generate images unconditionally, closely following this tutorial (which targets a different dataset) and adding Horovod support. To fit within our time constraints we train for only 10 epochs, so the models will not converge, but the trends should still be clearly visible for comparison.

The question we investigate here is whether ZerO generalises to this task and architecture.
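Throughout, the training objective is the standard DDPM one: corrupt a clean image with Gaussian noise at a random timestep and train a U-Net to predict the added noise. As a minimal sketch (not part of the training script below, and assuming the scheduler's default linear beta schedule), the forward noising step that DDPMScheduler.add_noise performs later in the notebook looks roughly like this:

import torch

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
betas = torch.linspace(1e-4, 0.02, 1000)            # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t

x0 = torch.randn(1, 3, 32, 32)                      # stand-in for a clean image
t = torch.randint(0, 1000, (1,))                    # random timestep
eps = torch.randn_like(x0)                          # noise the model must predict
x_t = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1) * x0 \
    + (1 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1) * eps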

%pip install datasets diffusers accelerate

As usual, we define the training configuration and the model, load the CIFAR-10 data, and set up the training script.

from functools import partialmethod

import datasets
from tqdm import tqdm

# Silence progress bars from `datasets` and `tqdm` to keep the driver logs readable
datasets.disable_progress_bar()
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
# ------------------------------------------------------------------------------------------------------------
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms

NUM_WORKERS = 1

@dataclass
class TrainingConfig:
    image_size = 32  # the generated image resolution
    train_batch_size = 32 * NUM_WORKERS
    eval_batch_size = 32  # how many images to sample during evaluation
    num_epochs = 10 // NUM_WORKERS
    gradient_accumulation_steps = 1
    learning_rate = 1e-4 * NUM_WORKERS 
    lr_warmup_steps = 500
    save_image_epochs = 5 // NUM_WORKERS
    save_model_epochs = 5 // NUM_WORKERS
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "/dbfs/ml/Group_7/diffusion"  

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0
    dataset_name = "cifar10"


config = TrainingConfig()

preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["img"]]
    return {"images": images}

from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

import os
def seed_everything(seed: int = 42):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

from accelerate import Accelerator
from diffusers import DDPMPipeline
from diffusers.hub_utils import init_git_repo, push_to_hub
from PIL import Image
from tqdm.auto import tqdm
import horovod.torch as hvd
from torch.utils.data.distributed import DistributedSampler
from torch import distributed as dist

def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    ).images

    # Make a grid out of the images
    image_grid = make_grid(images, rows=8, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    print(test_dir)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")


def train_loop(config, zero_init=False):
    seed_everything()
    print("Initialising...")
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    print(hvd.rank())
    print("Loading dataset")

    dataset = load_dataset(config.dataset_name, split="train")
    dataset.set_transform(transform)
    sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, sampler=sampler, worker_init_fn=seed_everything)

    print("Creating model")
    
    model = UNet2DModel(
        sample_size=config.image_size,  # the target image resolution
        in_channels=3,  # the number of input channels, 3 for RGB images
        out_channels=3,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    
    if zero_init:
        zerO_init_model_(model)
    
    print("Setting up training objects...")
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * config.num_epochs),
    )
    
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs"),
    )
    
    if accelerator.is_main_process:
        if config.push_to_hub:
            repo = init_git_repo(config, at_init=True)
        accelerator.init_trackers("train_example")


    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    

    global_step = 0
    print(config.output_dir)
    # Now you train the model
    for epoch in range(config.num_epochs):
        sampler.set_epoch(epoch)
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if not step % 200 and accelerator.is_main_process:
                print(f"[Epoch {epoch}, {step}/{len(train_dataloader)} steps])", *[str(k) + ": " + str(v) for k,v in logs.items()])
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                if config.push_to_hub:
                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
                else:
                    pipeline.save_pretrained(config.output_dir)

Then, we run the script in a distributed manner with HorovodRunner.

import horovod.torch as hvd
from sparkdl import HorovodRunner
import json


hr = HorovodRunner(np=NUM_WORKERS, driver_log_verbosity='all') 

hr.run(train_loop, config=config)

After training, we display the sample grid from the final epoch:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
img = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0009.png')
plt.figure(figsize=[10,20])
imgplot = plt.imshow(img)
plt.show()

After this, we define ZerO initialisation for the diffusion model. We initialise the convolutional layers, the linear layers and the attention modules by taking the corresponding functions from the previous notebooks and applying them to this architecture.

import torch
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
import numpy as np
from scipy.linalg import hadamard
from torch import nn

def zerO_init_conv_layer_(weight):
    """
    In-place initialise the given convolutional layer with zerO-init 
    using the following equation:
    ---------------------------------
    W[:,:,n,n] := c * I_p * H_m * I_p
    ---------------------------------
     where W:   out_dim x in_dim x k x k  (k = kernel size)
           I_p: out_dim x 2^m  (partial identity)
           H_m: 2^m x 2^m      (Hadamard matrix)
           I_p: 2^m x in_dim   (partial identity)
    """
    out_dim, in_dim, k = weight.shape[:3]
    n = int(np.floor(k / 2))
    
    if out_dim == in_dim:
        weight.data[..., n, n] = torch.eye(in_dim)
    elif out_dim < in_dim:
        weight.data[..., n, n] = partial_identity(out_dim, in_dim).type_as(weight)
    else:
        m = int(np.ceil(np.log2(out_dim)))
        c = 2 ** (-(m - 1) / 2)

        H = lambda dim: torch.tensor(hadamard(dim)).type_as(weight)
        I = lambda outd, ind: partial_identity(outd, ind).type_as(weight)
        
        # NOTE: scipy's hadamard function differs from the paper's definition
        #       in that we need to pass 2^m as its size input instead of m
        weight.data[..., n, n] = (
            c * I(out_dim, 2**m) @ H(2**m) @ I(2**m, in_dim)
        )

def zerO_init_linear(weight):
    """
    Algorithm 1.
    
    hadamard: c * I_p * H_m * I_p
         I_p: out_dim x 2^m
         H_m: 2^m x 2^m
         I_p: 2^m x in_dim
    """
    out_dim, in_dim = weight.shape
    device = weight.device
    
    if out_dim == in_dim:
        weight.data = torch.eye(in_dim)
    elif out_dim < in_dim:
        weight.data = partial_identity(out_dim, in_dim).type_as(weight)
    else:
        m = int(np.ceil(np.log2(out_dim)))
        c = 2 ** (-(m - 1) / 2)

        H = lambda dim: torch.tensor(hadamard(dim)).type_as(weight)

        weight.data = (
            c * H(2**m)[:out_dim, :in_dim]
        )
    weight.data = weight.data.to(device)

def partial_identity(out_dim, in_dim):
    if out_dim < in_dim:
        I = torch.eye(out_dim)
        O = torch.zeros(out_dim, (in_dim - out_dim))

        return torch.cat((I, O), 1)

    elif out_dim > in_dim:
        I = torch.eye(in_dim)
        O = torch.zeros((out_dim - in_dim), in_dim)
        return torch.cat((I, O), 0)

    else:
        return torch.eye(out_dim)


def zerO_init_model_(model):
    for m in model.modules():
        # Initialize relevant matrices to zero in the beginning:
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.zeros_(m.weight)

    for name, m in model.named_modules():
        # Linear layers, including the attention projections (Q, K, V)
        if isinstance(m, nn.Linear):
            # Attention: Q is set to the identity; K and V stay at zero from above
            if name.endswith(".query"):
                nn.init.eye_(m.weight)
            elif not (name.endswith(".key") or name.endswith(".value")):
                zerO_init_linear(m.weight)
        # Convolution
        elif isinstance(m, nn.Conv2d):
            # Ignore the last conv layer in each residual block (kept at zero)
            if not name.endswith(".conv2"):
                zerO_init_conv_layer_(m.weight)

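Before training, a quick sanity check (a hypothetical snippet, not part of the original run) can confirm that a square linear layer is initialised to an exact identity and that a taller layer receives a scaled partial Hadamard matrix:

# Hypothetical sanity check for the initialisers defined above
lin_square = nn.Linear(64, 64, bias=False)
zerO_init_linear(lin_square.weight)
x = torch.randn(8, 64)
assert torch.allclose(lin_square(x), x)     # square case: identity mapping

lin_tall = nn.Linear(64, 128, bias=False)   # out_dim > in_dim
zerO_init_linear(lin_tall.weight)
print(lin_tall.weight.abs().unique())       # every entry has magnitude c = 2**(-(m - 1) / 2)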
We train the model with this setting as well:

import horovod.torch as hvd
from sparkdl import HorovodRunner
import json

config.output_dir += "_zero_init"
print(config.output_dir)

hr = HorovodRunner(np=NUM_WORKERS, driver_log_verbosity='all') 

hr.run(train_loop, config=config, zero_init=True)

Visualising samples halfway through training

Below we visually compare samples from the default-initialised and the ZerO-initialised models. We conclude that ZerO initialisation harms the model slightly: the images tend to have less contrast between foreground and background, and the objects are sometimes less well-developed. See the highlighted differences near the end of the notebook. Note that the samples are this similar only because we fix the random seed during sampling; without a fixed seed we would have obtained completely unrelated samples, making visual comparison much harder.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
fig, axes = plt.subplots(1,2, figsize=[30,40])
img1 = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0004.png')
img2 = mpimg.imread('/dbfs/ml/Group_7/diffusion_zero_init_zero_init/samples/0004.png')
axes[0].imshow(img1)
axes[0].set_title("default")
axes[1].imshow(img2)
axes[1].set_title("zero_init")
plt.show()
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
fig, axes = plt.subplots(1,2, figsize=[30,40])
img1 = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0009.png')
img2 = mpimg.imread('/dbfs/ml/Group_7/diffusion_zero_init_zero_init/samples/0009.png')
axes[0].imshow(img1)
axes[0].set_title("default")
axes[1].imshow(img2)
axes[1].set_title("zero_init")
plt.show()

Difference between the samples

To make comparison easier, we highlight the differences between the samples.
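One way to compute such a difference image (a minimal sketch, assuming the sample paths used above) is to take the absolute pixel-wise difference between the two grids and normalise it for display:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

img1 = mpimg.imread('/dbfs/ml/Group_7/diffusion/samples/0009.png')
img2 = mpimg.imread('/dbfs/ml/Group_7/diffusion_zero_init_zero_init/samples/0009.png')
diff = np.abs(img1.astype(float) - img2.astype(float))

plt.figure(figsize=[10, 20])
plt.imshow(diff / diff.max())   # normalise so the largest difference shows as white
plt.title("absolute difference (default vs zero_init)")
plt.show()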

Loss curves

Below we open the tensorboard logs to inspect the training curves. We can see that ZerO is consistently worse than the default initialisation, further corroborating the finding that ZerO is not applicable as an out-of-the-box replacement for the default initialisation.

%load_ext tensorboard
%tensorboard --logdir /dbfs/ml/Group_7/
# To view the logs, please filter for 'diffusion' in the search bar in tensorboard!
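If the tensorboard UI is inconvenient, the scalar logs written by the Accelerate tracker can also be read programmatically and plotted with matplotlib. This is a hedged sketch: the log directories and the scalar tag name are assumptions based on the output paths and logging calls above, and may need adjusting.

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt

def load_scalars(logdir, tag="loss"):
    # Read all scalar events with the given tag from a tensorboard log directory
    acc = EventAccumulator(logdir)
    acc.Reload()
    events = acc.Scalars(tag)
    return [e.step for e in events], [e.value for e in events]

# Assumed log locations; adjust to the actual layout under /dbfs/ml/Group_7/
runs = {
    "default": "/dbfs/ml/Group_7/diffusion/logs/train_example",
    "zero_init": "/dbfs/ml/Group_7/diffusion_zero_init/logs/train_example",
}
for label, logdir in runs.items():
    steps, values = load_scalars(logdir)
    plt.plot(steps, values, label=label)
plt.xlabel("step")
plt.ylabel("training loss")
plt.legend()
plt.show()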