# Source: transforms.py (branch: master, 2469 bytes)
# Copied from https://github.com/mlcommons/training/blob/637c82f9e699cd6caf108f92efb2c1d446b630e0/single_stage_detector/ssd/transforms.py

import torch
import torchvision

from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from typing import List, Tuple, Dict, Optional

from PIL import Image
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's pixel-count limit
# (its decompression-bomb guard) for images opened after this point.
Image.MAX_IMAGE_PIXELS = None
from typing import Any

# accimage is an optional dependency: when it cannot be imported the name is
# bound to None so later code can test `accimage is not None`.
try:
    import accimage
except ImportError:
    accimage = None

@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """Tell whether ``img`` is a PIL image (or an accimage image, if available).

    Args:
        img: Object to inspect.

    Returns:
        True when ``img`` is an instance of ``PIL.Image.Image`` or, when the
        optional ``accimage`` module was imported, of ``accimage.Image``.
    """
    # accimage is None when the optional import at module top failed.
    if accimage is None:
        image_types = (Image.Image,)
    else:
        image_types = (Image.Image, accimage.Image)
    return isinstance(img, image_types)

def get_image_size_tensor(img: Tensor) -> List[int]:
    """Return the size of a tensor image as ``[width, height]``.

    The width is taken from the last dimension and the height from the
    second-to-last one, so any number of leading dimensions is accepted.

    Args:
        img: Tensor with at least two dimensions.

    Returns:
        ``[img.shape[-1], img.shape[-2]]``.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
    """
    # Inlined equivalent of torchvision's private ``_assert_image_tensor``
    # helper. Reaching into ``torchvision.transforms._functional_tensor`` ties
    # this function to specific torchvision releases (the module was renamed
    # across versions), so the check is performed locally instead.
    if img.ndim < 2:
        raise TypeError("Tensor is not a torch image.")
    return [img.shape[-1], img.shape[-2]]

@torch.jit.unused
def get_image_size_pil(img: Any) -> List[int]:
    """Return the ``size`` attribute of a PIL-style image as a list.

    Args:
        img: Image object accepted by ``_is_pil_image``.

    Returns:
        ``list(img.size)``.

    Raises:
        TypeError: If ``img`` is not recognised by ``_is_pil_image``.
    """
    # Guard clause: reject anything that is not a PIL/accimage image.
    if not _is_pil_image(img):
        raise TypeError("Unexpected type {}".format(type(img)))
    size = img.size
    return list(size)

def get_image_size(img: Tensor) -> List[int]:
    """Return the size of an image as ``[width, height]``.

    Tensors are handled by ``get_image_size_tensor``; everything else is
    forwarded to ``get_image_size_pil``.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    # Non-tensor inputs take the PIL path; that helper raises for
    # unsupported types.
    if not isinstance(img, torch.Tensor):
        return get_image_size_pil(img)

    return get_image_size_tensor(img)

class Compose:
    """Chain several ``(image, target)`` transforms into a single callable."""

    def __init__(self, transforms):
        """Store the transforms to apply.

        Args:
            transforms: Iterable of callables, each taking ``(image, target)``
                and returning a new ``(image, target)`` pair.
        """
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every stored transform in order.

        Args:
            image: Image passed to the first transform.
            target: Target passed to the first transform.

        Returns:
            The ``(image, target)`` pair produced by the last transform, or the
            inputs unchanged when there are no transforms.
        """
        # Each step consumes the previous step's output pair.
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target


class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Random horizontal flip that also updates the detection target."""

    def forward(self, image: Tensor,
                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Flip ``image`` (and ``target``) horizontally when ``torch.rand(1) < self.p``.

        Args:
            image: Image to flip.
            target: Optional dict. ``target["boxes"]`` is updated in place and,
                when present, ``target["masks"]`` is replaced by a flipped copy.

        Returns:
            The (possibly flipped) image and the target.
        """
        if torch.rand(1) < self.p:
            image = F.hflip(image)
            if target is not None:
                img_width, _ = get_image_size(image)
                boxes = target["boxes"]
                # Columns 0 and 2 are mirrored and swapped; presumably these
                # are the x1/x2 coordinates of (x1, y1, x2, y2) boxes.
                boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
                if "masks" in target:
                    # Masks are mirrored along their last dimension.
                    flipped_masks = target["masks"].flip(-1)
                    target["masks"] = flipped_masks
        return image, target


class ToTensor(nn.Module):
    """Convert the image with ``F.to_tensor`` while passing the target through."""

    def forward(self, image: Tensor,
                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Return ``(F.to_tensor(image), target)``.

        Args:
            image: Image handed to ``F.to_tensor``.
            target: Optional target dict, returned untouched.

        Returns:
            The converted image and the original target.
        """
        converted = F.to_tensor(image)
        return converted, target