1from shapely.geometry import Polygon
2from abc import ABC, abstractmethod
3from torch.utils.data import Dataset
4from PIL import Image
5from torchvision.transforms import transforms
6import torch
7import os
8import numpy as np
9import math
10import cv2
class CustomDataset(Dataset):
    '''Text-detection dataset pairing images with quadrilateral annotation files.

    Images (``.jpg``/``.png``) and annotation files (``.txt``) are listed from
    two directories and paired by their position after sorting, so both
    directories must hold the same number of files with matching sort order.
    '''

    def __init__(self, img_path, gt_path, scale=0.25, length=512):
        '''
        Input:
            img_path: directory containing the training images
            gt_path : directory containing one annotation .txt per image
            scale   : feature map / image ratio used for the ground-truth maps
            length  : side length of the square training crop
        Raises:
            ValueError: if the number of images and annotation files differ
        '''
        super(CustomDataset, self).__init__()
        self.img_files = [os.path.join(img_path, name)
                          for name in sorted(os.listdir(img_path))
                          if name.endswith((".jpg", ".png"))]
        self.gt_files = [os.path.join(gt_path, name)
                         for name in sorted(os.listdir(gt_path))
                         if name.endswith(".txt")]
        # Pairing is purely positional, so a count mismatch would silently
        # attach the wrong annotations to every image after the first gap.
        if len(self.img_files) != len(self.gt_files):
            raise ValueError(
                "found {} images in {} but {} annotation files in {}".format(
                    len(self.img_files), img_path, len(self.gt_files), gt_path))

        self.scale = scale
        self.length = length

    def __getitem__(self, index):
        '''Load, augment and crop one sample.

        Output:
            normalized image tensor, score map, geo map, ignored map
        '''
        with open(self.gt_files[index], 'r', encoding="utf-8") as f:
            lines = f.readlines()
        vertices, labels = extract_vertices(lines)

        # convert() decodes the pixels so the file can be closed immediately,
        # and guarantees 3 channels: RGBA/grayscale inputs would otherwise
        # clash with the 3-channel Normalize below.
        with Image.open(self.img_files[index]) as raw:
            img = raw.convert('RGB')
        img, vertices = adjust_height(img, vertices)
        img, vertices = rotate_img(img, vertices)
        img, vertices = crop_img(img, vertices, labels, self.length)
        transform = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.25),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

        score_map, geo_map, ignored_map = get_score_geo(img, vertices, labels, self.scale, self.length)
        return transform(img), score_map, geo_map, ignored_map

    def __len__(self):
        '''Number of image/annotation pairs.'''
        return len(self.img_files)
def extract_vertices(lines):
    '''extract vertices info from txt lines
    Input:
        lines : list of string info, one "x1,y1,x2,y2,x3,y3,x4,y4[,...]" entry per line
    Output:
        vertices: vertices of text regions <numpy.ndarray, (n,8)>
        labels  : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
    Blank lines are skipped. Without any annotation the shapes are (0,8) and (0,),
    so callers can still index the result as a 2-D array.
    '''
    labels = []
    vertices = []
    for line in lines:
        # a UTF-8 byte-order mark may prefix the first line of the file
        line = line.lstrip('\ufeff')
        if not line.strip():
            continue
        vertices.append(list(map(int, line.rstrip('\n').split(',')[:8])))
        # a '###' transcription marks a region to ignore (label 0)
        labels.append(0 if '###' in line else 1)
    if not vertices:
        return np.zeros((0, 8), dtype=int), np.zeros((0,), dtype=int)
    return np.array(vertices), np.array(labels)
def adjust_height(img, vertices, ratio=0.2):
    '''adjust height of image to aug data
    Input:
        img      : PIL Image
        vertices : vertices of text regions <numpy.ndarray, (n,8)>
        ratio    : height is scaled by a random factor in [1 - ratio, 1 + ratio]
                   ([0.8, 1.2] with the default)
    Output:
        img         : adjusted PIL Image
        new_vertices: adjusted vertices <numpy.ndarray of float64, same shape as vertices>
    '''
    ratio_h = 1 + ratio * (np.random.rand() * 2 - 1)
    old_h = img.height
    new_h = int(np.around(old_h * ratio_h))
    img = img.resize((img.width, new_h), Image.BILINEAR)  # PIL api (caution, width x height)

    # Work on a float copy: vertices parsed from annotation files are integers,
    # and writing the scaled y values back into an int array truncates them.
    new_vertices = np.array(vertices, dtype=np.float64)
    if new_vertices.size > 0:
        new_vertices[:, [1, 3, 5, 7]] *= new_h / old_h
    return img, new_vertices
def rotate_img(img, vertices, angle_range=10):
    '''Randomly rotate an image and its text boxes for augmentation.

    Input:
        img         : PIL Image
        vertices    : vertices of text regions <numpy.ndarray, (n,8)>
        angle_range : maximum absolute rotation in degrees; the angle is drawn
                      uniformly from [-angle_range, angle_range]
    Output:
        img         : rotated PIL Image (same size as the input)
        new_vertices: rotated vertices <numpy.ndarray of float, same shape as vertices>
    '''
    # Image and boxes both pivot about the centre of the pixel grid.
    anchor = np.array([[(img.width - 1) / 2], [(img.height - 1) / 2]])
    degrees = angle_range * (np.random.rand() * 2 - 1)
    rotated_img = img.rotate(degrees, Image.BILINEAR)  # PIL api
    # PIL rotates counter-clockwise for positive angles while get_rotate_mat
    # treats positive angles as clockwise, hence the sign flip.
    theta = -degrees / 180 * math.pi
    new_vertices = np.zeros(vertices.shape)
    for row in range(len(vertices)):
        new_vertices[row, :] = rotate_vertices(vertices[row], theta, anchor)
    return rotated_img, new_vertices
def rotate_vertices(vertices, theta, anchor=None):
    '''Rotate the four corners of a quadrilateral about an anchor point.

    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)> as x1,y1,...,x4,y4
        theta   : angle in radian measure (direction as in get_rotate_mat)
        anchor  : <numpy.ndarray, (2,1)> point kept fixed by the rotation;
                  defaults to the first vertex (x1, y1)
    Output:
        rotated vertices <numpy.ndarray, (8,)>
    '''
    points = vertices.reshape((4, 2)).T  # 2 x 4: row 0 holds x, row 1 holds y
    pivot = points[:, :1] if anchor is None else anchor
    moved = np.dot(get_rotate_mat(theta), points - pivot) + pivot
    return moved.T.reshape(-1)
def get_rotate_mat(theta):
    '''Return the 2x2 rotation matrix for ``theta`` radians.

    With image coordinates (y axis pointing down), a positive theta value
    means rotate clockwise.
    '''
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    return np.array([[cos_t, -sin_t],
                     [sin_t, cos_t]])
def crop_img(img, vertices, labels, length):
    '''crop img patches to obtain batch and augment
    Input:
        img      : PIL Image
        vertices : vertices of text regions <numpy.ndarray, (n,8)>
        labels   : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
        length   : side length of the square cropped region
    Output:
        region      : cropped image region (length x length)
        new_vertices: vertices scaled and shifted into the cropped region
                      <numpy.ndarray of float, (n,8)>
    '''
    h, w = img.height, img.width
    # make sure the shortest side of the image is >= length
    if h >= w and w < length:
        img = img.resize((length, int(h * length / w)), Image.BILINEAR)
    elif h < w and h < length:
        img = img.resize((int(w * length / h), length), Image.BILINEAR)
    ratio_w = img.width / w
    ratio_h = img.height / h
    assert(ratio_w >= 1 and ratio_h >= 1)

    # reshape(-1, 8) keeps the 2-D (n, 8) layout even when an image has no
    # annotation and upstream code hands over a flat (0,) array; the
    # `[labels == 1, :]` indexing below needs two dimensions.
    new_vertices = np.asarray(vertices, dtype=np.float64).reshape(-1, 8).copy()
    new_vertices[:, [0, 2, 4, 6]] *= ratio_w
    new_vertices[:, [1, 3, 5, 7]] *= ratio_h

    # find a random position whose crop does not cut through valid text
    # (gives up after 1000 tries and keeps the last position)
    remain_h = img.height - length
    remain_w = img.width - length
    valid_text = new_vertices[labels == 1, :]
    crosses = True
    tries = 0
    while crosses and tries < 1000:
        tries += 1
        start_w = int(np.random.rand() * remain_w)
        start_h = int(np.random.rand() * remain_h)
        crosses = is_cross_text([start_w, start_h], length, valid_text)
    region = img.crop((start_w, start_h, start_w + length, start_h + length))

    new_vertices[:, [0, 2, 4, 6]] -= start_w
    new_vertices[:, [1, 3, 5, 7]] -= start_h
    return region, new_vertices
def is_cross_text(start_loc, length, vertices):
    '''Check whether a square crop would cut through a text region.

    Input:
        start_loc: [x, y] of the crop's top-left corner
        length   : side length of the square crop
        vertices : vertices of text regions <numpy.ndarray, (n,8)>
    Output:
        True if some region is only partially covered by the crop (between 1%
        and 99% of its area lies inside the crop); False otherwise, including
        when there are no regions at all.
    '''
    if vertices.size == 0:
        return False
    left, top = start_loc
    right = left + length
    bottom = top + length
    crop = Polygon(np.array([[left, top], [right, top],
                             [right, bottom], [left, bottom]])).convex_hull
    eps = 1e-6  # guards against zero-area (degenerate) text regions
    for quad in vertices:
        text = Polygon(quad.reshape((4, 2))).convex_hull
        covered = crop.intersection(text).area / (text.area + eps)
        if 0.01 <= covered <= 0.99:
            return True
    return False
def get_score_geo(img, vertices, labels, scale, length):
    '''generate score gt and geometry gt
    Input:
        img     : PIL Image
        vertices: vertices of text regions <numpy.ndarray, (n,8)>
        labels  : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
        scale   : feature map / image
        length  : image length
    Output:
        score gt <torch.Tensor, (1,H,W)>, geo gt <torch.Tensor, (5,H,W)> and
        ignored map <torch.Tensor, (1,H,W)>, with H = int(img.height * scale)
        and W = int(img.width * scale). Geo channels are the distances to the
        top, bottom, left and right edges, followed by the rotation angle.
    Note:
        the pixel sampling below assumes img is length x length, so that the
        sampled grid matches the feature-map size.
    '''
    map_size = (int(img.height * scale), int(img.width * scale))
    score_map = np.zeros(map_size + (1,), np.float32)
    geo_map = np.zeros(map_size + (5,), np.float32)
    ignored_map = np.zeros(map_size + (1,), np.float32)

    # image pixel sampled for every feature-map cell (one every 1/scale pixels)
    index = np.arange(0, length, int(1 / scale))
    index_x, index_y = np.meshgrid(index, index)
    ignored_polys = []
    polys = []

    for i, vertice in enumerate(vertices):
        if labels[i] == 0:
            ignored_polys.append(np.around(scale * vertice.reshape((4, 2))).astype(np.int32))
            continue

        poly = np.around(scale * shrink_poly(vertice).reshape((4, 2))).astype(np.int32)  # scaled & shrunk
        polys.append(poly)
        temp_mask = np.zeros(score_map.shape[:-1], np.float32)
        cv2.fillPoly(temp_mask, [poly], 1)
        inside = temp_mask > 0

        theta = find_min_rect_angle(vertice)
        rotate_mat = get_rotate_mat(theta)

        rotated_vertices = rotate_vertices(vertice, theta)
        x_min, x_max, y_min, y_max = get_boundary(rotated_vertices)
        rotated_x, rotated_y = rotate_all_pixels(rotate_mat, vertice[0], vertice[1], length)

        # given p in Polygon(vertice), top_distance = p - Pr_{top_L}(p) = r(p)_y - r(Pr_{top_L}(p))_y = r(p)_y - ymin
        # where r is the rotation anchored at the top-left corner
        # and p in Polygon(vertice) only if r(p)_y - ymin >= 0
        # d1 = distance from top to point (j, i)

        # the gt is top, bottom, left, right
        distances = (rotated_y - y_min,  # d1: top
                     y_max - rotated_y,  # d2: bottom
                     rotated_x - x_min,  # d3: left
                     x_max - rotated_x)  # d4: right
        for channel, dist in enumerate(distances):
            dist[dist < 0] = 0
            # Assign instead of accumulating: where two shrunk polygons
            # overlap, a sum would be a distance that belongs to neither box.
            geo_map[:, :, channel][inside] = dist[index_y, index_x][inside]
        geo_map[:, :, 4][inside] = theta

    cv2.fillPoly(ignored_map, ignored_polys, 1)
    cv2.fillPoly(score_map, polys, 1)
    return (torch.Tensor(score_map).permute(2, 0, 1),
            torch.Tensor(geo_map).permute(2, 0, 1),
            torch.Tensor(ignored_map).permute(2, 0, 1))
def shrink_poly(vertices, coef=0.3):
    '''shrink the text region
    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>
        coef    : shrink ratio in paper
    Output:
        v : vertices of the shrunk text region <numpy.ndarray of float64, (8,)>
    '''
    x1, y1, x2, y2, x3, y3, x4, y4 = vertices
    # r_i is the length of the shorter of the two edges meeting at vertex i
    r1 = min(cal_distance(x1, y1, x2, y2), cal_distance(x1, y1, x4, y4))
    r2 = min(cal_distance(x2, y2, x1, y1), cal_distance(x2, y2, x3, y3))
    r3 = min(cal_distance(x3, y3, x2, y2), cal_distance(x3, y3, x4, y4))
    r4 = min(cal_distance(x4, y4, x1, y1), cal_distance(x4, y4, x3, y3))
    r = [r1, r2, r3, r4]

    # obtain offset so that move_points() shrinks the two longer edges first
    if cal_distance(x1, y1, x2, y2) + cal_distance(x3, y3, x4, y4) > \
            cal_distance(x2, y2, x3, y3) + cal_distance(x1, y1, x4, y4):
        offset = 0  # two longer edges are (x1y1-x2y2) & (x3y3-x4y4)
    else:
        offset = 1  # two longer edges are (x2y2-x3y3) & (x4y4-x1y1)

    # How the shrinking works:
    # - every movement is parallel to the edge being shrunk;
    # - each move_points() call pushes the two endpoints of one edge towards
    #   each other;
    # - every vertex is moved twice, once along each of its two edges, which
    #   pulls it towards the centre of the region.
    # Work on a float copy: move_points() updates coordinates in place and an
    # integer array would silently truncate the fractional moves.
    v = np.array(vertices, dtype=np.float64)
    v = move_points(v, 0 + offset, 1 + offset, r, coef)
    v = move_points(v, 2 + offset, 3 + offset, r, coef)
    v = move_points(v, 1 + offset, 2 + offset, r, coef)
    v = move_points(v, 3 + offset, 4 + offset, r, coef)
    return v
def find_min_rect_angle(vertices):
    '''Find the rotation that best axis-aligns a quadrilateral.

    Every whole-degree angle in [-90, 90) is tried. The 10 angles giving the
    smallest axis-aligned bounding box are then re-ranked with cal_error, so
    the winner both hugs the region tightly and keeps vertex 1 nearest the
    top-left corner. Ties keep the earlier candidate.
    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>
    Output:
        the best angle <radian measure>
    '''
    candidates = list(range(-90, 90, 1))

    def bbox_area(degrees):
        rotated = rotate_vertices(vertices, degrees / 180 * math.pi)
        xs = (rotated[0], rotated[2], rotated[4], rotated[6])
        ys = (rotated[1], rotated[3], rotated[5], rotated[7])
        return (max(xs) - min(xs)) * (max(ys) - min(ys))

    areas = [bbox_area(degrees) for degrees in candidates]
    by_area = sorted(range(len(areas)), key=areas.__getitem__)

    best_error = float('inf')
    best = -1
    for idx in by_area[:10]:
        error = cal_error(rotate_vertices(vertices, candidates[idx] / 180 * math.pi))
        if error < best_error:
            best_error = error
            best = idx
    return candidates[best] / 180 * math.pi
def cal_distance(x1, y1, x2, y2):
    '''Return the Euclidean distance between (x1, y1) and (x2, y2).'''
    dx = x1 - x2
    dy = y1 - y2
    return math.sqrt(dx ** 2 + dy ** 2)
def get_boundary(vertices):
    '''Return the axis-aligned bounding box of a quadrilateral.

    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)> as x1,y1,...,x4,y4
    Output:
        (x_min, x_max, y_min, y_max)
    '''
    x1, y1, x2, y2, x3, y3, x4, y4 = vertices
    xs = (x1, x2, x3, x4)
    ys = (y1, y2, y3, y4)
    return min(xs), max(xs), min(ys), max(ys)
def rotate_all_pixels(rotate_mat, anchor_x, anchor_y, length):
    '''Rotate the coordinates of every pixel of a length x length grid.

    Input:
        rotate_mat: 2x2 rotation matrix
        anchor_x  : x of the point kept fixed by the rotation
        anchor_y  : y of the point kept fixed by the rotation
        length    : side length of the pixel grid
    Output:
        rotated_x : rotated x positions <numpy.ndarray, (length,length)>
        rotated_y : rotated y positions <numpy.ndarray, (length,length)>
        Entry [i, j] belongs to the pixel at x = j, y = i.
    '''
    grid_x, grid_y = np.meshgrid(np.arange(length), np.arange(length))
    pivot = np.array([[anchor_x], [anchor_y]])
    coords = np.vstack((grid_x.ravel(), grid_y.ravel()))  # 2 x length**2
    rotated = np.dot(rotate_mat, coords - pivot) + pivot
    return rotated[0, :].reshape(grid_x.shape), rotated[1, :].reshape(grid_y.shape)
def move_points(vertices, index1, index2, r, coef):
    '''Pull the two endpoints of one quadrilateral edge towards each other.

    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>; modified in place
        index1  : index of the first point (taken modulo 4)
        index2  : index of the second point (taken modulo 4)
        r       : [r1, r2, r3, r4] in paper
        coef    : shrink ratio in paper
    Output:
        vertices: the same array; each endpoint moved by coef * r[i] along the
                  edge. Edges not longer than 1 are left untouched.
    '''
    p1 = index1 % 4
    p2 = index2 % 4
    ax, ay = 2 * p1, 2 * p1 + 1
    bx, by = 2 * p2, 2 * p2 + 1

    dx = vertices[ax] - vertices[bx]
    dy = vertices[ay] - vertices[by]
    edge_len = cal_distance(vertices[ax], vertices[ay], vertices[bx], vertices[by])
    if not edge_len > 1:
        return vertices

    step = (r[p1] * coef) / edge_len
    vertices[ax] += step * (-dx)
    vertices[ay] += step * (-dy)
    step = (r[p2] * coef) / edge_len
    vertices[bx] += step * dx
    vertices[by] += step * dy
    return vertices
def cal_error(vertices):
    '''Measure how far a quadrilateral's vertex order is from the default one.

    Default orientation: x1y1 left-top, x2y2 right-top, x3y3 right-bottom,
    x4y4 left-bottom. The error is the sum of distances from each vertex to
    the matching corner of the quadrilateral's axis-aligned bounding box.
    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>
    Output:
        err : difference measure (0 for an axis-aligned box in default order)
    '''
    x_min, x_max, y_min, y_max = get_boundary(vertices)
    corners = ((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max))
    err = 0
    for k, (corner_x, corner_y) in enumerate(corners):
        err += cal_distance(vertices[2 * k], vertices[2 * k + 1], corner_x, corner_y)
    return err
1import torch
2import torch.nn as nn
def get_dice_loss(gt_score, pred_score):
    '''Dice loss between ground-truth and predicted score maps.

    Returns 1 - 2 * sum(gt * pred) / (sum(gt) + sum(pred)); the 1e-5 added to
    the denominator avoids division by zero when both maps are empty.
    '''
    overlap = (gt_score * pred_score).sum()
    total = gt_score.sum() + pred_score.sum() + 1e-5
    return 1. - 2 * overlap / total
def get_geo_loss(gt_geo, pred_geo):
    '''Per-pixel IoU and angle losses for the rotated-box geometry.

    Channels (dim 1) of both inputs are the distances to the top, bottom,
    left and right edges, followed by the angle.
    Output:
        iou_loss_map  : -log((intersection + 1) / (union + 1))
        angle_loss_map: 1 - cos(angle_pred - angle_gt)
    '''
    top_gt, bottom_gt, left_gt, right_gt, angle_gt = torch.split(gt_geo, 1, 1)
    top_pr, bottom_pr, left_pr, right_pr, angle_pr = torch.split(pred_geo, 1, 1)

    area_gt = (top_gt + bottom_gt) * (left_gt + right_gt)
    area_pr = (top_pr + bottom_pr) * (left_pr + right_pr)
    # both boxes share the pixel, so the overlap extends min(d) in each direction
    inter_w = torch.min(left_gt, left_pr) + torch.min(right_gt, right_pr)
    inter_h = torch.min(top_gt, top_pr) + torch.min(bottom_gt, bottom_pr)
    area_inter = inter_w * inter_h
    area_union = area_gt + area_pr - area_inter

    iou_loss_map = -torch.log((area_inter + 1.0) / (area_union + 1.0))
    angle_loss_map = 1 - torch.cos(angle_pr - angle_gt)
    return iou_loss_map, angle_loss_map
class Loss(nn.Module):
    '''EAST training loss: dice classification loss + IoU loss + weighted angle loss.'''

    def __init__(self, weight_angle=10):
        super(Loss, self).__init__()
        self.weight_angle = weight_angle  # multiplier applied to the angle term

    def forward(self, gt_score, pred_score, gt_geo, pred_geo, ignored_map):
        '''Compute the combined loss for one batch.

        Predicted scores inside ignored regions are zeroed before the dice
        loss; geometry losses are averaged over ground-truth text pixels only.
        '''
        positives = torch.sum(gt_score)
        if positives < 1:
            # No text in the batch: return a zero that stays attached to the
            # prediction graph.
            return torch.sum(pred_score + pred_geo) * 0

        classify_loss = get_dice_loss(gt_score, pred_score * (1 - ignored_map))
        iou_loss_map, angle_loss_map = get_geo_loss(gt_geo, pred_geo)

        angle_loss = torch.sum(angle_loss_map * gt_score) / positives
        iou_loss = torch.sum(iou_loss_map * gt_score) / positives
        return self.weight_angle * angle_loss + iou_loss + classify_loss
1import torch
2from torch.utils import data
3from torch import nn
4from torch.optim import lr_scheduler
5from dataset import CustomDataset
6from detect import performance_check
7from models import EAST
8from losses import Loss
9from tqdm import tqdm
10from device import device
11from utils import ConsoleLog
12import os
13import time
# Module-level logger used by train() for per-batch progress lines; ConsoleLog
# comes from the project's utils module (lines_up_on_end is passed through as-is).
console_log = ConsoleLog(lines_up_on_end=1)
def train(train_img_path, train_gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
    '''Train an EAST model and periodically checkpoint it.

    Input:
        train_img_path: directory with the training images
        train_gt_path : directory with one annotation .txt file per image
        pths_path     : checkpoint directory (created if missing)
        batch_size    : mini-batch size; an incomplete last batch is dropped
        lr            : initial Adam learning rate, multiplied by 0.1 after
                        epoch_iter // 2 epochs
        num_workers   : number of DataLoader worker processes
        epoch_iter    : number of training epochs
        interval      : a checkpoint is saved every `interval` epochs, and
                        performance_check() runs every `interval` mini-batches
    '''
    trainset = CustomDataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   drop_last=True)
    # With drop_last=True this equals len(trainset) // batch_size.
    batches_per_epoch = len(train_loader)

    criterion = Loss()
    model = EAST()
    data_parallel = False

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Milestones are counted in epochs, so scheduler.step() is called once per
    # epoch (after the batch loop), not once per mini-batch.
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=0.1)

    # torch.save() below and the performance_check() image paths write here.
    os.makedirs(pths_path, exist_ok=True)
    os.makedirs("results", exist_ok=True)

    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0
        epoch_time = time.time()

        for batch, (img, gt_score, gt_geo, ignored_map) in enumerate(tqdm(
                train_loader,
                total=batches_per_epoch,
                bar_format="{desc}: {percentage:.1f}%|{bar:15}| {n}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]"
        )):
            start_time = time.time()

            img = img.to(device)
            gt_score = gt_score.to(device)
            gt_geo = gt_geo.to(device)
            ignored_map = ignored_map.to(device)

            pred_score, pred_geo = model(img)

            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (batch + 1) % interval == 0:
                performance_check(model, save_image_path="results/epoch_{}_batch_{}.jpg".format(epoch, batch + 1))

            console_log.print(
                'Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                    epoch + 1, epoch_iter, batch + 1, batches_per_epoch, time.time() - start_time, loss.item()),
                is_key_value=False
            )

        scheduler.step()
        console_log.print(
            'Epoch is [{}/{}], mean loss is {:.8f}, epoch time is {:.8f}'.format(
                epoch + 1, epoch_iter, epoch_loss / max(batches_per_epoch, 1), time.time() - epoch_time),
            is_key_value=False
        )

        if (epoch + 1) % interval == 0:
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}.pth'.format(epoch + 1)))
64
65
if __name__ == '__main__':
    # Dataset locations, resolved against the current working directory.
    train_img_path = os.path.abspath('dataset/images')
    train_gt_path = os.path.abspath('dataset/annotations')
    pths_path = './pths'
    # Optimisation settings.
    batch_size, lr, num_workers = 24, 1e-3, 4
    epoch_iter = 600
    # Kept as a module-level name on purpose: code in this module may read
    # `save_interval` as a global.
    save_interval = 5
    train(train_img_path, train_gt_path, pths_path,
          batch_size=batch_size, lr=lr, num_workers=num_workers,
          epoch_iter=epoch_iter, interval=save_interval)
1from torchvision import transforms
2from PIL import Image, ImageDraw
3from models import EAST
4from dataset import get_rotate_mat
5from utils import nms_locality
6from device import device
7
8import config
9import torch
10import os
11import numpy as np
12import random
def _floor_to_multiple(size, base=32):
    '''Round size down to a multiple of base, but never below base itself.'''
    return max(base, size // base * base)


def resize_img(img):
    '''resize image so that both sides are divisible by 32
    Input:
        img: PIL Image
    Output:
        img    : resized PIL Image
        ratio_h: resized height / original height
        ratio_w: resized width / original width
    '''
    w, h = img.size
    # Sides are rounded down to a multiple of 32; a side shorter than 32 is
    # rounded up to 32 instead, since rounding down would collapse it to 0.
    resize_w = _floor_to_multiple(w)
    resize_h = _floor_to_multiple(h)
    img = img.resize((resize_w, resize_h), Image.BILINEAR)
    ratio_h = resize_h / h
    ratio_w = resize_w / w

    return img, ratio_h, ratio_w
def load_pil(img):
    '''convert PIL Image to a normalized torch.Tensor with a batch dimension
    Input:
        img: PIL Image in any mode (converted to RGB first when needed)
    Output:
        <torch.Tensor, (1,3,H,W)> normalized with mean 0.5 / std 0.5 per channel
    '''
    # Normalize below uses 3-channel statistics, so grayscale, RGBA or palette
    # images have to be converted to RGB first.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return t(img).unsqueeze(0)
def is_valid_poly(res, score_shape, scale):
    '''Decide whether a restored polygon lies (almost) inside the image.

    Input:
        res        : restored poly in original image, row 0 = x, row 1 = y
        score_shape: score map shape (rows, cols)
        scale      : feature map -> image
    Output:
        True if at most one vertex falls outside
        [0, cols * scale) x [0, rows * scale)
    '''
    width_limit = score_shape[1] * scale
    height_limit = score_shape[0] * scale
    outside = 0
    for x, y in zip(res[0], res[1]):
        if x < 0 or x >= width_limit or y < 0 or y >= height_limit:
            outside += 1
    return outside <= 1
def restore_polys(valid_pos, valid_geo, score_shape, scale=4):
    '''restore polys from feature maps in given positions
    Input:
        valid_pos  : potential text positions <numpy.ndarray, (n,2)> as (x, y)
                     feature-map cells; not modified
        valid_geo  : geometry in valid_pos <numpy.ndarray, (5,n)>: distances to
                     the top, bottom, left and right edges, then the angle
        score_shape: shape of score map
        scale      : image / feature map
    Output:
        restored polys <numpy.ndarray, (k,8)> for the k positions that pass
        is_valid_poly, and index: the list of those positions' indices
    '''
    polys = []
    index = []
    # Scale into a new array: an in-place `*=` would silently rewrite the
    # caller's array (and fails for an int array with a float scale).
    valid_pos = valid_pos * scale
    d = valid_geo[:4, :]  # 4 x N
    angle = valid_geo[4, :]  # N,

    for i in range(valid_pos.shape[0]):
        x = valid_pos[i, 0]
        y = valid_pos[i, 1]
        y_min = y - d[0, i]
        y_max = y + d[1, i]
        x_min = x - d[2, i]
        x_max = x + d[3, i]
        rotate_mat = get_rotate_mat(-angle[i])

        # corners relative to the pixel, rotated back about the pixel
        temp_x = np.array([[x_min, x_max, x_max, x_min]]) - x
        temp_y = np.array([[y_min, y_min, y_max, y_max]]) - y
        coordinates = np.concatenate((temp_x, temp_y), axis=0)
        res = np.dot(rotate_mat, coordinates)
        res[0, :] += x
        res[1, :] += y

        if is_valid_poly(res, score_shape, scale):
            index.append(i)
            polys.append([res[0, 0], res[1, 0], res[0, 1], res[1, 1], res[0, 2], res[1, 2], res[0, 3], res[1, 3]])
    return np.array(polys), index
def get_boxes(score, geo, score_thresh=config.detection_score_threshold, nms_thresh=0.2):
    '''Turn the model's score and geometry maps into scored quadrilaterals.

    Input:
        score       : score map from model <numpy.ndarray, (1,row,col)>
        geo         : geo map from model <numpy.ndarray, (5,row,col)>
        score_thresh: cells scoring above this are treated as text
        nms_thresh  : IoU threshold used by nms_locality
    Output:
        boxes : final polys <numpy.ndarray, (n,9)> (8 coordinates + score),
                or None when nothing is detected
    '''
    score = score[0, :, :]
    xy_text = np.argwhere(score > score_thresh)  # n x 2 as [row, col]
    if xy_text.size == 0:
        return None

    xy_text = xy_text[np.argsort(xy_text[:, 0])]
    rows = xy_text[:, 0]
    cols = xy_text[:, 1]
    # restore_polys works in (x, y) = (col, row) order
    valid_pos = xy_text[:, ::-1].copy()
    valid_geo = geo[:, rows, cols]  # 5 x n
    polys_restored, index = restore_polys(valid_pos, valid_geo, score.shape)
    if polys_restored.size == 0:
        return None

    boxes = np.zeros((polys_restored.shape[0], 9), dtype=np.float32)
    boxes[:, :8] = polys_restored
    boxes[:, 8] = score[rows[index], cols[index]]
    return nms_locality(boxes.astype('float32'), nms_thresh)
def adjust_ratio(boxes, ratio_w, ratio_h):
    '''Map boxes from the resized image back to original image coordinates.

    Input:
        boxes  : detected polys <numpy.ndarray, (n,9)>; coordinates are divided in place
        ratio_w: resized width / original width
        ratio_h: resized height / original height
    Output:
        the boxes rounded with np.around (all columns, the score included),
        or None when there are no boxes
    '''
    if boxes is None or not boxes.size:
        return None
    boxes[:, 0:8:2] /= ratio_w  # x coordinates
    boxes[:, 1:8:2] /= ratio_h  # y coordinates
    return np.around(boxes)
def detect(img, model, device):
    '''Run the detector on one image.

    Input:
        img   : PIL Image
        model : detection model
        device: torch device the model lives on
    Output:
        detected polys <numpy.ndarray, (n,9)> in original image coordinates,
        or None when nothing is found
    '''
    resized, ratio_h, ratio_w = resize_img(img)
    batch = load_pil(resized).to(device)
    with torch.no_grad():
        score, geo = model(batch)
    score_np = score.squeeze(0).cpu().numpy()
    geo_np = geo.squeeze(0).cpu().numpy()
    boxes = get_boxes(score_np, geo_np)
    return adjust_ratio(boxes, ratio_w, ratio_h)
def plot_boxes(img, boxes):
    '''Draw every box outline in green directly onto img.

    Input:
        img  : PIL Image (modified in place)
        boxes: <numpy.ndarray, (n,9)> or None; the first 8 values of each row
               are the 4 corners
    Output:
        the same img object
    '''
    if boxes is None:
        return img

    canvas = ImageDraw.Draw(img)
    for box in boxes:
        canvas.polygon(list(box[:8]), outline=(0, 255, 0))
    return img
def _result_filename(img_file):
    '''Return the submission file name for an image: res_<stem>.txt.'''
    stem = os.path.splitext(os.path.basename(img_file))[0]
    return 'res_' + stem + '.txt'


def detect_dataset(model, device, test_img_path, submit_path):
    '''detection on whole dataset, save .txt results in submit_path
    Input:
        model        : detection model
        device       : gpu if gpu is available
        test_img_path: dataset path
        submit_path  : submit result for evaluation; one res_<image stem>.txt
                       per image, one comma-separated box (8 ints) per line
    '''
    img_files = sorted(os.path.join(test_img_path, name) for name in os.listdir(test_img_path))

    for i, img_file in enumerate(img_files):
        print('evaluating {} image'.format(i), end='\r')
        # close the image file once detection is done
        with Image.open(img_file) as img:
            boxes = detect(img, model, device)
        seq = []
        if boxes is not None:
            seq.extend(','.join(str(int(b)) for b in box[:-1]) + '\n' for box in boxes)
        # splitext handles any extension (the old '.jpg' replace left e.g. '.png' names unchanged)
        with open(os.path.join(submit_path, _result_filename(img_file)), 'w') as f:
            f.writelines(seq)
def performance_check(model, save_image_path):
    '''Run detection on one random training image and save the visualisation.

    The annotated image is written to save_image_path and to
    results/latest_output.jpg (parent directories are created if needed).
    The model's previous train/eval mode is restored afterwards, even if
    detection fails. Nothing is done when dataset/images has no .jpg/.png file.
    '''
    image_dir = "dataset/images"
    candidates = [name for name in os.listdir(image_dir) if name.endswith((".jpg", ".png"))]
    if not candidates:
        return

    was_training = model.training
    model.eval()
    try:
        # convert() loads the pixels (so the file can be closed) and gives an
        # RGB image that the green outline colour can be drawn on.
        with Image.open(os.path.join(image_dir, random.choice(candidates))) as raw:
            img = raw.convert('RGB')
        boxes = detect(img, model, device)
        plot_img = plot_boxes(img, boxes)
        for path in (save_image_path, "results/latest_output.jpg"):
            folder = os.path.dirname(path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            plot_img.save(path)
    finally:
        model.train(was_training)
1import torch
# First CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
1from pydash.objects import get, set_
2from shapely.geometry import Polygon
3import numpy as np
def iou(g, p):
    '''Intersection-over-union of two quadrilaterals.

    Input:
        g, p: arrays whose first 8 values are the 4 corners (x1, y1, ..., x4, y4)
    Output:
        IoU as a float; 0 if either polygon is invalid or the union is empty
    '''
    poly_g = Polygon(g[:8].reshape((4, 2)))
    poly_p = Polygon(p[:8].reshape((4, 2)))
    if not (poly_g.is_valid and poly_p.is_valid):
        return 0
    inter = poly_g.intersection(poly_p).area
    union = poly_g.area + poly_p.area - inter
    return 0 if union == 0 else inter / union
def weighted_merge(g, p):
    '''Merge box p into box g (in place), weighting coordinates by score.

    Input:
        g, p: arrays of 8 coordinates followed by a score
    Output:
        g itself: coordinates become the score-weighted average of both boxes,
        and the score becomes the sum of both scores
    '''
    total = g[8] + p[8]
    g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / total
    g[8] = total
    return g
def standard_nms(S, thres):
    '''Greedy non-maximum suppression on scored quadrilaterals.

    Repeatedly keep the highest-scoring remaining box and discard every other
    remaining box whose IoU with it exceeds thres.
    Input:
        S    : <numpy.ndarray, (n,9)> boxes with the score in column 8
        thres: IoU above which a box is suppressed
    Output:
        the kept rows of S, highest score first
    '''
    remaining = np.argsort(S[:, 8])[::-1]
    kept = []
    while remaining.size > 0:
        best = remaining[0]
        rest = remaining[1:]
        kept.append(best)
        overlaps = np.array([iou(S[best], S[t]) for t in rest])
        remaining = rest[np.where(overlaps <= thres)[0]]

    return S[kept]
def nms_locality(polys, thres=0.3):
    '''
    locality aware nms of EAST
    :param polys: a N*9 numpy array. first 8 coordinates, then prob. Only
                  consecutive rows are merged, so the input order matters.
                  The array is not modified.
    :param thres: IoU above which boxes are merged (first pass) or
                  suppressed (standard_nms pass)
    :return: boxes after nms (an empty array when there are none)
    '''
    S = []
    p = None
    # weighted_merge() writes into its first argument, so iterate over a copy
    # rather than over views into the caller's array.
    for g in np.array(polys, copy=True):
        if p is not None and iou(g, p) > thres:
            p = weighted_merge(g, p)
        else:
            if p is not None:
                S.append(p)
            p = g
    if p is not None:
        S.append(p)

    if len(S) == 0:
        return np.array([])
    return standard_nms(np.array(S), thres)