Composable transformations of Python + NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.

Overview

JAX: Accelerated Machine Learning Research
This talk will introduce JAX and its core function transformations with a live demo.
jax video scipy-2020 tutorial
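The transformations named in the tagline (differentiate, vectorize, JIT) are ordinary higher-order functions that compose. The sketch below assumes only `jax` and `jax.numpy`; the `loss` function and input values are illustrative choices, not taken from the talk. It computes per-example gradients by stacking `jax.grad`, `jax.vmap` and `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar-valued function of a weight vector w and one input x.
def loss(w, x):
    return jnp.tanh(jnp.dot(w, x)) ** 2

# grad: derivative of loss with respect to its first argument (w).
grad_loss = jax.grad(loss)

# vmap: map over a batch of inputs (axis 0 of x) while sharing the same w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit: compile the composed function with XLA for CPU/GPU/TPU.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))  # batch of 4 inputs, each of length 3
g = fast_batched_grad(w, xs)
print(g.shape)  # (4, 3): one gradient per input
```

The order of composition is a design choice: `vmap(grad(f))` yields per-example gradients, whereas `grad` of a batch-averaged loss would yield a single averaged gradient.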

Tutorials

Getting started with JAX (MLPs, CNNs & RNNs)
Learn the building blocks of JAX and use them to build some standard Deep Learning architectures (MLP, CNN, RNN, etc.).
jax xla autograd tpu
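To give a sense of what these building blocks look like without a framework, here is a minimal MLP sketch in plain JAX: parameters as a list of `(W, b)` tuples (a pytree), a forward pass, an MSE loss, and a jit-compiled SGD step. The helper names (`init_mlp`, `mlp`, `sgd_step`), the layer sizes, the learning rate and the toy regression data are all hypothetical choices for illustration, not the tutorial's code.

```python
import jax
import jax.numpy as jnp

LR = 0.05  # hypothetical learning rate

def init_mlp(key, sizes):
    """Return a list of (W, b) pairs for layer sizes such as [2, 16, 1]."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)  # He-style scaling
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

def mlp(params, x):
    # ReLU hidden layers followed by a linear output layer.
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b

def mse(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(mse)(params, x, y)
    # params and grads share the same pytree structure, so tree_map pairs them up.
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

params = init_mlp(jax.random.PRNGKey(0), [2, 16, 1])
x = jax.random.normal(jax.random.PRNGKey(1), (64, 2))
y = jnp.sum(x, axis=1, keepdims=True)  # toy regression target
before = mse(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y)
after = mse(params, x, y)
```

Because the step function is pure (parameters in, new parameters out), the same pattern extends to CNN or RNN layers by swapping the forward pass.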
From PyTorch to JAX
Towards neural net frameworks that purify stateful code.
jax haiku tutorial article
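"Purifying" stateful code means turning hidden, mutable state (running statistics, an internal RNG) into explicit inputs and outputs so that JAX transformations can see it. A minimal sketch, assuming a hypothetical layer `pure_layer` that combines an affine map, dropout and a running mean: the PRNG key and the state dictionary are passed in, and the updated state is returned rather than mutated. Libraries such as Haiku automate this conversion (for example via `hk.transform`), but the code below uses only core JAX.

```python
import jax
import jax.numpy as jnp

def pure_layer(params, state, key, x, train=True):
    """Hypothetical layer: affine map + dropout + running-mean tracking.

    All state is explicit: returns (output, new_state) and mutates nothing.
    """
    h = x @ params["W"] + params["b"]
    if train:
        # Dropout randomness comes from the key argument, not a global RNG.
        keep = jax.random.bernoulli(key, p=0.9, shape=h.shape)
        h = jnp.where(keep, h / 0.9, 0.0)
        new_mean = 0.9 * state["mean"] + 0.1 * jnp.mean(h, axis=0)
        state = {"mean": new_mean}
    return h, state

# `train` is a Python bool that changes control flow, so it must be static under jit.
layer = jax.jit(pure_layer, static_argnames="train")

params = {"W": jnp.ones((3, 2)), "b": jnp.zeros(2)}
state = {"mean": jnp.zeros(2)}
key = jax.random.PRNGKey(0)
x = jnp.ones((5, 3))

key, sub = jax.random.split(key)      # fresh randomness is requested explicitly
y1, state1 = layer(params, state, sub, x)
y2, state2 = layer(params, state, sub, x)  # same inputs give the same outputs
y_eval, _ = layer(params, state, sub, x, train=False)
```

The caller decides when to split keys and whether to keep `state1`; this is what makes the function safe to `jit`, `grad` or `vmap`.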
Finetuning Transformers with JAX + Haiku
Walking through a port of the RoBERTa pre-trained model to JAX + Haiku, then fine-tuning the model to solve a downstream task.
jax haiku roberta transformers
Implementing Graph Neural Networks with JAX
I’ll talk about my experience building and training Graph Neural Networks (GNNs) with JAX.
graph-neural-networks jax graphs tutorial
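The core of most GNN layers is message passing: compute a message along each edge, then aggregate messages at the receiving node. A minimal sketch, assuming an edge-list representation (`senders`, `receivers`) and a simple sum aggregation without degree normalisation or self-loops; the function name `gcn_layer` and the toy graph are hypothetical. It relies on `jax.ops.segment_sum`, which sums rows that share a segment id.

```python
import jax
import jax.numpy as jnp

def gcn_layer(W, node_feats, senders, receivers):
    """One message-passing step: each node sums transformed features
    from its incoming neighbours, then applies ReLU."""
    messages = node_feats[senders] @ W  # one message per edge
    n = node_feats.shape[0]
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=n)
    return jax.nn.relu(aggregated)

# Toy directed graph with 3 nodes and edges 0->1, 1->2, 2->0, 0->2.
senders = jnp.array([0, 1, 2, 0])
receivers = jnp.array([1, 2, 0, 2])
node_feats = jnp.eye(3)   # one-hot node features
W = jnp.ones((3, 2))      # illustrative weight matrix

out = gcn_layer(W, node_feats, senders, receivers)
print(out)  # node 2 has two incoming edges, so its row is [2, 2]
```

Passing `num_segments` explicitly keeps the output shape static, which lets the layer be wrapped in `jax.jit`.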

Libraries

General
Flax: Google’s Open Source Approach To Flexibility In ML
A gentle introduction to Flax: a neural network library for JAX that is designed for flexibility.
flax jax deep-learning library
Elegy
A Keras-like Deep Learning framework based on JAX + Haiku.
jax haiku keras article
SymJAX
A symbolic CPU/GPU/TPU programming library.
jax xla autograd symjax
Foolbox Native
A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX.
adversarial-learning adversarial-attacks pytorch tensorflow
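Many adversarial attacks start from the gradient of the loss with respect to the input. As an illustration of the idea (not Foolbox's API), here is the Fast Gradient Sign Method, x_adv = x + eps · sign(∇ₓL), written in plain JAX against a hypothetical logistic-regression model; the weights, inputs and `eps` are arbitrary example values.

```python
import jax
import jax.numpy as jnp

def fgsm(loss_fn, params, x, y, eps):
    """Fast Gradient Sign Method: step each input feature by eps
    in the direction that increases the loss."""
    grad_x = jax.grad(loss_fn, argnums=1)(params, x, y)  # gradient w.r.t. the input
    return x + eps * jnp.sign(grad_x)

def bce_loss(params, x, y):
    # Hypothetical model: logistic regression with labels y in {0, 1}.
    w, b = params
    z = x @ w + b
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1 - y) * jax.nn.log_sigmoid(-z))

params = (jnp.array([1.0, -2.0]), 0.5)
x = jnp.array([[0.3, 0.1], [-0.5, 0.4]])
y = jnp.array([1.0, 0.0])
eps = 0.1

x_adv = fgsm(bce_loss, params, x, y, eps)
print(bce_loss(params, x, y), bce_loss(params, x_adv, y))  # loss rises on x_adv
```

Each feature moves by at most `eps` (an L-infinity bound), which is why FGSM perturbations can stay small while still increasing the loss.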
Chex
Chex is a library of utilities for helping to write reliable JAX code.
jax chex unit-tests testing
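Chex bundles helpers for checks such as array shapes (e.g. `chex.assert_shape`) and for running a test on both jitted and non-jitted variants (e.g. `chex.variants`). The sketch below writes the same kinds of checks by hand with plain JAX and NumPy, so it does not depend on Chex itself; the function under test, `softmax_rows`, and the helper `check_softmax` are hypothetical examples.

```python
import jax
import jax.numpy as jnp
import numpy as np

def softmax_rows(x):
    x = x - jnp.max(x, axis=-1, keepdims=True)  # subtract max for numerical stability
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=-1, keepdims=True)

def check_softmax(fn):
    # Large inputs would overflow without the max-subtraction above.
    x = jnp.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
    out = fn(x)
    assert out.shape == x.shape, out.shape      # shape check
    assert bool(jnp.all(jnp.isfinite(out)))     # no NaN or inf
    np.testing.assert_allclose(out.sum(axis=-1), 1.0, rtol=1e-5)  # rows sum to 1
    return out

# Run the same checks on the eager and the jit-compiled variant and confirm they agree.
eager = check_softmax(softmax_rows)
compiled = check_softmax(jax.jit(softmax_rows))
np.testing.assert_allclose(eager, compiled, rtol=1e-5)
```

Testing both variants matters because code that behaves correctly eagerly can still fail under `jit`, for example when it relies on Python control flow over traced values.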