Finetuning Transformers with JAX + Haiku
Walking through a port of the RoBERTa pre-trained model to JAX + Haiku, then fine-tuning the model to solve a downstream task.
Tags: jax, haiku, roberta, transformers, fine-tuning, natural-language-processing, pretraining, tutorial, article, code, notebook
Details

• This post will be code-oriented and will usually show code examples first before providing commentary.
• We're going to work in a top-down fashion: we'll lay out our Transformer model in broad strokes and then fill in the details (a minimal sketch of that layout follows below).
• I'll introduce Haiku's features as they're needed for our Transformer finetuning project.
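To make the top-down layout concrete, here is a minimal sketch (not the post's actual code) of how a Haiku forward function for fine-tuning might be laid out: embed tokens, run a Transformer body (stubbed out here with a single linear layer), pool, and classify. The constants (`VOCAB_SIZE`, `HIDDEN`, `NUM_CLASSES`) and the single-layer "body" are illustrative assumptions; the post fills in the real RoBERTa blocks. The key Haiku feature is `hk.transform`, which turns this stateful-looking function into a pure `(init, apply)` pair that composes with JAX transformations like `jit` and `grad`.

```python
import haiku as hk
import jax
import jax.numpy as jnp

VOCAB_SIZE = 50265   # RoBERTa's vocabulary size
HIDDEN = 768         # hidden width of roberta-base
NUM_CLASSES = 2      # hypothetical downstream task (e.g. binary sentiment)

def forward(token_ids):
    # Broad strokes first: embed -> transformer body -> pool -> classify.
    embeddings = hk.Embed(vocab_size=VOCAB_SIZE, embed_dim=HIDDEN)(token_ids)
    # The full model fills in multi-head attention + MLP blocks here;
    # a single Linear layer stands in for the Transformer body in this sketch.
    hidden = jax.nn.gelu(hk.Linear(HIDDEN)(embeddings))
    pooled = hidden[:, 0]  # use the first token's representation, BERT-style
    return hk.Linear(NUM_CLASSES)(pooled)

# hk.transform gives us pure functions we can jit / grad over.
model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
dummy_batch = jnp.zeros((2, 16), dtype=jnp.int32)  # (batch, sequence_length)
params = model.init(rng, dummy_batch)
logits = model.apply(params, rng, dummy_batch)
print(logits.shape)  # (2, 2)
```

From here, fine-tuning would mean loading the pre-trained RoBERTa weights into `params` and optimizing them with a JAX-compatible optimizer; the post walks through those steps in detail.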



Authors
Madison May, Machine Learning Architect at @IndicoDataSolutions
Similar projects
From PyTorch to JAX
Towards neural net frameworks that purify stateful code.
Flax: Google’s Open Source Approach To Flexibility In ML
A gentle introduction to Flax: a neural network library for JAX that is designed for flexibility.
Using JAX to Improve Separable Image Filters
Optimizing the filters to improve the filtered images for computer vision tasks.
Lagrangian Neural Networks
Trying to learn a simulation? Try Lagrangian Neural Networks, which explicitly conserve energy and may generalize better!