Getting started with JAX (MLPs, CNNs & RNNs)
Learn the building blocks of JAX and use them to build some standard Deep Learning architectures (MLP, CNN, RNN, etc.).

  • I will walk you through some exciting CS concepts that were new to me (I am not a computer engineer, so this will be an educational experience for both you and me).
  • Along the way, we will go through the individual building blocks of JAX and use them to build some standard Deep Learning architectures (a Multilayer Perceptron, Convolutional Neural Networks, and a Gated Recurrent Unit RNN) from scratch.
  • In the process, we will encounter the basic operators of JAX (jit, vmap, grad), dive deeper into stax - the sequential layer API of JAX - and use lax.scan to quickly compile the for-loop of an RNN.
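As a quick taste of the three basic operators mentioned above, here is a minimal sketch (not code from the article itself) that composes them on a hypothetical squared-error loss: grad differentiates it, vmap batches the gradient over examples, and jit compiles the result with XLA.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar loss for a single example: squared error of a
# linear prediction. The names (loss, w, x, y) are illustrative only.
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates with respect to the first argument by default.
dloss_dw = jax.grad(loss)

# vmap maps the gradient over a batch of (x, y) pairs without a Python loop;
# in_axes says: don't batch w, batch x and y along their leading axis.
batched_grads = jax.vmap(dloss_dw, in_axes=(None, 0, 0))

# jit compiles the whole batched computation into one XLA program.
fast_batched_grads = jax.jit(batched_grads)

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 inputs
ys = jnp.ones(4)
grads = fast_batched_grads(w, xs, ys)
print(grads.shape)  # one gradient per batch element: (4, 3)
```

Because the three transformations compose freely, the same pattern scales from this toy loss to the full MLP, CNN, and RNN training loops the article builds.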

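The lax.scan trick for RNNs can also be sketched in a few lines. This is an assumed, deliberately simplified recurrence (no learned weights), just to show the mechanics: scan threads a carry (the hidden state) through the sequence so XLA compiles one fused loop instead of tracing T Python iterations.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical recurrence for illustration: an elementwise update with
# no parameters. A real GRU step would use weight matrices here.
def rnn_step(h, x):
    h_new = jnp.tanh(h + x)
    return h_new, h_new  # (carry for next step, per-step output)

T, hidden = 5, 3
xs = jnp.ones((T, hidden))  # sequence of T inputs
h0 = jnp.zeros(hidden)      # initial hidden state

# lax.scan replaces `for x in xs: h = rnn_step(h, x)` with a single
# compiled loop, returning the final carry and the stacked outputs.
h_final, hs = lax.scan(rnn_step, h0, xs)
print(hs.shape)  # stacked hidden states: (5, 3)
```

The same structure carries over to the GRU in the article: only the body of rnn_step changes, while scan handles the time loop.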
Authors
PhD Fellow @ TU Berlin. Intelligence, Biological & Artificial { 🐠 ♥️ 🦈}