# branch: master
# mask_rcnn.py
# 41774 bytes (raw)
import re
import math
import os
import numpy as np
from pathlib import Path
from tinygrad import nn, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
from extra.models.resnet import ResNet
from extra.models.retinanet import nms as _box_nms

# Gather backend switch used by tensor_getitem: unless FULL_TINYGRAD is set to something
# other than '0', gathers go through numpy (npgather); otherwise the tinygrad-only _gather.
USE_NP_GATHER = os.environ.get('FULL_TINYGRAD', '0') == '0'

def rint(tensor):
  """Round each element to the nearest integer, resolving exact .5 ties to the even integer.

  This is the rounding of ``np.round`` / C ``rint``. The previous implementation rounded
  ties away from zero (22.5 -> 23 instead of 22), which changed anchors produced by
  ``_ratio_enum`` whenever ``ws * ratio`` landed exactly on a half.

  Args:
    tensor: floating-point Tensor.
  Returns:
    Tensor of the same shape holding integral float values.
  """
  fl = tensor.floor()
  frac = tensor - fl
  # parity of the floor: 0.0 when even, 1.0 when odd (also correct for negative values)
  parity = fl - (fl / 2).floor() * 2
  # above .5 -> round up; exactly .5 -> round up only when that lands on an even integer
  return (frac > 0.5).where(fl + 1, (frac == 0.5).where(fl + parity, fl))

def nearest_interpolate(tensor, scale_factor):
  """Nearest-neighbour upsample an NCHW tensor by an integer ``scale_factor`` on both spatial axes.

  Every input pixel is replicated into a ``scale_factor x scale_factor`` patch.
  """
  n, ch, height, width = tensor.shape
  # add a singleton axis after each spatial axis, broadcast it, then fold it back in
  upsampled = tensor.reshape(n, ch, height, 1, width, 1)
  upsampled = upsampled.expand(n, ch, height, scale_factor, width, scale_factor)
  return upsampled.reshape(n, ch, height * scale_factor, width * scale_factor)

def meshgrid(x, y):
  """Return flattened 'ij'-indexed grids of two 1-D tensors.

  Builds ``grid_x[i, j] = x[i]`` and ``grid_y[i, j] = y[j]`` over shape
  ``(len(x), len(y))`` and returns both reshaped to column vectors of shape
  ``(len(x) * len(y), 1)``. Broadcasting replaces the former per-element
  concatenation, whose graph grew linearly with ``len(x)``.
  """
  grid_shape = (x.shape[0], y.shape[0])
  grid_x = x.reshape(x.shape[0], 1).expand(grid_shape)
  grid_y = y.reshape(1, y.shape[0]).expand(grid_shape)
  return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)

def topk(input_, k, dim=-1, largest=True, sorted=False):
  """Select the ``k`` largest (or smallest) entries of ``input_`` along ``dim``.

  Computed on the host with numpy ``argpartition``, so no gradient flows through it.

  Args:
    input_: Tensor, or any array-like accepted by ``np.asarray``.
    k: number of entries to keep; clamped to ``[0, input_.shape[dim]]``.
    dim: axis to select along.
    largest: keep the largest entries when True, the smallest otherwise.
    sorted: when True, order the result (descending if ``largest``); otherwise the order is unspecified.
  Returns:
    ``(values, indices)``: ``values`` is a Tensor, ``indices`` a numpy integer array.
  """
  arr = input_.numpy() if isinstance(input_, Tensor) else np.asarray(input_)
  # clamp to the available length; the old `shape - 1` clamp silently dropped one element
  # whenever k >= shape[dim] (e.g. FPN levels with fewer anchors than pre_nms_top_n)
  k = max(0, min(k, arr.shape[dim]))
  # negate a copy (never the host buffer itself) so that "largest" becomes "smallest"
  key = -arr if largest else arr
  # kth=k-1 places the k best entries, in arbitrary order, in the first k slots
  ind = np.argpartition(key, k - 1, axis=dim) if k > 0 else np.argsort(key, axis=dim)
  ind = np.take(ind, np.arange(k), axis=dim)  # k non-sorted indices
  vals = np.take_along_axis(key, ind, axis=dim)  # k non-sorted (possibly negated) values
  if sorted:
    order = np.argsort(vals, axis=dim)
    ind = np.take_along_axis(ind, order, axis=dim)
    vals = np.take_along_axis(vals, order, axis=dim)
  if largest: vals = -vals
  return Tensor(vals), ind

# This is very slow for large arrays, or indices
def _gather(array, indices):
  """Gather along the last axis of ``array`` using only tinygrad ops.

  Every index is compared against ``arange(array.shape[-1])`` to build a one-hot mask
  of shape ``indices.shape + (array.shape[-1],)``; ``where`` selects the matching
  elements and the one-hot axis is summed away, so the result has ``indices.shape``.
  Leading axes of ``array`` broadcast against the trailing axes of ``indices``.
  """
  n = array.shape[-1]
  idx = indices.float().to(array.device)
  full_shape = (*idx.shape, n)
  positions = Tensor.arange(n).reshape(*([1] * array.ndim), n).expand(*full_shape)
  one_hot = idx.unsqueeze(idx.ndim).expand(*full_shape) == positions
  return one_hot.where(array, 0).sum(idx.ndim)

# TODO: replace npgather with a faster gather using tinygrad only
# NOTE: this blocks the gradient
def npgather(array, indices):
  """Return ``Tensor(array[indices])`` computed on the host with numpy.

  ``array`` may be a Tensor or numpy array; ``indices`` may be a Tensor, numpy array or
  list. Indices are cast to ``int`` and applied as fancy indexing on the first axis.
  """
  arr = array.numpy() if isinstance(array, Tensor) else array
  if isinstance(indices, Tensor):
    idx = indices.numpy()
  elif isinstance(indices, list):
    idx = np.asarray(indices)
  else:
    idx = indices
  return Tensor(arr[idx.astype(int)])

def get_strides(shape):
  """Return the row-major (C-order) element strides of ``shape`` as an int32 Tensor of shape ``(1, len(shape))``."""
  strides = []
  step = 1
  for dim in reversed(shape):
    strides.append(step)
    step *= dim
  # something about ints is broken with gpu, cuda
  return Tensor(strides[::-1], dtype=dtypes.int32).unsqueeze(0)

# with keys as integer array for all axes
def tensor_getitem(tensor, *keys):
  """Advanced-index ``tensor`` with one broadcastable integer index Tensor per axis.

  Behaves like numpy's ``tensor[k0, k1, ...]``: keys are broadcast to a common shape,
  turned into flat offsets with row-major strides and gathered from the flattened tensor
  with ``npgather`` or ``_gather`` (selected by ``USE_NP_GATHER``).
  """
  # broadcast shape of all keys (Python sum of Tensors broadcasts); computed once
  # instead of once per key inside the comprehension
  out_shape = sum(keys).shape
  # something about ints is broken with gpu, cuda
  flat_keys = Tensor.stack(*[key.expand(out_shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
  idxs = (flat_keys * get_strides(tensor.shape)).sum(1)
  gatherer = npgather if USE_NP_GATHER else _gather
  return gatherer(tensor.reshape(-1), idxs).reshape(out_shape)


# for gather with indicies only on axis=0
def tensor_gather(tensor, indices):
  """Select rows ``tensor[indices]`` along axis 0 using the tinygrad-only ``_gather``.

  ``indices`` may be a Tensor or anything ``Tensor(...)`` accepts. For inputs with more
  than one axis, trailing axes are flattened, the tensor is transposed so the gathered
  axis becomes last, and the index vector is tiled once per column.
  """
  if not isinstance(indices, Tensor):
    indices = Tensor(indices, requires_grad=False)
  if tensor.ndim <= 1:
    return _gather(tensor, indices)
  rem_shape = list(tensor.shape)[1:] if tensor.ndim > 2 else None
  if rem_shape:
    tensor = tensor.reshape(tensor.shape[0], -1)
  columns = tensor.T  # gathered axis is now the last one
  tiled = indices.unsqueeze(indices.ndim).repeat([1] * (columns.ndim - 1) + [columns.shape[-2]])
  out = _gather(columns, tiled)
  if rem_shape:
    out = out.reshape([tiled.shape[0]] + rem_shape)
  return out


class LastLevelMaxPool:
  """Extra FPN level: subsample the coarsest map with ``max_pool2d(x, 1, 2)`` (kernel 1, stride 2)."""
  def __call__(self, x):
    pooled = Tensor.max_pool2d(x, 1, 2)
    return [pooled]


# transpose methods understood by BoxList.transpose
FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM = 0, 1


def permute_and_flatten(layer:Tensor, N, A, C, H, W):
  """Regroup an ``(N, A*C, H, W)`` head output into ``(N, H*W*A, C)`` rows.

  The anchor axis is inferred with ``-1``; ``A`` is kept in the signature for callers.
  """
  return layer.reshape(N, -1, C, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, C)


class BoxList:
  """Bounding boxes of one image plus named per-box fields.

  ``bbox`` is a 2-D Tensor with 4 columns, read as ``"xyxy"`` (corners) or ``"xywh"``
  (corner + size). Coordinates use the inclusive-pixel convention, so an ``xyxy`` box has
  width ``x2 - x1 + 1``. ``size`` is ``(image_width, image_height)``. Entries of
  ``extra_fields`` are re-indexed together with the boxes by ``__getitem__``.
  """

  def __init__(self, bbox, image_size, mode="xyxy"):
    if not isinstance(bbox, Tensor):
      bbox = Tensor(bbox)
    if bbox.ndim != 2:
      raise ValueError(
        "bbox should have 2 dimensions, got {}".format(bbox.ndim)
      )
    if bbox.shape[-1] != 4:
      raise ValueError(
        "last dimenion of bbox should have a "
        "size of 4, got {}".format(bbox.shape[-1])
      )
    if mode not in ("xyxy", "xywh"):
      raise ValueError("mode should be 'xyxy' or 'xywh'")

    self.bbox = bbox
    self.size = image_size  # (image_width, image_height)
    self.mode = mode
    self.extra_fields = {}

  def __repr__(self):
    s = self.__class__.__name__ + "("
    s += "num_boxes={}, ".format(len(self))
    s += "image_width={}, ".format(self.size[0])
    s += "image_height={}, ".format(self.size[1])
    s += "mode={})".format(self.mode)
    return s

  def area(self):
    """Return the per-box area (inclusive-pixel convention in ``xyxy`` mode)."""
    box = self.bbox
    if self.mode == "xyxy":
      TO_REMOVE = 1
      return (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
    if self.mode == "xywh":
      return box[:, 2] * box[:, 3]
    raise ValueError("mode should be 'xyxy' or 'xywh'")

  def add_field(self, field, field_data):
    self.extra_fields[field] = field_data

  def get_field(self, field):
    return self.extra_fields[field]

  def has_field(self, field):
    return field in self.extra_fields

  def fields(self):
    return list(self.extra_fields.keys())

  def _copy_extra_fields(self, bbox):
    """Share (not copy) every extra field of ``bbox`` with this BoxList."""
    for k, v in bbox.extra_fields.items():
      self.extra_fields[k] = v

  def convert(self, mode):
    """Return the boxes in ``mode`` (``self`` if already there); extra fields are shared."""
    if mode == self.mode:
      return self
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if mode == "xyxy":
      bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1)
      bbox = BoxList(bbox, self.size, mode=mode)
    else:
      TO_REMOVE = 1
      bbox = Tensor.cat(
        *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
      )
      bbox = BoxList(bbox, self.size, mode=mode)
    bbox._copy_extra_fields(self)
    return bbox

  def _split_into_xyxy(self):
    """Return ``(xmin, ymin, xmax, ymax)`` column Tensors, each of shape ``(N, 1)``."""
    if self.mode == "xyxy":
      xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
      return xmin, ymin, xmax, ymax
    if self.mode == "xywh":
      TO_REMOVE = 1
      xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
      return (
        xmin,
        ymin,
        xmin + (w - TO_REMOVE).clamp(min=0),
        ymin + (h - TO_REMOVE).clamp(min=0),
      )

  def resize(self, size, *args, **kwargs):
    """Rescale the boxes to a new image ``size``; non-Tensor fields are resized too."""
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
    if ratios[0] == ratios[1]:
      ratio = ratios[0]
      scaled_box = self.bbox * ratio
      bbox = BoxList(scaled_box, size, mode=self.mode)
      for k, v in self.extra_fields.items():
        if not isinstance(v, Tensor):
          v = v.resize(size, *args, **kwargs)
        bbox.add_field(k, v)
      return bbox

    ratio_width, ratio_height = ratios
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    scaled_xmin = xmin * ratio_width
    scaled_xmax = xmax * ratio_width
    scaled_ymin = ymin * ratio_height
    scaled_ymax = ymax * ratio_height
    scaled_box = Tensor.cat(
      *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
    )
    bbox = BoxList(scaled_box, size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.resize(size, *args, **kwargs)
      bbox.add_field(k, v)

    return bbox.convert(self.mode)

  def transpose(self, method):
    """Flip the boxes horizontally (``FLIP_LEFT_RIGHT``) or vertically (``FLIP_TOP_BOTTOM``).

    Raises:
      NotImplementedError: for any other ``method`` (previously an UnboundLocalError).
    """
    if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
      raise NotImplementedError("Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented")
    image_width, image_height = self.size
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if method == FLIP_LEFT_RIGHT:
      TO_REMOVE = 1
      transposed_xmin = image_width - xmax - TO_REMOVE
      transposed_xmax = image_width - xmin - TO_REMOVE
      transposed_ymin = ymin
      transposed_ymax = ymax
    else:
      transposed_xmin = xmin
      transposed_xmax = xmax
      transposed_ymin = image_height - ymax
      transposed_ymax = image_height - ymin

    transposed_boxes = Tensor.cat(
      *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
    )
    bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.transpose(method)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def clip_to_image(self, remove_empty=True):
    """Clamp coordinates into the image (in place) and optionally drop degenerate boxes.

    With ``remove_empty`` a new BoxList containing only boxes with ``x2 > x1`` and
    ``y2 > y1`` is returned (``self`` when every box survives); otherwise ``self``.
    """
    TO_REMOVE = 1
    max_x = self.size[0] - TO_REMOVE
    max_y = self.size[1] - TO_REMOVE
    bb1 = self.bbox[:, 0].clip(min_=0, max_=max_x)
    bb2 = self.bbox[:, 1].clip(min_=0, max_=max_y)
    bb3 = self.bbox[:, 2].clip(min_=0, max_=max_x)
    bb4 = self.bbox[:, 3].clip(min_=0, max_=max_y)
    self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1)
    if remove_empty:
      box = self.bbox
      keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
      # __getitem__ gathers by integer position, so convert the boolean mask to indices;
      # passing the mask itself would gather rows 0/1 instead of filtering
      keep_idx = keep.numpy().nonzero()[0]
      if len(keep_idx) == len(self):
        return self
      return self[keep_idx]
    return self

  def __getitem__(self, item):
    """Select boxes (and every extra field) by integer indices.

    An empty list returns ``[]``; an all-``True`` list of bools returns ``self``.
    """
    if isinstance(item, list):
      if len(item) == 0:
        return []
      if sum(item) == len(item) and isinstance(item[0], bool):
        return self
    bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode)
    for k, v in self.extra_fields.items():
      bbox.add_field(k, tensor_gather(v, item))
    return bbox

  def __len__(self):
    return self.bbox.shape[0]


def cat_boxlist(bboxes):
  """Concatenate BoxLists of the same image into one BoxList.

  Image size, mode and field names come from ``bboxes[0]``. Empty inputs are skipped;
  when nothing non-empty remains, ``bboxes[0]``'s data is reused as-is.
  """
  size = bboxes[0].size
  mode = bboxes[0].mode
  # insertion-ordered list instead of a set, so the field order is deterministic
  fields = bboxes[0].fields()
  cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0]

  if cat_box_list:
    cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode)
  else:
    cat_boxes = BoxList(bboxes[0].bbox, size, mode)
  for field in fields:
    cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
    # decide on the list that is actually concatenated (was: cat_box_list)
    if cat_field_list:
      data = Tensor.cat(*cat_field_list, dim=0)
    else:
      data = bboxes[0].get_field(field)
    cat_boxes.add_field(field, data)

  return cat_boxes


class FPN:
  """Feature Pyramid Network top-down pathway.

  For each input stage a 1x1 lateral conv (``inner_blocks``) and a 3x3 output conv
  (``layer_blocks``) are created. Coarser levels are upsampled 2x (nearest) and added to
  the next finer lateral output; ``top_block`` appends one extra, coarser level.
  """
  def __init__(self, in_channels_list, out_channels):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.top_block = LastLevelMaxPool()

  def __call__(self, x):
    """Run the pathway over the backbone feature maps ``x`` (finest first, coarsest last).

    Returns a tuple of output maps ordered finest to coarsest, followed by the ``top_block`` level.
    """
    last_inner = self.inner_blocks[-1](x[-1])
    results = [self.layer_blocks[-1](last_inner)]
    for feature, inner_block, layer_block in zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
    ):
      # (the former `if not inner_block: continue` could never fire: conv modules are always truthy)
      inner_top_down = nearest_interpolate(last_inner, scale_factor=2)
      last_inner = inner_block(feature) + inner_top_down
      results.insert(0, layer_block(last_inner))
    results.extend(self.top_block(results[-1]))
    return tuple(results)


class ResNetFPN:
  """Backbone: run ``resnet`` and feed its stage outputs through an FPN.

  The FPN is built for four stages with 256, 512, 1024 and 2048 input channels.
  """
  def __init__(self, resnet, out_channels=256):
    self.out_channels = out_channels
    self.body = resnet
    base_channels = 256  # channels of the first (stage-2) feature map
    self.fpn = FPN([base_channels * 2 ** i for i in range(4)], out_channels)

  def __call__(self, x):
    return self.fpn(self.body(x))


class AnchorGenerator:
  """Produce per-image, per-level anchor BoxLists.

  Each level owns a set of cell anchors (its sizes x ``aspect_ratios``) that is tiled over
  the level's feature-map grid with that level's stride.
  """
  def __init__(
          self,
          sizes=(32, 64, 128, 256, 512),
          aspect_ratios=(0.5, 1.0, 2.0),
          anchor_strides=(4, 8, 16, 32, 64),
          straddle_thresh=0,
  ):
    if len(anchor_strides) == 1:
      # single level: every size shares the one stride
      cell_anchors = [generate_anchors(anchor_strides[0], sizes, aspect_ratios)]
    elif len(anchor_strides) != len(sizes):
      raise RuntimeError("FPN should have #anchor_strides == #sizes")
    else:
      cell_anchors = []
      for stride, size in zip(anchor_strides, sizes):
        level_sizes = size if isinstance(size, (tuple, list)) else (size,)
        cell_anchors.append(generate_anchors(stride, level_sizes, aspect_ratios))
    self.strides = anchor_strides
    self.cell_anchors = cell_anchors
    self.straddle_thresh = straddle_thresh

  def num_anchors_per_location(self):
    """Number of cell anchors per grid position, one entry per level."""
    return [level_anchors.shape[0] for level_anchors in self.cell_anchors]

  def grid_anchors(self, grid_sizes):
    """Tile each level's cell anchors over its ``(height, width)`` grid.

    Returns one ``(H*W*A, 4)`` Tensor per level, ordered by grid position, then cell anchor.
    """
    anchors = []
    for (grid_height, grid_width), stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
      dev = base_anchors.device
      xs = Tensor.arange(start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=dev)
      ys = Tensor.arange(start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=dev)
      shift_y, shift_x = meshgrid(ys, xs)
      shift_x, shift_y = shift_x.reshape(-1), shift_y.reshape(-1)
      shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1)
      anchors.append((shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4))
    return anchors

  def add_visibility_to(self, boxlist):
    """Attach a ``"visibility"`` field flagging anchors inside the image (with ``straddle_thresh`` slack).

    A negative ``straddle_thresh`` marks every anchor visible.
    """
    image_width, image_height = boxlist.size
    anchors = boxlist.bbox
    if self.straddle_thresh < 0:
      visible = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=anchors.device)
    else:
      slack = self.straddle_thresh
      visible = (
        (anchors[:, 0] >= -slack)
        * (anchors[:, 1] >= -slack)
        * (anchors[:, 2] < image_width + slack)
        * (anchors[:, 3] < image_height + slack)
      )
    boxlist.add_field("visibility", visible)

  def __call__(self, image_list, feature_maps):
    """Return ``anchors[image][level]`` BoxLists for every image size in ``image_list``."""
    level_anchors = self.grid_anchors([fm.shape[-2:] for fm in feature_maps])
    anchors = []
    for image_height, image_width in image_list.image_sizes:
      per_image = []
      for level in level_anchors:
        boxlist = BoxList(level, (image_width, image_height), mode="xyxy")
        self.add_visibility_to(boxlist)
        per_image.append(boxlist)
      anchors.append(per_image)
    return anchors


def generate_anchors(
    stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
):
  """Build the ``(len(aspect_ratios) * len(sizes), 4)`` cell anchors of one stride.

  ``sizes`` are absolute pixel sizes, converted to scales of the ``stride``-sized base box.
  """
  scales = Tensor(list(sizes)) / stride
  ratios = Tensor(list(aspect_ratios))
  return _generate_anchors(stride, scales, ratios)


def _generate_anchors(base_size, scales, aspect_ratios):
  """Enumerate aspect ratios of a ``base_size`` square box, then scales of each ratio anchor.

  Rows are grouped by aspect ratio, with the scale varying fastest.
  """
  base = Tensor([1, 1, base_size, base_size]) - 1
  ratio_anchors = _ratio_enum(base, aspect_ratios)
  per_ratio = [_scale_enum(ratio_anchors[r, :], scales).reshape(-1, 4) for r in range(ratio_anchors.shape[0])]
  return Tensor.cat(*per_ratio)


def _whctrs(anchor):
  w = anchor[2] - anchor[0] + 1
  h = anchor[3] - anchor[1] + 1
  x_ctr = anchor[0] + 0.5 * (w - 1)
  y_ctr = anchor[1] + 0.5 * (h - 1)
  return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
  """Stack ``[x1, y1, x2, y2]`` boxes of widths ``ws`` and heights ``hs`` centred on ``(x_ctr, y_ctr)``.

  ``ws`` and ``hs`` are 1-D; the result has one row per width/height pair.
  """
  half_w = 0.5 * (ws[:, None] - 1)
  half_h = 0.5 * (hs[:, None] - 1)
  return Tensor.cat(x_ctr - half_w, y_ctr - half_h, x_ctr + half_w, y_ctr + half_h, dim=1)


def _ratio_enum(anchor, ratios):
  """Anchors of roughly ``anchor``'s area for each aspect ratio (height / width) in ``ratios``, same centre."""
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  ws = rint(Tensor.sqrt((w * h) / ratios))
  hs = rint(ws * ratios)
  return _mkanchors(ws, hs, x_ctr, y_ctr)


def _scale_enum(anchor, scales):
  """Scale ``anchor``'s width and height by each factor in ``scales``, keeping its centre."""
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  return _mkanchors(w * scales, h * scales, x_ctr, y_ctr)


class RPNHead:
  """RPN head: shared 3x3 conv + ReLU, then 1x1 convs for per-anchor objectness and box deltas."""
  def __init__(self, in_channels, num_anchors):
    self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
    self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1)
    self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1)

  def __call__(self, x):
    """Apply the head to every feature level; returns ``(logits, bbox_reg)``, one list entry per level."""
    hidden = [Tensor.relu(self.conv(level)) for level in x]
    return [self.cls_logits(t) for t in hidden], [self.bbox_pred(t) for t in hidden]


class BoxCoder(object):
  """Encode/decode box-regression deltas ``(dx, dy, dw, dh)`` relative to reference boxes.

  Boxes are ``xyxy`` with inclusive pixel coordinates (width = x2 - x1 + 1). ``weights``
  scale the four deltas; ``bbox_xform_clip`` caps ``dw``/``dh`` before ``exp``.
  """
  def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
    self.weights = weights
    self.bbox_xform_clip = bbox_xform_clip

  def encode(self, reference_boxes, proposals):
    """Return the ``(N, 4)`` deltas that map each row of ``proposals`` onto ``reference_boxes``."""
    TO_REMOVE = 1  # TODO remove

    def geometry(b):
      widths = b[:, 2] - b[:, 0] + TO_REMOVE
      heights = b[:, 3] - b[:, 1] + TO_REMOVE
      return widths, heights, b[:, 0] + 0.5 * widths, b[:, 1] + 0.5 * heights

    ex_w, ex_h, ex_cx, ex_cy = geometry(proposals)
    gt_w, gt_h, gt_cx, gt_cy = geometry(reference_boxes)
    wx, wy, ww, wh = self.weights
    return Tensor.stack(
      wx * (gt_cx - ex_cx) / ex_w,
      wy * (gt_cy - ex_cy) / ex_h,
      ww * Tensor.log(gt_w / ex_w),
      wh * Tensor.log(gt_h / ex_h),
      dim=1,
    )

  def decode(self, rel_codes, boxes):
    """Apply deltas ``rel_codes`` (groups of four per row) to ``boxes`` (one box per row).

    Returns boxes in ``xyxy`` form with the same ``(N, 4*k)`` interleaved layout as ``rel_codes``.
    """
    boxes = boxes.cast(rel_codes.dtype)
    TO_REMOVE = 1  # TODO remove
    widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
    heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx, wy, ww, wh = self.weights
    dx, dy = rel_codes[:, 0::4] / wx, rel_codes[:, 1::4] / wy
    dw, dh = rel_codes[:, 2::4] / ww, rel_codes[:, 3::4] / wh
    # Prevent sending too large values into Tensor.exp(); only the upper bound is active
    dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip)
    dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip)

    w_col, h_col = widths[:, None], heights[:, None]
    pred_ctr_x = dx * w_col + ctr_x[:, None]
    pred_ctr_y = dy * h_col + ctr_y[:, None]
    pred_w = dw.exp() * w_col
    pred_h = dh.exp() * h_col
    x1 = pred_ctr_x - 0.5 * pred_w
    y1 = pred_ctr_y - 0.5 * pred_h
    x2 = pred_ctr_x + 0.5 * pred_w - 1
    y2 = pred_ctr_y + 0.5 * pred_h - 1
    # (4, N, k) -> (N, k, 4) -> (N, 4*k): coordinates interleaved per group
    return Tensor.stack(x1, y1, x2, y2).permute(1, 2, 0).reshape(rel_codes.shape[0], rel_codes.shape[1])


def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
  """Apply NMS to ``boxlist`` scored by ``score_field``, keeping at most ``max_proposals`` boxes.

  Suppression is computed on the host by the RetinaNet ``nms`` helper on ``xyxy`` boxes,
  and the survivors are returned in ``boxlist``'s original mode. A non-positive
  ``nms_thresh`` returns ``boxlist`` untouched; ``max_proposals <= 0`` means no cap.
  """
  if nms_thresh <= 0:
    return boxlist
  original_mode = boxlist.mode
  xyxy = boxlist.convert("xyxy")
  keep = _box_nms(xyxy.bbox.numpy(), xyxy.get_field(score_field).numpy(), nms_thresh)
  if max_proposals > 0:
    keep = keep[:max_proposals]
  return xyxy[keep].convert(original_mode)


def remove_small_boxes(boxlist, min_size):
  """Drop boxes whose ``xywh`` width or height is below ``min_size``.

  Returns ``boxlist`` itself when no box is removed.
  """
  _, _, ws, hs = boxlist.convert("xywh").bbox.chunk(4, dim=1)
  keep = (((ws >= min_size) * (hs >= min_size)) > 0).reshape(-1)
  if keep.sum().numpy() == len(boxlist):
    return boxlist
  return boxlist[keep.numpy().nonzero()[0]]


class RPNPostProcessor:
  """Turn RPN outputs into per-image proposal BoxLists.

  Per level: keep the ``pre_nms_top_n`` best-scoring anchors, decode them, clip to the
  image, drop boxes smaller than ``min_size`` and run NMS. Across levels: keep the
  ``fpn_post_nms_top_n`` highest-objectness proposals of each image.
  """
  # Not used in Loss calculation
  def __init__(
          self,
          pre_nms_top_n,
          post_nms_top_n,
          nms_thresh,
          min_size,
          box_coder=None,
          fpn_post_nms_top_n=None,
  ):
    self.pre_nms_top_n = pre_nms_top_n
    self.post_nms_top_n = post_nms_top_n
    self.nms_thresh = nms_thresh
    self.min_size = min_size
    self.box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) if box_coder is None else box_coder
    self.fpn_post_nms_top_n = post_nms_top_n if fpn_post_nms_top_n is None else fpn_post_nms_top_n

  def forward_for_single_feature_map(self, anchors, objectness, box_regression):
    """Proposals for one feature level.

    Args:
      anchors: one anchor BoxList per image for this level.
      objectness: ``(N, A, H, W)`` logits.
      box_regression: deltas, expected as ``(N, A*4, H, W)``.
    Returns:
      list of ``N`` BoxLists carrying an ``"objectness"`` field.
    """
    N, A, H, W = objectness.shape
    scores = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1).sigmoid()
    deltas = permute_and_flatten(box_regression, N, A, 4, H, W)

    k = min(self.pre_nms_top_n, A * H * W)
    scores, topk_idx = topk(scores, k, dim=1, sorted=False)
    all_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4)
    image_shapes = [a.size for a in anchors]

    # keep only the selected anchors / deltas, image by image
    deltas = Tensor.stack(*[tensor_gather(deltas[i], topk_idx[i]) for i in range(N)])
    all_anchors = Tensor.stack(*[tensor_gather(all_anchors[i], topk_idx[i]) for i in range(N)])

    proposals = self.box_coder.decode(deltas.reshape(-1, 4), all_anchors.reshape(-1, 4)).reshape(N, -1, 4)

    result = []
    for proposal, score, im_shape in zip(proposals, scores, image_shapes):
      boxlist = BoxList(proposal, im_shape, mode="xyxy")
      boxlist.add_field("objectness", score)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = remove_small_boxes(boxlist, self.min_size)
      boxlist = boxlist_nms(
        boxlist,
        self.nms_thresh,
        max_proposals=self.post_nms_top_n,
        score_field="objectness",
      )
      result.append(boxlist)
    return result

  def __call__(self, anchors, objectness, box_regression):
    """``anchors`` is indexed ``[image][level]``; the other two are per-level lists. Returns one BoxList per image."""
    per_level_anchors = list(zip(*anchors))
    sampled_boxes = [
      self.forward_for_single_feature_map(a, o, b)
      for a, o, b in zip(per_level_anchors, objectness, box_regression)
    ]
    boxlists = [cat_boxlist(per_image) for per_image in zip(*sampled_boxes)]
    if len(objectness) > 1:
      boxlists = self.select_over_all_levels(boxlists)
    return boxlists

  def select_over_all_levels(self, boxlists):
    """Keep each image's ``fpn_post_nms_top_n`` highest-objectness boxes; updates ``boxlists`` in place and returns it."""
    for i, boxlist in enumerate(boxlists):
      objectness = boxlist.get_field("objectness")
      k = min(self.fpn_post_nms_top_n, objectness.shape[0])
      _, keep = topk(objectness, k, dim=0, sorted=False)
      boxlists[i] = boxlist[keep]
    return boxlists


class RPN:
  """Region proposal network, inference path only: head + anchors + proposal selection."""
  def __init__(self, in_channels):
    self.anchor_generator = AnchorGenerator()
    # honour the caller's channel count (it used to be silently overwritten with 256)
    self.head = RPNHead(
      in_channels, self.anchor_generator.num_anchors_per_location()[0]
    )
    rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    self.box_selector_test = RPNPostProcessor(
      pre_nms_top_n=1000,
      post_nms_top_n=1000,
      nms_thresh=0.7,
      min_size=0,
      box_coder=rpn_box_coder,
      fpn_post_nms_top_n=1000,
    )

  def __call__(self, images, features, targets=None):
    """Return ``(proposals, {})``: one proposal BoxList per image and an empty loss dict.

    ``targets`` is accepted for interface compatibility and ignored.
    """
    objectness, rpn_box_regression = self.head(features)
    anchors = self.anchor_generator(images, features)
    boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
    return boxes, {}


def make_conv3x3(
  in_channels,
  out_channels,
  dilation=1,
  stride=1,
  use_gn=False,
):
  """Build a 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

  The bias is omitted when ``use_gn`` is truthy; no normalization layer is created here.
  """
  return nn.Conv2d(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=stride,
    padding=dilation,
    dilation=dilation,
    bias=not use_gn,
  )


class MaskRCNNFPNFeatureExtractor:
  """Mask-branch features: 14x14 multi-level ROIAlign, then four 3x3 conv + ReLU layers."""
  def __init__(self):
    resolution = 14
    self.pooler = Pooler(
      output_size=(resolution, resolution),
      scales=(0.25, 0.125, 0.0625, 0.03125),
      sampling_ratio=2,
    )
    input_size = 256
    layers = (256, 256, 256, 256)
    conv_kwargs = dict(dilation=1, stride=1, use_gn=False)
    self.mask_fcn1 = make_conv3x3(input_size, layers[0], **conv_kwargs)
    self.mask_fcn2 = make_conv3x3(layers[0], layers[1], **conv_kwargs)
    self.mask_fcn3 = make_conv3x3(layers[1], layers[2], **conv_kwargs)
    self.mask_fcn4 = make_conv3x3(layers[2], layers[3], **conv_kwargs)
    self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]

  def __call__(self, x, proposals):
    pooled = self.pooler(x, proposals)
    # Pooler returns None when there are no proposals; pass that through untouched
    if pooled is None:
      return pooled
    for layer in self.blocks:
      pooled = Tensor.relu(layer(pooled))
    return pooled


class MaskRCNNC4Predictor:
  """Mask predictor: stride-2 transposed conv + ReLU, then 1x1 conv to 81 per-class mask logits."""
  def __init__(self):
    num_classes = 81
    dim_reduced = 256
    self.conv5_mask = nn.ConvTranspose2d(dim_reduced, dim_reduced, 2, 2, 0)
    self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)

  def __call__(self, x):
    upsampled = Tensor.relu(self.conv5_mask(x))
    return self.mask_fcn_logits(upsampled)


class FPN2MLPFeatureExtractor:
  """Box-branch features: 7x7 multi-level ROIAlign, flattened through two 1024-d Linear + ReLU layers.

  ``cfg`` is accepted for interface compatibility and not used.
  """
  def __init__(self, cfg):
    resolution = 7
    self.pooler = Pooler(
      output_size=(resolution, resolution),
      scales=(0.25, 0.125, 0.0625, 0.03125),
      sampling_ratio=2,
    )
    representation_size = 1024
    self.fc6 = nn.Linear(256 * resolution ** 2, representation_size)
    self.fc7 = nn.Linear(representation_size, representation_size)

  def __call__(self, x, proposals):
    pooled = self.pooler(x, proposals)
    flat = pooled.reshape(pooled.shape[0], -1)
    return Tensor.relu(self.fc7(Tensor.relu(self.fc6(flat))))


def _bilinear_interpolate(
  input,  # [N, C, H, W]
  roi_batch_ind,  # [K]
  y,  # [K, PH, IY]
  x,  # [K, PW, IX]
  ymask,  # [K, IY]
  xmask,  # [K, IX]
):
  """Bilinearly sample ``input`` at every (y, x) sampling point of every ROI.

  Sampling coordinates are clipped onto the feature map. The four neighbouring pixels
  are fetched with ``tensor_getitem`` and blended with bilinear weights.

  Returns:
    Tensor of shape ``[K, C, PH, PW, IY, IX]``: one value per ROI, channel, output bin
    and sampling point. When ``ymask``/``xmask`` are given, masked-out sampling points
    read pixel row/column 0 and must be zeroed by the caller (``_roi_align`` does this).
  """
  _, channels, height, width = input.shape
  # clamp every sampling point onto the feature map
  y = y.clip(min_=0.0, max_=float(height-1))
  x = x.clip(min_=0.0, max_=float(width-1))

  # Tensor.where doesnt work well with int32 data so cast to float32
  # (the int cast truncates, which equals floor here since coordinates are now >= 0)
  y_low = y.cast(dtypes.int32).contiguous().float().contiguous()
  x_low = x.cast(dtypes.int32).contiguous().float().contiguous()

  # on the last row/column both neighbours collapse onto the border pixel
  y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1)
  y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low)

  x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1)
  x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low)

  # fractional distance from the low neighbour (l*) and its complement (h*)
  ly = y - y_low
  lx = x - x_low
  hy = 1.0 - ly
  hx = 1.0 - lx

  def masked_index(
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
  ):
    """Fetch ``input[roi_batch_ind[k], c, y[k, ph, iy], x[k, pw, ix]]`` as ``[K, C, PH, PW, IY, IX]``."""
    if ymask is not None:
      assert xmask is not None
      # send masked-out sampling points to index 0 so the gather stays in bounds
      y = Tensor.where(ymask[:, None, :], y, 0)
      x = Tensor.where(xmask[:, None, :], x, 0)
    # one index Tensor per input axis, laid out to broadcast to [K, C, PH, PW, IY, IX]
    key1 = roi_batch_ind[:, None, None, None, None, None]
    key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
    key3 = y[:, None, :, None, :, None]
    key4 = x[:, None, None, :, None, :]
    return tensor_getitem(input,key1,key2,key3,key4)  # [K, C, PH, PW, IY, IX]

  v1 = masked_index(y_low, x_low)
  v2 = masked_index(y_low, x_high)
  v3 = masked_index(y_high, x_low)
  v4 = masked_index(y_high, x_high)

  # all weights are shaped [K, 1, PH, PW, IY, IX] and broadcast over the channel axis
  def outer_prod(y, x):
    return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

  w1 = outer_prod(hy, hx)
  w2 = outer_prod(hy, lx)
  w3 = outer_prod(ly, hx)
  w4 = outer_prod(ly, lx)

  val = w1*v1 + w2*v2 + w3*v3 + w4*v4
  return val

#https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  """ROIAlign written with tensor ops, following the torchvision source linked above.

  Args:
    input: feature map ``[N, C, H, W]``.
    rois: ``[K, 5]`` rows of ``(batch_index, x1, y1, x2, y2)``, scaled by ``spatial_scale``.
    spatial_scale: factor that maps ROI coordinates onto ``input``.
    pooled_height, pooled_width: size of the output bin grid.
    sampling_ratio: sampling points per bin along each axis. When ``<= 0`` it is derived
      per ROI as ``ceil(roi_size / pooled_size)``, and unused points are masked out.
    aligned: when True, coordinates are shifted by half a pixel and ROIs are not
      forced to be at least 1x1.
  Returns:
    ``[K, C, pooled_height, pooled_width]`` Tensor, cast back to ``input``'s dtype.
  """
  orig_dtype = input.dtype
  _, _, height, width = input.shape
  ph = Tensor.arange(pooled_height, device=input.device)
  pw = Tensor.arange(pooled_width, device=input.device)

  # image index of each ROI, plus its corners in feature-map coordinates
  roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
  offset = 0.5 if aligned else 0.0
  roi_start_w = rois[:, 1] * spatial_scale - offset
  roi_start_h = rois[:, 2] * spatial_scale - offset
  roi_end_w = rois[:, 3] * spatial_scale - offset
  roi_end_h = rois[:, 4] * spatial_scale - offset

  roi_width = roi_end_w - roi_start_w
  roi_height = roi_end_h - roi_start_h
  if not aligned:
    # non-aligned mode: force ROIs to be at least 1x1
    roi_width = roi_width.maximum(1.0)
    roi_height = roi_height.maximum(1.0)

  bin_size_h = roi_height / pooled_height
  bin_size_w = roi_width / pooled_width

  exact_sampling = sampling_ratio > 0
  roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
  roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()

  if exact_sampling:
    # same sampling grid for every ROI: plain Python ints, no masks needed
    count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
    iy = Tensor.arange(roi_bin_grid_h, device=input.device)
    ix = Tensor.arange(roi_bin_grid_w, device=input.device)
    ymask = None
    xmask = None
  else:
    # per-ROI grid sizes (Tensors): sample over the full H/W range and mask the
    # points beyond each ROI's own grid
    count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
    iy = Tensor.arange(height, device=input.device)
    ix = Tensor.arange(width, device=input.device)
    ymask = iy[None, :] < roi_bin_grid_h[:, None]
    xmask = ix[None, :] < roi_bin_grid_w[:, None]

  def from_K(t):
    # [K] -> [K, 1, 1] so it broadcasts against [K, P, I] sampling grids
    return t[:, None, None]

  # y[k, ph, iy] / x[k, pw, ix]: centre of sampling point i inside bin p of ROI k
  y = (
    from_K(roi_start_h)
    + ph[None, :, None] * from_K(bin_size_h)
    + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
  )
  x = (
    from_K(roi_start_w)
    + pw[None, :, None] * from_K(bin_size_w)
    + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
  )

  val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)
  if not exact_sampling:
    # zero the contributions of masked-out sampling points
    val = ymask[:, None, None, None, :, None].where(val, 0)
    val = xmask[:, None, None, None, None, :].where(val, 0)

  # average the sampling points of each bin
  output = val.sum((-1, -2))
  if isinstance(count, Tensor):
    output /= count[:, None, None, None]
  else:
    output /= count

  output = output.cast(orig_dtype)
  return output

class ROIAlign:
  """ROIAlign layer for one feature level; always uses the non-aligned (``aligned=False``) variant."""
  def __init__(self, output_size, spatial_scale, sampling_ratio):
    self.output_size = output_size
    self.spatial_scale = spatial_scale
    self.sampling_ratio = sampling_ratio

  def __call__(self, input, rois):
    out_h, out_w = self.output_size[0], self.output_size[1]
    return _roi_align(input, rois, self.spatial_scale, out_h, out_w, self.sampling_ratio, aligned=False)


class LevelMapper:
  """Assign every box to a pyramid level from its size.

  ``level = floor(canonical_level + log2(sqrt(area) / canonical_scale + eps))``, clipped
  to ``[k_min, k_max]`` and returned as an offset from ``k_min``.
  """
  def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    self.k_min, self.k_max = k_min, k_max
    self.s0 = canonical_scale
    self.lvl0 = canonical_level
    self.eps = eps

  def __call__(self, boxlists):
    """Return a float Tensor of level offsets (0 = ``k_min``) for all boxes across ``boxlists``."""
    sizes = Tensor.sqrt(Tensor.cat(*[b.area() for b in boxlists]))
    levels = (self.lvl0 + Tensor.log2(sizes / self.s0 + self.eps)).floor()
    return levels.clip(min_=self.k_min, max_=self.k_max) - self.k_min


class Pooler:
  """Multi-level ROIAlign: route each box to one pyramid level and pool it there.

  ``scales`` are the feature-map-to-image scales of the levels (finest first). A
  ``LevelMapper`` picks the level of every box, and the pooled features are returned in
  the original box order.
  """
  def __init__(self, output_size, scales, sampling_ratio):
    self.output_size = output_size
    self.scales = scales
    self.sampling_ratio = sampling_ratio
    self.poolers = [
      ROIAlign(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio)
      for scale in scales
    ]
    # pyramid level k corresponds to scale 2**-k (e.g. 1/4 -> 2)
    lvl_min = -math.log2(scales[0])
    lvl_max = -math.log2(scales[-1])
    self.map_levels = LevelMapper(lvl_min, lvl_max)

  def convert_to_roi_format(self, boxes):
    """Concatenate BoxLists into ``[K, 5]`` rois ``(image_index, x1, y1, x2, y2)``; ``None`` when there are no boxes."""
    concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = Tensor.cat(
      *[Tensor.full((len(b), 1), i, dtype=dtype, device=device) for i, b in enumerate(boxes)],
      dim=0,
    )
    if concat_boxes.shape[0] != 0:
      return Tensor.cat(ids, concat_boxes, dim=1)
    return None

  def __call__(self, x, boxes):
    """Pool ``boxes`` (one BoxList per image) from the per-level feature maps ``x``.

    Returns the pooled features in box order, or ``None`` when there are no boxes.
    """
    rois = self.convert_to_roi_format(boxes)
    if rois is None:
      return None
    if len(self.poolers) == 1:
      return self.poolers[0](x[0], rois)

    # level assignment is index bookkeeping only (no grad flows through it), so it is
    # realized on the host once instead of once per level
    levels = self.map_levels(boxes).numpy()
    results, order = [], []
    for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
      idx_in_level = (levels == level).nonzero()[0]
      if len(idx_in_level) > 0:
        results.append(pooler(per_level_feature, tensor_gather(rois, idx_in_level)))
        order.extend(idx_in_level)
    # rows of cat(results) are grouped by level; argsort(order) maps each box back to its row
    return tensor_gather(Tensor.cat(*results), np.argsort(order, kind="stable").tolist())


class FPNPredictor:
  """Box-head predictor: per-RoI class logits and box-regression deltas.

  The defaults (81 classes, 1024-dim input, per-class regression) reproduce the values
  this class used to hard-code, so ``FPNPredictor()`` is unchanged. Parameter and
  attribute names (``cls_score``, ``bbox_pred``) match the pretrained checkpoint keys.
  """
  def __init__(self, num_classes=81, representation_size=1024, cls_agnostic_bbox_reg=False):
    self.cls_score = nn.Linear(representation_size, num_classes)
    # Class-agnostic regression predicts only 2 box sets (background + shared foreground).
    # This matches PostProcessor taking the last 4 columns in that mode.
    num_bbox_reg_classes = 2 if cls_agnostic_bbox_reg else num_classes
    self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)

  def __call__(self, x):
    """Return ``(class_logits, box_deltas)`` for RoI feature vectors ``x``."""
    scores = self.cls_score(x)
    bbox_deltas = self.bbox_pred(x)
    return scores, bbox_deltas


class PostProcessor:
  """Turn box-head outputs into final per-image detections.

  Not used in training. For each image it:
  - decodes per-class boxes from the regression deltas;
  - clips them to the image;
  - drops low scores and runs per-class NMS;
  - caps the number of detections.
  """
  def __init__(
          self,
          score_thresh=0.05,
          nms=0.5,
          detections_per_img=100,
          box_coder=None,
          cls_agnostic_bbox_reg=False
  ):
    self.score_thresh = score_thresh  # class probabilities must exceed this to be kept
    self.nms = nms  # threshold passed to boxlist_nms
    self.detections_per_img = detections_per_img  # per-image cap; <= 0 disables the cap
    if box_coder is None:
      box_coder = BoxCoder(weights=(10., 10., 5., 5.))
    self.box_coder = box_coder
    self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg

  def __call__(self, x, boxes):
    """Return one filtered BoxList per image.

    Args:
      x: ``(class_logits, box_regression)``. Their rows are aligned with the
        concatenation of all proposals in ``boxes``.
      boxes: per-image proposal BoxLists.
    """
    class_logits, box_regression = x
    class_prob = Tensor.softmax(class_logits, -1)
    image_shapes = [box.size for box in boxes]
    boxes_per_image = [len(box) for box in boxes]
    concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0)

    if self.cls_agnostic_bbox_reg:
      box_regression = box_regression[:, -4:]
    proposals = self.box_coder.decode(
      box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
    )
    if self.cls_agnostic_bbox_reg:
      proposals = proposals.repeat([1, class_prob.shape[1]])
    num_classes = class_prob.shape[1]

    # Rows hold every image's proposals back to back. Slice them per image so each
    # image is clipped and filtered against its own size and gets its own BoxList.
    results = []
    start = 0
    for num_boxes, image_shape in zip(boxes_per_image, image_shapes):
      end = start + num_boxes
      boxlist = self.prepare_boxlist(proposals[start:end], class_prob[start:end], image_shape)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = self.filter_results(boxlist, num_classes)
      results.append(boxlist)
      start = end
    return results

  def prepare_boxlist(self, boxes, scores, image_shape):
    """Flatten per-class boxes to (-1, 4) and scores to 1-D, and wrap them in an xyxy BoxList."""
    boxes = boxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    boxlist = BoxList(boxes, image_shape, mode="xyxy")
    boxlist.add_field("scores", scores)
    return boxlist

  @staticmethod
  def _kth_largest(values, k):
    """Return the k-th largest entry of the 1-D array ``values`` (requires 1 <= k <= len(values))."""
    pos = len(values) - k
    return np.partition(values, pos)[pos]

  def filter_results(self, boxlist, num_classes):
    """Apply the score threshold, per-class NMS and the per-image detection cap.

    Class 0 is skipped (the loop starts at 1). Each kept box gets a ``labels`` field
    with its class index.
    """
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)

    device = scores.device
    result = []
    scores = scores.numpy()
    boxes = boxes.numpy()
    inds_all = scores > self.score_thresh
    for j in range(1, num_classes):
      inds = inds_all[:, j].nonzero()[0]
      # This needs to be done in numpy because it can create empty arrays
      scores_j = scores[inds, j]
      boxes_j = boxes[inds, j * 4: (j + 1) * 4]
      boxes_j = Tensor(boxes_j)
      scores_j = Tensor(scores_j)
      boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
      boxlist_for_class.add_field("scores", scores_j)
      if len(boxlist_for_class):
        boxlist_for_class = boxlist_nms(
          boxlist_for_class, self.nms
        )
      num_labels = len(boxlist_for_class)
      boxlist_for_class.add_field(
        "labels", Tensor.full((num_labels,), j, device=device)
      )
      result.append(boxlist_for_class)

    result = cat_boxlist(result)
    number_of_detections = len(result)

    if number_of_detections > self.detections_per_img > 0:
      cls_scores = result.get_field("scores").numpy()
      # Threshold at the detections_per_img-th largest score. An unsorted top-k cannot
      # supply this by taking its last element. Ties at the threshold may keep a few extra.
      image_thresh = self._kth_largest(cls_scores, self.detections_per_img)
      keep = (cls_scores >= image_thresh).nonzero()[0]
      result = result[keep]
    return result


class RoIBoxHead:
  """Box branch of the RoI heads: feature extraction, prediction and post-processing.

  Only the inference path is implemented. When ``Tensor.training`` is set,
  ``__call__`` returns None.
  """
  def __init__(self, in_channels):
    self.feature_extractor = FPN2MLPFeatureExtractor(in_channels)
    self.predictor = FPNPredictor()
    self.post_processor = PostProcessor(
      score_thresh=0.05,
      nms=0.5,
      detections_per_img=100,
      box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
      cls_agnostic_bbox_reg=False,
    )

  def __call__(self, features, proposals, targets=None):
    """Return ``(roi_features, detections, {})`` at inference; ``targets`` is unused."""
    roi_features = self.feature_extractor(features, proposals)
    logits, deltas = self.predictor(roi_features)
    if Tensor.training:
      return None
    detections = self.post_processor((logits, deltas), proposals)
    return roi_features, detections, {}


class MaskPostProcessor:
  """Attach predicted mask probabilities to detections. Not used in loss calculation."""
  def __call__(self, x, boxes):
    """Return copies of ``boxes`` with a ``mask`` field added.

    For each box, only the mask channel selected by its ``labels`` field is kept,
    with a singleton channel axis.
    """
    probs = x.sigmoid().numpy()
    label_ids = Tensor.cat(*[b.get_field("labels") for b in boxes]).numpy().astype(np.int32)
    # Pick, for row i, the channel given by label_ids[i], then restore a channel axis.
    probs = probs[np.arange(x.shape[0]), label_ids][:, None]
    # np.split takes cut offsets (running totals). Tensor.chunk can't do uneven sizes.
    offsets = np.cumsum([len(b) for b in boxes])
    per_image = np.split(probs, offsets, axis=0)

    results = []
    for image_probs, src in zip(per_image, boxes):
      out = BoxList(src.bbox, src.size, mode="xyxy")
      for name in src.fields():
        out.add_field(name, src.get_field(name))
      out.add_field("mask", Tensor(image_probs))
      results.append(out)
    return results


class Mask:
  """Mask branch of the RoI heads: feature extraction, mask prediction, post-processing."""
  def __init__(self):
    self.feature_extractor = MaskRCNNFPNFeatureExtractor()
    self.predictor = MaskRCNNC4Predictor()
    self.post_processor = MaskPostProcessor()

  def __call__(self, features, proposals, targets=None):
    """Return ``(mask_features, detections, {})``.

    ``detections`` is ``[]`` when the feature extractor yields None or in training
    mode. ``targets`` is unused.
    """
    mask_features = self.feature_extractor(features, proposals)
    if mask_features is None:
      return mask_features, [], {}
    mask_logits = self.predictor(mask_features)
    if Tensor.training:
      return mask_features, [], {}
    return mask_features, self.post_processor(mask_logits, proposals), {}


class RoIHeads:
  """Run the box head, then feed its detections to the mask head."""
  def __init__(self, in_channels):
    self.box = RoIBoxHead(in_channels)
    self.mask = Mask()

  def __call__(self, features, proposals, targets=None):
    """Return the mask head's ``(features, detections)`` plus an empty dict."""
    _, box_detections, _ = self.box(features, proposals, targets)
    mask_features, detections, _ = self.mask(features, box_detections, targets)
    return mask_features, detections, {}


class ImageList:
  """A batched image tensor together with the per-image sizes recorded by to_image_list."""
  def __init__(self, tensors, image_sizes):
    self.tensors = tensors
    self.image_sizes = image_sizes

  def to(self, *args, **kwargs):
    """Return a new ImageList whose tensors went through ``tensors.to(...)``; sizes are shared."""
    return ImageList(self.tensors.to(*args, **kwargs), self.image_sizes)


def to_image_list(tensors, size_divisible=32):
  """Pack images into an ImageList (preprocessing).

  Behavior depends on the input:
  - An ImageList is returned unchanged.
  - A lone Tensor with ``size_divisible > 0`` is treated as a one-item list.
  - A lone Tensor with ``size_divisible <= 0`` must be 4-D and is used as the batch as-is.
  - For a list/tuple, the images are zero-padded into one batch. Its shape is the
    per-axis max, with axes 1 and 2 rounded up to a multiple of ``size_divisible``
    when it is positive.

  Raises:
    TypeError: for any other input type.
  """
  if isinstance(tensors, Tensor) and size_divisible > 0:
    tensors = [tensors]

  if isinstance(tensors, ImageList):
    return tensors

  if isinstance(tensors, Tensor):
    # single tensor shape can be inferred
    assert tensors.ndim == 4
    return ImageList(tensors, [t.shape[-2:] for t in tensors])

  if not isinstance(tensors, (tuple, list)):
    raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))

  padded_shape = [max(dims) for dims in zip(*[img.shape for img in tensors])]
  if size_divisible > 0:
    for axis in (1, 2):
      padded_shape[axis] = int(math.ceil(padded_shape[axis] / size_divisible) * size_divisible)

  batch = np.zeros((len(tensors), *padded_shape), dtype=_to_np_dtype(tensors[0].dtype))
  for img, slot in zip(tensors, batch):
    slot[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()

  return ImageList(Tensor(batch), [img.shape[-2:] for img in tensors])


class MaskRCNN:
  """Mask R-CNN: ResNet-FPN backbone, RPN, and box/mask RoI heads (inference)."""
  def __init__(self, backbone: ResNet):
    self.backbone = ResNetFPN(backbone, out_channels=256)
    self.rpn = RPN(self.backbone.out_channels)
    self.roi_heads = RoIHeads(self.backbone.out_channels)

  @staticmethod
  def _rename_key(key):
    """Map a checkpoint state-dict key to this model's attribute path."""
    # str.replace is a no-op when the substring is absent, so no membership test is needed.
    key = key.replace("module.", "").replace("stem.", "")
    if "fpn_inner" in key:
      # Checkpoint FPN blocks are 1-based (fpn_inner1...); ours are 0-based list indices.
      block_index = int(re.search(r"fpn_inner(\d+)", key).group(1))
      key = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", key)
    if "fpn_layer" in key:
      block_index = int(re.search(r"fpn_layer(\d+)", key).group(1))
      key = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", key)
    return key

  def load_from_pretrained(self):
    """Download the reference checkpoint, assign every tensor into the model, return the renamed keys."""
    weights_path = Path('./') / "weights/maskrcnn.pt"
    fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", weights_path)

    renamed = []
    for raw_key, value in torch_load(weights_path)['model'].items():
      key = self._rename_key(raw_key)
      renamed.append(key)
      get_child(self, key).assign(value.numpy()).realize()
    return renamed

  def __call__(self, images):
    """Run detection on ``images`` and return the per-image results from the RoI heads."""
    image_list = to_image_list(images)
    feature_maps = self.backbone(image_list.tensors)
    proposals, _ = self.rpn(image_list, feature_maps)
    _, detections, _ = self.roi_heads(feature_maps, proposals)
    return detections


if __name__ == '__main__':
  # Build a ResNet-50 backbone without a classifier head and load the pretrained
  # Mask R-CNN weights into the full model.
  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model = MaskRCNN(backbone=resnet)
  model.load_from_pretrained()