SoFunction
Updated on 2024-11-18

TorchVision Transforms API Examples: Object Detection, Instance and Semantic Segmentation, and Video

Content Introduction

The TorchVision Transforms API has been extended and now supports object detection, instance and semantic segmentation, and video tasks. The new API is still in beta, and developers are welcome to try it out.

This article was first published on WeChat: PyTorch Developer Community

TorchVision has extended its Transforms API. The details are as follows:

  • In addition to image classification, it can now be used for tasks such as object detection, instance and semantic segmentation, and video classification;
  • SoTA data augmentations such as MixUp, CutMix, Large Scale Jitter and SimpleCopyPaste can be imported directly from TorchVision;
  • New functional transforms support transforming videos, bounding boxes and segmentation masks.

Current limitations of Transforms

The stable version of the TorchVision Transforms API, also known as Transforms V1, only supports single images and is therefore only suitable for classification tasks:

from torchvision import transforms
trans = transforms.Compose([
   transforms.ColorJitter(contrast=0.5),
   transforms.RandomRotation(30),
   transforms.CenterCrop(480),
])
imgs = trans(imgs)

The approach above does not support object detection or segmentation, nor classification transforms that require labels (such as MixUp and CutMix). This prevents computer vision tasks other than classification from using the Transforms API to perform the necessary augmentations. It has also made it harder to train high-accuracy models with TorchVision primitives.
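To make the limitation concrete, here is a minimal pure-Python sketch (no TorchVision involved; the image is a hypothetical nested list of pixels and the box a plain `[x1, y1, x2, y2]` list with exclusive right and bottom edges). A V1-style transform only receives the image, so a horizontal flip moves the object while its box stays where it was:

```python
def hflip_image(img):
    """Flip a 2D image (a list of rows) left to right."""
    return [list(reversed(row)) for row in img]

def object_columns(img):
    """Sorted column indices that contain at least one non-zero pixel."""
    return sorted({x for row in img for x, v in enumerate(row) if v})

# 2x6 toy image; the "object" occupies columns 0-1.
img = [
    [1, 1, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],
]
box = [0, 0, 2, 2]  # x1, y1, x2, y2 around the object (x2, y2 exclusive)

flipped = hflip_image(img)      # V1-style: only the image is transformed
print(object_columns(flipped))  # [4, 5]
print(box)                      # [0, 0, 2, 2]: the box still covers columns 0-1
```

The box now points at empty background, which is the kind of silent annotation corruption that image-only transforms cause in detection and segmentation pipelines.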

To overcome this limitation, TorchVision provided custom implementations in its reference scripts to demonstrate how augmentation is performed for each task.

Although this approach allowed developers to train highly accurate models for classification, object detection and segmentation, it is relatively crude, and these transforms still cannot be imported from the TorchVision binaries.

New Transforms API

The Transforms V2 API supports videos, bounding boxes, labels and segmentation masks, which means it provides native support for many computer vision tasks. The new solution is a drop-in replacement:

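from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)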
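What a multi-input call makes possible is that the transform samples its random parameters once and applies them to every input. The sketch below shows that idea in plain Python; `JointRandomHorizontalFlip` is a hypothetical class written for illustration, not TorchVision's implementation. The image is a nested list of pixels and the boxes are `[x1, y1, x2, y2]` lists:

```python
import random

class JointRandomHorizontalFlip:
    """Hypothetical joint transform: flips an image and its boxes together."""

    def __init__(self, p=0.5, rng=None):
        self.p = p
        self.rng = rng or random.Random()

    def __call__(self, img, bboxes, labels):
        # Sample the random decision ONCE so image and boxes stay consistent.
        if self.rng.random() >= self.p:
            return img, bboxes, labels
        width = len(img[0])
        img = [list(reversed(row)) for row in img]
        # x' = width - x, and the left and right edges swap roles.
        bboxes = [[width - x2, y1, width - x1, y2] for x1, y1, x2, y2 in bboxes]
        return img, bboxes, labels  # labels are unaffected by a flip

img = [[1, 1, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0]]
flip = JointRandomHorizontalFlip(p=1.0)  # p=1.0 forces the flip for the demo
img2, boxes2, labels2 = flip(img, [[0, 0, 2, 2]], [1])
print(boxes2)  # [[4, 0, 6, 2]]
```

Passing a seeded `random.Random` as `rng` makes the flip decision reproducible, which is convenient in tests.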

The new Transform classes can take any number of inputs without enforcing a specific order or structure:

# Already supported:
trans(imgs)  # Image Classification
trans(videos)  # Video Tasks
trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels)  # Object Detection
trans(imgs, bboxes, masks, labels)  # Instance Segmentation
trans(imgs, masks)  # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
trans(stereo_images, disparities, masks)  # Depth Perception
trans(image1, image2, optical_flows, masks)  # Optical Flow
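Accepting inputs without a fixed order or structure means a transform has to walk whatever container it is given and act only on the leaves it recognizes. One way to do that is sketched below in pure Python; the `Image` and `Boxes` wrappers and the `flip` helper are hypothetical stand-ins for TorchVision's tensor subclasses, and anything unrecognized (such as labels) passes through untouched:

```python
from dataclasses import dataclass

@dataclass
class Image:
    """Hypothetical image wrapper: a 2D list of pixels."""
    data: list

@dataclass
class Boxes:
    """Hypothetical XYXY boxes plus the width of their image."""
    data: list
    width: int

def flip_leaf(leaf):
    """Flip one recognized leaf; anything else passes through unchanged."""
    if isinstance(leaf, Image):
        return Image([list(reversed(row)) for row in leaf.data])
    if isinstance(leaf, Boxes):
        w = leaf.width
        return Boxes([[w - x2, y1, w - x1, y2] for x1, y1, x2, y2 in leaf.data], w)
    return leaf

def apply_to_structure(obj, fn):
    """Rebuild dicts, lists and tuples, applying fn to every leaf."""
    if isinstance(obj, dict):
        return {k: apply_to_structure(v, fn) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_structure(v, fn) for v in obj)
    return fn(obj)

def flip(*inputs):
    """Accept any number of inputs, in any nesting, like trans(...) above."""
    out = apply_to_structure(inputs, flip_leaf)
    return out[0] if len(out) == 1 else out

sample = {"image": Image([[1, 1, 0, 0]]), "box": Boxes([[0, 0, 2, 1]], width=4), "tag": [7]}
print(flip(sample)["box"])  # Boxes(data=[[2, 0, 4, 1]], width=4)
```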

The functional API has been updated to support all the necessary signal-processing kernels (resizing, cropping, affine transforms, padding, etc.) for all inputs:

from torchvision.prototype.transforms import functional as F
from torchvision.transforms import InterpolationMode
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, size=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, size=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, size=[224, 224], interpolation=InterpolationMode.BILINEAR)
# Video kernel
F.resize_video(video, size=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, size=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, size=[224, 224], spatial_size=[256, 256])
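The split between a high-level dispatcher and per-type kernels can be mimicked with `functools.singledispatch` from the standard library. This is an analogy rather than TorchVision's mechanism (which dispatches on tensor subclasses); `Image`, `BoundingBoxes` and the nearest-neighbour kernel are hypothetical, and `spatial_size` is taken to be `(height, width)`:

```python
from dataclasses import dataclass
from functools import singledispatch

@dataclass
class Image:
    """Hypothetical image wrapper: a 2D list of pixels (rows = height)."""
    data: list

@dataclass
class BoundingBoxes:
    """Hypothetical XYXY boxes plus the (height, width) of their image."""
    data: list
    spatial_size: tuple

@singledispatch
def resize(inpt, size):
    """High-level entry point: chooses a kernel from the type of `inpt`."""
    raise TypeError(f"unsupported input type: {type(inpt).__name__}")

@resize.register
def _(inpt: Image, size):
    # Image kernel: nearest-neighbour sampling to the new (height, width).
    new_h, new_w = size
    old_h, old_w = len(inpt.data), len(inpt.data[0])
    rows = [inpt.data[i * old_h // new_h] for i in range(new_h)]
    return Image([[row[j * old_w // new_w] for j in range(new_w)] for row in rows])

@resize.register
def _(inpt: BoundingBoxes, size):
    # Box kernel: scale x by the width ratio and y by the height ratio.
    new_h, new_w = size
    old_h, old_w = inpt.spatial_size
    sx, sy = new_w / old_w, new_h / old_h
    scaled = [[x1 * sx, y1 * sy, x2 * sx, y2 * sy] for x1, y1, x2, y2 in inpt.data]
    return BoundingBoxes(scaled, (new_h, new_w))

print(resize(Image([[1, 2], [3, 4]]), [4, 4]).data)
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
print(resize(BoundingBoxes([[10, 20, 110, 220]], (256, 256)), [224, 224]).data)
# [[8.75, 17.5, 96.25, 192.5]]
```

The box kernel also shows why a box resize needs the original `spatial_size`: without the old dimensions there is no way to compute the scale factors.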

The API uses tensor subclassing to wrap inputs, attach useful metadata and dispatch to the correct kernel. Once the Datasets V2 work, which uses TorchData DataPipes, is complete, inputs will no longer need to be wrapped manually. For now, users can wrap inputs manually as follows:

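from torchvision.prototype import features
imgs = features.Image(images, color_space=features.ColorSpace.RGB)
vids = features.Video(videos, color_space=features.ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=features.BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])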
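Attaching metadata matters because the same four numbers describe different boxes in different formats. The sketch below uses a hypothetical `Boxes` dataclass and `to_xyxy` helper (not TorchVision's API) to show how a `format` field lets code interpret coordinates correctly; the three format names mirror the XYXY, XYWH and CXCYWH conventions:

```python
from dataclasses import dataclass
from enum import Enum

class BoxFormat(Enum):
    XYXY = "xyxy"      # x1, y1, x2, y2 (corners)
    XYWH = "xywh"      # x1, y1, width, height
    CXCYWH = "cxcywh"  # centre x, centre y, width, height

@dataclass
class Boxes:
    """Hypothetical box container: raw coordinates plus the metadata to read them."""
    data: list
    format: BoxFormat
    spatial_size: tuple  # (height, width) of the image the boxes belong to

def to_xyxy(boxes):
    """Return a copy of `boxes` converted to XYXY, keeping the other metadata."""
    if boxes.format is BoxFormat.XYXY:
        out = [list(b) for b in boxes.data]
    elif boxes.format is BoxFormat.XYWH:
        out = [[x, y, x + w, y + h] for x, y, w, h in boxes.data]
    else:  # BoxFormat.CXCYWH
        out = [[cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]
               for cx, cy, w, h in boxes.data]
    return Boxes(out, BoxFormat.XYXY, boxes.spatial_size)

xywh = Boxes([[2, 0, 204, 253]], BoxFormat.XYWH, spatial_size=(300, 500))
print(to_xyxy(xywh).data)  # [[2, 0, 206, 253]]
cxcywh = Boxes([[50, 40, 20, 10]], BoxFormat.CXCYWH, spatial_size=(300, 500))
print(to_xyxy(cxcywh).data)  # [[40.0, 35.0, 60.0, 45.0]]
```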

In addition to the new API, PyTorch now provides importable implementations of several data augmentations used in SoTA research, such as MixUp, CutMix, Large Scale Jitter, SimpleCopyPaste and AutoAugment methods, as well as several new geometric, color and type-conversion transforms.
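As a reminder of what a label-aware transform such as MixUp does, here is the usual recipe in plain Python: draw a weight `lam` from a Beta(alpha, alpha) distribution, then blend two images and their one-hot labels with that same weight. The `mixup` function, the flat-list images and the default `alpha=0.2` are illustrative assumptions, not TorchVision's implementation:

```python
import random

def one_hot(label, num_classes):
    """Encode an integer class label as a list of floats."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def mixup(img_a, label_a, img_b, label_b, num_classes, alpha=0.2, rng=random):
    """Blend two samples and their one-hot labels with one shared weight.

    Images are flat lists of floats of equal length (an illustrative assumption).
    """
    lam = rng.betavariate(alpha, alpha)  # lam ~ Beta(alpha, alpha), in [0, 1]
    img = [lam * a + (1 - lam) * b for a, b in zip(img_a, img_b)]
    y_a = one_hot(label_a, num_classes)
    y_b = one_hot(label_b, num_classes)
    target = [lam * a + (1 - lam) * b for a, b in zip(y_a, y_b)]
    return img, target, lam

img, target, lam = mixup([1.0, 1.0], 0, [0.0, 0.0], 2, num_classes=3,
                         rng=random.Random(0))
print(lam, img, target)
```

Because the target is blended as well, this transform needs the labels as an input, which the V1 single-image interface could not provide.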

The API continues to support both PIL and tensor backends for single images or batched input images, and retains JIT-scriptability on the functional API. It also allows the conversion of images from uint8 to float to be deferred, which can bring further performance gains.

It is currently available in the prototype area of TorchVision and can be imported from the nightly builds. The new API has been verified to achieve the same accuracy as the previous implementation.

Current limitations

The functional API (kernels) remains JIT-scriptable and fully backward compatible (BC), while the Transform classes provide the same interface but are not scriptable.

This is because the Transform classes use tensor subclassing and accept an arbitrary number of inputs, neither of which is supported by JIT. This limitation is expected to be addressed in future releases.

An End-to-End Demonstration

Below is an example of the new API that works with both PIL images and tensors.

[Figure: test image, COCO_val2014_000000418825.jpg]

Code Example:

import PIL.Image
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
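Cropping an image also means shifting and clipping its boxes, which is the kind of bookkeeping a box-aware crop has to perform. A minimal pure-Python sketch for a centred crop follows; `center_crop_boxes` is a hypothetical helper, sizes are `(height, width)`, the crop is assumed to fit inside the image, and it is not TorchVision's kernel:

```python
def center_crop_boxes(boxes, image_size, crop_size):
    """Shift XYXY boxes into a centred crop and clip them to the crop window.

    image_size and crop_size are (height, width); assumes crop <= image.
    Returns the new boxes and a keep mask (True where the box still has area).
    """
    img_h, img_w = image_size
    crop_h, crop_w = crop_size
    # Offset of the crop window's top-left corner inside the original image.
    top = round((img_h - crop_h) / 2)
    left = round((img_w - crop_w) / 2)
    out = []
    for x1, y1, x2, y2 in boxes:
        # Translate into crop coordinates, then clamp to [0, crop_w] / [0, crop_h].
        nx1 = min(max(x1 - left, 0), crop_w)
        nx2 = min(max(x2 - left, 0), crop_w)
        ny1 = min(max(y1 - top, 0), crop_h)
        ny2 = min(max(y2 - top, 0), crop_h)
        out.append([nx1, ny1, nx2, ny2])
    keep = [x2 > x1 and y2 > y1 for x1, y1, x2, y2 in out]
    return out, keep

# Hypothetical 480x640 (height x width) image, 480x480 centre crop.
boxes, keep = center_crop_boxes([[2, 0, 206, 253], [10, 10, 50, 50]], (480, 640), (480, 480))
print(boxes)  # [[0, 0, 126, 253], [0, 10, 0, 50]]
print(keep)   # [True, False]
```

Boxes that end up with zero width or height, like the second one here, no longer cover any pixels, so a pipeline would normally drop them together with their labels using the `keep` mask.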
