How to join?
Webinar Registration - Zoom
What to expect
In this meetup, Sanyam will walk us through neural networks in JAX, and we’ll go through a paper implementation done in JAX.
Find recording here
What is Flax? What is the Linen API?
bfloat16 is supported on NVIDIA A100
source: https://moocaholic.medium.com/fp64-fp32-fp16-bfloat16-tf32-and-other-members-of-the-zoo-a1ca7897d407
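To make the bfloat16 trade-off concrete, here is a small sketch (plain JAX, runs on CPU too) comparing it with float16. bfloat16 keeps float32's 8 exponent bits and drops mantissa bits, so it keeps float32's range at lower precision, while float16 overflows much earlier. The value 70000 is just an arbitrary number above the float16 maximum.

```python
import jax.numpy as jnp

# bfloat16: 8 exponent bits, 7 mantissa bits -> float32-like range, low precision.
# float16:  5 exponent bits, 10 mantissa bits -> more precision, much smaller range.
bf16_info = jnp.finfo(jnp.bfloat16)
f16_info = jnp.finfo(jnp.float16)
print("bfloat16 max:", bf16_info.max)  # roughly 3.4e38
print("float16 max: ", f16_info.max)   # 65504

# Cast a float32 value that exceeds the float16 range.
x32 = jnp.array(70000.0, dtype=jnp.float32)
big_f16 = x32.astype(jnp.float16)    # overflows to inf
big_bf16 = x32.astype(jnp.bfloat16)  # stays finite, but is rounded
print(big_f16, big_bf16)
```

This range advantage is one reason bfloat16 is popular for mixed-precision training: it usually avoids the loss scaling that float16 often needs.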
Can you provide an example where pytrees are useful in JAX? Why not just use native Python structures?
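One way to see it: native Python structures are pytrees. A "pytree" is simply JAX's name for any nesting of lists, tuples, dicts (and registered custom types) with arrays at the leaves. The pytree utilities are what let JAX walk those structures, which is also why `jax.grad` can return gradients with the same nested shape as your parameters. A minimal sketch with a made-up two-layer parameter dict:

```python
import jax
import jax.numpy as jnp

# Parameters stored in ordinary Python containers (a dict holding a list of dicts).
params = {
    "layers": [
        {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
        {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)},
    ]
}

# tree_leaves flattens any nesting into a flat list of arrays.
leaves = jax.tree_util.tree_leaves(params)
num_params = sum(leaf.size for leaf in leaves)
print(len(leaves), num_params)  # 4 leaves; 12 + 4 + 8 + 2 = 26 parameters

# tree_map applies a function to every leaf and rebuilds the same structure,
# e.g. casting every parameter to bfloat16 in one call.
params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
print(jax.tree_util.tree_map(lambda p: p.dtype, params_bf16))
```

Without these utilities you would write a recursive loop over your own container layout for every such operation.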
Towards the end of the session, I would like to have a mental map of where PyTorch (fastai) and JAX-based NN modelling each fit in. This isn’t a specific technical question about JAX, but I would like to understand the ecosystem from your perspective.
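One concrete difference that helps place them: in PyTorch a module owns its parameters and `optimizer.step()` updates them in place, while in JAX the parameters are plain data passed into pure functions, and an update returns new parameters. A minimal sketch of a JAX-style training step (hand-written SGD on a toy linear regression; the data, learning rate and step count are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp

LR = 0.1  # learning rate (arbitrary)

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit  # compile the whole update step with XLA
def train_step(params, x, y):
    # Loss and gradients w.r.t. the whole params pytree in one call.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Return NEW parameters instead of mutating the old ones.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss

# Toy data: y = 3x + 1
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 3.0 * x + 1.0

params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}
first_loss = loss_fn(params, x, y)
for _ in range(200):
    params, loss = train_step(params, x, y)
final_loss = loss_fn(params, x, y)
print(first_loss, final_loss, params["w"], params["b"])
```

Libraries like Flax (models) and Optax (optimizers) package these pieces, but the explicit "state in, new state out" pattern stays the same.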
What was the @ flag?
What is features?
Is it like a list of the number of neurons in each layer?
Thanks for the comparison. Will take a detailed look at your references.
Here are some of the relevant resources from today’s lecture:
Here is a link to the AutoDiff Notebook from Akash: TF_JAX_Tutorials - Part 9 (Autodiff in JAX) | Kaggle
Here’s the PyTree Notebook from Akash: TF_JAX_tutorials - Part 10 (Pytrees in JAX) | Kaggle
Here’s the link to “Flax Basics” from the Flax docs: Google Colab
Here’s the link to “Annotated MNIST” from the Flax docs: Google Colab
Here’s a link to the UNET lecture from last month: PyTorch Book Reading - 8. U-net, Image Segmentation and Image Augmentations in PyTorch - YouTube
Here’s the Scenic implementation of U-net: scenic/unet.py at main · google-research/scenic · GitHub