"""
__author__: HashTagML
license: MIT
Created: Thursday, 8th April 2021
"""
import collections
import copy
import os
import random
import textwrap
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
import bounding_box.bounding_box as bb
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image, ImageDraw, ImageOps
from .config import COLORS, IMAGE_BORDER, IMAGE_EXT
def check_num_imgs(images_dir: Union[str, os.PathLike]) -> int:
    """Counts the entries of a directory whose suffix is listed in ``IMAGE_EXT``.

    Args:
        images_dir (Union[str, os.PathLike]): Directory to scan (top level only, not recursive).

    Returns:
        int: Total number of directory entries whose suffix matches an entry of ``IMAGE_EXT``.
    """
    suffix_tally = collections.Counter()
    for entry in Path(images_dir).iterdir():
        suffix_tally[entry.suffix] += 1
    # Sum per configured extension so the result follows IMAGE_EXT exactly as it is declared.
    return sum(suffix_tally.get(ext, 0) for ext in IMAGE_EXT)
def check_df_cols(df_cols: List, req_cols: List) -> bool:
    """Verifies whether input dataframe contains required columns or not.

    Args:
        df_cols (List): List of columns in the input dataframe.
        req_cols (List): List of required columns.

    Returns:
        bool: ``True`` if all required columns are present (also when ``req_cols`` is empty),
            otherwise ``False``.
    """
    return all(col in df_cols for col in req_cols)
class Resizer(object):
    """Letterbox-resizes an image and maps its bounding boxes onto the resized canvas.

    The image is scaled with its aspect ratio preserved so that it fits inside ``expected_size``
    and is then pasted, centred, onto a white canvas of exactly ``expected_size``.

    Args:
        expected_size (Tuple): Output canvas size as ``(width, height)``. Defaults to ``(512, 512)``.
    """

    def __init__(self, expected_size: Tuple = (512, 512)):
        assert isinstance(expected_size, tuple), "expected_size must be a tuple of (width, height)"
        self.expected_size = expected_size

    def __call__(self, sample):
        """Resizes the image of a sample and transforms its annotations accordingly.

        Args:
            sample: Mapping with the keys ``"image_path"`` (image file to open) and ``"anns"`` (bounding boxes).

        Returns:
            Tuple of the letterboxed PIL image and the transformed bounding boxes.
        """
        img_path, anns = sample["image_path"], sample["anns"]
        # The image must be processed first: it sets the scale/offset state the boxes rely on.
        img = self._get_resized_img(img_path)
        bboxes = self._regress_boxes(anns)
        return img, bboxes

    def _set_letterbox_dims(self):
        """Computes resized dimensions, paste offset and scale factors from ``self.orig_dim``."""
        iw, ih = self.orig_dim
        ew, eh = self.expected_size
        scale = min(eh / ih, ew / iw)
        # Clamp to one pixel so extreme aspect ratios cannot produce an empty image
        # or a division by zero in the scale factors below.
        nw = max(1, int(iw * scale))
        nh = max(1, int(ih * scale))
        self.new_dim = (nw, nh)
        # Centre the resized image on the canvas.
        self.offset = ((ew - nw) // 2, (eh - nh) // 2)
        # Ratio original/resized; box coordinates are divided by it.
        self.upsample = (iw / nw, ih / nh)

    def _get_resized_img(self, img_path: str):
        """Opens ``img_path`` and returns it letterboxed on a white ``expected_size`` RGB canvas."""
        # The context manager releases the underlying file handle once the pixels are resized.
        with Image.open(img_path) as img:
            self.orig_dim = img.size
            self._set_letterbox_dims()
            resized = img.resize(self.new_dim)
        new_img = Image.new("RGB", self.expected_size, color=(255, 255, 255))
        new_img.paste(resized, self.offset)
        return new_img

    def _regress_boxes(self, bboxes: np.ndarray):
        """Maps bounding boxes from original image coordinates to letterboxed coordinates.

        Columns 0 and 2 are treated as x coordinates and columns 1 and 3 as y coordinates; any
        further columns are left untouched. No width/height to corner conversion is performed.

        Args:
            bboxes (np.ndarray): Two-dimensional array (or nested sequence) of boxes.

        Returns:
            A new array with the transformed boxes, or an empty list when there are no boxes.
            The input is never modified.
        """
        if not len(bboxes):
            return []
        # np.array copies by default, so the caller's annotations are not overwritten in place.
        boxes = np.array(bboxes)
        scale_x, scale_y = self.upsample
        offset_x, offset_y = self.offset
        x_cols, y_cols = [0, 2], [1, 3]
        boxes[:, x_cols] = boxes[:, x_cols] / scale_x
        boxes[:, y_cols] = boxes[:, y_cols] / scale_y
        boxes[:, x_cols] += offset_x
        boxes[:, y_cols] += offset_y
        return boxes
def plot_boxes(
    img: Image,
    bboxes: np.ndarray,
    scores: Optional[List] = None,
    class_map: Optional[Dict] = None,
    class_color_map: Optional[Dict] = None,
    **kwargs,
) -> Image:
    """Plots bounding boxes annotations on the images.

    Args:
        img (Image): Pillow image on which annotations to be drawn.
        bboxes (np.ndarray): Bounding boxes of the input image. The last element of every box is its
            category (an id or a name), the preceding elements are the box coordinates.
        scores (Optional[List], optional): Scores in case of simple json format. Defaults to None.
        class_map (Optional[Dict], optional): mapping between category ids and their names.
            Defaults to None (empty mapping).
        class_color_map (Optional[Dict], optional): mapping between categories and their colors.
            Defaults to None (every box is drawn in green).
        **kwargs: ``threshold`` skips boxes whose score is below it (only applied when ``scores`` is given);
            ``truncate_label`` is a separator used to abbreviate the label to the lower-cased initials
            of its parts.

    Returns:
        Image: PIL images on which annotations are drawn.
    """
    class_map = {} if class_map is None else class_map
    class_color_map = {} if class_color_map is None else class_color_map
    # Loop-invariant options are read once.
    threshold = kwargs.get("threshold", None)
    separator = kwargs.get("truncate_label", None)

    draw_img = np.array(img)
    for i, box in enumerate(bboxes):
        # Filtering needs scores; without them every box is drawn.
        if threshold is not None and scores is not None and scores[i] < threshold:
            continue
        bbox = [max(0, int(coord)) for coord in box[:-1]]

        label = box[-1]
        if isinstance(label, str):
            category = label
        else:
            category = class_map.get(int(label), str(int(label)))

        # Colors are looked up by integer id when the label is numeric, otherwise by the raw
        # label, so named categories no longer fail on the int() conversion.
        try:
            color_key = int(label)
        except (TypeError, ValueError):
            color_key = label

        if separator is not None:
            # Empty parts (e.g. from repeated separators) have no initial and are skipped.
            category = "".join(part[0].lower() for part in category.split(separator) if part)
        if scores is not None:
            category = category + ":" + str(round(scores[i], 2))

        color = class_color_map.get(color_key, "green")
        bb.add(draw_img, *bbox, category, color=color)
    return Image.fromarray(draw_img)
def check_save_path(save_path: Union[str, os.PathLike], name: Optional[str] = None) -> str:
    """Validates output path, creates missing directories if they do not exist.

    Args:
        save_path (Union[str, os.PathLike]): Either a target file (its suffix is in ``IMAGE_EXT``)
            or an output directory.
        name (Optional[str], optional): File name used when ``save_path`` is a directory.
            Defaults to None, which means ``"vis.jpg"``.

    Returns:
        str: Path of the file to write.
    """
    save_path = Path(save_path)
    if save_path.suffix in IMAGE_EXT:
        # A file was requested: only its parent may be created, otherwise the file name
        # itself would become a directory and the subsequent save would fail.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        return str(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    file_name = "vis.jpg" if name is None else name
    return str(save_path.joinpath(file_name))
def get_class_color_map(class_map: Dict) -> Dict:
    """Assigns a randomly chosen color from ``COLORS`` to every category of the dataset.

    Each color is handed out once; when the palette is exhausted the remaining categories
    fall back to ``"green"``.

    Args:
        class_map (Dict): Mapping whose keys are the category ids.

    Returns:
        Dict: Mapping from category id to its color.
    """
    palette = copy.deepcopy(COLORS)
    assigned = dict()
    for cat_id in class_map:
        if len(palette) == 0:
            assigned[cat_id] = "green"
            continue
        picked = random.choice(palette)
        assigned[cat_id] = picked
        # Drop the color so that it is not reused for another category.
        palette.remove(picked)
    return assigned
def render_grid_mpl(
    drawn_imgs: List,
    image_names: List,
    num_imgs: int,
    cols: int,
    rows: int,
    img_size: int,
    save_path: Optional[str] = None,
    **kwargs,
):
    """Uses ``matplotlib`` to render image grid.

    Args:
        drawn_imgs (List): List of images with annotations.
        image_names (List): List of image names to be drawn.
        num_imgs (int): Number of images in the batch
        cols (int): Number of columns required in the grid.
        rows (int): Number of rows required in the grid.
        img_size (int): Each resized image size.
        save_path (Optional[str], optional): Output path if images and annotations to be saved. Defaults to None.
    """
    # figsize is (width, height) in inches: the width follows the column count and the
    # height follows the row count.
    fig = plt.figure(
        figsize=(
            (cols * img_size + 3 * IMAGE_BORDER * cols) / 72,
            (rows * img_size + 3 * IMAGE_BORDER * rows) / 72,
        )
    )
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(rows, cols),
        axes_pad=0.5,  # pad between axes in inch
    )

    for ax, im, im_name in zip(grid, drawn_imgs, image_names):
        # Iterating over the grid returns the Axes.
        ax.imshow(im)
        # Only the file name is shown; basename honours the separators of the running OS.
        title = "\n".join(textwrap.wrap(os.path.basename(im_name), width=32))
        ax.set_title(title)
        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide the axes of grid cells that received no image.
    for ax in grid[num_imgs:]:
        ax.axis("off")

    if save_path is not None:
        save_path = check_save_path(save_path)
        fig.savefig(save_path)
    plt.show()
def render_grid_mpy(drawn_imgs: List, image_names: List, **kwargs) -> Any:
    """Shows a batch of images as a video, one image per frame.

    Args:
        drawn_imgs (List): List of images with annotations.
        image_names (List): List of image names to be drawn.
        **kwargs: ``show_image_name`` (truthy) adds a name strip to every frame;
            ``image_time`` is the display time of one frame, the frame rate being its inverse
            (1 frame per second when it is not given).

    Returns:
        Any: Whatever ``media.show_video`` returns.
    """
    with_names = kwargs.get("show_image_name", None)
    if with_names:
        frames = []
        for frame, label in zip(drawn_imgs, image_names):
            frames.append(np.array(add_name_strip(frame, label)))
    else:
        frames = list(map(np.array, drawn_imgs))

    image_time = kwargs.get("image_time", None)
    if image_time is None:
        fps = 1
    else:
        fps = 1 / image_time
    return media.show_video(frames, fps=fps)
def add_name_strip(img: "Image.Image", name: str) -> "Image.Image":
    """Adds a white border around the image and writes the file name, centred, at the top.

    Args:
        img (Image.Image): Pillow image to decorate (it is passed to ``ImageOps.expand``).
        name (str): Image name or path; only its last path component is written.

    Returns:
        Image.Image: New image, ``IMAGE_BORDER`` pixels larger on every side, with the name drawn.
    """
    drawn_img = ImageOps.expand(img, border=IMAGE_BORDER, fill=(255, 255, 255))
    # basename honours the path separators of the running OS.
    name = os.path.basename(name)
    lines = textwrap.wrap(name, width=32)
    # A single line is pushed down to the middle of the border; wrapped names start at the top.
    y_text = IMAGE_BORDER // 2 if len(lines) <= 1 else 0
    dimg = ImageDraw.Draw(drawn_img)
    font = dimg.getfont()
    canvas_width = drawn_img.size[0]
    for line in lines:
        if hasattr(font, "getbbox"):
            # ``getsize`` was removed in Pillow 10; the right/bottom edges of the bounding box
            # are used as the text extent in its place.
            _, _, width, height = font.getbbox(line)
        else:
            # Older Pillow releases without ``getbbox`` on this font type.
            width, height = font.getsize(line)
        dimg.multiline_text(((canvas_width - width) // 2, y_text), line, font=font, fill=(0, 0, 0))
        y_text += height
    return drawn_img
def render_grid_pil(
    drawn_imgs: List,
    image_names: List,
    num_imgs: int,
    cols: int,
    rows: int,
    img_size: int,
    save_path: Optional[str] = None,
    **kwargs,
) -> Any:
    """Uses ``PIL`` to render image grid.

    Args:
        drawn_imgs (List): List of images with annotations. The list is not modified.
        image_names (List): List of image names to be drawn.
        num_imgs (int): Number of images in the batch
        cols (int): Number of columns required in the grid.
        rows (int): Number of rows required in the grid.
        img_size (int): Each resized image size.
        save_path (Optional[str], optional): Output path if images and annotations to be saved. Defaults to None.

    Returns:
        Any: PIL image holding the whole grid.
    """
    # Build a new list so the caller's images are not replaced by their labelled versions.
    labelled_imgs = [add_name_strip(img, img_name) for img, img_name in zip(drawn_imgs, image_names)]

    # Every tile is the image plus the border that the name strip adds on each side.
    cell = img_size + 2 * IMAGE_BORDER
    canvas = Image.new("RGB", (cols * cell, rows * cell), color=(255, 255, 255))

    # Never index past the available images or past the grid capacity.
    num_tiles = min(num_imgs, len(labelled_imgs), rows * cols)
    for idx in range(num_tiles):
        row, col = divmod(idx, cols)
        # Tiles sit exactly on the cell grid, so none of them drifts off the canvas.
        canvas.paste(labelled_imgs[idx], (col * cell, row * cell))

    if save_path is not None:
        save_path = check_save_path(save_path)
        # Save the PIL canvas itself; there is no matplotlib figure involved here.
        canvas.save(save_path)
    return canvas