In this tutorial, you will learn how to use the **NumPy argmax() function** to find the index of the maximum element in arrays.

NumPy is a powerful library for scientific computing in Python; it provides N-dimensional arrays that are more performant than Python lists. One of the common operations you’ll perform when working with NumPy arrays is to find the maximum value in the array. However, you may sometimes want to find the **index** at which the maximum value occurs.

The `argmax()` function helps you find the index of the maximum in both one-dimensional and multidimensional arrays. Let's proceed to learn how it works.

## How to Find the Index of Maximum Element in a NumPy Array

To follow along with this tutorial, you need to have Python and NumPy installed. You can code along by starting a Python REPL or launching a Jupyter notebook.

First, let's import NumPy under the usual alias `np`.

```
import numpy as np
```

You can use the NumPy `max()` function to get the maximum value in an array (optionally along a specific axis).

```
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# Output
10
```
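Since `np.max()` can also work along an axis, here is a minimal sketch of that option. The 2D array `matrix` is a hypothetical example made up for illustration; `axis=0` takes the maximum of each column and `axis=1` the maximum of each row.

```python
import numpy as np

# A small 2D array made up for illustration
matrix = np.array([[1, 5, 7, 2],
                   [10, 9, 8, 4]])

overall_max = np.max(matrix)       # maximum over all elements
col_max = np.max(matrix, axis=0)   # maximum of each column
row_max = np.max(matrix, axis=1)   # maximum of each row

print(overall_max)  # 10
print(col_max)      # [10  9  8  4]
print(row_max)      # [ 7 10]
```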

In this case, `np.max(array_1)` returns 10, which is correct.

Suppose you’d like to find the index at which the maximum value occurs in the array. You can take the following two-step approach:

- Find the maximum element.
- Find the index of the maximum element.

In `array_1`, the maximum value of 10 occurs at index 4, following zero indexing: the first element is at index 0, the second element is at index 1, and so on.

To find the index at which the maximum occurs, you can use the NumPy `where()` function. `np.where(condition)` returns a tuple of index arrays, one per dimension of the array, containing the indices where the `condition` is `True`.

For a one-dimensional array, that tuple holds a single array of indices, so you have to index into the tuple and then pick the first element. To find where the maximum value occurs, we set the `condition` to `array_1==10`; recall that 10 is the maximum value in `array_1`.

```
# [0] selects the index array from the tuple; the second [0] picks its first element
print(np.where(array_1==10)[0][0])
# Output
4
```

We have used `np.where()` with *only* the condition, but this is *not* the recommended way to use this function: in that case, the NumPy documentation recommends calling `np.nonzero()` directly instead.
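To make the tuple return value concrete, here is a short sketch comparing `np.where(condition)` with `np.nonzero(condition)` on `array_1`, plus a 2D case (`array_2d`, a name used only for this illustration) where the tuple holds one index array per dimension.

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# With only a condition, np.where() returns a tuple with one index array per dimension
result = np.where(array_1 == 10)
same = np.nonzero(array_1 == 10)  # equivalent call, preferred by the NumPy docs

# Take the first (and only) index array, then its first element
index_of_max = result[0][0]
print(index_of_max)  # 4

# For a 2D array, the tuple holds row indices and column indices
array_2d = array_1.reshape(2, 4)
rows, cols = np.where(array_2d == 10)
print(rows, cols)  # [1] [0]
```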

**Note: NumPy where() function.** `np.where(condition, x, y)` returns:

- elements from `x` where the condition is `True`, and
- elements from `y` where the condition is `False`.
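A minimal sketch of the three-argument form on `array_1`; the threshold of 5 and the `"high"`/`"low"` labels are arbitrary choices for illustration.

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Keep elements greater than 5; replace everything else with 0
filtered = np.where(array_1 > 5, array_1, 0)
print(filtered)  # [ 0  0  7  0 10  9  8  0]

# x and y can also be scalars: label each element as "high" or "low"
labels = np.where(array_1 > 5, "high", "low")
print(labels)
```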

So, by combining the `np.max()` and `np.where()` functions, we can find the maximum element, followed by the index at which it occurs.
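Here is a sketch of that two-step approach that avoids hard-coding the value 10: the maximum is computed first and then reused in the condition. The final comparison with `np.argmax()` simply confirms that both approaches agree on this array.

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 9, 8, 4])

# Step 1: find the maximum element
max_value = np.max(array_1)

# Step 2: find the index (or indices) where that maximum occurs
indices = np.where(array_1 == max_value)[0]
first_index = indices[0]

print(max_value, first_index)  # 10 4
```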

Instead of the above two-step process, you can use the NumPy argmax() function to get the index of the maximum element in the array.

## Syntax of the NumPy argmax() Function

The general syntax to use the NumPy argmax() function is as follows:

```
np.argmax(array, axis=None, out=None)
# we've imported numpy under the alias np
```

In the above syntax:

- **array** is any valid NumPy array.
- **axis** is an optional parameter. When working with multidimensional arrays, you can use the `axis` parameter to find the index of the maximum along a specific axis.
- **out** is another optional parameter. You can set the `out` parameter to a NumPy array to store the output of the `argmax()` function.

Note: From NumPy version 1.22.0, there's an additional `keepdims` parameter. When we specify the `axis` parameter in the `argmax()` function call, the array is reduced along that axis, so that axis is dropped from the result. Setting the `keepdims` parameter to `True` instead keeps the reduced axis with size one, so the returned output has the same number of dimensions as the input array.
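A small sketch of `keepdims`, assuming NumPy 1.22 or later. It compares the result shapes and then uses `np.take_along_axis()` to show one reason the kept axis is handy: the indices line up with the input, so they can be used to pull the maxima back out.

```python
import numpy as np

array_2 = np.array([[1, 5, 7, 2],
                    [10, 9, 8, 4]])

reduced = np.argmax(array_2, axis=1)               # axis 1 is dropped
kept = np.argmax(array_2, axis=1, keepdims=True)   # axis 1 kept with length 1

print(reduced.shape)  # (2,)
print(kept.shape)     # (2, 1)

# Because the dimensions line up, the indices can select values from array_2
row_max = np.take_along_axis(array_2, kept, axis=1)
print(row_max.ravel())  # [ 7 10]
```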

## Using NumPy argmax() to Find the Index of the Maximum Element

**#1**. Let's use the NumPy argmax() function to find the index of the maximum element in `array_1`.

```
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# Output
4
```

The `argmax()` function returns 4, which is correct! ✅

**#2**. If we redefine `array_1` such that 10 occurs twice, the `argmax()` function returns *only* the index of the first occurrence.

```
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# Output
4
```
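If you need *every* position of the maximum rather than just the first, one option is to compare against the maximum and collect the matching indices with `np.flatnonzero()`. This is a sketch of an alternative to `argmax()`, not something `argmax()` does itself.

```python
import numpy as np

array_1 = np.array([1, 5, 7, 2, 10, 10, 8, 4])

first = np.argmax(array_1)                          # index of the first maximum only
all_max = np.flatnonzero(array_1 == array_1.max())  # every index holding the maximum

print(first)    # 4
print(all_max)  # [4 5]
```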

For the rest of the examples, we'll use the elements of `array_1` we defined in example #1.

### Using NumPy argmax() to Find the Index of the Maximum Element in a 2D Array

Let's reshape the NumPy array `array_1` into a two-dimensional array with two rows and four columns.

```
array_2 = array_1.reshape(2,4)
print(array_2)
# Output
[[ 1  5  7  2]
 [10  9  8  4]]
```

For a two-dimensional array, axis 0 denotes the rows and axis 1 denotes the columns. NumPy arrays follow **zero-indexing**, so the rows of `array_2` have indices 0 and 1, and the columns have indices 0, 1, 2, and 3.

Now, let's call the `argmax()` function on the two-dimensional array, `array_2`.

```
print(np.argmax(array_2))
# Output
4
```

Even though we called `argmax()` on the two-dimensional array, it still returns 4. This is identical to the output for the one-dimensional array, `array_1`, from the previous section.

**Why does this happen?** 🤔

This is because we have not specified any value for the `axis` parameter. When `axis` is not set, the `argmax()` function by default returns the index of the maximum element in the flattened array.

**What is a flattened array?** If there is an N-dimensional array of shape d1 x d2 x … x dN, where d1, d2, up to dN are the sizes of the array along the N dimensions, then the flattened array is a long one-dimensional array of size d1 * d2 * … * dN.

To check what the flattened array looks like for `array_2`, you can call the `flatten()` method, as shown below:

```
array_2.flatten()
# Output
array([ 1, 5, 7, 2, 10, 9, 8, 4])
```
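Because `argmax()` without `axis` returns a position in the flattened array, you may want to convert it back to a (row, column) pair. A minimal sketch using `np.unravel_index()`, which maps a flat index to coordinates for a given shape:

```python
import numpy as np

array_2 = np.array([1, 5, 7, 2, 10, 9, 8, 4]).reshape(2, 4)

flat_index = np.argmax(array_2)  # index into the flattened array
row, col = np.unravel_index(flat_index, array_2.shape)

print(flat_index)         # 4
print(row, col)           # 1 0
print(array_2[row, col])  # 10

# The flattened array has d1 * d2 elements
print(array_2.size == array_2.shape[0] * array_2.shape[1])  # True
```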

### Index of the Maximum Element Along the Rows (axis = 0)

Let’s proceed to find the index of the maximum element along the rows (axis = 0).

```
np.argmax(array_2,axis=0)
# Output
array([1, 1, 1, 1])
```

This output can be a bit difficult to interpret at first, so let's break down how it works.

We've set the `axis` parameter to zero (`axis = 0`), as we'd like to find the index of the maximum element along the rows. Therefore, the `argmax()` function returns the row number in which the maximum element occurs, for each of the four columns.

Let’s visualize this for better understanding.

From the above diagram and the `argmax()` output, we have the following:

- For the first column at index 0, the maximum value **10** occurs in the second row, at index = 1.
- For the second column at index 1, the maximum value **9** occurs in the second row, at index = 1.
- For the third and fourth columns at indices 2 and 3, the maximum values **8** and **4** both occur in the second row, at index = 1.

This is precisely why we get the output `array([1, 1, 1, 1])`: the maximum element along the rows occurs in the second row for every column.
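One way to check this reading of the output is to pair each returned row index with its column index and pull out the values; they should match the column maxima from `array_2.max(axis=0)`. A short sketch:

```python
import numpy as np

array_2 = np.array([[1, 5, 7, 2],
                    [10, 9, 8, 4]])

row_idx = np.argmax(array_2, axis=0)  # for each column, the row holding its maximum
cols = np.arange(array_2.shape[1])    # column indices 0, 1, 2, 3

# Pair each column with its argmax row to pull out the maxima
col_maxima = array_2[row_idx, cols]

print(row_idx)     # [1 1 1 1]
print(col_maxima)  # [10  9  8  4]
print(np.array_equal(col_maxima, array_2.max(axis=0)))  # True
```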

### Index of the Maximum Element Along the Columns (axis = 1)

Next, let's use the `argmax()` function to find the index of the maximum element along the columns.

Run the following code snippet and observe the output.

```
np.argmax(array_2,axis=1)
# Output
array([2, 0])
```

**Can you parse the output?**

We have set `axis = 1` to compute the index of the maximum element along the columns.

The `argmax()` function returns, for each row, the column number in which the maximum value occurs.

Here’s a visual explanation:

From the above diagram and the `argmax()` output, we have the following:

- For the first row at index 0, the maximum value **7** occurs in the third column, at index = 2.
- For the second row at index 1, the maximum value **10** occurs in the first column, at index = 0.

I hope you now understand what the output `array([2, 0])` means.
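Another way to read `axis=1` is as "call `argmax()` on each row separately". The sketch below compares the vectorized call with an explicit Python loop over the rows; the loop is only for illustration, since the vectorized form is the idiomatic one.

```python
import numpy as np

array_2 = np.array([[1, 5, 7, 2],
                    [10, 9, 8, 4]])

# Vectorized: one column index per row
vectorized = np.argmax(array_2, axis=1)

# Equivalent loop: call argmax on each row separately
looped = [int(np.argmax(row)) for row in array_2]

print(vectorized)  # [2 0]
print(looped)      # [2, 0]
```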

## Using the Optional out Parameter in NumPy argmax()

You can use the optional `out` parameter in the NumPy argmax() function to store the output in a NumPy array.

Let's initialize an array of zeros to store the output of the previous `argmax()` function call, which finds the index of the maximum along the columns (`axis = 1`).

```
out_arr = np.zeros((2,))
print(out_arr)
# Output
[0. 0.]
```

Now, let's revisit the example of finding the index of the maximum element along the columns (`axis = 1`) and set the `out` parameter to the `out_arr` array we've defined above.

```
np.argmax(array_2,axis=1,out=out_arr)
```

We see that NumPy raises a `TypeError`, as `out_arr` was initialized as an array of floats (`float64`), which is the default data type for `np.zeros()`.

```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
```

Therefore, when setting the `out` parameter to an output array, it's important to ensure that the output array has the correct shape and data type. As array indices are always integers, we should set the `dtype` parameter to `int` when defining the output array.

```
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# Output
[0 0]
```

We can now go ahead and call the `argmax()` function with both the `axis` and `out` parameters, and this time, it runs without error.

```
np.argmax(array_2,axis=1,out=out_arr)
```

The output of the `argmax()` function can now be accessed in the array `out_arr`.

```
print(out_arr)
# Output
[2 0]
```
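The `out` parameter is most useful when the same output buffer is reused across many calls, for example in a loop. Below is a minimal sketch of that pattern; the `batches` list is hypothetical data made up for illustration. It allocates the buffer with `np.intp`, the integer type NumPy uses for indices, and copies the buffer after each call because every call overwrites it.

```python
import numpy as np

# Hypothetical batches of 2x4 arrays, made up for illustration
batches = [
    np.array([[1, 5, 7, 2], [10, 9, 8, 4]]),
    np.array([[3, 1, 0, 8], [2, 6, 4, 1]]),
]

# Allocate the output once, using NumPy's index integer type
out_arr = np.empty(2, dtype=np.intp)

results = []
for batch in batches:
    np.argmax(batch, axis=1, out=out_arr)  # overwrite out_arr in place
    results.append(out_arr.copy())         # copy, since out_arr is reused

print(results[0])  # [2 0]
print(results[1])  # [3 1]
```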

## Conclusion

I hope this tutorial helped you understand how to use the NumPy argmax() function. You can run the code examples in a Jupyter notebook.

Let’s review what we’ve learned.

- The NumPy argmax() function returns the index of the maximum element in an array. If the maximum element occurs more than once in an array **a**, then **np.argmax(a)** returns the index of the first occurrence of the element.
- When working with multidimensional arrays, you can use the optional **axis** parameter to get the index of the maximum element along a particular axis. For example, in a two-dimensional array, by setting **axis = 0** and **axis = 1**, you can get the index of the maximum element along the rows and columns, respectively.
- If you'd like to store the returned value in another array, you can set the optional **out** parameter to the output array. However, the output array should be of compatible shape and data type.
