Getting started with JAX (MLPs, CNNs & RNNs)
Learn the building blocks of JAX and use them to build some standard Deep Learning architectures (MLP, CNN, RNN, etc.).
Tags: jax, xla, autograd, tpu, tutorial
Objectives & Highlights

• I will walk you through some exciting CS concepts that were new to me (I am not a computer engineer, so this will be an educational experience for both of us).
• Along the way, we will go through the individual building blocks of JAX and use them to build some standard Deep Learning architectures (a Multilayer Perceptron, Convolutional Neural Networks, and a Gated Recurrent Unit RNN) from scratch.
• In the process we will encounter the basic operators of JAX (jit, vmap, grad), dive deeper into stax, JAX's sequential layer API, and use lax.scan to compile the for-loop of an RNN (see the sketches after this list).
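As a taste of that material, here are three minimal sketches. They are written for this summary rather than taken from the post, so variable names, shapes, and layer sizes are illustrative assumptions.

Composing the three core transformations, grad for differentiation, vmap for automatic batching, and jit for XLA compilation:

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x ** 2)

grad_fn = jax.grad(sum_of_squares)         # gradient: d/dx sum(x^2) = 2x
batched_grad = jax.vmap(grad_fn)           # map the gradient over a leading batch axis
fast_batched_grad = jax.jit(batched_grad)  # compile the batched gradient with XLA

x = jnp.arange(6.0).reshape(2, 3)          # a batch of two 3-vectors
print(fast_batched_grad(x))                # [[0. 2. 4.] [6. 8. 10.]]
```

A small MLP assembled with stax, the sequential layer API (exposed as jax.example_libraries.stax in current JAX; older releases shipped it as jax.experimental.stax):

```python
from jax import random
from jax.example_libraries import stax

# init_fn initializes the parameters; apply_fn runs the forward pass.
init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(256), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)

rng = random.PRNGKey(0)
out_shape, params = init_fn(rng, (-1, 784))       # -1 leaves the batch dim unspecified
logits = apply_fn(params, random.normal(rng, (32, 784)))
```

And lax.scan compiling what would otherwise be a Python for-loop over the time steps of a recurrent net (a plain tanh RNN cell here, standing in for the post's GRU):

```python
import jax.numpy as jnp
from jax import lax, random

k1, k2 = random.split(random.PRNGKey(0))
W_h = random.normal(k1, (32, 32)) * 0.1  # hidden-to-hidden weights
W_x = random.normal(k2, (16, 32)) * 0.1  # input-to-hidden weights

def step(h, x_t):
    # One recurrent step; scan threads the carry h through all time steps.
    h_new = jnp.tanh(h @ W_h + x_t @ W_x)
    return h_new, h_new                   # (new carry, per-step output)

xs = random.normal(k2, (100, 16))         # a sequence of 100 input vectors
h0 = jnp.zeros(32)
h_final, h_seq = lax.scan(step, h0, xs)   # h_seq has shape (100, 32)
```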

Authors
Robert T. Lange (@RobertTLange). PhD Fellow @ TU Berlin. Intelligence, Biological & Artificial { 🐠 ♥️ 🦈}
Similar projects
SymJAX
A symbolic CPU/GPU/TPU programming library.
Convoluted Stuff
Optimising compilers, and how thousand-year-old math shaped deep learning.
Using JAX to Improve Separable Image Filters
Optimizing the filters to improve the filtered images for computer vision tasks.
Implementing Graph Neural Networks with JAX
I'll talk about my experience building and training Graph Neural Networks (GNNs) with JAX.