From PyTorch to JAX
Towards neural net frameworks that purify stateful code.
Tags: jax, haiku
Objectives & Highlights

• Quickly recap a stateful LSTM-LM implementation in a tape-based gradient framework, specifically PyTorch.
• See how PyTorch-style coding relies on mutating state, learn about mutation-free pure functions, and build (pure) zappy one-liners in JAX (a minimal sketch follows below).
• Step by step, go from individual parameters to medium-sized modules by registering them as pytree nodes (also sketched below).
• Combat growing pains by building fancy scaffolding and controlling context to extract initialized parameters, purify functions, and realize that we could get all of that easily in a framework like DeepMind's haiku using its transform mechanism (see the final sketch below).
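To make the second point concrete, here is a minimal sketch (not code from the tutorial) of the JAX style: parameters live in an explicit container, the loss is a pure function of them, and `jax.grad` differentiates it. The names `params` and `loss` and the toy shapes are purely illustrative.

```python
import jax
import jax.numpy as jnp

# Parameters live in an explicit container instead of hiding
# inside a module's mutable state.
params = {"w": jnp.ones((3,)), "b": jnp.array(0.5)}

def loss(params, x, y):
    # A pure function: the result depends only on its inputs, nothing is mutated.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)

# jax.grad builds a new function returning d(loss)/d(params),
# structured exactly like the params dict itself.
grads = jax.grad(loss)(params, jnp.arange(3.0), jnp.array(1.0))
```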

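The third point, registering your own containers as pytree nodes, looks roughly like the following sketch. The `Linear` class here is hypothetical (the tutorial itself builds an LSTM-LM); only the registration mechanics are the point.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class Linear:
    """Hypothetical minimal module that carries its own parameters."""
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

# Teach JAX how to flatten/unflatten the module so transformations
# like grad and jit can traverse its parameter leaves.
register_pytree_node(
    Linear,
    lambda m: ((m.w, m.b), None),             # flatten: (children, static aux data)
    lambda aux, children: Linear(*children),  # unflatten: rebuild from children
)

layer = Linear(jnp.ones((4, 2)), jnp.zeros((2,)))

def loss(layer, x):
    return jnp.sum(layer(x) ** 2)

# The gradient comes back as another Linear whose fields hold the gradients.
grads = jax.grad(loss)(layer, jnp.ones((4,)))
```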
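Finally, haiku's transform mechanism from the last point: you write seemingly stateful module code, and `hk.transform` hands back a pure `init`/`apply` pair. A rough sketch with made-up sizes (hidden size 32, batch 8, input dimension 16, 10 outputs), not the tutorial's actual model:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Inside the function this reads like ordinary stateful module code...
    lstm = hk.LSTM(hidden_size=32)
    state = lstm.initial_state(batch_size=x.shape[0])
    out, _ = lstm(x, state)
    return hk.Linear(10)(out)

# ...but hk.transform turns it into a pair of pure functions.
model = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones((8, 16))

params = model.init(rng, x)           # explicit parameter pytree, nothing hidden
logits = model.apply(params, rng, x)  # pure apply: parameters are passed in
```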