tf.data.Dataset
Storing a set of images requires a lot of memory, but saving a list of strings does not. In TensorFlow we can therefore base our pipeline on saving image paths.
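To make the memory claim concrete, a back-of-the-envelope comparison (the file name below is just illustrative):

# A 256x256 RGB image held in memory as uint8:
print(256 * 256 * 3)  # 196608 bytes per image

# Its path, by contrast, is a short string:
path = "horse2zebra/trainA/n02381460_1001.jpg"  # a few dozen bytes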
tf.data.Dataset.list_files
We take the horse2zebra dataset as an example.
dataset = tf.data.Dataset.list_files("horse2zebra/*/**")
We can use take to get a specific number of elements from the dataset:
for d in dataset.take(1):
    print(d)
# output: tf.Tensor(b'horse2zebra\\testB\\n02391049_1020.jpg', shape=(), dtype=string)
Data Processing with tf.data.Dataset: .map(), .filter(), .cache(), .shuffle(), .batch()
Most of the time our dataset is structured as dataset/class1/a.jpg, dataset/class2/b.jpg, ..., so it is much more helpful to transform the list of paths into (img, label) pairs.
We will construct (generated) batches of tensors (these constitute our dataset) in the following steps, with the help of .map(), which also handles the data processing. For this we define:
def get_label(file_path):
    return tf.strings.split(file_path, os.path.sep)[-2]

def path_to_imgLabel(file_path):
    label = get_label(file_path)
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    return img, label

def normalize_img(img):
    img = tf.cast(img, dtype=tf.float32)
    # Map values into the range [-1, 1]
    return (img / 127.5) - 1.0

def preprocess_train_image(img):
    img = tf.image.random_flip_left_right(img)
    img = tf.image.resize(img, [*orig_img_size])
    img = tf.image.random_crop(img, size=[*input_img_size])
    img = normalize_img(img)
    return img
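The helpers above (and the pipeline below) rely on a few names that are never defined in this note: os, orig_img_size, input_img_size, autotune, buffer_size and batch_size. A minimal sketch of the assumed setup, with values taken from the Keras CycleGAN example (an assumption, not part of this note):

import os

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa  # used later for InstanceNormalization
from tensorflow import keras
from tensorflow.keras import layers

# Assumed constants (values from the Keras CycleGAN example):
orig_img_size = (286, 286)      # resize target before random cropping
input_img_size = (256, 256, 3)  # crop size fed to the networks
buffer_size = 256
batch_size = 1
autotune = tf.data.AUTOTUNE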
Now we can chain our processing as follows:
dataset = tf.data.Dataset.list_files("horse2zebra/*/**").map(path_to_imgLabel)
# at this point our dataset generates (img, label) pairs

train_horses = (
    dataset
    .filter(lambda _, label: label == "trainA")
    # this line is specific to this file; usually we keep the label
    .map(lambda img, label: img)
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)
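The zebra pipeline used below is not shown in this note; presumably it is built the same way, filtering on "trainB" instead. A sketch under that assumption:

# Assumed analogue of train_horses ("trainB" holds the zebra
# training images in the horse2zebra layout):
train_zebras = (
    dataset
    .filter(lambda _, label: label == "trainB")
    .map(lambda img, label: img)
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)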
Now we can run

for d in train_zebras.take(1):
    print(d)

to check whether the batch of data suits our training purpose.
Plotting with matplotlib.pyplot
We mainly use ax (an array of axes):
import matplotlib.pyplot as plt

_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, samples in enumerate(zip(train_horses.take(4), train_zebras.take(4))):
    horse = (((samples[0][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
    zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
    ax[i, 0].imshow(horse)
    ax[i, 1].imshow(zebra)
plt.show()
Result: a 4×2 grid of sample images, with horses in the left column and zebras in the right column.
ReflectionPadding2D
This is a built-in layer in PyTorch, but in TensorFlow we need to build it manually:
class ReflectionPadding2D(layers.Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        # no padding for batch_size and channel axis
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
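A quick shape check of the layer (the input shape here is just for illustration):

x = tf.zeros([1, 256, 256, 3])
y = ReflectionPadding2D(padding=(1, 1))(x)
print(y.shape)  # (1, 258, 258, 3): each spatial dimension grows by 2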
Residual Blocks
There are many versions of residual blocks. In CycleGAN the blocks keep the number of filters and also the spatial dimensions. In some other variants the middle Conv2D layer shrinks the filter depth and finally restores it for the addition in the skip connection.
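The function below defaults its initializers to kernel_init and gamma_init, which are not defined in this note; the Keras CycleGAN example (assumed here) uses small random normal initializers:

# Assumed initializers, as in the Keras CycleGAN example:
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)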
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    dim = x.shape[-1]
    input_tensor = x

    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.add([input_tensor, x])
    return x
Though the kernel_size is (3, 3) and padding="valid" shrinks each spatial dimension by 2, every Conv2D is preceded by a ReflectionPadding2D, which first grows each dimension by 2, so the width remains unchanged, and so does the height.
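A sanity-check sketch of this (the input shape is arbitrary):

inp = keras.Input(shape=(64, 64, 256))
out = residual_block(inp, activation=layers.Activation("relu"))
print(out.shape)  # (None, 64, 64, 256): spatial dims and filter count preserved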
Model.compile() and Model.train_step()
In GANs most of the training step must be implemented manually. We record the one in CycleGAN for reference; the technique applies to all other custom training.
Note the analogy with PyTorch's .backward() and .step().
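Schematically, the correspondence looks like this (PyTorch pseudocode for comparison; the names are generic, not from this note):

# PyTorch training step (schematic):
#     optimizer.zero_grad()
#     loss = loss_fn(model(x), y)
#     loss.backward()    # <-> grads = tape.gradient(loss, model.trainable_variables)
#     optimizer.step()   # <-> optimizer.apply_gradients(zip(grads, ...))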
class CycleGan(keras.Model):
    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super(CycleGan, self).__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        super(CycleGan, self).compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()

    def train_step(self, batch_data):
        # x is Horse and y is zebra
        real_x, real_y = batch_data

        # For CycleGAN, we need to calculate different
        # kinds of losses for the generators and discriminators.
        # We will perform the following steps here:
        #
        # 1. Pass real images through the generators and get the generated images
        # 2. Pass the generated images back to the generators to check if we
        #    can predict the original image from the generated image.
        # 3. Do an identity mapping of the real images using the generators.
        # 4. Pass the generated images in 1) to the corresponding discriminators.
        # 5. Calculate the generators' total loss (adversarial + cycle + identity)
        # 6. Calculate the discriminators' loss
        # 7. Update the weights of the generators
        # 8. Update the weights of the discriminators
        # 9. Return the losses in a dictionary

        with tf.GradientTape(persistent=True) as tape:
            # Horse to fake zebra
            fake_y = self.gen_G(real_x, training=True)
            # Zebra to fake horse -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (Horse to fake zebra to fake horse): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (Zebra to fake horse to fake zebra): y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            # expect/hope that G|_{zebras} = id and F|_{horses} = id,
            # i.e., almost no change
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)
            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        # loss.backward() as in PyTorch
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Update the weights of the generators
        # optimizer.step() as in PyTorch
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        # Conclusion: the only difference to PyTorch is that we need to
        # wrap the calculation of the losses in a GradientTape to record
        # the computation graph, i.e., wrap everything whose weights need
        # to be updated, which usually starts right after getting batch_data.
        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }
Monitor
class GANMonitor(keras.callbacks.Callback):
    """A callback to generate and save images after each epoch"""

    def __init__(self, num_img=4):
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        _, ax = plt.subplots(4, 2, figsize=(12, 12))
        for i, img in enumerate(test_horses.take(self.num_img)):
            prediction = self.model.gen_G(img)[0].numpy()
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = keras.preprocessing.image.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()
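The callback reads from test_horses, which is also not defined in this note; presumably it mirrors the training pipeline, filtering on "testA" and skipping the random augmentation. A sketch under that assumption:

# Assumed test pipeline: resize + normalize only, no random flip/crop.
def preprocess_test_image(img):
    img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
    return normalize_img(img)

test_horses = (
    dataset
    .filter(lambda _, label: label == "testA")
    .map(lambda img, label: img)
    .map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .batch(batch_size)
)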
Start the Training Process
Our gen_G and gen_F only need optimizers, since we calculate the losses on our own. We also pass TensorFlow built-in loss functions for convenience (they will turn out to be summands of our total loss).
# Loss function for evaluating adversarial loss
adv_loss_fn = keras.losses.MeanSquaredError()

# Define the loss function for the generators
def generator_loss_fn(fake):
    fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
    return fake_loss

# Define the loss function for the discriminators
def discriminator_loss_fn(real, fake):
    real_loss = adv_loss_fn(tf.ones_like(real), real)
    fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss) * 0.5

# Create cycle gan model
cycle_gan_model = CycleGan(
    generator_G=gen_G,
    generator_F=gen_F,
    discriminator_X=disc_X,
    discriminator_Y=disc_Y,
)

# Compile the model
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# Callbacks
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath
)

# Here we will train the model for just one epoch as each epoch takes around
# 7 minutes on a single P100 backed machine.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_horses, train_zebras)),
    epochs=1,
    callbacks=[plotter, model_checkpoint_callback],
)
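To resume from a saved checkpoint later, something like the following should work (the epoch suffix in the file name is illustrative; expect_partial() silences warnings about unrestored optimizer state):

# Load weights from a checkpoint written by ModelCheckpoint above:
weight_file = "./model_checkpoints/cyclegan_checkpoints.001"
cycle_gan_model.load_weights(weight_file).expect_partial()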