In this tutorial, you will learn how to use the NumPy argmax() function to find the index of the maximum element in arrays.
NumPy is a powerful library for scientific computing in Python; it provides N-dimensional arrays that are more performant than Python lists. One of the common operations you’ll perform when working with NumPy arrays is to find the maximum value in the array. However, you may sometimes want to find the index at which the maximum value occurs.
The argmax() function helps you find the index of the maximum in both one-dimensional and multidimensional arrays. Let's proceed to learn how it works.
How to Find the Index of Maximum Element in a NumPy Array
To follow along with this tutorial, you need to have Python and NumPy installed. You can code along by starting a Python REPL or launching a Jupyter notebook.
First, let's import NumPy under the usual alias np.
import numpy as np
You can use the NumPy max() function to get the maximum value in an array (optionally along a specific axis).
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# Output
10
In this case, np.max(array_1) returns 10, which is correct.
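The optional axis argument mentioned above works for np.max() as well. As a small sketch (the 2×2 array here is made up purely for illustration), axis=0 gives the maximum of each column and axis=1 gives the maximum of each row:

```python
import numpy as np

# A small, made-up 2x2 array used only for illustration
small = np.array([[1, 5],
                  [7, 2]])

print(np.max(small))          # 7      maximum over all elements
print(np.max(small, axis=0))  # [7 5]  maximum of each column
print(np.max(small, axis=1))  # [5 7]  maximum of each row
```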
Suppose you’d like to find the index at which the maximum value occurs in the array. You can take the following two-step approach:
- Find the maximum element.
- Find the index of the maximum element.
In array_1, the maximum value of 10 occurs at index 4, following zero indexing: the first element is at index 0, the second element is at index 1, and so on.
To find the index at which the maximum occurs, you can use the NumPy where() function. np.where(condition) returns the indices of all elements where the condition is True.
Because np.where() returns a tuple containing one array of indices per dimension, you'll have to index into that result to get the actual position. To find where the maximum value occurs, we set the condition to array_1==10; recall that 10 is the maximum value in array_1.
print(np.where(array_1 == 10)[0][0])  # first index array, then its first entry
# Output
4
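Hard-coding 10 only works because we already know the maximum. A slightly more general sketch of the same two-step approach computes the maximum first and then uses it in the condition:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Step 1: find the maximum element
max_value = np.max(array_1)

# Step 2: find the index (or indices) where that maximum occurs.
# np.where() returns a tuple with one index array per dimension,
# so [0] selects the index array for the only dimension of array_1.
indices = np.where(array_1 == max_value)[0]

print(max_value)   # 10
print(indices)     # [4]
print(indices[0])  # 4
```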
We have used np.where() with only the condition, but this is not the recommended way to use this function: for that case, the NumPy documentation suggests calling np.nonzero() directly.
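Calling np.nonzero() on the condition returns the same tuple of index arrays as np.where(condition). A quick sketch comparing the two on array_1:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])
condition = array_1 == np.max(array_1)

# Both return a tuple with one index array per dimension
print(np.where(condition))    # (array([4]),)
print(np.nonzero(condition))  # (array([4]),)

# Take the first (and only) index array, then its first entry
print(np.nonzero(condition)[0][0])  # 4
```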
📑 Note: NumPy where() Function:
np.where(condition, x, y) returns:
- Elements from x when the condition is True, and
- Elements from y when the condition is False.
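Here is a short sketch of the three-argument form described in the note. The threshold of 5 is an arbitrary choice for illustration: elements greater than 5 are taken from array_1 (the x argument), and everything else is replaced by 0 (the y argument):

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Keep values greater than 5; replace the rest with 0
result = np.where(array_1 > 5, array_1, 0)
print(result)  # [ 0  0  7  0 10  9  8  0]
```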
Therefore, by chaining the np.max() and np.where() functions, we can find the maximum element, followed by the index at which it occurs.
Instead of the above two-step process, you can use the NumPy argmax() function to get the index of the maximum element in the array.
Syntax of the NumPy argmax() Function
The general syntax to use the NumPy argmax() function is as follows:
np.argmax(array, axis=None, out=None)
# we've imported numpy under the alias np
In the above syntax:
- array is any valid NumPy array.
- axis is an optional parameter. When working with multidimensional arrays, you can use the axis parameter to find the index of the maximum along a specific axis.
- out is another optional parameter. You can set the out parameter to a NumPy array to store the output of the argmax() function.
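NumPy arrays also expose argmax() as a method, so np.argmax(a) and a.argmax() are two ways to make the same call. A quick sketch confirming they agree on array_1:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Function form and method form of the same operation
print(np.argmax(array_1))  # 4
print(array_1.argmax())    # 4
```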
Note: From NumPy version 1.22.0, there's an additional keepdims parameter. When we specify the axis parameter in the argmax() function call, the array is reduced along that axis. Setting the keepdims parameter to True instead keeps the reduced axis in the result as a dimension of size one, so the output has the same number of dimensions as the input array and broadcasts correctly against it.
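To see the difference keepdims makes, the sketch below (assuming NumPy 1.22.0 or later) compares result shapes for a 2×4 array holding the same values as array_2, which is introduced later in this tutorial. Because the reduced axis is kept with size 1, the result can be passed straight to np.take_along_axis() to pick out the maximum of each row:

```python
import numpy as np

# Same values as array_2 later in the tutorial: 2 rows, 4 columns
array_2 = np.array([[1, 5, 7, 2],
                    [10, 9, 8, 4]])

idx = np.argmax(array_2, axis=1)
idx_kept = np.argmax(array_2, axis=1, keepdims=True)

print(idx.shape)       # (2,)    the reduced axis is dropped
print(idx_kept.shape)  # (2, 1)  the reduced axis is kept with size 1

# The kept dimension lines up with array_2, so it can index it directly
row_maxima = np.take_along_axis(array_2, idx_kept, axis=1)
print(row_maxima)
```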
Using NumPy argmax() to Find the Index of the Maximum Element
#1. Let's use the NumPy argmax() function to find the index of the maximum element in array_1.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# Output
4
The argmax() function returns 4, which is correct! ✅
#2. If we redefine array_1 such that 10 occurs twice, the argmax() function returns only the index of the first occurrence.
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# Output
4
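If you need every index at which the maximum occurs, not just the first, one option is to compare the array against its maximum and collect the matching positions with np.flatnonzero(). A minimal sketch using the redefined array_1:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 10, 8, 4])

first = np.argmax(array_1)                                # first occurrence only
all_indices = np.flatnonzero(array_1 == np.max(array_1))  # every occurrence

print(first)        # 4
print(all_indices)  # [4 5]
```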
For the rest of the examples, we'll use the elements of array_1 we defined in example #1.
Using NumPy argmax() to Find the Index of the Maximum Element in a 2D Array
Let's reshape the NumPy array array_1 into a two-dimensional array with two rows and four columns.
array_2 = array_1.reshape(2,4)
print(array_2)
# Output
[[ 1  5  7  2]
 [10  9  8  4]]
For a two-dimensional array, axis 0 denotes the rows and axis 1 denotes the columns. NumPy arrays follow zero-indexing, so array_2 has rows at indices 0 and 1 and columns at indices 0 through 3. For example, array_2[1, 0] is 10, the element in the second row and first column.
Now, let's call the argmax() function on the two-dimensional array, array_2.
print(np.argmax(array_2))
# Output
4
Even though we called argmax() on the two-dimensional array, it still returns 4. This is identical to the output for the one-dimensional array, array_1, from the previous section.
Why does this happen? 🤔
This is because we have not specified any value for the axis parameter. When the axis parameter is not set, the argmax() function by default returns the index of the maximum element in the flattened array.
What is a flattened array? If there is an N-dimensional array of shape d1 x d2 x … x dN, where d1, d2, up to dN are the sizes of the array along the N dimensions, then the flattened array is a long one-dimensional array of size d1 * d2 * … * dN.
To check what the flattened array looks like for array_2, you can call the flatten() method, as shown below:
array_2.flatten()
# Output
array([ 1,  5,  7,  2, 10,  9,  8,  4])
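Because the default result is an index into the flattened array, you may want to convert it back into a (row, column) pair. One way to do that is np.unravel_index(), sketched here with array_2:

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])
array_2 = array_1.reshape(2, 4)

flat_index = np.argmax(array_2)                         # index into the flattened array
row, col = np.unravel_index(flat_index, array_2.shape)  # convert to 2D coordinates

print(flat_index)         # 4
print(row, col)           # 1 0
print(array_2[row, col])  # 10
```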
Index of the Maximum Element Along the Rows (axis = 0)
Let’s proceed to find the index of the maximum element along the rows (axis = 0).
np.argmax(array_2,axis=0)
# Output
array([1, 1, 1, 1])
This output can be a bit difficult to interpret at first, so let's break down how it works.
We've set the axis parameter to zero (axis = 0), as we'd like to find the index of the maximum element along the rows. Therefore, the argmax() function returns the row number in which the maximum element occurs, for each of the four columns.
Let’s visualize this for better understanding.
From the above diagram and the argmax() output, we have the following:
- For the first column at index 0, the maximum value 10 occurs in the second row, at index = 1.
- For the second column at index 1, the maximum value 9 occurs in the second row, at index = 1.
- For the third and fourth columns at index 2 and 3, the maximum values 8 and 4 both occur in the second row, at index = 1.
This is precisely why we have the output array([1, 1, 1, 1]): the maximum element along the rows occurs in the second row (for all columns).
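One way to double-check this reading of the output is to use the returned row indices to pull one value from each column and compare them with np.max() along the same axis. A sketch:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

rows = np.argmax(array_2, axis=0)   # row index of the max, for each column
cols = np.arange(array_2.shape[1])  # column indices 0, 1, 2, 3

# Pair each column with its argmax row to pick out the maximum values
print(array_2[rows, cols])          # [10  9  8  4]
print(np.max(array_2, axis=0))      # [10  9  8  4]
```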
Index of the Maximum Element Along the Columns (axis = 1)
Next, let's use the argmax() function to find the index of the maximum element along the columns.
Run the following code snippet and observe the output.
np.argmax(array_2,axis=1)
# Output
array([2, 0])
Can you parse the output?
We have set axis = 1 to compute the index of the maximum element along the columns.
The argmax() function returns, for each row, the column number in which the maximum value occurs.
Here’s a visual explanation:
From the above diagram and the argmax() output, we have the following:
- For the first row at index 0, the maximum value 7 occurs in the third column, at index = 2.
- For the second row at index 1, the maximum value 10 occurs in the first column, at index = 0.
I hope you now understand what the output array([2, 0]) means.
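As a quick check, pairing each row index with the column index returned by argmax() should recover the maximum of each row. A short sketch:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

cols = np.argmax(array_2, axis=1)   # column index of the max, for each row
rows = np.arange(array_2.shape[0])  # row indices 0, 1

print(cols)                 # [2 0]
print(array_2[rows, cols])  # [ 7 10]
```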
Using the Optional out Parameter in NumPy argmax()
You can use the optional out parameter in the NumPy argmax() function to store the output in a NumPy array.
Let's initialize an array of zeros to store the output of the previous argmax() function call, which finds the index of the maximum along the columns (axis = 1).
out_arr = np.zeros((2,))
print(out_arr)
# Output
[0. 0.]
Now, let's revisit the example of finding the index of the maximum element along the columns (axis = 1) and set the out parameter to the out_arr we've defined above.
np.argmax(array_2,axis=1,out=out_arr)
We see that the Python interpreter throws a TypeError, as out_arr was initialized to an array of floats by default.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Therefore, when setting the out parameter to the output array, it's important to ensure that the output array is of the correct shape and data type. As array indices are always integers, we should set the dtype parameter to int when defining the output array.
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# Output
[0 0]
We can now go ahead and call the argmax() function with both the axis and out parameters, and this time, it runs without error.
np.argmax(array_2,axis=1,out=out_arr)
The output of the argmax() function can now be accessed in the array out_arr.
print(out_arr)
# Output
[2 0]
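As a variation, you could size and type the output array from the data rather than by hand. The sketch below uses np.intp, NumPy's integer type for indexing and the dtype argmax() itself returns, and allocates one slot per row for the axis = 1 case; using np.empty() instead of np.zeros() is just a choice here, since every slot gets overwritten:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

# One slot per row (axis=1 reduces the columns away), with argmax()'s own index dtype
out_arr = np.empty(array_2.shape[0], dtype=np.intp)

np.argmax(array_2, axis=1, out=out_arr)
print(out_arr)  # [2 0]
```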
Conclusion
I hope this tutorial helped you understand how to use the NumPy argmax() function. You can run the code examples in a Jupyter notebook.
Let’s review what we’ve learned.
- The NumPy argmax() function returns the index of the maximum element in an array. If the maximum element occurs more than once in an array a, then np.argmax(a) returns the index of the first occurrence of the element.
- When working with multidimensional arrays, you can use the optional axis parameter to get the index of the maximum element along a particular axis. For example, in a two-dimensional array: by setting axis = 0 and axis = 1, you can get the index of the maximum element along the rows and columns, respectively.
- If you’d like to store the returned value in another array, you can set the optional out parameter to the output array. However, the output array should be of compatible shape and data type.