SoFunction
Updated on 2025-05-13

Use of the torch.argmax() function in PyTorch

torch.argmax() is a function in PyTorch that returns the index of the maximum value in the input tensor. It mirrors the mathematical notion of argmax: the argument (here, the position index) at which a function attains its maximum over a given range.
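Stated formally for a one-dimensional tensor x with n elements and 0-based indices, the operation looks like this. The min encodes the tie-breaking rule that PyTorch documents, under which the first maximal index wins.

```latex
\operatorname{argmax}(x) = \min \left\{\, i \in \{0, \dots, n-1\} \;:\; x_i = \max_{0 \le j \le n-1} x_j \,\right\}
```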

Function definition

torch.argmax(input, dim=None, keepdim=False)
  • Input:
    • input: the input tensor.
    • dim (optional): the dimension along which to find the maximum. If None (the default), the whole tensor is treated as flattened.
    • keepdim (optional): whether the output keeps the reduced dimension (with size 1), so that it has the same number of dimensions as the input. Default: False.
  • Output:
    A tensor of type torch.int64 holding the index (or indices) of the maximum value(s)
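A quick way to see what each call form returns is to inspect the shape and dtype of the result. The sketch below uses a small 2-D float tensor chosen purely for illustration. In every case the indices come back as int64.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [6.0, 5.0, 4.0]])

# dim=None: a single index into the flattened tensor (0-dim result)
flat = torch.argmax(x)
# dim=1: one index per row; the reduced dimension is dropped
per_row = torch.argmax(x, dim=1)
# dim=1, keepdim=True: the reduced dimension is kept with size 1
per_row_kept = torch.argmax(x, dim=1, keepdim=True)

for name, t in [("flat", flat), ("per_row", per_row), ("per_row_kept", per_row_kept)]:
    print(name, tuple(t.shape), t.dtype)
# flat () torch.int64
# per_row (2,) torch.int64
# per_row_kept (2, 1) torch.int64
```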

Core behavior

1. Global maximum index (when dim=None)

  • The input tensor is flattened, and the index of the maximum value in the flattened tensor is returned
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x))  # Output: tensor(3)
# Flattened: 1, 2, 3, 6, 5, 4 → the maximum is 6, at index 3 (0-based)
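Because the global result indexes the flattened tensor, you need one extra step to recover the (row, column) position. The sketch below does this for a 2-D tensor, assuming the default row-major (C-order) flattening. `divmod` is plain Python, so no extra PyTorch helper is needed.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])

flat_idx = torch.argmax(x)              # index into the flattened tensor
same = torch.argmax(x.flatten())        # flattening explicitly gives the same index
# Row-major layout: row = idx // ncols, col = idx % ncols
row, col = divmod(flat_idx.item(), x.shape[1])

print(flat_idx.item(), same.item(), (row, col), x[row, col].item())
# 3 3 (1, 0) 6
```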

2. Maximum indices along a specified dimension (when dim is given)

  • The search runs along dimension dim, returning the index of the maximum within each slice (for a 2-D tensor, each row or each column)
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
# dim=1: search within each row. Row 0 [1, 2, 3] → max 3 at index 2; row 1 [6, 5, 4] → max 6 at index 0
print(torch.argmax(x, dim=1))  # Output: tensor([2, 0])
# dim=0: search within each column. Columns [1, 6], [2, 5], [3, 4] → maxima 6, 5, 4, each at index 1
print(torch.argmax(x, dim=0))  # Output: tensor([1, 1, 1])
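When you need the maximum values as well as their positions, `torch.max(x, dim=...)` returns both as a (values, indices) pair. The sketch below checks that its indices agree with `torch.argmax` on this tie-free example. It also shows how to read the values back out of `x` with advanced indexing.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])

idx = torch.argmax(x, dim=1)
values, max_idx = torch.max(x, dim=1)   # torch.max with dim returns (values, indices)

# Use the indices to pick each row's maximum back out of x
picked = x[torch.arange(x.shape[0]), idx]

print(idx, max_idx, values, picked)
# tensor([2, 0]) tensor([2, 0]) tensor([3, 6]) tensor([3, 6])
```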

Detailed explanation of parameters

1. The dim parameter

  • Effect: specifies the dimension along which to search for the maximum.
  • Example (2-D tensor):
    • dim=0: search down each column (vertically), giving one index per column.
    • dim=1: search across each row (horizontally), giving one index per row.
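The row/column picture generalizes to any number of dimensions: the output shape is the input shape with the dim entry removed. The sketch below uses an arbitrary 3-D tensor to show this. It also shows that negative dim values count from the end.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)   # an arbitrary 3-D tensor of shape (2, 3, 4)

for d in range(t.dim()):
    # Reducing along dimension d removes that dimension from the shape
    print(d, tuple(torch.argmax(t, dim=d).shape))
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)

# Negative values count from the end: dim=-1 is the last dimension
print(torch.equal(torch.argmax(t, dim=-1), torch.argmax(t, dim=2)))  # True
```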

2. The keepdim parameter

  • Effect: when True, the reduced dimension is retained with size 1, so the output has the same number of dimensions as the input.
  • Example:
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
out = torch.argmax(x, dim=1, keepdim=True)
print(out)  # Output: tensor([[2], [0]]), shape (2, 1)
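A common reason to ask for keepdim=True is that the result can go straight into operations expecting matching dimensionality. The sketch below shows one possible use, not the only one. It gathers each row's maximum with `torch.gather` and builds a one-hot mask by broadcasting.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])

idx = torch.argmax(x, dim=1, keepdim=True)   # shape (2, 1)
# gather needs an index tensor with the same number of dimensions as x,
# which keepdim=True provides directly
row_max = torch.gather(x, 1, idx)            # tensor([[3], [6]])

# The kept size-1 dimension also broadcasts against a column range to form a one-hot mask
mask = torch.arange(x.shape[1]) == idx       # shape (2, 3)
print(row_max.tolist(), mask.tolist())
# [[3], [6]] [[False, False, True], [True, False, False]]
```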

Common uses

1. Obtaining predicted labels in classification tasks

logits = torch.tensor([0.1, 0.8, 0.05, 0.05])  # Model output for one sample (here already a probability distribution)
predicted_class = torch.argmax(logits)         # Output: tensor(1)
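In practice a model usually outputs a batch of unnormalized scores (logits). Softmax is strictly increasing within each row, so applying it does not change which index is largest. Argmax can therefore be taken on the logits directly. The sketch below uses made-up scores for a batch of three samples.

```python
import torch

# Hypothetical raw scores for a batch of 3 samples and 4 classes
logits = torch.tensor([[2.0, 0.5, -1.0, 0.1],
                       [-0.3, 0.2, 1.7, 0.0],
                       [0.9, 3.1, 0.4, 3.0]])

probs = torch.softmax(logits, dim=1)   # per-sample probability distribution
pred_from_logits = torch.argmax(logits, dim=1)
pred_from_probs = torch.argmax(probs, dim=1)

print(pred_from_logits)                                 # tensor([0, 2, 1])
print(torch.equal(pred_from_logits, pred_from_probs))   # True
```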

2. Calculate accuracy

import torch
# Assume batch_size=4, num_classes=3
preds = torch.tensor([[0.1, 0.2, 0.7],
                      [0.9, 0.05, 0.05],
                      [0.3, 0.4, 0.3],
                      [0.05, 0.8, 0.15]])
labels = torch.tensor([2, 0, 1, 1])
# Get the predicted class of each sample
predicted_classes = torch.argmax(preds, dim=1)  # Output: tensor([2, 0, 1, 1])
# Count the correct predictions and compute accuracy
correct = (predicted_classes == labels).sum()   # Output: tensor(4)
accuracy = correct.item() / labels.numel()      # Output: 1.0
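The same steps can be wrapped in a small reusable helper. `accuracy` below is a hypothetical name chosen for this sketch, not a PyTorch API. It assumes scores of shape (batch_size, num_classes) and integer labels of shape (batch_size,), and returns a Python float.

```python
import torch

def accuracy(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose highest-scoring class matches the label.

    scores: (batch_size, num_classes); labels: (batch_size,) of class indices.
    """
    predicted = torch.argmax(scores, dim=1)
    return (predicted == labels).float().mean().item()

scores = torch.tensor([[0.1, 0.2, 0.7],
                       [0.9, 0.05, 0.05],
                       [0.3, 0.4, 0.3],
                       [0.05, 0.8, 0.15]])
print(accuracy(scores, torch.tensor([2, 0, 1, 1])))  # 1.0
print(accuracy(scores, torch.tensor([2, 0, 0, 1])))  # 0.75
```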

Things to note

1. Ties (multiple equal maximum values)

  • If several elements share the maximum value, the index of the first occurrence is returned
import torch
x = torch.tensor([3, 1, 4, 4])
print(torch.argmax(x))  # Output: tensor(2)
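If you need the last occurrence instead, one approach is to search the reversed tensor and map the index back. This is a sketch that relies on the first-occurrence rule described above.

```python
import torch

x = torch.tensor([3, 1, 4, 4])

first = torch.argmax(x)   # first maximal index
# Flip, take the first maximum of the reversed tensor, then map back to the original position
last = x.numel() - 1 - torch.argmax(torch.flip(x, dims=[0]))

print(first.item(), last.item())  # 2 3
```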

2. Data types

  • The input tensor should have a numeric dtype (such as float32 or int64); the returned indices are always int64
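The input dtype affects only how values are compared, not the type of the result. The sketch below stores the same illustrative values as int64 and as float32, and gets identical indices from both. The tie in the second row resolves to the first occurrence either way.

```python
import torch

values = [[1, 7, 3], [9, 2, 9]]
as_int = torch.tensor(values, dtype=torch.int64)
as_float = torch.tensor(values, dtype=torch.float32)

idx_int = torch.argmax(as_int, dim=1)
idx_float = torch.argmax(as_float, dim=1)

print(idx_int, idx_float)              # tensor([1, 0]) tensor([1, 0])
print(idx_int.dtype, idx_float.dtype)  # torch.int64 torch.int64
```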

3. Valid dimensions

  • Specifying a dimension that does not exist (such as dim=3 for a two-dimensional tensor) raises an IndexError; negative values are allowed and count from the end (dim=-1 is the last dimension)
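Here is a quick way to observe that failure mode. The exact wording of the error message may vary between PyTorch versions.

```python
import torch

x = torch.zeros(2, 3)   # a 2-D tensor: valid dims are -2, -1, 0, 1

caught = None
try:
    torch.argmax(x, dim=3)
except IndexError as err:
    # PyTorch reports an out-of-range dimension as an IndexError
    caught = err
    print("IndexError:", err)
```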

Summary

torch.argmax() is an efficient tool, widely used for tasks such as turning classification model outputs into predicted labels and computing indices. Once you understand how its dim and keepdim parameters behave, you can apply it flexibly to data of any dimensionality.
