Getting started with JAX (MLPs, CNNs & RNNs)
2020-03-16 · Learn the building blocks of JAX and use them to build some standard Deep Learning architectures (MLP, CNN, RNN, etc.).
I will walk you through some exciting computer science concepts that were new to me. I am not a computer engineer, so this will be an educational experience for both of us.
Along the way, we will go through the individual building blocks of JAX and use them to build some standard Deep Learning architectures from scratch: a Multilayer Perceptron (MLP), a Convolutional Neural Net (CNN) and a Gated Recurrent Unit (GRU) RNN.
We will encounter JAX's basic operators (jit, vmap and grad), dive deeper into stax, JAX's sequential layer API, and use lax.scan to compile the for-loop of an RNN without unrolling it.
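To give a first taste of these operators, here is a minimal sketch. The tiny one-hidden-layer MLP, its parameter names (`W1`, `b1`, `W2`, `b2`), the shapes and the helper functions are all hypothetical and chosen only for illustration. `grad` differentiates the loss with respect to the parameters, `jit` compiles the result with XLA, and `vmap` turns a single-example gradient into per-example gradients over a batch. The `lax.scan` part uses a plain (non-gated) tanh RNN as a stand-in, not the GRU built later.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax


def init_params(key, n_in=3, n_hidden=8, n_out=1):
    # Hypothetical one-hidden-layer MLP; shapes are illustrative only
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (n_in, n_hidden)) * 0.1,
        "b1": jnp.zeros(n_hidden),
        "W2": jax.random.normal(k2, (n_hidden, n_out)) * 0.1,
        "b2": jnp.zeros(n_out),
    }


def predict(params, x):
    # Forward pass for a single example x of shape (n_in,)
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


def loss(params, x, y):
    # Squared error for a single example
    return jnp.sum((predict(params, x) - y) ** 2)


# grad: gradient w.r.t. the first argument (the params pytree); jit: compile with XLA
grad_loss = jit(grad(loss))

# vmap: map over the leading (batch) axis of x and y, broadcast params (in_axes=None)
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))


def rnn_forward(W_h, W_x, h0, xs):
    # Simple tanh RNN (a stand-in, not a GRU) over a sequence xs of shape (T, n_in)
    def step(h, x_t):
        h_new = jnp.tanh(h @ W_h + x_t @ W_x)
        return h_new, h_new  # (new carry, per-step output)

    # lax.scan traces `step` once and loops over the leading axis of xs
    h_final, hs = lax.scan(step, h0, xs)
    return h_final, hs


if __name__ == "__main__":
    params = init_params(jax.random.PRNGKey(0))
    x_batch = jnp.ones((4, 3))
    y_batch = jnp.zeros((4, 1))
    pg = per_example_grads(params, x_batch, y_batch)
    print(pg["W1"].shape)  # (4, 3, 8): one gradient of W1 per example
```

Note the `in_axes=(None, 0, 0)` choice: the parameters are shared across the batch, so only the data arguments are mapped, and the resulting gradient pytree gains a leading batch dimension.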