Skip to content

Commit

Permalink
Merge pull request #458 from SamSamhuns/master
Browse files Browse the repository at this point in the history
Add batch img inference support for ocr det with readtext_batched
  • Loading branch information
rkcosmos authored Jun 24, 2021
2 parents 229a451 + 4a3829c commit 89ec92f
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 50 deletions.
83 changes: 51 additions & 32 deletions easyocr/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,42 +22,55 @@ def copyStateDict(state_dict):
return new_state_dict

def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
if isinstance(image, np.ndarray) and len(image.shape) == 4: # image is batch of np arrays
image_arrs = image
else: # image is single numpy array
image_arrs = [image]

img_resized_list = []
# resize
img_resized, target_ratio, size_heatmap = resize_aspect_ratio(image, canvas_size,\
interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
for img in image_arrs:
img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
interpolation=cv2.INTER_LINEAR,
mag_ratio=mag_ratio)
img_resized_list.append(img_resized)
ratio_h = ratio_w = 1 / target_ratio

# preprocessing
x = normalizeMeanVariance(img_resized)
x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
x = np.array([normalizeMeanVariance(n_img) for n_img in img_resized_list])
x = Variable(torch.from_numpy(x).permute(0, 3, 1, 2)) # [b,h,w,c] to [b,c,h,w]
x = x.to(device)

# forward pass
with torch.no_grad():
y, feature = net(x)

# make score and link map
score_text = y[0,:,:,0].cpu().data.numpy()
score_link = y[0,:,:,1].cpu().data.numpy()
boxes_list, polys_list = [], []
for out in y:
# make score and link map
score_text = out[:, :, 0].cpu().data.numpy()
score_link = out[:, :, 1].cpu().data.numpy()

# Post-processing
boxes, polys, mapper = getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
# Post-processing
boxes, polys, mapper = getDetBoxes(
score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)

# coordinate adjustment
boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
if estimate_num_chars:
boxes = list(boxes)
polys = list(polys)
for k in range(len(polys)):
# coordinate adjustment
boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
if estimate_num_chars:
boxes[k] = (boxes[k], mapper[k])
if polys[k] is None: polys[k] = boxes[k]

return boxes, polys

def get_detector(trained_model, device='cpu', quantize=True):
boxes = list(boxes)
polys = list(polys)
for k in range(len(polys)):
if estimate_num_chars:
boxes[k] = (boxes[k], mapper[k])
if polys[k] is None:
polys[k] = boxes[k]
boxes_list.append(boxes)
polys_list.append(polys)

return boxes_list, polys_list

def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
net = CRAFT()

if device == 'cpu':
Expand All @@ -70,21 +83,27 @@ def get_detector(trained_model, device='cpu', quantize=True):
else:
net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
net = torch.nn.DataParallel(net).to(device)
cudnn.benchmark = False
cudnn.benchmark = cudnn_benchmark

net.eval()
return net

def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None):
result = []
estimate_num_chars = optimal_num_chars is not None
bboxes, polys = test_net(canvas_size, mag_ratio, detector, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars)

bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
image, text_threshold,
link_threshold, low_text, poly,
device, estimate_num_chars)
if estimate_num_chars:
polys = [p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]

for i, box in enumerate(polys):
poly = np.array(box).astype(np.int32).reshape((-1))
result.append(poly)
polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
for polys in polys_list]

for polys in polys_list:
single_img_result = []
for i, box in enumerate(polys):
poly = np.array(box).astype(np.int32).reshape((-1))
single_img_result.append(poly)
result.append(single_img_result)

return result
81 changes: 63 additions & 18 deletions easyocr/easyocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from .recognition import get_recognizer, get_text
from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
download_and_unzip, printProgressBar, diff, reformat_input,\
make_rotated_img_list, set_result_with_confidence
make_rotated_img_list, set_result_with_confidence,\
reformat_input_batched
from .config import *
from bidi.algorithm import get_display
import numpy as np
Expand All @@ -31,7 +32,7 @@ class Reader(object):
def __init__(self, lang_list, gpu=True, model_storage_directory=None,
user_network_directory=None, recog_network = 'standard',
download_enabled=True, detector=True, recognizer=True,
verbose=True, quantize=True):
verbose=True, quantize=True, cudnn_benchmark=False):
"""Create an EasyOCR Reader.
Parameters:
Expand Down Expand Up @@ -75,7 +76,7 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
else:
self.device = gpu
self.recognition_models = recognition_models

# check and download detection model
detector_model = 'craft'
corrupt_msg = 'MD5 hash mismatch, possible file corruption'
Expand Down Expand Up @@ -215,7 +216,7 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
dict_list[lang] = os.path.join(BASE_PATH, 'dict', lang + ".txt")

if detector:
self.detector = get_detector(detector_path, self.device, quantize)
self.detector = get_detector(detector_path, self.device, quantize, cudnn_benchmark=cudnn_benchmark)
if recognizer:
if recog_network == 'generation1':
network_params = {
Expand Down Expand Up @@ -271,19 +272,25 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\
if reformat:
img, img_cv_grey = reformat_input(img)

text_box = get_textbox(self.detector, img, canvas_size, mag_ratio,\
text_threshold, link_threshold, low_text,\
False, self.device, optimal_num_chars)
horizontal_list, free_list = group_text_box(text_box, slope_ths,\
ycenter_ths, height_ths,\
width_ths, add_margin, \
(optimal_num_chars is None))

if min_size:
horizontal_list = [i for i in horizontal_list if max(i[1]-i[0],i[3]-i[2]) > min_size]
free_list = [i for i in free_list if max(diff([c[0] for c in i]), diff([c[1] for c in i]))>min_size]

return horizontal_list, free_list
text_box_list = get_textbox(self.detector, img, canvas_size, mag_ratio,
text_threshold, link_threshold, low_text,
False, self.device, optimal_num_chars)

horizontal_list_agg, free_list_agg = [], []
for text_box in text_box_list:
horizontal_list, free_list = group_text_box(text_box, slope_ths,
ycenter_ths, height_ths,
width_ths, add_margin,
(optimal_num_chars is None))
if min_size:
horizontal_list = [i for i in horizontal_list if max(
i[1] - i[0], i[3] - i[2]) > min_size]
free_list = [i for i in free_list if max(
diff([c[0] for c in i]), diff([c[1] for c in i])) > min_size]
horizontal_list_agg.append(horizontal_list)
free_list_agg.append(free_list)

return horizontal_list_agg, free_list_agg

def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
Expand Down Expand Up @@ -381,11 +388,49 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
slope_ths, ycenter_ths,\
height_ths,width_ths,\
add_margin, False)

# get the 1st result from hor & free list as self.detect returns a list of depth 3
horizontal_list, free_list = horizontal_list[0], free_list[0]
result = self.recognize(img_cv_grey, horizontal_list, free_list,\
decoder, beamWidth, batch_size,\
workers, allowlist, blocklist, detail, rotation_info,\
paragraph, contrast_ths, adjust_contrast,\
filter_ths, y_ths, x_ths, False, output_format)

return result

def readtext_batched(self, image, n_width=None, n_height=None,\
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
rotation_info = None, paragraph = False, min_size = 20,\
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
canvas_size = 2560, mag_ratio = 1.,\
slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, output_format='standard'):
'''
Parameters:
image: file path or numpy-array or a byte stream object
When sending a list of images, they all must of the same size,
the following parameters will automatically resize if they are not None
n_width: int, new width
n_height: int, new height
'''
img, img_cv_grey = reformat_input_batched(image, n_width, n_height)

horizontal_list_agg, free_list_agg = self.detect(img, min_size, text_threshold,\
low_text, link_threshold,\
canvas_size, mag_ratio,\
slope_ths, ycenter_ths,\
height_ths, width_ths,\
add_margin, False)
result_agg = []
# put img_cv_grey in a list if its a single img
img_cv_grey = [img_cv_grey] if len(img_cv_grey.shape) == 2 else img_cv_grey
for grey_img, horizontal_list, free_list in zip(img_cv_grey, horizontal_list_agg, free_list_agg):
result_agg.append(self.recognize(grey_img, horizontal_list, free_list,\
decoder, beamWidth, batch_size,\
workers, allowlist, blocklist, detail, rotation_info,\
paragraph, contrast_ths, adjust_contrast,\
filter_ths, y_ths, x_ths, False, output_format))

return result_agg
29 changes: 29 additions & 0 deletions easyocr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,35 @@ def reformat_input(image):
return img, img_cv_grey


def reformat_input_batched(image, n_width=None, n_height=None):
"""
reformats an image or list of images or a 4D numpy image array &
returns a list of corresponding img, img_cv_grey nd.arrays
image:
[file path, numpy-array, byte stream object,
list of file paths, list of numpy-array, 4D numpy array,
list of byte stream objects]
"""
if ((isinstance(image, np.ndarray) and len(image.shape) == 4) or isinstance(image, list)):
# process image batches if image is list of image np arr, paths, bytes
img, img_cv_grey = [], []
for single_img in image:
clr, gry = reformat_input(single_img)
if n_width is not None and n_height is not None:
clr = cv2.resize(clr, (n_width, n_height))
gry = cv2.resize(gry, (n_width, n_height))
img.append(clr)
img_cv_grey.append(gry)
img, img_cv_grey = np.array(img), np.array(img_cv_grey)
# ragged tensors created when all input imgs are not of the same size
if len(img.shape) == 1 and len(img_cv_grey.shape) == 1:
raise ValueError("The input image array contains images of different sizes. " +
"Please resize all images to same shape or pass n_width, n_height to auto-resize")
else:
img, img_cv_grey = reformat_input(image)
return img, img_cv_grey


def make_rotated_img_list(rotationInfo, img_list):
result_img_list = img_list[:]

Expand Down

0 comments on commit 89ec92f

Please sign in to comment.