JAX


Composable transformations of Python + NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.
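The transformations in the tagline compose freely. Below is a minimal sketch that assumes a toy linear model with a squared-error loss. The names `loss`, `per_example_grads` and `fast_per_example_grads` are illustrative and do not come from any library. `jax.grad` differentiates the loss, `jax.vmap` maps that gradient over a batch, and `jax.jit` compiles the composed function with XLA.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: derivative of loss with respect to its first argument (w)
grad_loss = jax.grad(loss)

# vmap: apply grad_loss to every (x, y) pair in a batch, sharing the same w
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])
g = fast_per_example_grads(w, xs, ys)  # one gradient row per example
```

Each row of `g` is `2 * (x·w - y) * x` for one example. This is the per-example gradient that an explicit Python loop would otherwise have to compute.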

Overview

JAX: Accelerated Machine Learning Research
This talk will introduce JAX and its core function transformations with a live demo.
jax video scipy-2020 tutorial

Tutorials

Getting started with JAX (MLPs, CNNs & RNNs)
Learn the building blocks of JAX and use them to build some standard deep learning architectures (MLPs, CNNs, RNNs, etc.).
jax xla autograd tpu
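The building blocks an MLP needs are `jax.numpy`, `jax.random`, `jax.grad` and `jax.jit`. A minimal sketch that combines them follows. It assumes a small regression MLP trained with plain full-batch gradient descent on synthetic data. The helpers `init_mlp`, `mlp`, `mse` and `sgd_step` are hypothetical names chosen for illustration, and the tutorial's own code may be structured differently.

```python
import jax
import jax.numpy as jnp

LR = 0.01  # assumed learning rate for this toy example

def init_mlp(key, sizes):
    """Create (W, b) pairs for a fully connected network with the given layer sizes."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)  # He-style scaling
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params

def mlp(params, x):
    # ReLU on hidden layers, linear output layer
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def mse(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(mse)(params, x, y)
    # params and grads are matching pytrees (lists of tuples), so tree_map pairs them leaf by leaf
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# Usage: fit a simple linear target y = x0 + x1 + x2
x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = jnp.sum(x, axis=1, keepdims=True)
params = init_mlp(jax.random.PRNGKey(1), [3, 16, 1])
initial = mse(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
final = mse(params, x, y)
```

The parameters are an ordinary list of tuples, so no framework is needed. `jax.grad` returns gradients with the same structure, and `jax.tree_util.tree_map` applies the update to every leaf.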
From PyTorch to JAX
Towards neural net frameworks that purify stateful code.
jax haiku tutorial article
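"Purifying" stateful code means that whatever a PyTorch module would mutate in place becomes an explicit input and output of a pure function. Hidden global random state becomes an explicit PRNG key in the same way. A minimal sketch follows. It assumes a hypothetical normalization "layer" that tracks a running mean, plus a dropout function. The names `normalize`, `dropout` and `forward` are illustrative and are not the API of Haiku or any other library.

```python
import jax
import jax.numpy as jnp

# A stateful (PyTorch-like) layer would update self.running_mean in place.
# In the pure version, the state goes in as an argument and the updated state comes back out.

def normalize(state, x, momentum=0.9):
    """Hypothetical pure layer: centers x and returns an updated running mean."""
    batch_mean = jnp.mean(x, axis=0)
    new_running_mean = momentum * state["running_mean"] + (1.0 - momentum) * batch_mean
    return x - batch_mean, {"running_mean": new_running_mean}

def dropout(key, x, rate=0.5):
    # Randomness is explicit too: the caller supplies a PRNG key instead of relying on a global seed
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

@jax.jit
def forward(state, key, x):
    y, state = normalize(state, x)
    return dropout(key, y), state

# Usage: thread the state and a fresh subkey through each call
state0 = {"running_mean": jnp.zeros(2)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
state, key = state0, jax.random.PRNGKey(0)
for _ in range(3):
    key, subkey = jax.random.split(key)
    y, state = forward(state, subkey, x)
```

`forward` has no side effects, so calling it twice with the same state and key gives identical results. This property is what lets `jax.jit` and `jax.grad` work on it safely.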
Implementing Graph Neural Networks with JAX
I’ll talk about my experience building and training Graph Neural Networks (GNNs) with JAX.
graph-neural-networks jax graphs tutorial
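GNNs in JAX usually represent a graph as sender/receiver index arrays and aggregate messages with a segment reduction. The example below is a minimal sketch of one mean-aggregation message-passing layer built on `jax.ops.segment_sum`. It assumes that self-loops are included in the edge list, which keeps every in-degree nonzero. The function `gcn_layer` is a hypothetical illustration and is not the code from the talk.

```python
import jax
import jax.numpy as jnp

def gcn_layer(w, node_feats, senders, receivers):
    """One message-passing step: average neighbour features, then apply a linear map and ReLU.

    Edge i goes from senders[i] to receivers[i]. Self-loops are assumed to be present,
    so no node has in-degree zero.
    """
    num_nodes = node_feats.shape[0]
    messages = node_feats[senders]  # gather source-node features, one row per edge
    summed = jax.ops.segment_sum(messages, receivers, num_segments=num_nodes)
    ones = jnp.ones_like(receivers, dtype=node_feats.dtype)
    degree = jax.ops.segment_sum(ones, receivers, num_segments=num_nodes)  # in-degree per node
    return jax.nn.relu((summed / degree[:, None]) @ w)

gcn = jax.jit(gcn_layer)

# Usage: path graph 0 - 1 - 2, stored as directed edges in both directions plus self-loops
senders = jnp.array([0, 1, 2, 0, 1, 1, 2])
receivers = jnp.array([0, 1, 2, 1, 0, 2, 1])
node_feats = jnp.array([[1.0], [2.0], [3.0]])
w = jnp.eye(1)
out = gcn(w, node_feats, senders, receivers)
```

With an identity weight matrix, each output row is the mean of a node's own feature and its neighbours' features: `(1+2)/2`, `(2+1+3)/3` and `(3+2)/2`.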

Libraries

Flax: Google’s Open Source Approach To Flexibility In ML
A gentle introduction to Flax: a neural network library for JAX that is designed for flexibility.
flax jax deep-learning library
SymJAX
Symbolic CPU/GPU/TPU programming.
jax xla autograd symjax
Foolbox Native
A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX.
adversarial-learning adversarial-attacks pytorch tensorflow
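Gradient-based attacks are one kind of adversarial example a toolbox like Foolbox packages. For intuition, the sketch below implements the Fast Gradient Sign Method (FGSM) in plain JAX against a toy linear classifier. It does not use Foolbox's API. The functions `logistic_loss` and `fgsm`, the weights, and the value of `epsilon` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def logistic_loss(w, x, label):
    # Logistic loss of a linear classifier; label is +1 or -1
    return jax.nn.softplus(-label * jnp.dot(w, x))

def fgsm(w, x, label, epsilon):
    """FGSM: move each input coordinate by epsilon in the sign of the loss gradient."""
    grad_x = jax.grad(logistic_loss, argnums=1)(w, x, label)  # gradient w.r.t. the input, not w
    return x + epsilon * jnp.sign(grad_x)

# Usage: a correctly classified point is pushed across the decision boundary
w = jnp.array([1.0, -2.0])
x = jnp.array([0.5, -0.25])
label = 1.0
x_adv = fgsm(w, x, label, epsilon=0.5)
clean_score = jnp.dot(w, x)    # positive: classified as +1
adv_score = jnp.dot(w, x_adv)  # negative: now classified as -1
```

The perturbation is bounded by `epsilon` in the L-infinity norm. That bound is what keeps FGSM examples close to the original input.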