Google JAX ou Just Aaprès Execution est un framework développé par Google pour accélérer les tâches d'apprentissage automatique.

Vous pouvez le considérer comme une bibliothèque pour Python, qui aide à accélérer l'exécution des tâches, le calcul scientifique, les transformations de fonctions, l'apprentissage en profondeur, les réseaux de neurones et bien plus encore.

About Google JAX

Le package de calcul le plus fondamental en Python est le package NumPy qui possède toutes les fonctions telles que les agrégations, les opérations vectorielles, l'algèbre linéaire, les manipulations de tableaux et de matrices à n dimensions et de nombreuses autres fonctions avancées.

Et si nous pouvions encore accélérer les calculs effectués à l'aide de NumPy, en particulier pour les énormes ensembles de données ?

Avons-nous quelque chose qui pourrait fonctionner aussi bien sur différents types de processeurs comme un GPU ou un TPU, sans aucun changement de code ?

Et si le système pouvait effectuer des transformations de fonctions composables automatiquement et plus efficacement ?

Google JAX est une bibliothèque (ou un framework, comme le dit Wikipedia) qui fait exactement cela et peut-être bien plus encore. Il a été conçu pour optimiser les performances et effectuer efficacement des tâches d'apprentissage automatique (ML) et d'apprentissage en profondeur. Google JAX fournit les fonctionnalités de transformation suivantes qui le rendent unique par rapport aux autres bibliothèques ML et aident au calcul scientifique avancé pour l'apprentissage en profondeur et les réseaux de neurones :

  • Différenciation automatique
  • Vectorisation automatique
  • Parallélisation automatique
  • Compilation juste à temps (JIT)
Fonctionnalités uniques de Google JAX

Toutes les transformations utilisent XLA (Accelerated Linear Algebra) pour de meilleures performances et une optimisation de la mémoire. XLA est un moteur de compilateur d'optimisation spécifique à un domaine qui effectue de l'algèbre linéaire et accélère les modèles TensorFlow. L'utilisation de XLA en plus de votre code Python ne nécessite aucune modification significative du code !

Explorons en détail chacune de ces fonctionnalités.

Features of Google JAX

Google JAX est livré avec d'importantes fonctions de transformation composables pour améliorer les performances et effectuer des tâches d'apprentissage en profondeur plus efficacement. Par exemple, la différenciation automatique pour obtenir le gradient d'une fonction et trouver des dérivées de n'importe quel ordre. De même, la parallélisation automatique et le JIT permettent d'effectuer plusieurs tâches en parallèle. Ces transformations sont essentielles pour des applications telles que la robotique, les jeux et même la recherche.

A fonction de transformation composable est un pur fonction qui transforme un ensemble de données en une autre forme. Ils sont appelés composables car ils sont autonomes (c'est-à-dire que ces fonctions n'ont aucune dépendance avec le reste du programme) et sont sans état (c'est-à-dire que la même entrée entraînera toujours la même sortie).

Y(x) = T : (f(x))

Dans l'équation ci-dessus, f(x) est la fonction d'origine sur laquelle une transformation est appliquée. Y(x) est la fonction résultante après l'application de la transformation.

Par exemple, si vous avez une fonction nommée 'total_bill_amt' et que vous voulez que le résultat soit une transformation de fonction, vous pouvez simplement utiliser la transformation que vous souhaitez, disons gradient (grad):

grad_total_bill = grad(total_bill_amt)

En transformant des fonctions numériques à l'aide de fonctions telles que grad (), nous pouvons facilement obtenir leurs dérivés d'ordre supérieur, que nous pouvons largement utiliser dans des algorithmes d'optimisation d'apprentissage en profondeur tels que la descente de gradient, rendant ainsi les algorithmes plus rapides et plus efficaces. De même, en utilisant jit(), nous pouvons compiler des programmes Python juste-à-temps (paresseusement).

# 1. Différenciation automatique

Python utilise la fonction autograd pour différencier automatiquement le code NumPy du code Python natif. JAX utilise une version modifiée d'autograd (c'est-à-dire grad) et combine XLA (Accelerated Linear Algebra) pour effectuer une différenciation automatique et trouver des dérivés de n'importe quel ordre pour GPU (Graphic Processing Units) et TPU (Tensor Processing Units).]

Remarque rapide sur le TPU, le GPU et le CPU : CPU ou Central Processing Unit gère toutes les opérations sur l'ordinateur. Le GPU est un processeur supplémentaire qui améliore la puissance de calcul et exécute des opérations haut de gamme. Le TPU est une unité puissante spécialement développée pour les charges de travail complexes et lourdes telles que l'IA et les algorithmes d'apprentissage en profondeur.

Dans le même esprit que la fonction autograd, qui peut se différencier par des boucles, des récursions, des branches, etc., JAX utilise la fonction grad() pour les gradients en mode inverse (rétropropagation). De plus, nous pouvons différencier une fonction de n'importe quel ordre en utilisant grad :

grad(grad(grad(sin θ))) (1.0)

Différenciation automatique d'ordre supérieur

Comme nous l'avons mentionné précédemment, grad est très utile pour trouver les dérivées partielles d'une fonction. Nous pouvons utiliser une dérivée partielle pour calculer la descente de gradient d'une fonction de coût par rapport aux paramètres du réseau de neurones en apprentissage profond afin de minimiser les pertes.

Calcul de la dérivée partielle

Supposons qu'une fonction ait plusieurs variables, x, y et z. Trouver la dérivée d'une variable en gardant les autres variables constantes s'appelle une dérivée partielle. Supposons que nous ayons une fonction,

f(x,y,z) = x + 2y + z2

Exemple pour montrer la dérivée partielle

La dérivée partielle de x sera ∂f/∂x, ce qui nous indique comment une fonction change pour une variable lorsque les autres sont constantes. Si nous effectuons cela manuellement, nous devons écrire un programme pour différencier, l'appliquer pour chaque variable, puis calculer la descente du gradient. Cela deviendrait une affaire complexe et chronophage pour plusieurs variables.

La différenciation automatique décompose la fonction en un ensemble d'opérations élémentaires, comme +, -, *, / ou sin, cos, tan, exp, etc., puis applique la règle de la chaîne pour calculer la dérivée. Nous pouvons le faire en mode avant et arrière.

C'est pas ce! Tous ces calculs se produisent si vite (eh bien, pensez à un million de calculs similaires à ceux ci-dessus et au temps que cela peut prendre !). XLA prend soin de la vitesse et des performances.

# 2. Algèbre linéaire accélérée

Reprenons l'équation précédente. Sans XLA, le calcul prendra trois (ou plus) noyaux, où chaque noyau effectuera une tâche plus petite. Par exemple,

Noyau k1 –> x * 2y (multiplication)

k2 –> x * 2y + z (addition)

k3 –> Réduction

Si la même tâche est effectuée par le XLA, un seul noyau prend en charge toutes les opérations intermédiaires en les fusionnant. Les résultats intermédiaires des opérations élémentaires sont diffusés en continu au lieu de les stocker en mémoire, économisant ainsi de la mémoire et améliorant la vitesse.

# 3. Compilation juste à temps

JAX utilise en interne le compilateur XLA pour augmenter la vitesse d'exécution. XLA peut augmenter la vitesse du CPU, du GPU et du TPU. Tout cela est possible en utilisant l'exécution du code JIT. Pour l'utiliser, nous pouvons utiliser jit via import :

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Une autre méthode consiste à décorer jit sur la définition de la fonction :

@jit
def my_function(x):
	…………some lines of code

Ce code est beaucoup plus rapide car la transformation renverra la version compilée du code à l'appelant plutôt que d'utiliser l'interpréteur Python. Ceci est particulièrement utile pour les entrées vectorielles, comme les tableaux et les matrices.

Il en va de même pour toutes les fonctions python existantes. Par exemple, les fonctions du package NumPy. Dans ce cas, nous devrions importer jax.numpy en tant que jnp plutôt que NumPy :

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Une fois que vous avez fait cela, l'objet de tableau JAX principal appelé DeviceArray remplace le tableau NumPy standard. DeviceArray est paresseux - les valeurs sont conservées dans l'accélérateur jusqu'à ce qu'elles soient nécessaires. Cela signifie également que le programme JAX n'attend pas que les résultats reviennent au programme appelant (Python), suivant ainsi une répartition asynchrone.

# 4. Vectorisation automatique (vmap)

Dans un monde typique d'apprentissage automatique, nous avons des ensembles de données avec un million de points de données ou plus. Très probablement, nous effectuerions des calculs ou des manipulations sur chacun ou la plupart de ces points de données - ce qui est une tâche très consommatrice de temps et de mémoire ! Par exemple, si vous voulez trouver le carré de chacun des points de données dans l'ensemble de données, la première chose à laquelle vous pensez est de créer une boucle et de prendre le carré un par un - argh !

Si nous créons ces points sous forme de vecteurs, nous pourrions faire tous les carrés en une seule fois en effectuant des manipulations vectorielles ou matricielles sur les points de données avec notre NumPy préféré. Et si votre programme pouvait le faire automatiquement, que demander de plus ? C'est exactement ce que fait JAX ! Il peut vectoriser automatiquement tous vos points de données afin que vous puissiez facilement effectuer toutes les opérations dessus, ce qui rend vos algorithmes beaucoup plus rapides et plus efficaces.

JAX utilise la fonction vmap pour la vectorisation automatique. Considérez le tableau suivant :

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

En faisant juste ce qui précède, la méthode square s'exécutera pour chaque point du tableau. Mais si vous procédez comme suit :

vmap(jnp.square(x))

La méthode square ne s'exécutera qu'une seule fois car les points de données sont désormais vectorisés automatiquement à l'aide de la méthode vmap avant d'exécuter la fonction, et la boucle est poussée au niveau élémentaire de fonctionnement - ce qui entraîne une multiplication matricielle plutôt qu'une multiplication scalaire, offrant ainsi de meilleures performances. .

# 5. Programmation SPMD (pmap)

SPMD – ou SIngle Programme Mmultiple DLa programmation ata est essentielle dans les contextes d'apprentissage en profondeur - vous appliquez souvent les mêmes fonctions sur différents ensembles de données résidant sur plusieurs GPU ou TPU. JAX a une fonction nommée pompe, qui permet une programmation parallèle sur plusieurs GPU ou n'importe quel accélérateur. Comme JIT, les programmes utilisant pmap seront compilés par XLA et exécutés simultanément sur tous les systèmes. Cette parallélisation automatique fonctionne à la fois pour les calculs directs et inverses.

Comment fonctionne pmap

Nous pouvons également appliquer plusieurs transformations en une seule fois dans n'importe quel ordre sur n'importe quelle fonction comme :

pmap(vmap(jit(grad (f(x)))))

Multiples transformations composables

Limitations of Google JAX

Les développeurs de Google JAX ont bien pensé à accélérer l'apprentissage en profondeur algorithmes tout en introduisant toutes ces transformations impressionnantes. Les fonctions et packages de calcul scientifique sont sur le modèle de NumPy, vous n'avez donc pas à vous soucier de la courbe d'apprentissage. Cependant, JAX présente les limitations suivantes :

  • Google JAX en est encore aux premiers stades de développement, et bien que son objectif principal soit l'optimisation des performances, il n'offre pas beaucoup d'avantages pour le calcul CPU. NumPy semble mieux fonctionner et l'utilisation de JAX ne peut qu'ajouter à la surcharge.
  • JAX en est encore à ses débuts ou à ses débuts et a besoin de plus de précision pour atteindre les normes d'infrastructure de frameworks comme TensorFlow, qui sont plus établis et ont plus de modèles prédéfinis, de projets open source et de matériel d'apprentissage.
  • Pour l'instant, JAX ne prend pas en charge le système d'exploitation Windows - vous auriez besoin d'une machine virtuelle pour le faire fonctionner.
  • JAX ne fonctionne que sur les fonctions pures - celles qui n'ont pas d'effets secondaires. Pour les fonctions avec des effets secondaires, JAX peut ne pas être une bonne option.

How to install JAX in your Python environment

Si vous avez une configuration python sur votre système et que vous souhaitez exécuter JAX sur votre machine locale (CPU), utilisez les commandes suivantes :

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Si vous souhaitez exécuter Google JAX sur un GPU ou un TPU, suivez les instructions données sur GitHubJAX page. Pour configurer Python, visitez le téléchargements officiels de python Venez regardez des photos heureuses et inspirantes.

Conclusion

Google JAX est idéal pour écrire des algorithmes d'apprentissage en profondeur efficaces, de la robotique et de la recherche. Malgré les limitations, il est largement utilisé avec d'autres frameworks comme Haiku, Flax et bien d'autres. Vous pourrez apprécier ce que JAX fait lorsque vous exécutez des programmes et voir les différences de temps dans l'exécution du code avec et sans JAX. Vous pouvez commencer par lire le documentation officielle de Google JAX, ce qui est assez complet.