April 22, 2022

YOLOv3 Deep Dive

deep-learning

About this Article

The main objective of this article is to understand every function of YOLOv3 as implemented in TensorFlow. We will understand the algorithm by examining the shape of every output.

The sections below are named by module_name::{ function_1, function_2, ... }, where module_name means we are studying the file module_name.py.

A link to the source code is provided at the beginning of each such section.

dataset::{ parse_annotation, bbox_iou, preprocess_true_boxes }

Source Code

def parse_annotation(self, annotation):
    line = annotation.split()
    image_path = line[0]
    if not os.path.exists(image_path):
        raise KeyError("%s does not exist ... " % image_path)
    image = cv2.imread(image_path)
    bboxes = np.array([list(map(int, box.split(','))) for box in line[1:]])

    if self.data_aug:
        image, bboxes = self.random_horizontal_flip(np.copy(image), np.copy(bboxes))
        image, bboxes = self.random_crop(np.copy(image), np.copy(bboxes))
        image, bboxes = self.random_translate(np.copy(image), np.copy(bboxes))

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image, bboxes = utils.image_preporcess(np.copy(image), [self.train_input_size, self.train_input_size], np.copy(bboxes))
    return image, bboxes

This part is straightforward.
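For concreteness, here is a toy run of the parsing step on a hypothetical annotation line (the path and numbers are made up):

import numpy as np

# An annotation line: image path followed by boxes of the form "xmin,ymin,xmax,ymax,class_id".
annotation = "./data/images/000001.jpg 48,240,195,371,11 8,12,352,498,14"

line = annotation.split()
bboxes = np.array([list(map(int, box.split(','))) for box in line[1:]])
print(bboxes)
# [[ 48 240 195 371  11]
#  [  8  12 352 498  14]]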

1def bbox_iou(self, boxes1, boxes2):
2
3    boxes1 = np.array(boxes1)
4    boxes2 = np.array(boxes2)
5
6    boxes1_area = boxes1[..., 2] * boxes1[..., 3]
7    boxes2_area = boxes2[..., 2] * boxes2[..., 3]
8
9    boxes1 = np.concatenate([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
10                             boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
11    boxes2 = np.concatenate([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
12                             boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)

From the function we can deduce that each box takes the form $(x, y, w, h)$, where $(x, y)$ is the center of the box, therefore

boxes1_area $= w_1 h_1$

and

boxes2_area $= w_2 h_2$,

while concatenating np-arrays along the last axis simply means combining them into one np-array; here the concatenation converts each box into corner form $(x - w/2,\ y - h/2,\ x + w/2,\ y + h/2)$. Numerically in training boxes1 and boxes2 are like:

[[0.59375 3.78125 0.5     0.625  ]]

and

[[ 0.5      3.5      3.625    2.8125 ]
 [ 0.5      3.5      4.875    6.1875 ]
 [ 0.5      3.5     11.65625 10.1875 ]]

respectively.

13    left_up = np.maximum(boxes1[..., :2], boxes2[..., :2])
14    right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:])

Think of the above as entrywise comparisons under broadcasting: np.maximum and np.minimum give arrays of entrywise maxima and minima, which yields the top-left and bottom-right coordinates of the intersection rectangle of the box in boxes1 with each of the boxes in boxes2.

15    inter_section = np.maximum(right_down - left_up, 0.0)

The entries in inter_section are the width and height of the intersection; the (broadcasted) 0.0 is just a tricky way to handle empty intersections.

16    inter_area = inter_section[..., 0] * inter_section[..., 1]
17    union_area = boxes1_area + boxes2_area - inter_area
18
19    return inter_area / union_area
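As a sanity check, here is a minimal standalone sketch of the same computation, run on the numerical example shown above:

import numpy as np

def bbox_iou(boxes1, boxes2):
    # (x, y, w, h) -> areas
    boxes1, boxes2 = np.array(boxes1), np.array(boxes2)
    area1 = boxes1[..., 2] * boxes1[..., 3]
    area2 = boxes2[..., 2] * boxes2[..., 3]
    # (x, y, w, h) -> (xmin, ymin, xmax, ymax)
    b1 = np.concatenate([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                         boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
    b2 = np.concatenate([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                         boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)
    left_up = np.maximum(b1[..., :2], b2[..., :2])
    right_down = np.minimum(b1[..., 2:], b2[..., 2:])
    inter = np.maximum(right_down - left_up, 0.0)
    inter_area = inter[..., 0] * inter[..., 1]
    return inter_area / (area1 + area2 - inter_area)

box     = np.array([[0.59375, 3.78125, 0.5, 0.625]])   # shape (1, 4)
anchors = np.array([[0.5, 3.5,  3.625,   2.8125],
                    [0.5, 3.5,  4.875,   6.1875],
                    [0.5, 3.5, 11.65625, 10.1875]])    # shape (3, 4)
print(bbox_iou(box, anchors).shape)  # (3,) -- one IOU per anchor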

1def preprocess_true_boxes(self, bboxes):

Here bboxes are the boxes from the annotation file, in which each box takes the form (xmin, ymin, xmax, ymax, class_id).

2    label = [np.zeros((self.train_output_sizes[i],
3                       self.train_output_sizes[i],
4                       self.anchor_per_scale,
5                       5 + self.num_classes)) for i in range(3)]
6    bboxes_xywh = [np.zeros((self.max_bbox_per_scale, 4)) for _ in range(3)]
7    bbox_count = np.zeros((3,))
8
9    for bbox in bboxes:
10        bbox_coor = bbox[:4]
11        bbox_class_ind = bbox[4]
12
13        onehot = np.zeros(self.num_classes, dtype=np.float)
14        onehot[bbox_class_ind] = 1.0
15        uniform_distribution = np.full(self.num_classes, 1.0 / self.num_classes)
16        deta = 0.01
17        smooth_onehot = onehot * (1 - deta) + deta * uniform_distribution
18        # bbox_xywh is ground truth
19        bbox_xywh = np.concatenate([(bbox_coor[2:] + bbox_coor[:2]) * 0.5, bbox_coor[2:] - bbox_coor[:2]], axis=-1)
20        # bbox_xywh_scaled is the ground truth scaled by stride, i.e., measured in grid units (on the 13x13, 26x26, 52x52 grids)
21        bbox_xywh_scaled = 1.0 * bbox_xywh[np.newaxis, :] / self.strides[:, np.newaxis]

Note that

bbox_xywh[np.newaxis, :]

is of shape (1, 4) and

self.strides[:, np.newaxis]

is of shape (3, 1); their division is conducted by "broadcasting" in numpy, which yields a (3, 4) dimensional numpy array. The rows of the result are (x, y, w, h) boxes which use "stride" as a unit, so 1 means "1 grid" (recall there are 13x13, 26x26 and 52x52 grids of predictions from the Darknet backbone).
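To see the broadcasting concretely, here is a small numpy sketch using the repo's default strides [8, 16, 32]; the pixel box is made up, chosen to reproduce the scaled example shown earlier:

import numpy as np

bbox_xywh = np.array([19.0, 121.0, 16.0, 20.0])  # a made-up ground-truth box (x, y, w, h) in pixels
strides = np.array([8, 16, 32])

bbox_xywh_scaled = 1.0 * bbox_xywh[np.newaxis, :] / strides[:, np.newaxis]
print(bbox_xywh_scaled.shape)  # (3, 4)
print(bbox_xywh_scaled[2])     # [0.59375 3.78125 0.5 0.625] -- the stride-32 row, in grid units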

22        iou = []
23        exist_positive = False
24        for i in range(3):
25            anchors_xywh = np.zeros((self.anchor_per_scale, 4))
26            anchors_xywh[:, 0:2] = np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32) + 0.5
27            anchors_xywh[:, 2:4] = self.anchors[i]

Lines 25-27 essentially move the center of bbox_xywh_scaled[i] to the middle of the grid cell where the original center lies; then the anchor boxes' widths and heights are assigned, replacing the original width and height of bbox_xywh_scaled[i].
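A one-liner check of that snapping, reusing the numbers above:

import numpy as np

# Snap the scaled center (0.59375, 3.78125) to the middle of its grid cell:
center = np.array([0.59375, 3.78125])
print(np.floor(center).astype(np.int32) + 0.5)  # [0.5 3.5] -- the first two columns of anchors_xywh above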

28            iou_scale = self.bbox_iou(bbox_xywh_scaled[i][np.newaxis, :], anchors_xywh)

The presence of np.newaxis is simply because elementwise computation between a (4,) and a (3, 4) array does not make sense here. The additional dimension expands the (4,) array into a (1, 4) array, which is broadcasted against the (3, 4) array anchors_xywh inside bbox_iou, and therefore iou_scale is of shape (3,).

29            iou.append(iou_scale)
30            iou_mask = iou_scale > 0.3 # a boolean array of length 3
31
32            if np.any(iou_mask): # if one of them is True
33                xind, yind = np.floor(bbox_xywh_scaled[i, 0:2]).astype(np.int32)
34                label[i][yind, xind, iou_mask, :] = 0
35                label[i][yind, xind, iou_mask, 0:4] = bbox_xywh
36                label[i][yind, xind, iou_mask, 4:5] = 1.0
37                label[i][yind, xind, iou_mask, 5:] = smooth_onehot

label is initialized at the beginning; label[i] is of size

(train_output_sizes[i], train_output_sizes[i], anchor_per_scale, 5 + num_classes)

for each i, where i = 0, 1 or 2.
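For instance, assuming a 416x416 input (so output sizes 52, 26, 13) and 20 classes as in VOC, the three label tensors have the following shapes:

import numpy as np

train_input_size = 416                            # assumed input resolution
strides = np.array([8, 16, 32])
train_output_sizes = train_input_size // strides  # [52 26 13]
num_classes, anchor_per_scale = 20, 3             # assumed; e.g. VOC

label = [np.zeros((train_output_sizes[i], train_output_sizes[i],
                   anchor_per_scale, 5 + num_classes)) for i in range(3)]
print([l.shape for l in label])
# [(52, 52, 3, 25), (26, 26, 3, 25), (13, 13, 3, 25)]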

38                bbox_ind = int(bbox_count[i] % self.max_bbox_per_scale)
39                bboxes_xywh[i][bbox_ind, :4] = bbox_xywh

bboxes_xywh is initialized (with zeros) at the beginning; bboxes_xywh[i] is of shape (max_bbox_per_scale, 4).

40                bbox_count[i] += 1
41
42                exist_positive = True
43
44        if not exist_positive:
45            best_anchor_ind = np.argmax(np.array(iou).reshape(-1), axis=-1) # flatten and take max
46            # best_detect belongs to which "i", namely, best "i"
47            best_detect = int(best_anchor_ind / self.anchor_per_scale)
48            # for this i, which index it is:
49            best_anchor = int(best_anchor_ind % self.anchor_per_scale)
50            # get the grid point in our 13x13, 26x26, 52x52 grid:
51            xind, yind = np.floor(bbox_xywh_scaled[best_detect, 0:2]).astype(np.int32)
52
53            label[best_detect][yind, xind, best_anchor, :] = 0
54            label[best_detect][yind, xind, best_anchor, 0:4] = bbox_xywh
55            label[best_detect][yind, xind, best_anchor, 4:5] = 1.0
56            label[best_detect][yind, xind, best_anchor, 5:] = smooth_onehot
57
58            bbox_ind = int(bbox_count[best_detect] % self.max_bbox_per_scale)
59            bboxes_xywh[best_detect][bbox_ind, :4] = bbox_xywh
60            # assign bbox_xywh into the list instead of append,
61            # this is to make sure there are at most 150 boxes within all 3 scales.
62
63            bbox_count[best_detect] += 1
64    label_sbbox, label_mbbox, label_lbbox = label
65    sbboxes, mbboxes, lbboxes = bboxes_xywh
66    return label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes
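To make the fallback branch (lines 44-51) concrete, here is a toy run with made-up IOU values:

import numpy as np

# No anchor passed the 0.3 IOU threshold on any scale, so pick the single
# best anchor among all 3 x 3 = 9 candidates.
iou = [np.array([0.05, 0.10, 0.02]),   # scale i = 0
       np.array([0.20, 0.15, 0.08]),   # scale i = 1
       np.array([0.12, 0.25, 0.18])]   # scale i = 2
best_anchor_ind = np.argmax(np.array(iou).reshape(-1), axis=-1)
print(best_anchor_ind // 3, best_anchor_ind % 3)  # 2 1 -- best scale 2, anchor 1 (IOU 0.25)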

In short,

sbboxes, mbboxes, lbboxes

are just the ground-truth bounding boxes (center, width and height), while

label_sbbox, label_mbbox, label_lbbox

are the ground-truth bounding boxes together with objectness and class probabilities for each grid cell.

Model Structure of YOLOv3

In backbone.darknet53, after a bunch of residual modules we get 3 branches route_1, route_2 and conv, where route_1, route_2, conv = backbone.darknet53(input_layer) in the YOLOv3 function. Moreover, for a 416x416 input these branches have shapes 52x52x256, 26x26x512 and 13x13x1024 respectively.

Each branch then jumps into several stages of feature extraction; the whole process finally gives another 3 branches of undecoded/raw feature data, and they are endowed with the meaning of "grid-based detection" after being reshaped into a

(batch_size, output_size, output_size, 3, 5 + NUM_CLASS)

dimensional array in the decode function (see conv_output in the body).

yolov3::{ YOLOv3, decode }

Source Code

1def YOLOv3(input_layer):
2    route_1, route_2, conv = backbone.darknet53(input_layer)
3
4    conv = common.convolutional(conv, (1, 1, 1024,  512))
5    conv = common.convolutional(conv, (3, 3,  512, 1024))
6    conv = common.convolutional(conv, (1, 1, 1024,  512))
7    conv = common.convolutional(conv, (3, 3,  512, 1024))
8    conv = common.convolutional(conv, (1, 1, 1024,  512))

As padding="same" is used along the chain of conv nets, there is no change in spatial dimensions.

9    conv_lobj_branch = common.convolutional(conv, (3, 3, 512, 1024))
10    conv_lbbox = common.convolutional(conv_lobj_branch, (1, 1, 1024, 3*(NUM_CLASS + 5)), activate=False, bn=False)
11
12    conv = common.convolutional(conv, (1, 1,  512,  256))
13    conv = common.upsample(conv)
14
15    conv = tf.concat([conv, route_2], axis=-1)
16
17    conv = common.convolutional(conv, (1, 1, 768, 256))
18    conv = common.convolutional(conv, (3, 3, 256, 512))
19    conv = common.convolutional(conv, (1, 1, 512, 256))
20    conv = common.convolutional(conv, (3, 3, 256, 512))
21    conv = common.convolutional(conv, (1, 1, 512, 256))
22
23    conv_mobj_branch = common.convolutional(conv, (3, 3, 256, 512))
24    conv_mbbox = common.convolutional(conv_mobj_branch, (1, 1, 512, 3*(NUM_CLASS + 5)), activate=False, bn=False)
25
26    conv = common.convolutional(conv, (1, 1, 256, 128))
27    conv = common.upsample(conv)
28
29    conv = tf.concat([conv, route_1], axis=-1)
30
31    conv = common.convolutional(conv, (1, 1, 384, 128))
32    conv = common.convolutional(conv, (3, 3, 128, 256))
33    conv = common.convolutional(conv, (1, 1, 256, 128))
34    conv = common.convolutional(conv, (3, 3, 128, 256))
35    conv = common.convolutional(conv, (1, 1, 256, 128))
36
37    conv_sobj_branch = common.convolutional(conv, (3, 3, 128, 256))
38    conv_sbbox = common.convolutional(conv_sobj_branch, (1, 1, 256, 3*(NUM_CLASS +5)), activate=False, bn=False)
39
40    return [conv_sbbox, conv_mbbox, conv_lbbox]
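A quick shape sanity check of the returned list (a sketch assuming NUM_CLASS = 20, so the last channel is 3 * (20 + 5) = 75, and assuming the repo's YOLOv3 and common modules are in scope):

import tensorflow as tf

input_layer = tf.keras.layers.Input([416, 416, 3])
for branch in YOLOv3(input_layer):
    print(branch.shape)
# (None, 52, 52, 75)
# (None, 26, 26, 75)
# (None, 13, 13, 75)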

1def decode(conv_output, i=0):
2    """
3    return tensor of shape [batch_size, output_size, output_size, anchor_per_scale, 5 + num_classes]
4            contains (x, y, w, h, score, probability)
5    """

Here conv_output is the output of YOLOv3 (i.e., conv_sbbox, conv_mbbox or conv_lbbox).

6    conv_shape       = tf.shape(conv_output)
7    batch_size       = conv_shape[0]
8    output_size      = conv_shape[1]
9
10    conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))
11
12    conv_raw_dxdy = conv_output[:, :, :, :, 0:2]
13    conv_raw_dwdh = conv_output[:, :, :, :, 2:4]
14    conv_raw_conf = conv_output[:, :, :, :, 4:5]
15    conv_raw_prob = conv_output[:, :, :, :, 5: ]
16
17    y = tf.tile(tf.range(output_size, dtype=tf.int32)[:, tf.newaxis], [1, output_size])
18    x = tf.tile(tf.range(output_size, dtype=tf.int32)[tf.newaxis, :], [output_size, 1])

For example, let's take output_size = 13, then

y

and

x

are respectively:

[[ 0  0  0  0  0  0  0  0  0  0  0  0  0]   [[ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 1  1  1  1  1  1  1  1  1  1  1  1  1]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 2  2  2  2  2  2  2  2  2  2  2  2  2]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 3  3  3  3  3  3  3  3  3  3  3  3  3]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 4  4  4  4  4  4  4  4  4  4  4  4  4]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 5  5  5  5  5  5  5  5  5  5  5  5  5]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 6  6  6  6  6  6  6  6  6  6  6  6  6]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 7  7  7  7  7  7  7  7  7  7  7  7  7]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 8  8  8  8  8  8  8  8  8  8  8  8  8]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [ 9  9  9  9  9  9  9  9  9  9  9  9  9]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [10 10 10 10 10 10 10 10 10 10 10 10 10]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [11 11 11 11 11 11 11 11 11 11 11 11 11]    [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
 [12 12 12 12 12 12 12 12 12 12 12 12 12]]   [ 0  1  2  3  4  5  6  7  8  9 10 11 12]]

For x and y we expand dimensions again along the last axis (break every single element into a bracketed element) before concatenation:

19    xy_grid = tf.concat([x[:, :, tf.newaxis], y[:, :, tf.newaxis]], axis=-1)

At this point, xy_grid is (output_size, output_size, 2) dimensional (in our example, (13, 13, 2)).

20    xy_grid = tf.tile(xy_grid[tf.newaxis, :, :, tf.newaxis, :], [batch_size, 1, 1, 3, 1])
21    xy_grid = tf.cast(xy_grid, tf.float32)

Now xy_grid is (batch_size, output_size, output_size, 3, 2) dimensional. Recall that from the YOLOv3 paper the raw predictions $(t_x, t_y, t_w, t_h)$ are decoded as

$$b_x = \sigma(t_x) + c_x, \quad b_y = \sigma(t_y) + c_y, \quad b_w = p_w e^{t_w}, \quad b_h = p_h e^{t_h},$$

where $(c_x, c_y)$ is the grid offset stored in xy_grid and $(p_w, p_h)$ is the anchor prior:

22    pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * STRIDES[i]
23    pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i]) * STRIDES[i]
24    pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
25
26    pred_conf = tf.sigmoid(conv_raw_conf)
27    pred_prob = tf.sigmoid(conv_raw_prob)
28
29    return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)

Bear in mind that the decoded x, y in pred_xywh denote the center of the prediction rectangle, and since the sigmoid outputs values strictly between 0 and 1, the predicted center always stays within the grid cell given by xy_grid.
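A quick numerical check of that bound:

import tensorflow as tf

# sigma(t) lies strictly in (0, 1), so sigmoid(conv_raw_dxdy) + xy_grid can
# never leave the grid cell indexed by xy_grid:
t = tf.constant([-10.0, 0.0, 10.0])
print(tf.sigmoid(t).numpy())  # ~[4.54e-05 0.5 0.99995]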

yolov3::{ compute_loss }

Source Code

def bbox_giou(boxes1, boxes2):

    boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                        boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
    boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                        boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)

    boxes1 = tf.concat([tf.minimum(boxes1[..., :2], boxes1[..., 2:]),
                        tf.maximum(boxes1[..., :2], boxes1[..., 2:])], axis=-1)
    boxes2 = tf.concat([tf.minimum(boxes2[..., :2], boxes2[..., 2:]),
                        tf.maximum(boxes2[..., :2], boxes2[..., 2:])], axis=-1)

    boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

    left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2])
    right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])

    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = boxes1_area + boxes2_area - inter_area
    iou = inter_area / union_area

    enclose_left_up = tf.minimum(boxes1[..., :2], boxes2[..., :2])
    enclose_right_down = tf.maximum(boxes1[..., 2:], boxes2[..., 2:])
    enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0)
    enclose_area = enclose[..., 0] * enclose[..., 1]
    giou = iou - 1.0 * (enclose_area - union_area) / enclose_area

    return giou
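For intuition, a tiny worked example calling the bbox_giou above on made-up boxes:

import tensorflow as tf

# Disjoint boxes in (x, y, w, h) form: IOU alone is 0 and gives no signal,
# but GIoU goes negative as the gap grows.
a = tf.constant([[1.0, 1.0, 2.0, 2.0]])  # covers [0, 2] x [0, 2]
b = tf.constant([[4.0, 1.0, 2.0, 2.0]])  # covers [3, 5] x [0, 2]
print(bbox_giou(a, b).numpy())  # [-0.2]: iou = 0, union = 8, enclosing area = 10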
1def compute_loss(pred, conv, label, bboxes, i=0):
2
3    conv_shape  = tf.shape(conv)
4    batch_size  = conv_shape[0]
5    output_size = conv_shape[1]
6    input_size  = STRIDES[i] * output_size
7    conv = tf.reshape(conv, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))
8
9    conv_raw_conf = conv[:, :, :, :, 4:5]
10    conv_raw_prob = conv[:, :, :, :, 5:]
11
12    pred_xywh     = pred[:, :, :, :, 0:4]
13    pred_conf     = pred[:, :, :, :, 4:5]
14
15    label_xywh    = label[:, :, :, :, 0:4]
16    respond_bbox  = label[:, :, :, :, 4:5] # objectness
17    label_prob    = label[:, :, :, :, 5:]
18
19    giou = tf.expand_dims(bbox_giou(pred_xywh, label_xywh), axis=-1)
20    input_size = tf.cast(input_size, tf.float32)
21
22    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
23    giou_loss = respond_bbox * bbox_loss_scale * (1 - giou)

Note that for two bounded sets $A, B \subseteq \mathbb{R}^n$, the function $d(A, B) = 1 - \mathrm{GIoU}(A, B)$ defines a metric, so the term 1 - giou makes sense as a distance.

24    iou = bbox_iou(pred_xywh[:, :, :, :, np.newaxis, :], bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :])

bboxes are batched inside Dataset("train").__next__ before being passed into compute_loss (in a while loop until the image count reaches the batch size). Therefore bboxes is of shape (batch_size, max_bbox_per_scale, 4), where max_bbox_per_scale is the maximal number of ground-truth boxes per scale (most rows are zeros due to initialization), so we see 3 np.newaxis's in bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :] to align it with the grid dimensions of pred_xywh.
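The resulting broadcast can be checked directly (a sketch with made-up sizes: batch_size = 4, a 13x13 grid, max_bbox_per_scale = 150):

import numpy as np

pred = np.zeros((4, 13, 13, 3, 1, 4))    # pred_xywh[:, :, :, :, np.newaxis, :]
gt   = np.zeros((4, 1, 1, 1, 150, 4))    # bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :]
print(np.broadcast_shapes(pred.shape, gt.shape))  # (4, 13, 13, 3, 150, 4)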

Finally

iou is of shape (batch_size, output_size, output_size, 3, max_bbox_per_scale)

and

max_iou is of shape (batch_size, output_size, output_size, 3, 1),

where the tf.reduce_max computation gets rid of the last dimension (and tf.expand_dims restores it as a singleton). Every ground-truth box is copied to every grid cell for the IOU computation because, from the original paper of YOLOv3:

"the confidence prediction represents the IOU between the predicted box and any ground truth box"

25    max_iou = tf.expand_dims(tf.reduce_max(iou, axis=-1), axis=-1)
26    respond_bgd = (1.0 - respond_bbox) * tf.cast( max_iou < IOU_LOSS_THRESH, tf.float32 )

On the internet some people refer to respond_bgd as the background (no-object) mask. respond_bgd determines whether to penalize a prediction

  • that overlaps too little with the ground-truth boxes (i.e., detected a wrong location), and
  • that makes a false-positive error.

27    conf_focal = tf.pow(respond_bbox - pred_conf, 2)

The concept of focal loss was introduced in [TY]; it down-weights the loss contributed by well-classified (high-confidence) examples, here with exponent $\gamma = 2$.
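A quick look at how this factor down-weights (the confidences are made up):

import tensorflow as tf

# conf_focal = (respond_bbox - pred_conf)^2 is small for confident, correct
# predictions and close to 1 for badly wrong ones:
respond_bbox = tf.constant([1.0, 1.0, 0.0])
pred_conf    = tf.constant([0.95, 0.30, 0.90])
print(tf.pow(respond_bbox - pred_conf, 2).numpy())  # ~[0.0025 0.49 0.81]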

28    conf_loss = conf_focal * (
29        respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
30        +
31        respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
32    )
33

where sigmoid_cross_entropy_with_logits(labels=z, logits=x) computes

$$z \cdot (-\log \sigma(x)) + (1 - z) \cdot (-\log(1 - \sigma(x))),$$

therefore conv_raw_conf has to be raw prediction data (i.e., without sigmoid).
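A one-line check that the function expects logits:

import tensorflow as tf

# The logits argument must be a raw score; TF applies the sigmoid internally.
z = tf.constant([1.0])   # label
x = tf.constant([2.0])   # raw logit, not a probability
print(tf.nn.sigmoid_cross_entropy_with_logits(labels=z, logits=x).numpy())
# [0.126928] == -log(sigmoid(2.0))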

34    prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_prob, logits=conv_raw_prob)
35
36    giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis=[1,2,3,4]))
37    conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis=[1,2,3,4]))
38    prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis=[1,2,3,4]))
39
40    return giou_loss, conf_loss, prob_loss

Training

Source Code

Recap of Customized Training Loop with tf.GradientTape

Apart from predefined loss functions (such as cross-entropy for classification, mean squared error for regression, etc.), we occasionally come across non-standard loss functions in other repositories implemented with the use of tf.GradientTape.

Such an implementation usually involves 4 components (illustrated by the toy sketch after this list):

  • Component 1. The model architecture

  • Component 2. The loss function used when computing the model loss

  • Component 3. The optimizer used to update the model weights

  • Component 4. The step function that encapsulates the forward and backward pass of the network
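Each component maps onto one line of this minimal toy sketch (a made-up model and loss, not the YOLO code):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])   # Component 1
loss_fn = tf.keras.losses.MeanSquaredError()              # Component 2
optimizer = tf.keras.optimizers.Adam(1e-3)                # Component 3

def step(x, y):                                           # Component 4
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal((8, 3))
y = tf.random.normal((8, 1))
print(step(x, y).numpy())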

Now the code below is self-explanatory:

Start Training
def train_step(image_data, target, epoch):

    # image_data = batch of images

    with tf.GradientTape() as tape:
        pred_result = model(image_data, training=True)
        giou_loss = conf_loss = prob_loss = 0

        # optimizing process
        for i in range(3):
            conv, pred = pred_result[i*2], pred_result[i*2+1]
            batch_label, batch_bboxes = target[i]
            loss_items = compute_loss(pred, conv, batch_label, batch_bboxes, i)
            giou_loss += loss_items[0]
            conf_loss += loss_items[1]
            prob_loss += loss_items[2]

        total_loss = giou_loss + conf_loss + prob_loss

        gradients = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # update learning rate
        global_steps.assign_add(1)
        if global_steps < warmup_steps:
            lr = global_steps / warmup_steps * cfg.TRAIN.LR_INIT
        else:
            lr = cfg.TRAIN.LR_END + 0.5 * (cfg.TRAIN.LR_INIT - cfg.TRAIN.LR_END) * (
                (1 + tf.cos((global_steps - warmup_steps) / (total_steps - warmup_steps) * np.pi))
            )
        optimizer.lr.assign(lr.numpy())

        # writing summary data
        with writer.as_default():
            tf.summary.scalar("lr", optimizer.lr, step=global_steps)
            tf.summary.scalar("loss/total_loss", total_loss, step=global_steps)
            tf.summary.scalar("loss/giou_loss", giou_loss, step=global_steps)
            tf.summary.scalar("loss/conf_loss", conf_loss, step=global_steps)
            tf.summary.scalar("loss/prob_loss", prob_loss, step=global_steps)
        writer.flush()

for epoch in range(cfg.TRAIN.EPOCHS):
    for index, (image_data, target) in enumerate(trainset):
        train_step(image_data, target, epoch)

    model.save_weights("./checkpoints/yolov3-{}-{}.h5".format(cfg.WEIGHT_NAME_TO_SAVE, epoch))
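The learning-rate update inside train_step is linear warmup followed by cosine decay; here is a standalone sketch of the same schedule with made-up hyperparameters (the real values come from cfg.TRAIN):

import numpy as np

lr_init, lr_end = 1e-3, 1e-6
warmup_steps, total_steps = 100, 1000

def lr_at(step):
    if step < warmup_steps:
        return step / warmup_steps * lr_init
    return lr_end + 0.5 * (lr_init - lr_end) * (
        1 + np.cos((step - warmup_steps) / (total_steps - warmup_steps) * np.pi))

for s in (0, 50, 100, 550, 1000):
    print(s, lr_at(s))  # ramps 0 -> 1e-3 over warmup, then decays to 1e-6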

Reference

[TY] Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. "Focal Loss for Dense Object Detection." ICCV 2017.