- I will walk you through some exciting CS concepts that were new to me (I am not a computer engineer, so this will be an educational experience for both of us).
- Along the way, we will go through the individual building blocks of JAX and use them to build standard Deep Learning architectures (a Multilayer Perceptron, Convolutional Neural Networks and a Gated Recurrent Unit RNN) from scratch.
- We will encounter the core JAX transformations (jit, vmap, grad), dive deeper into stax - the sequential layer API of JAX - and use lax.scan to compile the for-loop of an RNN efficiently.
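As a first taste of the transformations listed above, here is a minimal sketch (the function `f` is a made-up toy example, not from the post) showing how `grad` differentiates a function, `jit` compiles it, and `vmap` vectorizes it over a batch:

```python
import jax
import jax.numpy as jnp

# A toy scalar function: f(x) = x^2 + 3x.
def f(x):
    return x ** 2 + 3.0 * x

# grad transforms f into its derivative: f'(x) = 2x + 3.
df = jax.grad(f)

# jit compiles the derivative with XLA for fast repeated calls.
fast_df = jax.jit(df)

# vmap maps df over a batch of inputs without an explicit Python loop.
batched_df = jax.vmap(df)

print(df(2.0))                      # 2*2 + 3 = 7.0
print(fast_df(2.0))                 # same value, but compiled
print(batched_df(jnp.arange(3.0)))  # [3. 5. 7.]
```

The three transformations compose freely, e.g. `jax.jit(jax.vmap(jax.grad(f)))` is also valid.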
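To preview `lax.scan`, here is a hedged sketch with a toy recurrence (a running sum, not the full GRU built later in the post): `scan` takes a step function `f(carry, x) -> (carry, y)` and threads it over a sequence, replacing a slow Python for-loop with a single compiled primitive:

```python
import jax.numpy as jnp
from jax import lax

# Toy recurrence: carry_{t+1} = carry_t + x_t, emitting the running sum.
def step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

xs = jnp.arange(1.0, 5.0)           # [1. 2. 3. 4.]
final, ys = lax.scan(step, 0.0, xs)  # initial carry is 0.0
print(final)  # 10.0
print(ys)     # [ 1.  3.  6. 10.]
```

An RNN follows the same pattern, with the hidden state as the carry and the input sequence as `xs`.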


