SoFunction
Updated on 2024-11-19

Example of a Fourier Convolution Implementation in PyTorch

Convolution

Convolutions are ubiquitous in data analysis. They have been used in signal and image processing for decades, and more recently they have become an important part of modern neural networks. If you work with data, you have very likely encountered convolutions.

Mathematically, the convolution of two functions f and g is defined as:

$$(f * g)(x) = \int f(y)\, g(x - y)\, dy$$

Although discrete convolution is more common in computational applications, I'll be using the continuous form for most of this article because it's much easier to prove the convolution theorem (discussed below) using continuous variables. After that, we'll return to the discrete case and implement it in PyTorch using the Fourier transform. The discrete convolution can be thought of as an approximation of the continuous convolution, where the continuous function is discretized on a regular grid. Therefore, we will not re-prove the convolution theorem for this discrete case.
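To make the discrete form concrete, here is a minimal sketch of direct (full) discrete convolution, (f * g)[n] = Σ_m f[m] g[n − m], written with plain loops. The helper name direct_conv1d_full is hypothetical and only for illustration; the nested loops also show where the O(n^2) cost discussed below comes from.

```python
import torch

def direct_conv1d_full(f_sig: torch.Tensor, g_ker: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: full discrete convolution (f * g)[n] = sum_m f[m] * g[n - m]."""
    n_out = f_sig.numel() + g_ker.numel() - 1
    out = torch.zeros(n_out, dtype=f_sig.dtype)
    # Every element of f meets every element of g: O(len(f) * len(g)) multiplications.
    for m in range(f_sig.numel()):
        for j in range(g_ker.numel()):
            out[m + j] += f_sig[m] * g_ker[j]  # contributes to output index n = m + j
    return out

f_sig = torch.tensor([1.0, 2.0, 3.0])
g_ker = torch.tensor([0.0, 1.0, 0.5])
print(direct_conv1d_full(f_sig, g_ker))  # values: 0, 1, 2.5, 4, 1.5
```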

Convolution Theorem

Mathematically, the convolution theorem can be stated as:

$$\mathcal{F}\{f * g\} = \mathcal{F}\{f\} \cdot \mathcal{F}\{g\}$$

where the continuous Fourier transform is defined (up to a normalization constant) as:

$$\mathcal{F}\{f\}(k) = \int f(x)\, e^{-ikx}\, dx$$

In other words, convolution in position space is equivalent to element-wise multiplication in frequency space. The idea is rather unintuitive, but proving the convolution theorem is surprisingly easy in the continuous case. To do so, start by writing out the left-hand side of the equation:

$$\mathcal{F}\{f * g\}(k) = \int dx\, e^{-ikx} \int dy\, f(y)\, g(x - y)$$

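Now swap the order of integration, substitute x = y + z, and separate the product into two independent integrals:

$$\mathcal{F}\{f * g\}(k) = \int dy\, f(y) \int dz\, g(z)\, e^{-ik(y + z)} = \left[\int dy\, f(y)\, e^{-iky}\right] \left[\int dz\, g(z)\, e^{-ikz}\right] = \mathcal{F}\{f\}(k) \cdot \mathcal{F}\{g\}(k)$$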
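In the discrete setting the theorem holds exactly for circular convolution, where indices wrap around modulo the array length. The following sketch, assuming only torch.fft.fft and torch.fft.ifft, checks numerically that a direct circular convolution matches the inverse FFT of the product of the two spectra.

```python
import torch

torch.manual_seed(0)
n = 16
f_sig = torch.randn(n, dtype=torch.float64)
g_ker = torch.randn(n, dtype=torch.float64)

# Direct circular convolution: (f * g)[k] = sum_m f[m] * g[(k - m) mod n]
direct = torch.stack([
    sum(f_sig[m] * g_ker[(k - m) % n] for m in range(n)) for k in range(n)
])

# Convolution theorem: multiply the spectra, then transform back.
via_fft = torch.fft.ifft(torch.fft.fft(f_sig) * torch.fft.fft(g_ker)).real

print(torch.allclose(direct, via_fft))  # True
```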

Why should we care about any of this?

The reason is that the Fast Fourier Transform (FFT) has lower algorithmic complexity than direct convolution. Direct convolution has complexity O(n^2), because every element of g must be multiplied with every element of f, whereas the FFT can be computed in O(n log n) time. As a result, the FFT approach is much faster than direct convolution when the input arrays are large. In these cases, we can use the convolution theorem to compute the convolution in frequency space and then perform an inverse Fourier transform to return to position space.
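Because the FFT computes a circular convolution, both inputs need to be zero-padded to at least len(f) + len(g) - 1 samples to obtain an ordinary (linear) convolution. A minimal sketch of the pad, transform, multiply, invert recipe follows; fft_linear_conv is a hypothetical helper name, and the reference uses conv1d with a flipped kernel because conv1d computes cross-correlation (covered later in this article).

```python
import torch

def fft_linear_conv(f_sig: torch.Tensor, g_ker: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: full linear convolution of two real 1-D tensors via the FFT."""
    n_out = f_sig.numel() + g_ker.numel() - 1
    # rfft zero-pads both inputs to n_out samples, which is long enough that the
    # circular convolution computed by the FFT never wraps around.
    f_fr = torch.fft.rfft(f_sig, n=n_out)
    g_fr = torch.fft.rfft(g_ker, n=n_out)
    # Element-wise product in frequency space, then back to position space.
    return torch.fft.irfft(f_fr * g_fr, n=n_out)

torch.manual_seed(0)
x = torch.randn(1000, dtype=torch.float64)
w = torch.randn(257, dtype=torch.float64)

# Reference: conv1d computes cross-correlation, so flip the kernel to get a true
# convolution, and pad by len(w) - 1 on both sides to get the full output length.
ref = torch.nn.functional.conv1d(
    x.view(1, 1, -1), w.flip(0).view(1, 1, -1), padding=w.numel() - 1
).view(-1)
print(torch.allclose(fft_linear_conv(x, w), ref))  # True
```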

When the kernel is small (e.g. 3x3 convolution kernels), direct convolution is still faster. Machine learning applications most commonly use small kernels, so deep learning libraries like PyTorch and TensorFlow only provide direct convolution through their standard convolution functions. But there are many real-world use cases with large kernels where Fourier convolution is more efficient.

PyTorch Implementation

Now, I will demonstrate how to implement a Fourier convolution function in PyTorch. It should mimic the functionality of PyTorch's built-in convolution functions (such as torch.nn.functional.conv1d) and use FFTs under the hood, without requiring any additional work from the user. As such, it should accept three Tensors (signal, kernel, and optionally bias) and the padding to apply to the input. Conceptually, the inner workings of this function are:

from typing import Optional
from torch import Tensor

def fft_conv(
  signal: Tensor, kernel: Tensor, bias: Optional[Tensor] = None, padding: int = 0,
) -> Tensor:
  # 1. Pad the input signal & kernel tensors
  # 2. Compute FFT for both signal & kernel
  # 3. Multiply the transformed Tensors together
  # 4. Compute inverse FFT
  # 5. Add bias and return
  ...

Let's build the FFT convolution step-by-step in the order of operations shown above. For this example, I'll construct a one-dimensional Fourier convolution, but extending it to two- and three-dimensional convolutions is easy.

1. Padding the input arrays

We need to make sure that the signal and kernel have the same size after padding. Apply the initial padding to the signal and then adjust the padding of the kernel to match.

import torch.nn.functional as f

# 1. Pad the input signal & kernel tensors
signal = f.pad(signal, [padding, padding])
kernel_padding = [0, signal.size(-1) - kernel.size(-1)]
padded_kernel = f.pad(kernel, kernel_padding)

Note that I'm only padding the kernel on one side. We want the original kernel to sit at the left end of the padded array so that it is aligned with the start of the signal array.
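As a quick sanity check of the shapes involved, the sketch below pads tensors with the same sizes used in the test later in this article (batch 3, 3 input channels, 2 output channels, signal length 4096, kernel length 1025, padding 512) and confirms that the original kernel taps stay at the start of the padded kernel.

```python
import torch
import torch.nn.functional as f

torch.manual_seed(0)
signal = torch.randn(3, 3, 4096)   # (batch, in_channels, length)
kernel = torch.randn(2, 3, 1025)   # (out_channels, in_channels, kernel_length)
padding = 512

# Pad the signal on both sides, then pad the kernel on the right only so that
# both end up with the same length along the last dimension.
signal = f.pad(signal, [padding, padding])
padded_kernel = f.pad(kernel, [0, signal.size(-1) - kernel.size(-1)])

print(signal.shape)         # torch.Size([3, 3, 5120])
print(padded_kernel.shape)  # torch.Size([2, 3, 5120])
# The original kernel taps are unchanged at the start of the padded kernel.
print(torch.equal(padded_kernel[..., :1025], kernel))  # True
```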

2. Computing the Fourier transform

This is very simple, because n-dimensional FFTs are already implemented in PyTorch's torch.fft module. We simply use the built-in real-input FFT function, rfftn, and compute the FFT along the last dimension of each tensor.

from torch.fft import rfftn

# 2. Perform fourier convolution
signal_fr = rfftn(signal, dim=-1)
kernel_fr = rfftn(padded_kernel, dim=-1)
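One detail worth knowing about rfftn: for a real input of length n it returns only the n // 2 + 1 non-redundant frequencies, so an even and an odd length can produce spectra of the same size. The inverse irfftn therefore assumes an even output length unless you pass the original size through its s argument, as this short sketch shows.

```python
import torch
from torch.fft import rfftn, irfftn

torch.manual_seed(0)
x_even = torch.randn(8)
x_odd = torch.randn(9)

# rfftn keeps only the n // 2 + 1 non-redundant frequencies of a real input.
print(rfftn(x_even, dim=-1).shape)  # torch.Size([5])
print(rfftn(x_odd, dim=-1).shape)   # torch.Size([5])

# Without s, irfftn assumes an even output length: 2 * (5 - 1) = 8 samples.
print(irfftn(rfftn(x_odd, dim=-1), dim=-1).shape)  # torch.Size([8])

# Passing the original length through s recovers the odd-length signal.
roundtrip = irfftn(rfftn(x_odd, dim=-1), s=[9], dim=-1)
print(torch.allclose(roundtrip, x_odd, atol=1e-6))  # True
```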

3. Multiplying the transformed tensors

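Surprisingly, this is the most complex part of our function, for two reasons. (1) PyTorch convolutions operate on multidimensional tensors, so our signal and kernel tensors are actually three-dimensional. From this equation in the PyTorch documentation for torch.nn.Conv1d, we can see that a matrix multiplication is performed over the first two dimensions (excluding the bias term):

$$\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k=0}^{C_{\text{in}}-1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)$$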

We need to include this matrix multiplication over the channel dimensions, as well as an element-wise multiplication along the transformed (frequency) dimension.
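The einsum specification used for the channel mixing in complex_matmul, "ab..., cb... -> ac...", sums over the shared input-channel axis while leaving the trailing (frequency) axis untouched. The sketch below checks this against explicit loops on small real-valued tensors with arbitrary sizes.

```python
import torch

torch.manual_seed(0)
batch, c_in, c_out, n_freq = 3, 3, 2, 5
signal_fr = torch.randn(batch, c_in, n_freq)
kernel_fr = torch.randn(c_out, c_in, n_freq)

# Sum over the shared input-channel index b, keep batch a and output channel c,
# and multiply element-wise along the trailing "..." (frequency) axis.
mixed = torch.einsum("ab..., cb... -> ac...", signal_fr, kernel_fr)

# The same computation with explicit loops, for comparison.
loops = torch.zeros(batch, c_out, n_freq)
for a in range(batch):
    for c in range(c_out):
        for b in range(c_in):
            loops[a, c] += signal_fr[a, b] * kernel_fr[c, b]

print(mixed.shape)  # torch.Size([3, 2, 5])
print(torch.allclose(mixed, loops, atol=1e-6))  # True
```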

(2) PyTorch actually implements cross-correlation rather than convolution. (The same is true of TensorFlow and other deep learning libraries.) Cross-correlation is closely related to convolution, but with one important sign change:

$$(f \star g)(x) = \int f(y)\, g(y - x)\, dy$$
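A tiny example makes the sign change visible: torch.nn.functional.conv1d slides the kernel across the signal without flipping it, so its output matches the cross-correlation sum Σ_k x[t + k] w[k] rather than a convolution. The input values here are arbitrary.

```python
import torch
import torch.nn.functional as f

x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # signal
w = torch.tensor([1.0, 0.0, -1.0])       # kernel

# conv1d does NOT flip the kernel: out[t] = sum_k x[t + k] * w[k]
out = f.conv1d(x.view(1, 1, -1), w.view(1, 1, -1)).view(-1)
manual = torch.stack([sum(x[t + k] * w[k] for k in range(3)) for t in range(2)])

print(out)                     # tensor([-2., -2.])
print(torch.equal(out, manual))  # True
```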

This effectively reverses the orientation of the kernel (g) compared to convolution. Rather than manually flipping the kernel, we correct for this by taking the complex conjugate of the kernel in Fourier space. Since we do not need to create an entirely new Tensor, this is significantly faster and more memory efficient. (A brief explanation of why this works is given in the appendix at the end of this article.)
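The conjugation trick can be checked numerically on 1-D tensors. This sketch uses .conj() on the kernel spectrum, which has the same effect as the in-place kernel_fr.imag *= -1 used in the function, and compares the result with a direct circular cross-correlation.

```python
import torch
from torch.fft import rfftn, irfftn

torch.manual_seed(0)
n = 32
x = torch.randn(n, dtype=torch.float64)   # signal
w = torch.randn(n, dtype=torch.float64)   # kernel (already padded to length n)

# Multiply by the conjugated kernel spectrum instead of flipping the kernel.
xcorr_fft = irfftn(rfftn(x, dim=-1) * rfftn(w, dim=-1).conj(), s=[n], dim=-1)

# Direct circular cross-correlation: out[t] = sum_m x[(t + m) mod n] * w[m]
xcorr_direct = torch.stack([
    sum(x[(t + m) % n] * w[m] for m in range(n)) for t in range(n)
])
print(torch.allclose(xcorr_fft, xcorr_direct))  # True
```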

import torch
from functools import partial
from torch import Tensor

# 3. Multiply the transformed matrices

def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
  """Multiplies two complex-valued tensors."""
  # Scalar matrix multiplication of two tensors, over only the first two dimensions.
  # Dimensions 3 and higher will have the same shape after multiplication.
  scalar_matmul = partial(torch.einsum, "ab..., cb... -> ac...")

  # Compute the real and imaginary parts independently, then manually insert them
  # into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
  # because Autograd is not enabled for complex matrix operations yet. Not exactly
  # idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
  real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag)
  imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag)
  c = torch.zeros(real.shape, dtype=torch.complex64, device=real.device)
  c.real, c.imag = real, imag
  return c

# Conjugate the kernel for cross-correlation
kernel_fr.imag *= -1
output_fr = complex_matmul(signal_fr, kernel_fr)

PyTorch 1.7 improved support for complex numbers, but many operations on complex tensors are not yet supported in autograd. For now, we have to write our own complex matrix multiplication as a workaround. It's not ideal, but it works and should not cause problems in future versions.
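As an aside, recent PyTorch releases accept complex tensors in torch.einsum, including autograd, so the real/imaginary split may no longer be needed; the exact release in which this became available is not covered here, so treat it as an assumption to verify against your installed version. The sketch compares a direct complex einsum with the split approach (complex_matmul_split is a hypothetical name, and it assembles the result with torch.complex) and runs a backward pass through the einsum.

```python
import torch
from functools import partial

def complex_matmul_split(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: the real/imaginary split used by complex_matmul."""
    matmul = partial(torch.einsum, "ab..., cb... -> ac...")
    real = matmul(a.real, b.real) - matmul(a.imag, b.imag)
    imag = matmul(a.imag, b.real) + matmul(a.real, b.imag)
    return torch.complex(real, imag)

torch.manual_seed(0)
a = torch.randn(3, 3, 5, dtype=torch.complex64, requires_grad=True)  # (batch, c_in, freq)
b = torch.randn(2, 3, 5, dtype=torch.complex64)                      # (c_out, c_in, freq)

# Assumption: the installed PyTorch supports complex einsum with autograd.
direct = torch.einsum("ab..., cb... -> ac...", a, b)
split = complex_matmul_split(a, b)
print(torch.allclose(direct.detach(), split.detach(), atol=1e-5))  # True

# Gradients flow through the complex einsum.
direct.abs().sum().backward()
print(a.grad.shape)  # torch.Size([3, 3, 5])
```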

4. Computing the inverse transform

Use irfftn to compute the inverse transform directly, and then crop off the extra padded values.

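from torch.fft import irfftn

# 4. Compute inverse FFT, and remove extra padded values
# (s restores the exact padded length, which matters when that length is odd)
output = irfftn(output_fr, s=[signal.size(-1)], dim=-1)
output = output[:, :, :signal.size(-1) - kernel.size(-1) + 1]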

5. Add bias and return

Adding the bias term is also easy. Remember that the bias has one element for each channel in the output array, so we reshape it accordingly before adding it.

# 5. Optionally, add a bias term before returning.
if bias is not None:
  output += bias.view(1, -1, 1)
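For example, with an output of shape (batch, out_channels, length) = (3, 2, 4096) and two arbitrary bias values, view(1, -1, 1) reshapes the bias to (1, 2, 1) so that it broadcasts across the batch and length dimensions:

```python
import torch

output = torch.zeros(3, 2, 4096)   # (batch, out_channels, length)
bias = torch.tensor([0.5, -1.0])   # one illustrative value per output channel

# Shape (2,) -> (1, 2, 1): broadcasts over the batch and length dimensions.
output += bias.view(1, -1, 1)
print(output[0, :, :3])
```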

Putting the above code together

For completeness, let's compile all these code snippets into one cohesive function.

import torch.nn.functional as f
from typing import Optional
from torch import Tensor
from torch.fft import rfftn, irfftn

# complex_matmul is the helper defined in step 3 above.

def fft_conv_1d(
  signal: Tensor, kernel: Tensor, bias: Optional[Tensor] = None, padding: int = 0,
) -> Tensor:
  """
  Args:
    signal: (Tensor) Input tensor to be convolved with the kernel.
    kernel: (Tensor) Convolution kernel.
    bias: (Optional, Tensor) Bias tensor to add to the output.
    padding: (int) Number of zero samples to pad the input on the last dimension.
  Returns:
    (Tensor) Convolved tensor
  """
  # 1. Pad the input signal & kernel tensors
  signal = f.pad(signal, [padding, padding])
  kernel_padding = [0, signal.size(-1) - kernel.size(-1)]
  padded_kernel = f.pad(kernel, kernel_padding)

  # 2. Perform fourier convolution
  signal_fr = rfftn(signal, dim=-1)
  kernel_fr = rfftn(padded_kernel, dim=-1)

  # 3. Multiply the transformed matrices
  kernel_fr.imag *= -1
  output_fr = complex_matmul(signal_fr, kernel_fr)

  # 4. Compute inverse FFT, and remove extra padded values
  output = irfftn(output_fr, s=[signal.size(-1)], dim=-1)
  output = output[:, :, :signal.size(-1) - kernel.size(-1) + 1]

  # 5. Optionally, add a bias term before returning.
  if bias is not None:
    output += bias.view(1, -1, 1)

  return output
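Extending this to two dimensions should be straightforward. As a rough sketch of what that might look like (fft_conv_2d is a hypothetical name, it assumes the same symmetric integer padding on both spatial axes, and it uses a complex einsum with .conj() instead of complex_matmul, which assumes a PyTorch version with complex autograd support), the same five steps carry over with the FFTs taken over the last two dimensions:

```python
import torch
import torch.nn.functional as f
from typing import Optional
from torch import Tensor
from torch.fft import rfftn, irfftn

def fft_conv_2d(signal: Tensor, kernel: Tensor, bias: Optional[Tensor] = None,
                padding: int = 0) -> Tensor:
    """Hypothetical 2-D analogue of fft_conv_1d (same padding on both spatial axes)."""
    # 1. Pad height and width of the signal, then pad the kernel on the bottom/right.
    signal = f.pad(signal, [padding, padding, padding, padding])
    h, w = signal.shape[-2:]
    kh, kw = kernel.shape[-2:]
    padded_kernel = f.pad(kernel, [0, w - kw, 0, h - kh])
    # 2. Real FFT over both spatial dimensions.
    signal_fr = rfftn(signal, dim=(-2, -1))
    kernel_fr = rfftn(padded_kernel, dim=(-2, -1))
    # 3. Conjugate the kernel (cross-correlation) and mix channels with a complex einsum.
    output_fr = torch.einsum("ab..., cb... -> ac...", signal_fr, kernel_fr.conj())
    # 4. Inverse FFT with the exact padded size, then crop the valid region.
    output = irfftn(output_fr, s=(h, w), dim=(-2, -1))
    output = output[..., : h - kh + 1, : w - kw + 1]
    # 5. Optional per-channel bias.
    if bias is not None:
        output = output + bias.view(1, -1, 1, 1)
    return output

torch.manual_seed(0)
x = torch.randn(2, 3, 41, 35, dtype=torch.float64)
k = torch.randn(4, 3, 9, 7, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64)
y_fft = fft_conv_2d(x, k, bias=b, padding=3)
y_ref = f.conv2d(x, k, bias=b, padding=3)
print(y_fft.shape, torch.allclose(y_fft, y_ref))
```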

Direct Convolution Test

Finally, we will use torch.nn.functional.conv1d to confirm that this is numerically equivalent to direct one-dimensional convolution. We construct random tensors for all inputs and measure the absolute difference in output values.

import torch
import torch.nn.functional as f

torch.manual_seed(1234)
kernel = torch.randn(2, 3, 1025)
signal = torch.randn(3, 3, 4096)
bias = torch.randn(2)

y0 = f.conv1d(signal, kernel, bias=bias, padding=512)
y1 = fft_conv_1d(signal, kernel, bias=bias, padding=512)

abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

# Abs Error Mean: 1.272E-05

Considering that we are using 32-bit precision, a mean difference of about 1e-5 per element is quite accurate! Let's also perform a quick benchmark to measure the speed of each method:

from timeit import timeit
direct_time = timeit(
  "f.conv1d(signal, kernel, bias=bias, padding=512)", 
  globals=locals(), 
  number=100
) / 100
fourier_time = timeit(
  "fft_conv_1d(signal, kernel, bias=bias, padding=512)", 
  globals=locals(), 
  number=100
) / 100
print(f"Direct time: {direct_time:.3E} s")
print(f"Fourier time: {fourier_time:.3E} s")
 
# Direct time: 1.523E-02 s
# Fourier time: 1.149E-03 s

Absolute timings will vary significantly with the machine you are using. (I'm testing on a very old MacBook Pro.) For a kernel of length 1025, Fourier convolution appears to be more than 10 times faster.

Summary

I hope this has provided a thorough introduction to Fourier convolution. I think it's a really cool trick, and there are many real-world applications that can use it. I also love math, so it's fun to see programming and pure math come together. All comments and constructive criticism are welcome and encouraged, and please clap if you enjoyed this post!

Appendix:

Convolution vs. Cross-Correlation

Earlier in this article, we implemented cross-correlation by taking the complex conjugate of the kernel in Fourier space. This effectively reverses the direction of the kernel, and now I want to demonstrate why. First, recall the formulas for convolution and cross-correlation:

$$(f * g)(x) = \int f(y)\, g(x - y)\, dy$$

$$(f \star g)(x) = \int f(y)\, g(y - x)\, dy$$

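Then, let's look at the complex conjugate of the Fourier transform of g(x):

$$\overline{\mathcal{F}\{g\}(k)} = \overline{\int g(x)\, e^{-ikx}\, dx} = \int \overline{g(x)}\, e^{ikx}\, dx$$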

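Note that g(x) is real-valued, so it is unaffected by complex conjugation. Then, change variables (y = -x) and simplify the expression:

$$\overline{\mathcal{F}\{g\}(k)} = \int g(x)\, e^{ikx}\, dx = \int g(-y)\, e^{-iky}\, dy = \mathcal{F}\{g(-x)\}(k)$$

In other words, conjugating the kernel's spectrum is the same as transforming the flipped kernel g(-x). Since the cross-correlation above is just the convolution of f with g(-x), the convolution theorem gives

$$\mathcal{F}\{f \star g\} = \mathcal{F}\{f\} \cdot \overline{\mathcal{F}\{g\}}$$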
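The discrete analogue of this identity can be checked directly: for a real array g of length n, the conjugate of its DFT equals the DFT of the circularly reversed array g[(-m) mod n]. A short sketch:

```python
import torch

torch.manual_seed(0)
n = 16
g = torch.randn(n, dtype=torch.float64)   # real-valued kernel

# Discrete counterpart of g(-x): index m maps to (-m) mod n, giving [g0, g15, ..., g1].
g_reversed = g[(-torch.arange(n)) % n]

lhs = torch.fft.fft(g).conj()     # conjugate of the transform
rhs = torch.fft.fft(g_reversed)   # transform of the reversed kernel
print(torch.allclose(lhs, rhs))   # True
```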
