Google JAX o Just AfterExecutiones un marco desarrollado por Google para acelerar las tareas de aprendizaje automático.
Se puede considerar una biblioteca para Python, que ayuda en la ejecución más rápida de tareas, computación científica, transformaciones de funciones, aprendizaje profundo, redes neuronales y mucho más.
Acerca de Google JAX
El paquete de computación más fundamental en Python es el paquete NumPy, que tiene todas las funciones como agregaciones, operaciones vectoriales, álgebra lineal, manipulaciones de matrices y arrays n-dimensionales, y muchas otras funciones avanzadas.
¿Y si pudiéramos acelerar aún más los cálculos realizados con NumPy, sobre todo para conjuntos de datos enormes?
¿Tendríamos algo que pudiera funcionar igual de bien en distintos tipos de procesadores, como una GPU o una TPU, sin necesidad de modificar el código?
¿Y si el sistema pudiera realizar transformaciones de funciones componibles de forma automática y más eficiente?
Google JAX es una biblioteca (o marco de trabajo, como dice Wikipedia) que hace precisamente eso y quizá mucho más. Se construyó para optimizar el rendimiento y realizar de forma eficiente tareas de aprendizaje automático (ML) y aprendizaje profundo. Google JAX proporciona las siguientes características de transformación que la hacen única con respecto a otras bibliotecas de ML y ayudan en la computación científica avanzada para el aprendizaje profundo y las redes neuronales:
- Auto diferenciación
- Autovectorización
- Auto paralelización
- Compilación justo a tiempo (JIT)
Todas las transformaciones utilizan XLA (álgebra lineal acelerada) para un mayor rendimiento y optimización de la memoria. XLA es un motor compilador optimizador específico del dominio que realiza álgebra lineal y acelera los modelos TensorFlow. El uso de XLA sobre su código Python no requiere cambios significativos en el código
Exploremos en detalle cada una de estas características.
Características de Google JAX
Google JAX viene con importantes funciones de transformación componibles para mejorar el rendimiento y realizar tareas de aprendizaje profundo de forma más eficiente. Por ejemplo, la autodiferenciación para obtener el gradiente de una función y encontrar derivadas de cualquier orden. Del mismo modo, autoparalelización y JIT para realizar múltiples tareas de forma paralela. Estas transformaciones son clave para aplicaciones como la robótica, los juegos e incluso la investigación.
Una función de transformación componible es una función pura que transforma un conjunto de datos en otra forma. Se denominan componibles porque son autocontenidas (es decir, estas funciones no tienen dependencias con el resto del programa) y no tienen estado (es decir, la misma entrada siempre dará como resultado la misma salida).
Y(x) = T: (f(x))
En la ecuación anterior, f(x) es la función original sobre la que se aplica una transformación. Y(x) es la función resultante después de aplicar la transformación.
Por ejemplo, si tiene una función llamada ‘total_factura_amt’, y quiere el resultado como una función transformada, puede utilizar simplemente la transformación que desee, digamos gradiente (grad):
grad_factura_total = grad(factura_total_amt)
Al transformar funciones numéricas utilizando funciones como grad(), podemos obtener fácilmente sus derivadas de orden superior, que podemos utilizar ampliamente en algoritmos de optimización de aprendizaje profundo como el descenso de gradiente, haciendo así que los algoritmos sean más rápidos y eficientes. Del mismo modo, utilizando jit(), podemos compilar programas Python justo a tiempo (lazily).
#1. Autodiferenciación
Python utiliza la función autograd para diferenciar automáticamente el código NumPy y el nativo de Python. JAX utiliza una versión modificada de autograd (es decir, grad) y combina XLA (álgebra lineal acelerada) para realizar la diferenciación automática y encontrar derivadas de cualquier orden para GPU (unidades de procesamiento gráfico) y TPU (unidades de procesamiento tensorial)]
Nota rápida sobre TPU, GPU y CPU: La CPU o Unidad Central de Procesamiento gestiona todas las operaciones del ordenador. La GPU es un procesador adicional que aumenta la potencia de cálculo y ejecuta operaciones de gama alta. La TPU es una potente unidad desarrollada específicamente para cargas de trabajo complejas y pesadas como la IA y los algoritmos de aprendizaje profundo.
En la misma línea que la función autograd, que puede diferenciar a través de bucles, recursiones, ramas, etc., JAX utiliza la función grad() para gradientes de modo inverso (retropropagación). Además, podemos diferenciar una función a cualquier orden utilizando grad:
grad(grad(grad(sen θ))) (1.0)
Autodiferenciación de orden superior
Como hemos mencionado antes, grad es bastante útil para encontrar las derivadas parciales de una función. Podemos utilizar una derivada parcial para calcular el descenso de gradiente de una función de coste con respecto a los parámetros de la red neuronal en el aprendizaje profundo para minimizar las pérdidas.
Cálculo de la derivada parcial
Supongamos que una función tiene múltiples variables, x, y, y z. Encontrar la derivada de una variable manteniendo las otras variables constantes se llama derivada parcial. Supongamos que tenemos una función
f(x,y,z) = x 2y z2
Ejemplo para mostrar la derivada parcial
La derivada parcial de x será ∂f/∂x, que nos dice cómo cambia una función para una variable cuando las demás son constantes. Si realizamos esto manualmente, debemos escribir un programa para diferenciar, aplicarlo para cada variable y luego calcular el descenso del gradiente. Esto se convertiría en un asunto complejo y laborioso para múltiples variables.
La diferenciación automática descompone la función en un conjunto de operaciones elementales, como , -, *, / o sen, cos, tan, exp, etc., y luego aplica la regla de la cadena para calcular la derivada. Podemos hacerlo tanto en modo directo como inverso.
Esto no ¡es! Todos estos cálculos se realizan muy rápido (bueno, ¡piense en un millón de cálculos similares a los anteriores y el tiempo que puede llevar!) XLA se encarga de la velocidad y el rendimiento.
#2. Álgebra lineal acelerada
Tomemos la ecuación anterior. Sin XLA, el cálculo llevará tres (o más) núcleos, donde cada núcleo realizará una tarea más pequeña. Por ejemplo
Núcleo k1 –> x * 2y (multiplicación)
k2 –> x * 2y z (suma)
k3 –> Reducción
Si la misma tarea es realizada por el XLA, un único núcleo se encarga de todas las operaciones intermedias fusionándolas. Los resultados intermedios de las operaciones elementales se transmiten en flujo en lugar de almacenarlos en memoria, con lo que se ahorra memoria y se aumenta la velocidad.
#3. Compilación justo a tiempo
JAX utiliza internamente el compilador XLA para aumentar la velocidad de ejecución. XLA puede potenciar la velocidad de la CPU, la GPU y la TPU. Todo esto es posible utilizando la ejecución de código JIT. Para ello, podemos utilizar jit mediante import:
from jax import jit
def mi_funcion(x):
............algunas líneas de código
mi_funcion_jit = jit(mi_funcion)
Otra forma es decorando jit sobre la definición de la función:
@jit
def mi_funcion(x):
............algunas líneas de código
Este código es mucho más rápido porque la transformación devolverá la versión compilada del código a la persona que llama en lugar de utilizar el intérprete de Python. Esto es especialmente útil para entradas vectoriales, como matrices y arrays.
Lo mismo ocurre también con todas las funciones de Python existentes. Por ejemplo, las funciones del paquete NumPy. En este caso, deberíamos importar jax.numpy como jnp en lugar de NumPy:
import jax
importar jax.numpy como jnp
x = jnp.array([[1,2,3,4], [5,6,7,8]])
Una vez hecho esto, el objeto array central de JAX llamado DeviceArray sustituye al array estándar de NumPy. DeviceArray es perezoso – los valores se mantienen en el acelerador hasta que se necesiten. Esto también significa que el programa JAX no espera a que los resultados vuelvan al programa llamante (Python), siguiendo así un envío asíncrono.
#4. Vectorización automática (vmap)
En un mundo típico de aprendizaje automático, tenemos conjuntos de datos con un millón o más de puntos de datos. Lo más probable es que realicemos algunos cálculos o manipulaciones en cada uno o en la mayoría de estos puntos de datos, ¡lo cual es una tarea que consume mucho tiempo y memoria! Por ejemplo, si queremos hallar el cuadrado de cada uno de los puntos de datos del conjunto de datos, lo primero que se nos ocurriría es crear un bucle y sacar el cuadrado uno a uno… ¡argh!
Si creamos estos puntos como vectores, podríamos hacer todos los cuadrados de una sola vez realizando manipulaciones vectoriales o matriciales en los puntos de datos con nuestro NumPy favorito. Y si su programa pudiera hacer esto automáticamente – ¿se puede pedir algo más? ¡Eso es exactamente lo que hace JAX! Puede vectorizar automáticamente todos sus puntos de datos para que pueda realizar fácilmente cualquier operación sobre ellos – haciendo que sus algoritmos sean mucho más rápidos y eficientes.
JAX utiliza la función vmap para la autovectorización. Considere la siguiente matriz:
x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.cuadrado(x)
Haciendo sólo lo anterior, el método cuadrado se ejecutará para cada punto del array. Pero si hace lo siguiente
vmap(jnp.cuadrado(x))
El método cuadrado se ejecutará sólo una vez porque ahora los puntos de datos se vectorizan automáticamente utilizando el método vmap antes de ejecutar la función, y el bucle se empuja hacia abajo en el nivel elemental de operación – lo que resulta en una multiplicación matricial en lugar de una multiplicación escalar, dando así un mejor rendimiento.
#5. Programación SPMD (pmap)
SPMD – o Single Program Multiple Dataprogramming es esencial en contextos de aprendizaje profundo – a menudo se aplican las mismas funciones en diferentes conjuntos de datos que residen en múltiples GPUs o TPUs. JAX dispone de una función denominada bomba, que permite la programación paralela en múltiples GPU o cualquier acelerador. Al igual que JIT, los programas que utilicen pmap serán compilados por el XLA y ejecutados simultáneamente en todos los sistemas. Esta paralelización automática funciona tanto para los cálculos directos como para los inversos.
También podemos aplicar múltiples transformaciones de una sola vez en cualquier orden sobre cualquier función como:
pmap(vmap(jit(grad (f(x)))))
Múltiples transformaciones componibles
Limitaciones de Google JAX
Los desarrolladores de Google JAX han pensado muy bien en acelerar los algoritmos de aprendizaje profundo al introducir todas estas impresionantes transformaciones. Las funciones y paquetes de cálculo científico están en la línea de NumPy, por lo que no tiene que preocuparse por la curva de aprendizaje. Sin embargo, JAX tiene las siguientes limitaciones:
- Google JAX se encuentra aún en las primeras fases de desarrollo y, aunque su principal objetivo es la optimización del rendimiento, no aporta grandes ventajas para la computación en la CPU. NumPy parece ofrecer un mejor rendimiento, y el uso de JAX sólo puede aumentar la sobrecarga.
- JAX se encuentra todavía en su fase de investigación o inicial y necesita más ajustes para alcanzar los estándares de infraestructura de marcos como TensorFlow, que están más establecidos y cuentan con más modelos predefinidos, proyectos de código abierto y material de aprendizaje.
- Por ahora, JAX no es compatible con el sistema operativo Windows – necesitaría una máquina virtual para hacerlo funcionar.
- JAX sólo funciona con funciones puras – las que no tienen efectos secundarios. Para funciones con efectos secundarios, JAX puede no ser una buena opción.
Cómo instalar JAX en su entorno Python
Si tiene instalado python en su sistema y desea ejecutar JAX en su máquina local (CPU), utilice los siguientes comandos:
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
Si desea ejecutar Google JAX en una GPU o TPU, siga las instrucciones indicadas en la página JAX de GitHub. Para instalar Python, visite la página oficial de descargas de python.
Conclusión
Google JAX es excelente para escribir algoritmos eficientes de aprendizaje profundo, robótica e investigación. A pesar de sus limitaciones, se utiliza ampliamente con otros frameworks como Haiku, Flax, y muchos más. Podrá apreciar lo que hace JAX cuando ejecute programas y vea las diferencias de tiempo en la ejecución de código con y sin JAX. Puede empezar leyendo la documentación oficial de Google sobre JAX, que es bastante completa.