SoFunction
Updated on 2024-11-14

Usage of tensordot in numpy

Introduction

NumPy provides a tensordot function that is especially useful in machine learning. The name may remind you of TensorFlow, and indeed TensorFlow has a tensordot function as well. It lets you multiply two arrays whose shapes are not compatible for element-wise multiplication, by summing products over the axes you specify. Let's take an example:

import numpy as np

a = np.random.randint(0, 9, (3, 4))
b = np.random.randint(0, 9, (4, 5))
try:
    print(a * b)
except Exception as e:
    print(e)  # operands could not be broadcast together with shapes (3,4) (4,5)

# a and b have different shapes, so they cannot be multiplied element-wise.
# tensordot, however, works:
print(np.tensordot(a, b, 1))
"""
[[32 32 28 28 52]
 [10 25 40 38 78]
 [56  7 28  0 42]]
"""
# tensordot handles this case. The arrays are random, so the exact values differ from run to run.

Let's look at how this function is used.

Function prototype

@array_function_dispatch(_tensordot_dispatcher)
def tensordot(a, b, axes=2):

The function takes three parameters. The first two are NumPy arrays, and the last one specifies the axes to contract (sum over). axes can be an integer, a list, or a list of nested lists; the examples below show what each form means.

Understanding axes

axes is an integer

If axes is an integer n, the last n axes of a are contracted with the first n axes of b: elements at corresponding positions are multiplied, and the products are summed.

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

# These two arrays cannot be multiplied element-wise, but the last two axes of a
# and the first two axes of b both have shape (4, 5), so they can be contracted.
# The contracted axes disappear, leaving a result of shape (3, 8).
print(np.tensordot(a, b, 2).shape)  # (3, 8)

axes defaults to 2, so with the default both arrays need at least two dimensions.
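As a sanity check on the integer rule, the following sketch recomputes the axes=2 result in two other ways: with an explicit loop over the output positions, and with a reshape followed by a matrix product. The names manual and flat are introduced here only for illustration.

```python
import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

# Explicit version: for each output position, multiply the matching
# (4, 5) blocks element-wise and sum the products.
manual = np.zeros((3, 8), dtype=a.dtype)
for i in range(3):
    for z in range(8):
        manual[i, z] = np.sum(a[i, :, :] * b[:, :, z])

# Flatten the contracted axes (4 * 5 = 20) and use a matrix product.
flat = a.reshape(3, 20) @ b.reshape(20, 8)

result = np.tensordot(a, b, 2)
print(np.array_equal(result, manual))  # True
print(np.array_equal(result, flat))    # True
```

The second form is a useful mental model: contracting n axes amounts to a matrix product after the contracted axes have been flattened into one.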

To make this concrete, later sections work through one-dimensional and two-dimensional examples. First, let's see what results different values of axes produce, so that the meaning of axes becomes clear.

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

try:
    print(np.tensordot(a, b, 1).shape)
except Exception as e:
    print(e)  # shape-mismatch for sum
# This fails because of a shape mismatch.
# axes=1 contracts the last axis of a with the first axis of b,
# but their lengths are 5 and 4, so the elements cannot be paired up.

# Reshape b so that the last axis of a and the first axis of b both have length 5.
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((5, 4, 8))
print(np.tensordot(a, b, 1).shape)  # (3, 4, 4, 8)
"""
Now the computation succeeds. Contracting a pair of axes takes an inner product
along them, which collapses them into single values, so the contracted axes vanish.
Here (3, 4, 5) and (5, 4, 8) become (3, 4, 4, 8).

In the previous example the shapes were (3, 4, 5) and (4, 5, 8) with axes=2.
The last two axes of a and the first two axes of b were contracted into a single value,
so the final shape was (3, 8).
"""

What happens if axes is 0?

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

print(np.tensordot(a, b, 0).shape)  # (3, 4, 5, 4, 5, 8)
print(np.tensordot(b, a, 0).shape)  # (4, 5, 8, 3, 4, 5)
"""
np.tensordot(a, b, 0) multiplies each element of a by the whole of b,
then puts the resulting array in place of that element of a.
"""

The operations above can also be written with Einstein summation (einsum).

axes=0

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

c1 = np.tensordot(a, b, 0)
c2 = np.einsum("ijk,xyz->ijkxyz", a, b)
print(c1.shape, c2.shape)  # (3, 4, 5, 4, 5, 8) (3, 4, 5, 4, 5, 8)
print(np.all(c1 == c2))  # True
"""
c1 and c2 are identical
"""

c3 = np.tensordot(b, a, 0)
c4 = np.einsum("ijk,xyz->xyzijk", a, b)
print(c3.shape, c4.shape)  # (4, 5, 8, 3, 4, 5) (4, 5, 8, 3, 4, 5)
print(np.all(c3 == c4))  # True
"""
c3 and c4 are identical
"""

So which of the two is more efficient? Let's test it in Jupyter.

>>> %timeit c1 = np.tensordot(a, b, 0)
50.5 µs ± 206 ns per loop
>>> %timeit c2 = np.einsum("ijk,xyz->ijkxyz", a, b)
7.29 µs ± 242 ns per loop

In this test, einsum is considerably faster.
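Outside Jupyter, a similar comparison can be sketched with the standard timeit module. The helper names with_tensordot and with_einsum and the run count are choices made here for illustration, and the measured times depend on the machine and the NumPy version, so no expected timings are shown.

```python
import timeit
import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

def with_tensordot():
    return np.tensordot(a, b, 0)

def with_einsum():
    return np.einsum("ijk,xyz->ijkxyz", a, b)

# Both calls must agree before their speed is worth comparing.
same = np.array_equal(with_tensordot(), with_einsum())
print(same)  # True

# Average time per call in microseconds.
runs = 2000
t_tensordot = timeit.timeit(with_tensordot, number=runs) / runs * 1e6
t_einsum = timeit.timeit(with_einsum, number=runs) / runs * 1e6
print(f"tensordot: {t_tensordot:.2f} µs, einsum: {t_einsum:.2f} µs")
```

This measures only one small outer product. The ranking may well differ for other shapes and other values of axes, so it is worth timing the contraction you actually need.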

axes=1

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((5, 4, 8))

c1 = np.tensordot(a, b, 1)
c2 = np.einsum("ijk,kyz->ijyz", a, b)
print(c1.shape, c2.shape)  # (3, 4, 4, 8) (3, 4, 4, 8)
print(np.all(c1 == c2))  # True

axes=2

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

c1 = np.tensordot(a, b, 2)
c2 = np.einsum("ijk,jkz->iz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True

axes is a list

If axes is a list [m, n], the axis of a at index m (its (m+1)-th axis) is contracted with the axis of b at index n (its (n+1)-th axis). The main advantage of the list form is that you can pick axes at any position.

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
# The second axis of a and the first axis of b both have length 4, so they can be contracted.
c1 = np.tensordot(a, b, [1, 0])
# The contracted axes collapse, so (3, 4, 5) and (4, 5, 8) give (3, 5, 5, 8):
# each array drops its 4, and the remaining axes are joined together.
print(c1.shape)  # (3, 5, 5, 8)

# Likewise, the last axis of a and the second axis of b can be contracted.
# The last axis can be written as -1, just like negative indexing into a list.
c2 = np.tensordot(a, b, [-1, 1])
print(c2.shape)  # (3, 4, 4, 8)
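The shape rule described in the comments above (drop the contracted axis from each array, then join what remains) can be sketched as a small function. predict_shape is a hypothetical helper written for this article, not part of NumPy, and it handles only a single pair of axes.

```python
import numpy as np

def predict_shape(shape_a, shape_b, axis_a, axis_b):
    """Hypothetical helper (not a NumPy function): result shape of
    np.tensordot(a, b, [axis_a, axis_b]) for a single pair of axes."""
    axis_a %= len(shape_a)  # normalize negative indices such as -1
    axis_b %= len(shape_b)
    if shape_a[axis_a] != shape_b[axis_b]:
        raise ValueError("contracted axes must have the same length")
    # Remaining axes of a, followed by remaining axes of b.
    rest_a = shape_a[:axis_a] + shape_a[axis_a + 1:]
    rest_b = shape_b[:axis_b] + shape_b[axis_b + 1:]
    return rest_a + rest_b

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))

print(predict_shape(a.shape, b.shape, 1, 0))   # (3, 5, 5, 8)
print(predict_shape(a.shape, b.shape, -1, 1))  # (3, 4, 4, 8)
print(predict_shape(a.shape, b.shape, 1, 0) == np.tensordot(a, b, [1, 0]).shape)  # True
```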

The operations above can also be written with Einstein summation.

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
c1 = np.tensordot(a, b, [1, 0])
c2 = np.einsum("ijk,jyz->ikyz", a, b)
print(c1.shape, c2.shape)  # (3, 5, 5, 8) (3, 5, 5, 8)
print(np.all(c1 == c2))  # True

c3 = np.tensordot(a, b, [-1, 1])
c4 = np.einsum("ijk,akz->ijaz", a, b)
print(c3.shape, c4.shape)  # (3, 4, 4, 8) (3, 4, 4, 8)
print(np.all(c3 == c4))  # True

axes is a list of nested lists

If axes is a list of nested lists, such as [[m1, m2], [n1, n2]], several axes can be selected at once: the first inner list gives the axes of a, and the second gives the matching axes of b.

import numpy as np

a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
# Contract the last two axes of a with the first two axes of b.
c1 = np.tensordot(a, b, axes=2)
c2 = np.tensordot(a, b, [[1, 2], [0, 1]])
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True

Selecting axes with lists has the added benefit that the chosen axes can sit at any position; they do not have to be the trailing axes of a and the leading axes of b.

import numpy as np

a = np.arange(60).reshape((4, 3, 5))
b = np.arange(160).reshape((4, 5, 8))
# An integer axes cannot express this contraction.
c3 = np.tensordot(a, b, [[0, 2], [0, 1]])
print(c3.shape)  # (3, 8)

Another powerful feature of selecting axes with lists is that the axes can be listed in reverse order.

import numpy as np

a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((5, 4, 8))

# We want the first two axes of each array, but one is (4, 5) and the other is (5, 4),
# so they cannot be paired up in their natural order.
# One of the two lists therefore has to be reversed:
# [[0, 1], [1, 0]] -> (4, 5) and (4, 5), or [[1, 0], [0, 1]] -> (5, 4) and (5, 4)
c3 = np.tensordot(a, b, [[0, 1], [1, 0]])
print(c3.shape)  # (3, 8)
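One way to see what the nested lists do is to reorder the axes by hand and then fall back to an integer axes. The sketch below, with the illustrative names a_t, b_t, c_list and c_int, moves the contracted axes of a to the end and those of b to the front, in matching order.

```python
import numpy as np

a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((5, 4, 8))

c_list = np.tensordot(a, b, [[0, 1], [1, 0]])

# Move the contracted axes of a to the end and those of b to the front,
# in matching order, then contract with a plain integer.
a_t = a.transpose(2, 0, 1)  # shape (3, 4, 5)
b_t = b.transpose(1, 0, 2)  # shape (4, 5, 8)
c_int = np.tensordot(a_t, b_t, 2)

print(c_list.shape, c_int.shape)      # (3, 8) (3, 8)
print(np.array_equal(c_list, c_int))  # True
```

The list form simply saves you from writing the transposes yourself.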

Finally, let's again see how the same contractions can be written with Einstein summation.

import numpy as np

a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((4, 5, 8))

c1 = np.tensordot(a, b, [[0, 1], [0, 1]])
c2 = np.einsum("ijk,ijz->kz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True


a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((5, 4, 8))

c1 = np.tensordot(a, b, [[0, 1], [1, 0]])
c2 = np.einsum("ijk,jiz->kz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True


a = np.arange(60).reshape((4, 3, 5))
b = np.arange(160).reshape((5, 4, 8))

c1 = np.tensordot(a, b, [[0, 2], [1, 0]])
c2 = np.einsum("ijk,kiz->jz", a, b)
print(c1.shape, c2.shape)  # (3, 8) (3, 8)
print(np.all(c1 == c2))  # True

Take two one-dimensional arrays as an example

Let's get a feel for tensordot by printing concrete arrays.

import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

print(np.tensordot(a, b, axes=0))
"""
[[ 2  3  4]
 [ 4  6  8]
 [ 6  9 12]]
"""
print(np.einsum("i,j->ij", a, b))
"""
[[ 2  3  4]
 [ 4  6  8]
 [ 6  9 12]]
"""

# With axes=0, each element of a is multiplied by the whole of b, and the result takes that element's place.
# So 1, 2 and 3 in a are each multiplied by b, giving [2 3 4], [4 6 8] and [6 9 12], which replace 1, 2 and 3.
# The result is therefore [[2 3 4] [4 6 8] [6 9 12]].

What if axes=1?

import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

print(np.tensordot(a, b, axes=1))  # 20
"""
The last axis of a is contracted with the first axis of b.
a and b each have only one axis, so the result is a scalar.
"""
print(np.einsum("i,i->", a, b))  # 20
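For one-dimensional inputs, these two cases coincide with familiar operations: axes=0 gives the outer product and axes=1 gives the ordinary dot product. A minimal check, using np.outer and np.dot for comparison:

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

outer = np.tensordot(a, b, axes=0)
inner = np.tensordot(a, b, axes=1)

print(np.array_equal(outer, np.outer(a, b)))  # True
print(inner == np.dot(a, b))  # True
print(int(inner))  # 20
```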

What if axes=2? As stated earlier, an integer axes selects the last n axes of a and the first n axes of b, but a one-dimensional array has only one axis.

import numpy as np

a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

try:
    print(np.tensordot(a, b, axes=2))
except Exception as e:
    print(e)  # tuple index out of range

As expected, the index is out of range.

Take a one-dimensional array and a two-dimensional array for example

Let's get a feel for tensordot with a one-dimensional array and a two-dimensional array.

axes=0

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4]])

print(np.tensordot(a, b, 0))
"""
[[[ 2  3  4]]

 [[ 4  6  8]]
 
 [[ 6  9 12]]]
"""
print(np.einsum("i,jk->ijk", a, b))
"""
[[[ 2  3  4]]

 [[ 4  6  8]]
 
 [[ 6  9 12]]]
"""
# 1, 2 and 3 are each multiplied by [[2, 3, 4]], and the results replace 1, 2 and 3.
print(np.tensordot(a, b, 0).shape)  # (3, 1, 3)


##########################
print(np.tensordot(b, a, 0))
"""
[[[ 2  4  6]
  [ 3  6  9]
  [ 4  8 12]]]
"""
print(np.einsum("i,jk->jki", a, b))
"""
[[[ 2  4  6]
  [ 3  6  9]
  [ 4  8 12]]]
"""
# Here 2, 3 and 4 are each multiplied by [1 2 3], and the results replace 2, 3 and 4.
print(np.tensordot(b, a, 0).shape)  # (1, 3, 3)

What if axes=1?

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])
try:
    print(np.tensordot(a, b, 1))
except Exception as e:
    print(e)  # shape-mismatch for sum
# This fails: axes=1 contracts the last axis of a with the first axis of b.
# a has shape (3,), so its only axis (both its first and its last) has length 3,
# but the first axis of b has length 2, which does not match.

print(np.tensordot(b, a, 1))  # [20 32]
# This works: the last axis of b has length 3, which matches the first axis of a.
# The rows [2 3 4] and [4 5 6] each take an inner product with [1 2 3], giving two scalars.

try:
    print(np.einsum("i,ij->ij", a, b))
except Exception as e:
    print(e)
    # operands could not be broadcast together with remapped shapes [original->remapped]: (3,)->(3,newaxis) (2,3)->(2,3)

# The same applies to einsum: the subscripts must pair up the matching axes.
print(np.einsum("i,ji->j", a, b))  # [20 32]
# Or
print(np.einsum("j,ij->i", a, b))  # [20 32]
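Swapping the arguments is not the only way around the mismatch. As a small sketch, the list form of axes can pair axis 0 of a with axis 1 of b directly, and the result is the same as an ordinary matrix-vector product. The name c_list is used here only for illustration.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])

# Contract axis 0 of a with axis 1 of b, without swapping the arguments.
c_list = np.tensordot(a, b, [0, 1])
print(c_list)  # [20 32]

# The same numbers as an ordinary matrix-vector product.
print(b @ a)  # [20 32]
print(np.array_equal(c_list, np.tensordot(b, a, 1)))  # True
```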

What if axes=2?

import numpy as np

a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])
try:
    print(np.tensordot(a, b, 2))
except Exception as e:
    print(e)  # tuple index out of range
# axes=2 takes the last two axes of a and the first two axes of b,
# but a has only one axis, so looking up its second-to-last axis fails.

try:
    print(np.tensordot(b, a, 2))
except Exception as e:
    print(e)  # shape-mismatch for sum
# This also fails, but not with an index error.
# Here the last two axes of b are paired with the first two axes of a. The first pair compared is
# b's axis of length 2 and a's axis of length 3; they do not match, so the shape mismatch is
# reported before a's missing second axis is ever looked up.
# b is (2, 3) and a is (3,). One might expect broadcasting to kick in, but tensordot does not broadcast.

Take two two-dimensional arrays as an example

Let's get a feel for tensordot with two two-dimensional arrays.

axes=0

import numpy as np

a = np.array([[1, 2, 3]])
b = np.array([[2, 3, 4], [4, 5, 6]])

# a.shape: (1, 3), b.shape: (2, 3)
print(np.tensordot(a, b, 0))
"""
[[[[ 2  3  4]
   [ 4  5  6]]

  [[ 4  6  8]
   [ 8 10 12]]

  [[ 6  9 12]
   [12 15 18]]]]
"""
print(np.einsum("ij,xy->ijxy", a, b))
"""
[[[[ 2  3  4]
   [ 4  5  6]]

  [[ 4  6  8]
   [ 8 10 12]]

  [[ 6  9 12]
   [12 15 18]]]]
"""
print(np.tensordot(a, b, 0).shape)  # (1, 3, 2, 3)

#############
print(np.tensordot(b, a, 0))
"""
[[[[ 2  4  6]]

  [[ 3  6  9]]

  [[ 4  8 12]]]


 [[[ 4  8 12]]

  [[ 5 10 15]]

  [[ 6 12 18]]]]
"""
print(np.einsum("ij,xy->xyij", a, b))
"""
[[[[ 2  4  6]]

  [[ 3  6  9]]

  [[ 4  8 12]]]


 [[[ 4  8 12]]

  [[ 5 10 15]]

  [[ 6 12 18]]]]
"""
print(np.tensordot(b, a, 0).shape)  # (2, 3, 1, 3)

axes=1

import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3, 4], [4, 5, 6]])

# a.shape: (2, 2), b.shape: (2, 3)
print(np.tensordot(a, b, 1))
"""
[[10 13 16]
 [22 29 36]]
"""
print(np.einsum("ij,jk->ik", a, b))
"""
[[10 13 16]
 [22 29 36]]
"""
# As you may have noticed, this is the same as matrix multiplication.
print(a @ b)
"""
[[10 13 16]
 [22 29 36]]
"""

axes=2

import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3, 4], [4, 5, 6]])

# a.shape: (2, 2), b.shape: (2, 3)

# Contracting both axes does not work here because (2, 2) and (2, 3) do not match.
try:
    print(np.tensordot(a, b, 2))
except Exception as e:
    print(e)  # shape-mismatch for sum

a = np.array([[1, 2, 3], [2, 2, 2]])
b = np.array([[2, 3, 4], [4, 5, 6]])
print(np.tensordot(a, b, 2))  # 50
print(np.einsum("ij,ij->", a, b))  # 50
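With axes=2 on two matrices of the same shape, every axis is contracted, so the result is just the sum of all element-wise products. A short sketch that checks this against two plain formulations (the variable name full is illustrative):

```python
import numpy as np

a = np.array([[1, 2, 3], [2, 2, 2]])
b = np.array([[2, 3, 4], [4, 5, 6]])

full = np.tensordot(a, b, 2)
print(int(full))  # 50

# Element-wise product followed by a sum over everything.
print(int(np.sum(a * b)))  # 50

# Dot product of the flattened arrays.
print(int(a.ravel() @ b.ravel()))  # 50
```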

Finally, here is an einsum example that shows how it differs from tensordot, the main subject of this article. Feel free to skip it if you are not familiar with Einstein summation.

import numpy as np

a = np.random.randint(1, 9, (5, 3, 2, 3))
b = np.random.randint(1, 9, (3, 3, 2))

c1 = a @ b  # For multi-dimensional arrays, @ matrix-multiplies over the last two axes
c2 = np.einsum("ijkm,jmn->ijkn", a, b)
print(np.all(c1 == c2))  # True
print(c2.shape)  # (5, 3, 2, 2)
print(np.einsum("...km,...mn->...kn", a, b).shape)  # (5, 3, 2, 2)

# But with different subscripts:
c3 = np.einsum("ijkm,amn->ijkn", a, b)
print(c3.shape)  # (5, 3, 2, 2)
# The shapes are the same, but the values are not, because the subscripts differ.
print(np.all(c3 == c1))  # False


a = np.random.randint(1, 9, (5, 3, 3, 2))
b = np.random.randint(1, 9, (1, 3, 2))

print(np.einsum("ijmk,jmn->ijkn", a, b).shape)  # (5, 3, 2, 2)
print(np.einsum("ijkm,jnm->ijkn", a, b).shape)  # (5, 3, 3, 3)
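The reason c3 differs from c1 in the example above is that the first axis of b is labeled a rather than j. That label does not appear in the output, so einsum sums over it. The sketch below reproduces c3 by summing b over its first axis beforehand; the fixed seed and the name c3_alt are choices made here so that the run is repeatable.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen here only for repeatability
a = rng.integers(1, 9, size=(5, 3, 2, 3))
b = rng.integers(1, 9, size=(3, 3, 2))

c1 = a @ b
c3 = np.einsum("ijkm,amn->ijkn", a, b)

# Summing b over its first axis beforehand reproduces c3.
c3_alt = a @ b.sum(axis=0)

print(np.array_equal(c3, c3_alt))  # True
print(np.array_equal(c3, c1))      # False
```

With the label j, each batch of a is paired with its own (3, 2) block of b; with the label a, every batch is multiplied by the sum of all three blocks.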

That concludes this introduction to using tensordot in NumPy.