Getting started with JAX (MLPs, CNNs & RNNs)
Learn the building blocks of JAX and use them to build some standard Deep Learning architectures (MLP, CNN, RNN, etc.).
Tags: jax, xla, autograd, tpu, tutorial
Objectives & Highlights

• I will walk you through some exciting CS concepts that were new to me (I am not a computer engineer, so this will be an educational experience for both of us).
• Along the way, we will go through the individual building blocks of JAX and use them to build some standard Deep Learning architectures (a Multilayer Perceptron, Convolutional Neural Networks, and a Gated Recurrent Unit RNN) from scratch.
• We will encounter the basic operators of JAX (jit, vmap, grad), dive deeper into stax, the sequential layer API of JAX, and use lax.scan to compile the for-loop of an RNN; short sketches of each follow below.
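To give a flavor of these operators before the full walkthrough, here is a minimal sketch (the loss function, weights, and shapes are illustrative, not taken from the tutorial): grad differentiates a scalar-valued function, vmap vectorizes it over a batch, and jit compiles the result with XLA.

```python
import jax
import jax.numpy as jnp

# Scalar loss of a linear model on a single example (illustrative).
def loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2

# grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# vmap maps the per-example gradient over a batch without a Python loop.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit compiles the whole computation with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.ones(3)
xs = jnp.ones((8, 3))  # batch of 8 inputs
ys = jnp.zeros(8)      # batch of 8 targets
print(fast_batched_grad(w, xs, ys).shape)  # -> (8, 3), one gradient per example
```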

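stax represents a network as an (init_fun, apply_fun) pair. A minimal sketch, assuming an MNIST-sized input; the layer sizes are illustrative, and on older JAX versions the module lives at jax.experimental.stax:

```python
from jax import random
from jax.example_libraries import stax  # jax.experimental.stax on older versions

# serial chains layers and returns an (init_fun, apply_fun) pair.
init_fun, apply_fun = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)

key = random.PRNGKey(0)
# init_fun takes an RNG key and the input shape (-1 = batch dimension).
out_shape, params = init_fun(key, (-1, 784))

# apply_fun evaluates the network given the params and a batch of inputs.
inputs = random.normal(key, (32, 784))
log_probs = apply_fun(params, inputs)  # -> shape (32, 10)
```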
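Finally, lax.scan replaces a Python for-loop over time steps with a single compiled loop, threading a carry (the hidden state) through the sequence. Below is a toy recurrence with placeholder weights, not the tutorial's GRU:

```python
import jax.numpy as jnp
from jax import lax

W_h = jnp.eye(4) * 0.5        # placeholder recurrent weights
W_x = jnp.full((4, 2), 0.1)   # placeholder input weights

# One step: takes the carry (hidden state) and the current input,
# returns the new carry and the per-step output.
def rnn_step(h, x_t):
    h_new = jnp.tanh(W_h @ h + W_x @ x_t)
    return h_new, h_new

h0 = jnp.zeros(4)             # initial hidden state
xs = jnp.ones((10, 2))        # sequence of 10 inputs
h_final, hs = lax.scan(rnn_step, h0, xs)
print(hs.shape)               # -> (10, 4): hidden state at every time step
```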

Similar projects
Flax: Google’s Open Source Approach To Flexibility In ML
A gentle introduction to Flax: a neural network library for JAX that is designed for flexibility.
From PyTorch to JAX
Towards neural net frameworks that purify stateful code.
Using JAX to Improve Separable Image Filters
Optimizing the filters to improve the filtered images for computer vision tasks.
Finetuning Transformers with JAX + Haiku
Walking through a port of the RoBERTa pre-trained model to JAX + Haiku, then fine-tuning the model to solve a downstream task.