branch: master
mask_rcnn.py
9697 bytesRaw
from extra.models.mask_rcnn import MaskRCNN
from extra.models.resnet import ResNet
from extra.models.mask_rcnn import BoxList
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as Ft
import random
from tinygrad.tensor import Tensor
from PIL import Image
import numpy as np
import torch
import argparse
import cv2


class Resize:
  """Resize an image so its shorter side matches a target length.

  Modified from torchvision to add support for capping the longer side.

  Args:
    min_size: target length of the shorter side; an int, or a list/tuple of
      candidate lengths from which one is picked with ``random.choice`` on
      every call.
    max_size: upper bound on the longer side after resizing, or ``None`` to
      disable the cap.
  """

  def __init__(self, min_size, max_size):
    if not isinstance(min_size, (list, tuple)):
      min_size = (min_size,)
    self.min_size = min_size
    self.max_size = max_size

  # modified from torchvision to add support for max size
  def get_size(self, image_size):
    """Return the ``(height, width)`` an image of ``image_size`` is resized to.

    Args:
      image_size: ``(width, height)`` of the input image (``image.size``).

    Returns:
      ``(height, width)`` whose shorter side equals the chosen target,
      reduced when necessary so the longer side stays within ``max_size``.
    """
    w, h = image_size
    size = random.choice(self.min_size)
    max_size = self.max_size
    if max_size is not None:
      min_original_size = float(min((w, h)))
      max_original_size = float(max((w, h)))
      # shrink the target so the longer side does not exceed max_size
      if max_original_size / min_original_size * size > max_size:
        size = int(round(max_size * min_original_size / max_original_size))

    # The remaining steps apply whether or not max_size is set (they used to
    # be nested under the check above, so max_size=None returned None).
    if (w <= h and w == size) or (h <= w and h == size):
      return (h, w)

    if w < h:
      ow = size
      oh = int(size * h / w)
    else:
      oh = size
      ow = int(size * w / h)

    return (oh, ow)

  def __call__(self, image):
    """Resize ``image`` to the size computed by :meth:`get_size`."""
    size = self.get_size(image.size)
    image = Ft.resize(image, size)
    return image


class Normalize:
  """Reorder the channels of an image tensor, scale it by 255, then normalize.

  The tensor's first axis is indexed as the channel axis.

  Args:
    mean: per-channel means forwarded to ``Ft.normalize``.
    std: per-channel standard deviations forwarded to ``Ft.normalize``.
    to_bgr255: when True the channels are taken in ``[2, 1, 0]`` order,
      otherwise in ``[0, 1, 2]`` order. Both paths multiply by 255.
  """

  def __init__(self, mean, std, to_bgr255=True):
    self.mean, self.std, self.to_bgr255 = mean, std, to_bgr255

  def __call__(self, image):
    channel_order = [2, 1, 0] if self.to_bgr255 else [0, 1, 2]
    rescaled = image[channel_order] * 255
    return Ft.normalize(rescaled, mean=self.mean, std=self.std)

def transforms(size_scale):
  """Build the image preprocessing pipeline.

  ``size_scale`` multiplies the default shorter-side target (800) and the
  longer-side cap (1333). The pipeline resizes, converts to a tensor, and
  normalizes with the listed means after reordering channels to BGR and
  scaling to 0-255.
  """
  steps = [
    Resize(int(800 * size_scale), int(1333 * size_scale)),
    T.ToTensor(),
    Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True),
  ]
  return T.Compose(steps)

def expand_boxes(boxes, scale):
  """Scale every box about its own centre.

  Args:
    boxes: tensor whose columns 0-3 are read as ``x0, y0, x1, y1``.
    scale: factor applied to each box's half-width and half-height.

  Returns:
    A new tensor created with ``torch.zeros_like(boxes)`` holding the
    expanded coordinates; ``boxes`` itself is not modified.
  """
  x0, y0, x1, y1 = (boxes[:, k] for k in range(4))
  half_w = (x1 - x0) * .5 * scale
  half_h = (y1 - y0) * .5 * scale
  cx = (x1 + x0) * .5
  cy = (y1 + y0) * .5

  expanded = torch.zeros_like(boxes)
  expanded[:, 0] = cx - half_w
  expanded[:, 1] = cy - half_h
  expanded[:, 2] = cx + half_w
  expanded[:, 3] = cy + half_h
  return expanded


def expand_masks(mask, padding):
  """Zero-pad masks by ``padding`` pixels on every side.

  Args:
    mask: tensor whose last dimension has size ``M``. It is written into the
      centre ``M x M`` window of an ``(N, 1, M + 2*padding, M + 2*padding)``
      tensor, so it must broadcast to ``(N, 1, M, M)`` (square masks).
    padding: number of zero pixels added on each side; ``0`` is allowed.

  Returns:
    ``(padded_mask, scale)`` where ``scale = (M + 2*padding) / M`` is how much
    larger the padded mask is than the original.
  """
  N = mask.shape[0]
  M = mask.shape[-1]
  pad2 = 2 * padding
  scale = float(M + pad2) / M
  padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
  # Use an explicit end index: with padding == 0, `padding:-padding` becomes
  # the empty slice `0:0` and the assignment would fail.
  padded_mask[:, :, padding:padding + M, padding:padding + M] = mask
  return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
  """Resize one detection's mask to its box and paste it into an image canvas.

  Args:
    mask: mask for one detection, converted with ``mask.numpy()``. It is
      expected to be a square 2-D ``(M, M)`` array, since ``expand_masks``
      uses the last dimension for both sides (TODO confirm against the
      model's mask output).
    box: ``(x0, y0, x1, y1)`` box in image pixels, converted with
      ``box.numpy()``.
    im_h, im_w: height and width of the output canvas.
    thresh: if ``>= 0`` the resized mask is binarised with ``mask > thresh``;
      if negative, the soft mask is scaled to ``uint8`` 0-255 instead.
    padding: zero border added around the mask before resizing. The box is
      enlarged by the same relative factor so the two stay aligned.

  Returns:
    A ``torch.uint8`` tensor of shape ``(im_h, im_w)``. It is zero outside the
    image-clipped box; inside, values are 0/1 when ``thresh >= 0``.
  """
  # TODO: remove torch
  mask = torch.tensor(mask.numpy())
  box = torch.tensor(box.numpy())
  # pad the mask and grow the box by the matching scale so they stay aligned
  padded_mask, scale = expand_masks(mask[None], padding=padding)
  mask = padded_mask[0, 0]
  box = expand_boxes(box[None], scale)[0]
  # truncate box coordinates to integer pixel positions
  box = box.to(dtype=torch.int32)

  # box corners are treated as inclusive, so width/height add one pixel
  TO_REMOVE = 1
  w = int(box[2] - box[0] + TO_REMOVE)
  h = int(box[3] - box[1] + TO_REMOVE)
  w = max(w, 1)
  h = max(h, 1)

  # add batch and channel dimensions for F.interpolate
  mask = mask.expand((1, 1, -1, -1))

  mask = mask.to(torch.float32)
  mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
  mask = mask[0][0]

  if thresh >= 0:
    mask = mask > thresh
  else:
    mask = (mask * 255).to(torch.uint8)

  im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
  # clip the paste region to the image bounds
  x_0 = max(box[0], 0)
  x_1 = min(box[2] + 1, im_w)
  y_0 = max(box[1], 0)
  y_1 = min(box[3] + 1, im_h)

  # copy the part of the resized mask that falls inside the image, offset by
  # the box origin.
  # NOTE(review): if the box lies entirely outside the image, the two slices
  # below can differ in size and the assignment would raise -- confirm boxes
  # are clipped to the image upstream.
  im_mask[y_0:y_1, x_0:x_1] = mask[
                              (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
                              ]
  return im_mask


class Masker:
  """Paste per-detection masks into full-size image canvases.

  Args:
    threshold: forwarded to ``paste_mask_in_image`` as ``thresh``.
    padding: forwarded to ``paste_mask_in_image`` as ``padding``.
  """

  def __init__(self, threshold=0.5, padding=1):
    self.threshold = threshold
    self.padding = padding

  def forward_single_image(self, masks, boxes):
    """Paste every mask of one image into an image-sized canvas.

    Args:
      masks: per-detection masks; ``mask[0]`` of each entry is pasted.
      boxes: a ``BoxList`` whose ``size`` is the image ``(width, height)``.

    Returns:
      A tinygrad ``Tensor`` of shape ``(num_detections, 1, im_h, im_w)``.
    """
    boxes = boxes.convert("xyxy")
    im_w, im_h = boxes.size
    res = [
      paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
      for mask, box in zip(masks, boxes.bbox)
    ]
    if res:
      # torch.stack takes the sequence itself; unpacking it (`*res`) passed
      # the first mask as the sequence and the next one as `dim`.
      res = torch.stack(res, dim=0)[:, None]
    else:
      # No detections: build an empty result with the same layout and dtype
      # as the non-empty path (paste_mask_in_image returns uint8 canvases),
      # without relying on a torch-only `new_empty` method of `masks`.
      res = torch.zeros((0, 1, im_h, im_w), dtype=torch.uint8)
    return Tensor(res.numpy())

  def __call__(self, masks, boxes):
    """Paste masks for a batch of images.

    Args:
      masks: one collection of masks per image.
      boxes: a single ``BoxList`` (treated as a batch of one) or a list of them.

    Returns:
      A list with one :meth:`forward_single_image` result per image.
    """
    if isinstance(boxes, BoxList):
      boxes = [boxes]
    return [self.forward_single_image(mask, box) for mask, box in zip(masks, boxes)]


# Module-level Masker used by compute_prediction (mask threshold 0.5, 1 px padding).
masker = Masker(padding=1, threshold=0.5)

def select_top_predictions(predictions, confidence_threshold=0.9):
  """Keep only the detections whose score is strictly above a threshold.

  Args:
    predictions: object exposing ``get_field("scores")`` (with a ``.numpy()``
      method) and supporting indexing by a list of positions.
    confidence_threshold: exclusive lower bound on the score.

  Returns:
    ``predictions`` indexed by the ascending list of kept positions.
  """
  score_array = predictions.get_field("scores").numpy()
  keep = np.flatnonzero(score_array > confidence_threshold).tolist()
  return predictions[keep]

def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
  """Run ``model`` on one image and return its filtered, rescaled prediction.

  The image is preprocessed with ``transforms(size_scale)``. The first
  prediction returned by ``model`` is filtered with ``select_top_predictions``
  and resized back to ``original_image.size``. When it has a ``"mask"``
  field, the masks are replaced by full-image canvases produced by
  ``masker``.
  """
  preprocess = transforms(size_scale)
  batch = Tensor(preprocess(original_image).numpy(), requires_grad=False)
  prediction = select_top_predictions(model(batch)[0], confidence_threshold)
  width, height = original_image.size
  prediction = prediction.resize((width, height))

  if not prediction.has_field("mask"):
    return prediction
  pasted = masker([prediction.get_field("mask")], [prediction])[0]
  prediction.add_field("mask", pasted)
  return prediction

def compute_prediction_batched(batch, model, size_scale=1.0):
  """Run ``model`` on a list of images preprocessed with ``transforms(size_scale)``.

  Returns the model output unchanged; no score filtering or mask pasting is
  applied.
  """
  preprocess = transforms(size_scale)
  tensors = [Tensor(preprocess(img).numpy(), requires_grad=False) for img in batch]
  predictions = model(tensors)
  del tensors
  return predictions

# Per-channel multipliers used by compute_colors_for_labels to derive label colours.
palette = np.array([(1 << 25) - 1, (1 << 15) - 1, (1 << 21) - 1])

def findContours(*args, **kwargs):
  """Call ``cv2.findContours`` and return ``(contours, hierarchy)``.

  OpenCV 3.x returns ``(image, contours, hierarchy)``, while 4.x (and 2.x)
  return ``(contours, hierarchy)``. Taking the last two elements handles every
  layout, instead of leaving ``contours`` unbound on versions whose string
  starts with neither "3" nor "4".
  """
  result = cv2.findContours(*args, **kwargs)
  contours, hierarchy = result[-2:]
  return contours, hierarchy

def compute_colors_for_labels(labels):
  """Map integer class labels to deterministic colour triples.

  Each label is multiplied by ``palette`` and reduced modulo 255, giving a
  ``uint8`` array with one row of three channel values per label.
  """
  return ((labels[:, None] * palette) % 255).astype("uint8")

def overlay_mask(image, predictions):
  """Draw the outline of every predicted mask on a copy of ``image``.

  Args:
    image: PIL image or array accepted by ``np.array``.
    predictions: object whose ``"mask"`` field gives 4-D masks after
      ``.numpy()`` (``(N, 1, H, W)`` as produced by ``Masker``) and whose
      ``"labels"`` field gives one class label per mask.

  Returns:
    A new numpy array with each mask's contours drawn 3 px thick in its
    label's colour; ``image`` is not modified.
  """
  # np.array (not np.asarray) gives a writable copy: np.asarray of a PIL
  # image is read-only, and cv2 drawing functions write into the array.
  canvas = np.array(image)
  masks = predictions.get_field("mask").numpy()
  labels = predictions.get_field("labels").numpy()
  colors = compute_colors_for_labels(labels).tolist()

  for mask, color in zip(masks, colors):
    # single-channel (H, W, 1) view of the mask for contour extraction
    thresh = mask[0, :, :, None]
    contours, _ = findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    canvas = cv2.drawContours(canvas, contours, -1, color, 3)

  return canvas

# Class names looked up by integer label in overlay_class_names
# (CATEGORIES[int(label)]). Index 0 is the background entry; the remaining 80
# names appear to be the COCO detection categories -- confirm against the
# pretrained weights' label map.
CATEGORIES = [
    "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
    "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
    "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
    "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

def overlay_boxes(image, predictions):
  """Draw every predicted box, 1 px thick, on a copy of ``image``.

  Args:
    image: PIL image or array accepted by ``np.array``.
    predictions: object with a ``bbox`` whose ``.numpy()`` gives one
      ``(x0, y0, x1, y1)`` row per detection, and a ``"labels"`` field.

  Returns:
    A new numpy array with the rectangles drawn in each label's colour;
    ``image`` is not modified.
  """
  labels = predictions.get_field("labels").numpy()
  # Convert all boxes at once; astype truncates toward zero, the same as the
  # previous per-box torch int64 cast, without the torch round-trip.
  boxes = predictions.bbox.numpy().astype(np.int64)
  # Writable copy: np.asarray of a PIL image is read-only and cv2.rectangle
  # draws into the array it is given.
  canvas = np.array(image)
  colors = compute_colors_for_labels(labels).tolist()

  for box, color in zip(boxes, colors):
    top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
    canvas = cv2.rectangle(
        canvas, tuple(top_left), tuple(bottom_right), tuple(color), 1
    )

  return canvas

def overlay_class_names(image, predictions):
  """Write ``"<class>: <score>"`` at the top-left corner of each box.

  Args:
    image: PIL image or array accepted by ``np.array``.
    predictions: object with ``"scores"`` and ``"labels"`` fields and a
      ``bbox`` whose ``.numpy()`` rows start with ``x0, y0``.

  Returns:
    A new numpy array with the labels drawn in white; ``image`` is not
    modified.
  """
  scores = predictions.get_field("scores").numpy().tolist()
  labels = predictions.get_field("labels").numpy().tolist()
  labels = [CATEGORIES[int(i)] for i in labels]
  boxes = predictions.bbox.numpy()
  # Writable copy: cv2.putText draws in place, and np.asarray of a PIL image
  # would be read-only.
  canvas = np.array(image)
  template = "{}: {:.2f}"
  for box, score, label in zip(boxes, scores, labels):
    x, y = box[:2]
    s = template.format(label, score)
    x, y = int(x), int(y)
    cv2.putText(
        canvas, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
    )

  return canvas


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # --image is mandatory: without it Image.open(None) would fail with an
  # unrelated-looking error.
  parser.add_argument('--image', type=str, required=True, help="Path of the image to run")
  parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
  parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
  args = parser.parse_args()

  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model_tiny = MaskRCNN(resnet)
  model_tiny.load_from_pretrained()
  # Normalize reads channels [2, 1, 0], so the input needs three colour
  # channels: convert grayscale/palette/RGBA inputs to RGB. convert() returns
  # a loaded copy, so the source file can be closed right away.
  with Image.open(args.image) as src:
    img = src.convert("RGB")
  top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
  bbox_image = overlay_boxes(img, top_result_tiny)
  mask_image = overlay_mask(bbox_image, top_result_tiny)
  final_image = overlay_class_names(mask_image, top_result_tiny)

  im = Image.fromarray(final_image)
  print(f"saving {args.out}")
  im.save(args.out)
  im.show()