• Quickly recap a stateful LSTM-LM implementation in a tape-based gradient framework, specifically PyTorch (a minimal sketch follows after this list).
• See how PyTorch-style coding relies on mutating state, learn about mutation-free pure functions, and build (pure) zappy one-liners in JAX (sketched below).
• Go step by step from individual parameters to medium-size modules by registering them as pytree nodes (sketched below).
• Combat growing pains by building fancy scaffolding that controls context to extract initialized parameters and purify functions, and realize that we could get all of that easily in a framework like DeepMind's haiku, using its transform mechanism (sketched below).
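For reference on the first bullet, here is a minimal sketch of the kind of stateful PyTorch LSTM-LM the post recaps. The class name `TinyLSTMLM` and all sizes are made up for illustration, not the post's actual code:

```python
import torch

# A stateful module: parameters live as mutable attributes, and autograd
# records the operations onto a tape as we call forward().
class TinyLSTMLM(torch.nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.lstm = torch.nn.LSTM(dim, dim, batch_first=True)
        self.out = torch.nn.Linear(dim, vocab_size)

    def forward(self, token_ids):  # (batch, time) -> (batch, time, vocab)
        hidden, _ = self.lstm(self.embed(token_ids))
        return self.out(hidden)

model = TinyLSTMLM(vocab_size=100, dim=32)
logits = model(torch.randint(0, 100, (2, 7)))
loss = torch.nn.functional.cross_entropy(
    logits.reshape(-1, 100), torch.randint(0, 100, (2 * 7,)))
loss.backward()  # gradients are written into each parameter's .grad: mutation!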
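For the second bullet, a sketch of what mutation-free code looks like in JAX: `predict` and `loss` here are toy stand-ins, but the compositional one-liner at the end is the real pattern:

```python
import jax
import jax.numpy as jnp

# Pure functions: all state (params) goes in, new values come out; nothing mutates.
def predict(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

params = (jnp.ones((3, 2)), jnp.zeros(2))
x, y = jnp.ones((5, 3)), jnp.zeros((5, 2))

# The "zappy one-liner": differentiate and jit-compile by composing transformations.
grads = jax.jit(jax.grad(loss))(params, x, y)

# One SGD step without mutation: build *new* params instead of updating in place.
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)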
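The third bullet hinges on JAX's pytree registry. A minimal sketch with a hypothetical `Linear` class: registering it as a pytree node lets `jax.grad`, `jax.jit`, and `tree_map` traverse the module and reach the arrays inside:

```python
import jax
import jax.numpy as jnp

# A small "module": a plain class holding parameters as attributes.
class Linear:
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, x):
        return x @ self.w + self.b

# Flatten yields the leaves (arrays) plus static aux data; unflatten rebuilds.
jax.tree_util.register_pytree_node(
    Linear,
    lambda m: ((m.w, m.b), None),
    lambda aux, leaves: Linear(*leaves),
)

layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
loss = lambda m, x: jnp.sum(m(x) ** 2)
grads = jax.grad(loss)(layer, jnp.ones((4, 3)))  # grads is itself a Linear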
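And for the last bullet, the haiku pattern in a nutshell: `hk.transform` turns stateful-looking module code into a pure pair of `init` and `apply` functions. A minimal sketch, assuming current haiku APIs:

```python
import haiku as hk
import jax
import jax.numpy as jnp

# Inside this function we write stateful-looking code; hk.transform purifies it.
def forward(x):
    return hk.Linear(2)(x)

model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 3))

params = model.init(rng, x)        # extract the initialized parameters...
out = model.apply(params, rng, x)  # ...and call the purified function with them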
Don't forget to tag @sjmielke in your comments.