# Source: openimages.py (branch: master)
from test.external.mlperf_retinanet.model.boxes import box_iou
from test.external.mlperf_retinanet.model.utils import Matcher

import torch

# This applies the filtering in https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/retinanet.py#L117
# and https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/retinanet.py#L203
# to match with tinygrad's dataloader implementation.
def postprocess_targets(targets, anchors):
    """Re-index every image's ground-truth boxes and labels by its foreground anchors.

    For each (anchors, target) pair a match vector is computed with a
    ``Matcher(0.5, 0.4, allow_low_quality_matches=True)`` applied to the
    ``box_iou`` of the target boxes against the anchors.  Afterwards each
    target's ``"boxes"`` and ``"labels"`` are replaced by the rows selected by
    the non-negative match indices, i.e. one entry per foreground anchor.

    Args:
        targets: iterable of per-image dicts holding ``"boxes"`` and ``"labels"``
            tensors.  NOTE(review): presumably boxes are (num_gt, 4) and labels
            are (num_gt,) -- confirm against the dataloader.
        anchors: iterable of per-image anchor tensors, paired positionally
            with ``targets``.

    Returns:
        The same ``targets`` object; its dicts are modified in place.
    """
    matcher = Matcher(0.5, 0.4, allow_low_quality_matches=True)

    def _match(image_anchors, image_targets):
        gt_boxes = image_targets['boxes']
        if gt_boxes.numel() > 0:
            return matcher(box_iou(gt_boxes, image_anchors))
        # No ground-truth boxes: mark every anchor as unmatched (-1) so that
        # the filtering step below selects nothing for this image.
        return torch.full(
            (image_anchors.size(0),),
            -1,
            dtype=torch.int64,
            device=image_anchors.device,
        )

    # Pass 1: compute all matches before any target dict is touched.
    per_image_matches = [
        _match(image_anchors, image_targets)
        for image_anchors, image_targets in zip(anchors, targets)
    ]

    # Pass 2: keep only foreground matches; negative entries are treated as
    # non-foreground and dropped.
    for image_targets, matches in zip(targets, per_image_matches):
        gather_idx = matches[matches >= 0]
        image_targets["boxes"] = image_targets["boxes"][gather_idx]
        image_targets["labels"] = image_targets["labels"][gather_idx]

    return targets