April 1, 2022

Color-GAN with Auto-Coloring on Animated Gray Images

deep-learning

tensorflow

About this Article

This article is an experiment inspired by this tutorial: Colorizing black & white images with U-Net and conditional GAN — A Tutorial.

Since I am not used to PyTorch, which the tutorial is based on, the following parts will be rewritten in TensorFlow:

  • The models
  • The data processing pipeline
  • The dataset generator
  • The train loop (including the update of training weights)
  • The visualization of our results

Enjoy!

Results

Since the training dataset consists only of animated characters, the only common characteristic across images is the skin color. The model is therefore not able to paint clothes or hair in a colorful way, as there is no consistent pattern for it to learn.

Original image:

Converted to grayscale and then colored by the GAN:

Preliminary Import

Usual Packages
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError
from tensorflow.keras.initializers import HeNormal
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.layers import (
  Input, Conv2D, Conv2DTranspose, LeakyReLU,
  Activation, Concatenate, BatchNormalization, ZeroPadding2D
)
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from glob import glob
import os
from tensorflow.keras.models import Model
from tqdm.notebook import tqdm
from skimage.color import rgb2lab, lab2rgb

%matplotlib inline
For Using GPU
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

config = tf.compat.v1.ConfigProto(
  gpu_options=tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.8)
)
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(session)

RGB Color Space and Lab Color Space (WIP)

To be added.
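In the meantime, here is a minimal sketch (my own snippet, not from the tutorial) of the conversion this article relies on. skimage's rgb2lab yields L in [0, 100] and a/b roughly within [-110, 110], which is exactly where the constants 50 and 110 in the normalizers below come from:

import numpy as np
from skimage.color import rgb2lab, lab2rgb

# rgb2lab expects RGB floats in [0, 1].
rgb = np.random.rand(8, 8, 3).astype("float32")

lab = rgb2lab(rgb)
print(lab[..., 0].min(), lab[..., 0].max())    # L stays within [0, 100]
print(lab[..., 1:].min(), lab[..., 1:].max())  # a/b stay roughly within [-110, 110]

# lab2rgb inverts the conversion; the round trip should be near-exact.
print(np.abs(rgb - lab2rgb(lab)).max())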

Dataset Generator

We will again make use of tf.data.Dataset.list_files as a starting point to create a tf.data.Dataset object.

The Dataset object is very handy because of the following useful methods:

  • .map()
  • .filter()
  • .shuffle(buffer_size)
  • .batch(batch_size)
  • .cache()
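As a quick illustration of how these chain together, here is a toy pipeline on plain integers (my own snippet, separate from the image pipeline built below):

import tensorflow as tf

# Square each element, keep the even results, then shuffle and batch them.
ds = tf.data.Dataset.range(10)\
  .map(lambda x: x * x)\
  .filter(lambda x: x % 2 == 0)\
  .shuffle(buffer_size=10)\
  .batch(batch_size=2)

for batch in ds:
  print(batch.numpy())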
Preprocessing Functions
SIZE = 256

def path_to_img(file_path):
  img = tf.io.read_file(file_path)
  img = tf.image.decode_jpeg(img, channels=3)
  img = tf.image.resize(img, (SIZE, SIZE))
  return img

def rgb_normalize_to_0_1(img):
  img = tf.cast(img, dtype=tf.float32)
  return img/255

def rgb_denormalize_from_0_1(img):
  return img * 255

def lab_normalize_to_minus1_to_1(img):
  L = img[:, :, 0] / 50. - 1   # L in [0, 100] -> [-1, 1]
  ab = img[:, :, [1, 2]]/110.  # a/b roughly in [-110, 110] -> about [-1, 1]
  return L[..., np.newaxis], ab

def lab_denormalize_from_minus1_to_1(img):
  L = (img[:, :, [0]] + 1) * 50
  ab = img[:, :, [1, 2]] * 110
  return np.concatenate([L, ab], axis=-1)

def process_img(img):
  img = img/255.  # scale RGB to [0, 1], as rgb2lab expects
  img = tf.image.random_flip_left_right(img)

  img_lab = tf.numpy_function(func=lambda x: rgb2lab(x).astype("float32"),
                              inp=[img],
                              Tout=tf.float32)

  L, ab = tf.numpy_function(func=lab_normalize_to_minus1_to_1,
                            inp=[img_lab],
                            Tout=[tf.float32, tf.float32])

  return L, ab
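Before wiring these functions into the pipeline, a quick round-trip check (my own sanity snippet, on a random image) confirms that the normalize/denormalize pair is consistent:

img = np.random.rand(SIZE, SIZE, 3).astype("float32")
img_lab = rgb2lab(img).astype("float32")

L, ab = lab_normalize_to_minus1_to_1(img_lab)
recovered = lab_denormalize_from_minus1_to_1(np.concatenate([L, ab], axis=-1))

print(np.abs(img_lab - recovered).max())  # should be close to 0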
Chaining Preprocessing Functions

Therefore we only need to define how to preprocess the data from an individual file path; the Dataset API will handle the rest.

Among the methods above, we use .map to chain our data processing pipeline:

def get_data_generator():
  buffer_size = 100
  batch_size = 16
  # `dataset_name` is assumed to be the folder holding the training .jpg images.
  dataset = tf.data.Dataset.list_files(f"{dataset_name}/*.jpg")\
    .map(path_to_img)\
    .map(process_img)\
    .shuffle(buffer_size)\
    .batch(batch_size)

  return iter(dataset)
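Pulling a single batch shows the tensor shapes that the models below will consume (this assumes dataset_name is already set to the image folder):

data_generator = get_data_generator()
Ls, ab_true = next(data_generator)
print(Ls.shape)       # (16, 256, 256, 1): the normalized L channel
print(ab_true.shape)  # (16, 256, 256, 2): the normalized ab channels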

Implementation

Generator by UNet Structure
# Downsampling block: a strided 3x3 convolution halves the spatial size.
def conv_block(n_filters, input, kernel_initialization=None):
  if kernel_initialization:
    y = Conv2D(n_filters, (3, 3), strides=(2, 2), padding="same", use_bias=False, kernel_initializer=kernel_initialization)(input)
  else:
    y = Conv2D(n_filters, (3, 3), strides=(2, 2), padding="same", use_bias=False)(input)
  y = BatchNormalization()(y)
  y = LeakyReLU(0.2)(y)
  return y

# Upsampling block: a transposed convolution doubles the spatial size,
# then the UNet skip connection is concatenated and refined by two convs.
def upconv_block(n_filters, input, skip_connection):
  u = Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding="same")(input)
  u = Concatenate(axis=-1)([u, skip_connection])
  u = Conv2D(n_filters, (3, 3), strides=1, padding="same", activation="relu")(u)
  u = Conv2D(n_filters, (3, 3), strides=1, padding="same", activation="relu")(u)
  return u

def get_generator():
  init = HeNormal()
  x = Input(shape=(SIZE, SIZE, 1))
  d1 = conv_block(64, x, kernel_initialization=init)
  d2 = conv_block(128, d1)
  d3 = conv_block(256, d2)

  d4 = conv_block(512, d3)

  u3 = upconv_block(256, d4, d3)
  u2 = upconv_block(128, u3, d2)
  u1 = upconv_block(64, u2, d1)

  final = upconv_block(2, u1, x)
  final = Activation("tanh")(final)

  return Model(x, final)
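As a quick sanity check (my own snippet), the generator should map a one-channel L image to a two-channel ab prediction at the same 256x256 resolution:

print(get_generator().output_shape)  # expected: (None, 256, 256, 2)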
PatchGAN Discriminator by Repeated Conv Blocks
def add_padding(padding=(1,1)):
  return ZeroPadding2D(padding=padding)

def get_discriminator():
  # The critic judges only the two ab (color) channels predicted by the generator.
  input = Input(shape=(SIZE, SIZE, 2))
  x = add_padding()(input)
  x = Conv2D(64, (4, 4), strides=2, padding="same", use_bias=False)(x)
  x = LeakyReLU(0.2)(x)

  x = add_padding()(x)
  x = Conv2D(128, (4, 4), strides=2, padding="same", use_bias=False)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)

  x = add_padding()(x)
  x = Conv2D(256, (4, 4), strides=2, padding="same", use_bias=False)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)

  x = add_padding()(x)
  x = Conv2D(512, (4, 4), strides=1, padding="same", use_bias=False)(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)


  x = add_padding()(x)
  x = Conv2D(1, (4, 4), strides=1, padding="same")(x)
  return Model(input, x)
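Note that the output is not a single scalar: each unit of the final feature map judges one overlapping patch of the input, which is what makes this a PatchGAN critic. A quick check of the patch grid (my own snippet):

print(get_discriminator().output_shape)  # expected: (None, 38, 38, 1), one score per patch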

Training

Functions to Visualize Intermediate Performance
def get_gray_image_from_path(img_path):
  im_gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
  im_gray_3ch = np.concatenate([im_gray[...,np.newaxis] for _ in range(3)], axis=-1)
  return im_gray_3ch

def visualize_result(epoch=0, step=0):
  # `imgs_paths` is assumed to be a list of file paths to the original images.
  random_index = np.random.randint(0, len(imgs_paths))
  img_path = imgs_paths[random_index]
  im_gray_3ch = get_gray_image_from_path(img_path)
  filename = os.path.basename(img_path)

  original_size = im_gray_3ch.shape[0:2][::-1]
  img_ = cv2.resize(im_gray_3ch, dsize=(SIZE, SIZE), interpolation=cv2.INTER_CUBIC)
  img_ = img_/255.
  img_lab = rgb2lab(img_).astype("float32")
  L, _ = lab_normalize_to_minus1_to_1(img_lab)
  faked_coloring = gen.predict(np.array([L]))[0]
  colored_img_in_lab_in_minus1_to_1 = np.concatenate([L, faked_coloring], axis=-1)
  colored_img_in_lab = lab_denormalize_from_minus1_to_1(colored_img_in_lab_in_minus1_to_1)
  faked_colored_image = (lab2rgb(colored_img_in_lab) * 255).astype("uint8")
  faked_colored_image = cv2.resize(faked_colored_image, dsize=original_size, interpolation=cv2.INTER_CUBIC)

  plt.figure(figsize=(18, 30))
  plt.subplot(1, 2, 1)
  plt.axis("off")
  plt.imshow(im_gray_3ch.astype("uint8"))

  plt.subplot(1, 2, 2)
  plt.axis("off")
  plt.imshow(faked_colored_image)

  target_folder = "./epoch_{}".format(str(epoch).zfill(2))
  if not os.path.exists(target_folder):
    os.makedirs(target_folder)

  plt.savefig(os.path.join(target_folder, "result_{}_from_{}.png".format(str(step).zfill(3), filename)), dpi=80, bbox_inches="tight")
Start Training
Training Without Noise

Now we implement our custom training loop for 10 epochs. We use get_data_generator to get a fresh dataset iterator for each epoch.

We put the model/optimizer initialization and the training loop in separate code blocks:

gen = get_generator()
disc = get_discriminator()

gen_opt = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
disc_opt = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)

gen_L1_loss_lambda = 100

mse = MeanSquaredError()
mae = MeanAbsoluteError()

and start:

for epoch in range(1, 11):
  batch = 0
  data_generator = get_data_generator()
  while True:
    try:
      batch += 1
      print(f"{batch}-th batch", end="\r")
      Ls, ab_true = next(data_generator)

      with tf.GradientTape(persistent=True) as tape:
        faked_coloring = gen(Ls)

        critic_on_faked_coloring = disc(faked_coloring)
        critic_on_true_coloring = disc(ab_true)

        # LSGAN-style adversarial loss plus an L1 reconstruction penalty.
        gen_loss = mse(tf.ones_like(critic_on_faked_coloring), critic_on_faked_coloring)\
                        + gen_L1_loss_lambda * mae(ab_true, faked_coloring)

        # The discriminator should score generated colorings as 0 and real ones as 1.
        disc_loss = 0.5 * mse(tf.zeros_like(critic_on_faked_coloring), critic_on_faked_coloring)\
            + 0.5 * mse(tf.ones_like(critic_on_true_coloring), critic_on_true_coloring)

      grad_gen = tape.gradient(gen_loss, gen.trainable_variables)
      grad_disc = tape.gradient(disc_loss, disc.trainable_variables)
      del tape  # release the resources held by the persistent tape

      gen_opt.apply_gradients(zip(grad_gen, gen.trainable_variables))
      disc_opt.apply_gradients(zip(grad_disc, disc.trainable_variables))

      if batch % 10 == 0:
        visualize_result(epoch, batch // 10)

    except StopIteration:
      print(f"Epoch {epoch} Ended")
      break
    except Exception as err:
      print(err)
      break
Training With Noise (WIP)
Training With Pretrained VGG-16 as a Backbone (WIP)

References

  • Colorizing black & white images with U-Net and conditional GAN — A Tutorial