March 7, 2023

Diffusion Model Study

deep-learning

python

Result on Local Machine

Due to time constraints I didn't let the model train for long enough, but it is convincing to me that the model is trying to converge.

Reference

Main reference for coding part:

Other references for theory:

Introduction

  • Define $x_t$ as the image at time $t$ for $t = 0, 1, \dots, T$. As $t$ travels from $0$ to $T$, we add noise gradually until the image is unreadable. The noise is added by

    $$x_t = \sqrt{1 - \beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon_{t-1}$$

    for $t = 1, \dots, T$, where $\epsilon_{t-1} \sim \mathcal{N}(0, I)$ and $0 < \beta_1 < \beta_2 < \dots < \beta_T < 1$.

  • In practice (coding) $\beta_t$ will go from $0.0001$ to $0.02$ (see `linear_beta_schedule` below); $\beta_t$ need not be very small once the accumulated noise is large enough.

  • Note that starting from $t = 1$, we no longer consider $x_t$ as a concrete image; rather we consider $x_t$ as a random variable where only the mean and variance make perfect sense.

  • The true image depends on the instance of values that the Gaussian noise provides.

  • That means $x_t$ denotes a set of possibilities of images (data points). To understand $x_t$, we need to understand the density of the probability distribution $q(x_t)$.

  • By direct expansion we have

    $$x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\, \epsilon_{t-1} = \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{\alpha_t (1 - \alpha_{t-1})}\, \epsilon_{t-2} + \sqrt{1 - \alpha_t}\, \epsilon_{t-1},$$

    where $\alpha_t := 1 - \beta_t$.

  • Since $\epsilon_{t-1}, \epsilon_{t-2} \sim \mathcal{N}(0, I)$ are independent, their weighted sum is again Gaussian, with variance $\alpha_t (1 - \alpha_{t-1}) + (1 - \alpha_t) = 1 - \alpha_t \alpha_{t-1}$, so the last two terms become

    $$\sqrt{1 - \alpha_t \alpha_{t-1}}\, \bar\epsilon_{t-2}$$

    for some $\bar\epsilon_{t-2} \sim \mathcal{N}(0, I)$.

  • Define $\bar\alpha_t := \prod_{i=1}^{t} \alpha_i$ for $t \ge 1$; then iterating the expansion gives $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ for some $\epsilon \sim \mathcal{N}(0, I)$.

  • Note that $q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) I\big)$.

  • The forward process of adding noise is denoted $q(x_t \mid x_{t-1})$. We wish to calculate the reverse (denoise) process $q(x_{t-1} \mid x_t)$.

  • Recall the Bayes formula $P(A \mid B) = \dfrac{P(B \mid A)\, P(A)}{P(B)}$.

  • Given a set of images $\{x_0\}$, we wish to understand the distribution of $x_{t-1}$ given $x_t$, i.e., we wish to calculate $q(x_{t-1} \mid x_t)$.

  • Usually $q$ is used to denote a known distribution.

  • To emphasize that we don't truly understand the distribution, we replace $q$ by $p_\theta$ to denote an unknown distribution (the distribution that we are going to find or estimate, or to learn); the problem becomes estimating the distribution $p_\theta(x_{t-1} \mid x_t)$.

  • As we add random noise going from $x_{t-1}$ to $x_t$, it makes no sense to estimate the exact value of the random variable $x_{t-1}$ from $x_t$.

  • Therefore what we want to estimate is the mean of $x_{t-1}$ given an instance of $x_t$.

  • By the Bayes formula, $q(x_{t-1} \mid x_t) = \dfrac{q(x_t \mid x_{t-1})\, q(x_{t-1})}{q(x_t)}$.

  • To enable ourselves to do computation, we also assume $q(x_{t-1} \mid x_t)$ follows a Gaussian distribution.

  • We now try to estimate the mean of $q(x_{t-1} \mid x_t)$; name it $\tilde\mu_t$.

  • We have already studied the distribution of $x_t$ given $x_0$ in the bullet points above.

  • Suppose that $x_0$ is given and we know that $x_t$ comes from the previous distribution by adding Gaussian noise with some weight (in the same way as before); then

    $$q(x_{t-1} \mid x_t, x_0) = q(x_t \mid x_{t-1}, x_0)\, \frac{q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \propto \exp\left( -\frac{1}{2} \left( \frac{(x_t - \sqrt{\alpha_t}\, x_{t-1})^2}{\beta_t} + \frac{(x_{t-1} - \sqrt{\bar\alpha_{t-1}}\, x_0)^2}{1 - \bar\alpha_{t-1}} - \frac{(x_t - \sqrt{\bar\alpha_t}\, x_0)^2}{1 - \bar\alpha_t} \right) \right),$$

    where the resulting variance is $\tilde\beta_t := \dfrac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\, \beta_t$ (this is `posterior_variance` in the code below).

  • By comparing coefficients we have

    $$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\alpha_t}\, (1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\, x_t + \frac{\sqrt{\bar\alpha_{t-1}}\, \beta_t}{1 - \bar\alpha_t}\, x_0;$$

    by $x_0 = \dfrac{1}{\sqrt{\bar\alpha_t}} \big( x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_t \big)$ we have

    $$\tilde\mu_t = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_t \right).$$

  • $\epsilon_t$ will be what we are trying to learn, via a model $\epsilon_\theta(x_t, t)$.

  • When we predict images in reverse timesteps, we iteratively predict them by

    $$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t) \right) + \sqrt{\tilde\beta_t}\, z$$

    for some $z$ sampled from the standard normal distribution. In code it is implemented as follows:

    @torch.no_grad()
    def sample_timestep(model, x, t):
        """
        Calls the model to predict the noise in the image and returns
        the denoised image.
        Applies noise to this image, if we are not in the last step yet.
        """
        betas_t = get_index_from_list(betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
        # Model mean: (1/sqrt(alpha_t)) * (x - beta_t/sqrt(1 - alpha_bar_t) * eps_theta(x, t))
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
        posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
        if t == 0:
            return model_mean
        else:
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise
  • From $x_0$ to $x_t$ we add a noise $\epsilon$. We estimate (learn) $\epsilon_\theta(x_t, t)$, the noise going from $x_t$ back to $x_0$; then $\epsilon$ will be our ground truth in the model. We elaborate on this in the next section.

    In code it is implemented as follows:

    def get_loss(model, x_0, times):
        # times is of shape (128,)
        # forward_diffusion_sample is sketched right after this snippet
        x_noisy, noise = forward_diffusion_sample(x_0, times, device)
        # 128 times, therefore 128 images; x_noisy is of shape [128, 3, 64, 64]
        noise_pred = model(x_noisy, times)

        return F.l1_loss(noise, noise_pred)
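
    The helper `forward_diffusion_sample` used above is not shown in these notes. Below is a minimal sketch, assuming it implements the closed form $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ with the precomputed constants from the Coding section; only the name and signature come from the call above, the body is my reconstruction:

    def forward_diffusion_sample(x_0, t, device="cpu"):
        """
        Samples x_t directly from x_0 via the closed form
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
        returning the noisy image and the noise itself (the ground truth).
        """
        noise = torch.randn_like(x_0)  # eps ~ N(0, I), one per image in the batch
        sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )
        # The per-timestep scalars broadcast over the image dimensions
        x_t = sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
            + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return x_t, noise.to(device)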

Training Algorithm

  • In the training algorithm (Algorithm 1 in the DDPM paper; a minimal training loop is sketched after this list):

    1. $x_0 \sim q(x_0)$ means we sample an image from our collection of images ($q(x_0)$ means the distribution that the images live in, like the category of dogs, cats, etc.).

    2. $t \sim \mathrm{Uniform}(\{1, \dots, T\})$ means the timestamp is uniformly random.

    3. $\epsilon \sim \mathcal{N}(0, I)$ means the noise we add from $x_0$ to $x_t$.

    4. $\epsilon_\theta$ is the estimate of $\epsilon$ from $x_t$ to $x_0$ (as we want to do the reverse). This is estimated from

      • $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ (see the closed form above) and
      • the timestamp $t$,

      so our loss function becomes $\big\| \epsilon - \epsilon_\theta\big( \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\ t \big) \big\|$ (the code above uses the $L^1$ norm).
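
    A minimal training loop under these definitions might look as follows; the optimizer choice, learning rate, number of epochs, and the `dataloader` of clean images are my assumptions, not from the original notes:

    model = SimpleUnet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumption
    epochs = 100  # assumption: train until the samples look reasonable

    for epoch in range(epochs):
        for x_0 in dataloader:  # assumed DataLoader of clean images [128, 3, 64, 64]
            optimizer.zero_grad()
            # One uniformly random timestep per image in the batch
            times = torch.randint(0, T, (TIMESTEPS_BATCH_SIZE,),
                                  device=device, dtype=torch.long)
            loss = get_loss(model, x_0.to(device), times)
            loss.backward()
            optimizer.step()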

Coding

Constants
import math

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn

# Assumed setup (not in the original notes): run on GPU when available
device = "cuda" if torch.cuda.is_available() else "cpu"

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    # same as .reshape( (batch_size,) + ((1,) * (len(x_shape) - 1)) )
    result = out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
    return result

# Define beta schedule
T = 300
IMG_SIZE = 64
TIMESTEPS_BATCH_SIZE = 128
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
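
As a quick illustration of `get_index_from_list` (a hypothetical sanity check, not from the original notes): for a batch of timesteps it gathers one scalar per image and reshapes the result so it broadcasts over the image dimensions.

t = torch.tensor([0, 150, 299])                      # a batch of 3 timesteps
out = get_index_from_list(betas, t, (3, 3, 64, 64))  # x_shape of a 3-image batch
print(out.shape)      # torch.Size([3, 1, 1, 1]) -- broadcasts over [3, 3, 64, 64]
print(out.flatten())  # betas[0], betas[150], betas[299]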
SinusoidalPositionEmbeddings

This is exactly the same as the one we use in the transformer, which basically takes a time $t$ to a vector of size $(32,)$.

To recall, positional encoding takes the following form: for each fixed $pos$,

$$PE_{(pos,\, 2i)} = \sin\big(pos \cdot a_i\big), \qquad PE_{(pos,\, 2i+1)} = \cos\big(pos \cdot a_i\big),$$

where $a_i = 1 / 10000^{2i / d_{\text{model}}}$ and $d_{\text{model}} = 32$.

 1  class SinusoidalPositionEmbeddings(nn.Module):
 2      def __init__(self, dim):
 3          super().__init__()
 4          # dim = 32
 5          self.dim = dim
 6
 7      def forward(self, times):
 8          half_dim = self.dim // 2
 9          embeddings = math.log(10000) / (half_dim - 1)
10          embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
11
12          # a_i = 1/10000^(i/(half_dim - 1)) for i = 0, ..., 15
13          # embeddings above = [a_0, a_1, ..., a_15]
14          embeddings = times[:, None] * embeddings[None, :]
15          # embeddings above = the matrix (t * a_i), one row per t in the batch
16          # the cat below gives t |-> (sin(t*a_0), ..., sin(t*a_15), cos(t*a_0), ..., cos(t*a_15))
17          # for each t, therefore the final dimension will be (128, 32)
18          embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
19          # note: all sines first, then all cosines (concatenated, not interleaved)
20          return embeddings

The variable `embeddings` in line 14 above is exactly the matrix of products $(pos \cdot a_i)$, with the timestep $t$ in place of $pos$ above.
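
A quick shape check (hypothetical, not from the original notes) confirms the mapping from a batch of timesteps to $(128, 32)$ embeddings:

pe = SinusoidalPositionEmbeddings(32)
times = torch.randint(0, T, (TIMESTEPS_BATCH_SIZE,), device=device)  # 128 timesteps
print(pe(times).shape)  # torch.Size([128, 32])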

UNet that Predicts Noise
class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        ).to(device)

        # Initial projection
        # stride = 1, padding = 1, no change in spatial dimension
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1).to(device)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1],
                                    time_emb_dim).to(device)
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1],
                                        time_emb_dim, up=True).to(device)
                                  for i in range(len(up_channels) - 1)])

        # Map back to 3 image channels; out_dim = 1 is the kernel size here (a 1x1 conv)
        self.output = nn.Conv2d(up_channels[-1], 3, out_dim).to(device)

    def forward(self, x, times):
        # Embed time
        t = self.time_mlp(times)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            # For the bottom block, x is concatenated with an identical copy of itself (just popped) for uniformity of the code.
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
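
The `Block` class is not defined in these notes. Below is a sketch consistent with how `SimpleUnet` uses it (time embedding injected between two convolutions; downsampling by a strided convolution, upsampling by a transposed convolution; on the up path the input has $2 \times$ `in_ch` channels because of the residual concatenation above). The exact layers are my assumption, following the tutorial this code appears to be based on:

class Block(nn.Module):
    """
    Conv block with time-embedding injection. With up=False it halves the
    spatial size (strided conv); with up=True it doubles it (transposed conv).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # Up path: the input is cat(x, residual_x), hence 2 * in_ch channels
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Project the time embedding to out_ch channels and add it channel-wise
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[(...,) + (None,) * 2]  # reshape to (B, out_ch, 1, 1)
        h = self.bnorm2(self.relu(self.conv2(h)))
        return self.transform(h)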
Sampling / Prediction
@torch.no_grad()
def sample_timestep(model, x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Model mean: (1/sqrt(alpha_t)) * (x - beta_t/sqrt(1 - alpha_bar_t) * eps_theta(x, t))
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image(model, img_path):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T / num_images)

    for i in range(0, T)[::-1]:
        # a tensor t of shape (1,) holding the current timestep: [299], [298], ..., [0]
        times = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(model, img, times)
        if i % stepsize == 0:
            plt.subplot(1, num_images, i // stepsize + 1)
            show_tensor_image(img.detach().cpu())

    plt.savefig(img_path)
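
`show_tensor_image` is also not defined in these notes. A minimal sketch, assuming the training images were normalized to $[-1, 1]$ (the usual setup for this kind of pipeline); it undoes the normalization and displays the first image of a batch:

import numpy as np
from torchvision import transforms

def show_tensor_image(image):
    # Assumes images were scaled to [-1, 1] during training
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),          # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),   # CHW -> HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    if len(image.shape) == 4:  # take the first image of a batch
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))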