April 20, 2022

Different Kinds of Residue Blocks

deep-learning

Two Structures

From the paper Deep Residual Learning for Image Recognition we can find two possible structures for a residual block:

The first is called the basic block structure, and the second the bottleneck structure.
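Both share the same idea: rather than learning a target mapping H(x) directly, the block learns a residual F(x) and adds the input back, so the output is x + F(x). A minimal sketch of that skip connection (the two placeholder convolutions are illustrative, not the paper's exact configuration):

from tensorflow.keras import layers

def skip_connection_sketch(x):
    # F(x): a placeholder stack of two 3x3 convolutions
    fx = layers.Conv2D(x.shape[-1], (3, 3), padding="same", activation="relu")(x)
    fx = layers.Conv2D(x.shape[-1], (3, 3), padding="same")(fx)
    # the residual connection: output = x + F(x)
    return layers.add([x, fx])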

Basic Block Structure

Reference: https://keras.io/examples/generative/cyclegan/
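The block below relies on a custom ReflectionPadding2D layer rather than a built-in Keras layer; a minimal sketch of it, following the referenced example:

import tensorflow as tf
from tensorflow.keras import layers

class ReflectionPadding2D(layers.Layer):
    # pads the spatial dimensions of an NHWC tensor with reflection padding
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super().__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")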

import tensorflow_addons as tfa
from tensorflow import keras

# both initializers follow the referenced example
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    dim = x.shape[-1]
    input_tensor = x

    # first 3x3 convolution; reflection padding keeps the spatial size
    # even though the convolution itself uses padding="valid"
    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    # second 3x3 convolution, with no activation before the addition
    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    # the skip connection: add the block's input back to its output
    x = layers.add([input_tensor, x])
    return x
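A quick usage sketch (the input shape and stem convolution are illustrative, not part of the referenced example):

inputs = layers.Input(shape=(256, 256, 3))
x = layers.Conv2D(64, (3, 3), padding="same", kernel_initializer=kernel_init)(inputs)
x = residual_block(x, activation=layers.Activation("relu"))
model = keras.Model(inputs, x)
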
Bottleneck Structure

For deeper networks, a bottleneck block is used instead; the version below follows the pre-activation design, applying batch normalization and ReLU before each convolution:

from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, add
from tensorflow.keras.regularizers import l2

def residual_block(
    x,
    filter_depth,
    strides=(1, 1),
    reduce_dim=False,
    reg=0.0001,
    bn_eps=2e-5,
    bn_mom=0.9
):
    shortcut = x
    # pre-activation: batch-normalize along the channel axis (-1), then ReLU
    bn = BatchNormalization(axis=-1, epsilon=bn_eps, momentum=bn_mom)(x)
    act = Activation("relu")(bn)
    # 1x1 convolution squeezes the channel count to a quarter of filter_depth
    x = Conv2D(
        int(filter_depth * 0.25),
        (1, 1),
        use_bias=False,
        kernel_regularizer=l2(reg)
    )(act)

    # the 3x3 convolution carries the stride, so any downsampling happens here
    x = BatchNormalization(axis=-1, epsilon=bn_eps, momentum=bn_mom)(x)
    x = Activation("relu")(x)
    x = Conv2D(
        int(filter_depth * 0.25),
        (3, 3),
        strides=strides,
        padding="same",
        use_bias=False,
        kernel_regularizer=l2(reg)
    )(x)

    # 1x1 convolution expands back to the full filter_depth
    x = BatchNormalization(axis=-1, epsilon=bn_eps, momentum=bn_mom)(x)
    x = Activation("relu")(x)
    x = Conv2D(
        filter_depth,
        (1, 1),
        use_bias=False,
        kernel_regularizer=l2(reg)
    )(x)

    # when the output shape changes, project the shortcut with a 1x1
    # convolution using the same strides as the main path
    if reduce_dim:
        shortcut = Conv2D(
            filter_depth,
            (1, 1),
            strides=strides,
            use_bias=False,
            kernel_regularizer=l2(reg)
        )(act)

    x = add([x, shortcut])

    return x
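
In a full ResNet these blocks are stacked into stages: the first block of a stage sets reduce_dim=True (and, after the first stage, strides=(2, 2)) so the projected shortcut matches the new shape, and the remaining blocks are plain identity blocks. A sketch of one stage (the filter depth and block count are illustrative):

# first block of the stage: downsample and change the channel count
x = residual_block(x, 256, strides=(2, 2), reduce_dim=True)
# remaining blocks keep the shape, so the identity shortcut works unchanged
for _ in range(2):
    x = residual_block(x, 256)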