In diesem Tutorial erfahren Sie, wie Sie die verwenden NumPy argmax()-Funktion um den Index des maximalen Elements in Arrays zu finden.
NumPy ist eine leistungsstarke Bibliothek für wissenschaftliches Rechnen in Python; Es bietet N-dimensionale Arrays, die leistungsfähiger sind als Python-Listen. Eine der häufigsten Operationen, die Sie beim Arbeiten mit NumPy-Arrays ausführen, besteht darin, den Maximalwert im Array zu finden. Möglicherweise möchten Sie jedoch manchmal die finden Index bei dem der Maximalwert auftritt.
Die argmax()
Funktion hilft Ihnen, den Index des Maximums sowohl in eindimensionalen als auch in mehrdimensionalen Arrays zu finden. Lassen Sie uns fortfahren, um zu lernen, wie es funktioniert.
How to Find the Index of Maximum Element in a NumPy Array
Um diesem Tutorial folgen zu können, benötigen Sie Python und NumPy Eingerichtet. Sie können mitcodieren, indem Sie eine Python-REPL starten oder ein Jupyter-Notebook starten.
Importieren wir zunächst NumPy unter dem üblichen Alias np
.
import numpy as np
Sie können NumPy verwenden max()
Funktion, um den Maximalwert in einem Array (optional entlang einer bestimmten Achse) zu erhalten.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# Output
10
In diesem Fall np.max(array_1)
gibt 10 zurück, was richtig ist.
Angenommen, Sie möchten den Index finden, an dem der maximale Wert im Array auftritt. Sie können den folgenden zweistufigen Ansatz wählen:
- Finden Sie das maximale Element.
- Finden Sie den Index des maximalen Elements.
In array_1
, tritt der Maximalwert von 10 bei Index 4 nach der Nullindizierung auf. Das erste Element hat den Index 0; das zweite Element befindet sich auf Index 1 und so weiter.

Um den Index zu finden, bei dem das Maximum auftritt, können Sie die verwenden NumPy wo() Funktion. np.where(condition)
gibt ein Array aller Indizes zurück, in denen die condition
is True
.
Sie müssen auf das Array tippen und auf das Element am ersten Index zugreifen. Um herauszufinden, wo der Maximalwert auftritt, setzen wir die condition
zu array_1==10
; Denken Sie daran, dass 10 der maximale Wert in ist array_1
.
print(int(np.where(array_1==10)[0]))
# Output
4
Wir haben benutzt np.where()
mit einzige die Bedingung, aber das ist nicht die empfohlene Methode zur Verwendung dieser Funktion.
📑 Hinweis: NumPy where() Funktion:
np.where(condition,x,y)
kehrt zurück:– Elemente aus
x
wenn die Bedingung istTrue
und
– Elemente ausy
wenn die Bedingung istFalse
.
Daher Verkettung der np.max()
und np.where()
Funktionen können wir das maximale Element finden, gefolgt von dem Index, an dem es auftritt.
Anstelle des obigen zweistufigen Prozesses können Sie die NumPy-Funktion argmax() verwenden, um den Index des größten Elements im Array abzurufen.
Syntax of the NumPy argmax() Function
Die allgemeine Syntax zur Verwendung der Funktion NumPy argmax() lautet wie folgt:
np.argmax(array,axis,out)
# we've imported numpy under the alias np
In der obigen Syntax:
- Array ist ein beliebiges gültiges NumPy-Array.
- Achse ist ein optionaler Parameter. Wenn Sie mit mehrdimensionalen Arrays arbeiten, können Sie den Achsenparameter verwenden, um den Index des Maximums entlang einer bestimmten Achse zu finden.
- ist ein weiterer optionaler Parameter. Sie können die einstellen
out
-Parameter in ein NumPy-Array, um die Ausgabe von zu speichernargmax()
Funktion.
Hinweis: Ab NumPy-Version 1.22.0 gibt es eine zusätzliche
keepdims
Parameter. Wenn wir die angebenaxis
Parameter in derargmax()
Funktionsaufruf wird das Array entlang dieser Achse reduziert. Aber die Einstellungkeepdims
ParameterTrue
stellt sicher, dass die zurückgegebene Ausgabe die gleiche Form wie das Eingabearray hat.
Using NumPy argmax() to Find the Index of the Maximum Element
#1. Lassen Sie uns die Funktion NumPy argmax() verwenden, um den Index des größten Elements in zu finden array_1
.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# Output
4
Die argmax()
Die Funktion gibt 4 zurück, was richtig ist! ✅
#2. Wenn wir neu definieren array_1
so dass 10 zweimal vorkommt, die argmax()
Funktion kehrt zurück einzige der Index des ersten Vorkommens.
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# Output
4
Für die restlichen Beispiele verwenden wir die Elemente von array_1
wir haben in Beispiel #1 definiert.
Verwenden von NumPy argmax() zum Finden des Index des maximalen Elements in einem 2D-Array
Lasst uns Forme das NumPy-Array um array_1
in ein zweidimensionales Array mit zwei Zeilen und vier Spalten.
array_2 = array_1.reshape(2,4)
print(array_2)
# Output
[[ 1 5 7 2]
[10 9 8 4]]
Bei einem zweidimensionalen Array bezeichnet die Achse 0 die Zeilen und die Achse 1 die Spalten. NumPy-Arrays folgen Null-Indizierung. Also die Indizes der Zeilen und Spalten für das NumPy-Array array_2
sind wie folgt:

Rufen wir jetzt die an argmax()
Funktion auf dem zweidimensionalen Array, array_2
.
print(np.argmax(array_2))
# Output
4
Obwohl wir angerufen haben argmax()
auf dem zweidimensionalen Array gibt es immer noch 4 zurück. Dies ist identisch mit der Ausgabe für das eindimensionale Array, array_1
aus dem vorherigen Abschnitt.
Warum geschieht das? 🤔
Dies liegt daran, dass wir keinen Wert für den Achsenparameter angegeben haben. Wenn dieser Achsenparameter nicht gesetzt ist, wird standardmäßig die argmax()
Die Funktion gibt den Index des größten Elements entlang des abgeflachten Arrays zurück.
Was ist ein abgeflachtes Array? Wenn es ein N-dimensionales Array von Formen gibt d1 x d2 x … x dN, wobei d1, d2 bis dN die Größen des Arrays entlang der N Dimensionen sind, dann die abgeflachtes Array ist ein langes eindimensionales Array der Größe d1 * d2 * … * dN.
Um zu überprüfen, wie das abgeflachte Array für aussieht array_2
, können Sie die anrufen flatten()
Methode, wie unten gezeigt:
array_2.flatten()
# Output
array([ 1, 5, 7, 2, 10, 9, 8, 4])
Index des maximalen Elements entlang der Zeilen (Achse = 0)
Lassen Sie uns fortfahren, um den Index des maximalen Elements entlang der Zeilen zu finden (Achse = 0).
np.argmax(array_2,axis=0)
# Output
array([1, 1, 1, 1])
Diese Ausgabe kann etwas schwierig zu verstehen sein, aber wir werden verstehen, wie sie funktioniert.
Wir haben das eingestellt axis
Parameter auf Null (axis = 0
), da wir den Index des maximalen Elements entlang der Zeilen finden möchten. deshalb, die argmax()
Die Funktion gibt die Zeilennummer zurück, in der das maximale Element vorkommt – für jede der drei Spalten.
Lassen Sie uns dies zum besseren Verständnis visualisieren.

Aus dem obigen Diagramm und der argmax()
Ausgabe haben wir folgendes:
- Für die erste Spalte bei Index 0 der Maximalwert 10 tritt in der zweiten Zeile bei Index = 1 auf.
- Für die zweite Spalte bei Index 1 der Maximalwert 9 tritt in der zweiten Zeile bei Index = 1 auf.
- Für die dritte und vierte Spalte bei Index 2 und 3 die Maximalwerte 8 und 4 beide treten in der zweiten Reihe bei Index = 1 auf.
Genau dafür haben wir die Ausgabe array([1, 1, 1, 1])
weil das maximale Element entlang der Zeilen in der zweiten Zeile auftritt (für alle Spalten).
Index des maximalen Elements entlang der Spalten (Achse = 1)
Als nächstes verwenden wir die argmax()
Funktion, um den Index des maximalen Elements entlang der Spalten zu finden.
Führen Sie das folgende Code-Snippet aus und beobachten Sie die Ausgabe.
np.argmax(array_2,axis=1)
array([2, 0])
Kannst du die Ausgabe parsen?
Wir haben eingestellt axis = 1
um den Index des maximalen Elements entlang der Spalten zu berechnen.
Die argmax()
Die Funktion gibt für jede Zeile die Spaltennummer zurück, in der der Maximalwert auftritt.
Hier ist eine visuelle Erklärung:

Aus dem obigen Diagramm und der argmax()
Ausgabe haben wir folgendes:
- Für die erste Zeile bei Index 0 der Maximalwert 7 tritt in der dritten Spalte bei Index = 2 auf.
- Für die zweite Zeile bei Index 1 der Maximalwert 10 tritt in der ersten Spalte bei Index = 0 auf.
Ich hoffe, Sie verstehen jetzt, was die Ausgabe ist, array([2, 0])
bedeutet.
Using the Optional out Parameter in NumPy argmax()
Sie können optional verwenden out
der Parameter in der Funktion NumPy argmax(), um die Ausgabe in einem NumPy-Array zu speichern.
Lassen Sie uns ein Array von Nullen initialisieren, um die Ausgabe des vorherigen zu speichern argmax()
Funktionsaufruf – um den Index des Maximums entlang der Spalten zu finden (axis= 1
).
out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]
Betrachten wir nun das Beispiel zum Suchen des Index des maximalen Elements entlang der Spalten (axis = 1
) und stellen Sie die out
zu out_arr
wir haben oben definiert.
np.argmax(array_2,axis=1,out=out_arr)
Wir sehen, dass der Python-Interpreter a wirft TypeError
, Da die out_arr
wurde standardmäßig mit einem Array von Gleitkommazahlen initialisiert.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Daher beim Einstellen der out
-Parameter an das Ausgabearray, ist es wichtig sicherzustellen, dass das Ausgabearray die richtige Form und den richtigen Datentyp hat. Da Array-Indizes immer Ganzzahlen sind, sollten wir die setzen dtype
Parameter int
beim Definieren des Ausgabearrays.
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# Output
[0 0]
Wir können jetzt weitermachen und anrufen argmax()
Funktion sowohl mit der axis
und out
Parameter, und dieses Mal läuft es ohne Fehler.
np.argmax(array_2,axis=1,out=out_arr)
Die Ausgabe der argmax()
Auf die Funktion kann jetzt im Array zugegriffen werden out_arr
.
print(out_arr)
# Output
[2 0]
Fazit
Ich hoffe, dieses Tutorial hat Ihnen geholfen, die NumPy-Funktion argmax() zu verstehen. Sie können die Codebeispiele in a ausführen Jupyter Notizbuch.
Lassen Sie uns überprüfen, was wir gelernt haben.
- Die Funktion NumPy argmax() gibt den Index des größten Elements in einem Array zurück. Wenn das maximale Element mehr als einmal in einem Array vorkommt a und dann np.argmax(a) gibt den Index des ersten Vorkommens des Elements zurück.
- Wenn Sie mit mehrdimensionalen Arrays arbeiten, können Sie die optionale verwenden Achse -Parameter, um den Index des größten Elements entlang einer bestimmten Achse zu erhalten. Zum Beispiel in einem zweidimensionalen Array: by setting Achse = 0 und Achse = 1, können Sie den Index des maximalen Elements entlang der Zeilen bzw. Spalten abrufen.
- Wenn Sie den zurückgegebenen Wert in einem anderen Array speichern möchten, können Sie optional festlegen Parameter in das Ausgabearray. Das Ausgabearray sollte jedoch eine kompatible Form und einen kompatiblen Datentyp haben.
Sehen Sie sich als Nächstes die ausführliche Anleitung an Python-Sets.