#2 JAX course - Working with Neural Networks in JAX

How to join?
:link: Webinar Registration - Zoom

:books: What to expect
In this meetup, Sanyam will walk us through neural networks in JAX, and we'll go through a paper implementation done in JAX.

:movie_camera: Find the recording here

What is Flax? What is the Linen API?

bfloat16 is supported on the NVIDIA A100.
Source: https://moocaholic.medium.com/fp64-fp32-fp16-bfloat16-tf32-and-other-members-of-the-zoo-a1ca7897d407
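
bfloat16 keeps the 8-bit exponent of float32 and drops mantissa bits. It therefore has roughly float32's dynamic range with less precision. float16 has a 5-bit exponent, so its largest finite value is 65504. The small sketch below shows that difference in JAX using only `jax.numpy`. It should also run on CPU; the speed benefit is what depends on hardware with native bfloat16 support, such as the A100.

```python
import jax.numpy as jnp

big = 300.0
f16 = jnp.array(big, dtype=jnp.float16)
bf16 = jnp.array(big, dtype=jnp.bfloat16)

f16_sq = f16 * f16     # 90000 exceeds float16's max (65504), so this overflows to inf
bf16_sq = bf16 * bf16  # stays finite: bfloat16 shares float32's exponent range

print(f16_sq, bf16_sq)

# A common mixed-precision pattern: keep float32 master weights,
# cast to bfloat16 for the compute-heavy parts.
w = jnp.ones((4, 4), dtype=jnp.float32)
w_bf16 = w.astype(jnp.bfloat16)
print(w_bf16.dtype)
```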

Can you provide an example where pytrees are useful in JAX? Why not just use native Python structures?
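
A pytree does not replace native Python structures. It is JAX's name for nested lists, tuples and dicts (plus any custom classes registered with JAX) whose leaves are arrays. The benefit is that JAX transformations understand that nesting. Below is a minimal sketch with a hypothetical two-layer model whose parameters live in a plain nested dict. `jax.grad` returns gradients with exactly the same nested structure, and `jax.tree_util.tree_map` applies one SGD step to every leaf without hand-written loops.

```python
import jax
import jax.numpy as jnp

# Parameters as a plain nested dict: this *is* a pytree (hypothetical model)
params = {
    "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"w": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}


def loss_fn(params, x, y):
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    pred = h @ params["layer2"]["w"] + params["layer2"]["b"]
    return jnp.mean((pred - y) ** 2)


x = jnp.ones((5, 3))
y = jnp.zeros((5, 1))

# grad differentiates w.r.t. every leaf and mirrors the input structure
grads = jax.grad(loss_fn)(params, x, y)

# One SGD step over all leaves at once, whatever the nesting looks like
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(jax.tree_util.tree_structure(grads))
print(len(jax.tree_util.tree_leaves(params)))  # number of arrays in the tree
```

Adding a third layer only means adding a dict entry; the update line stays the same.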

Towards the end of the session, I would like to have a mental map of where PyTorch (fastai) and JAX-based neural network modelling each fit in. This isn't a specific technical question about JAX, but I would like to understand the ecosystem from your perspective.

What was the @ flag?

What is `features`?

Is it like a list of the number of neurons in each layer?
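
Roughly, yes. In Flax, `nn.Dense(features=n)` creates a layer with `n` output units, and example MLP modules often take `features` as a sequence of those widths, one entry per layer. The sketch below is framework-free plain JAX and uses hypothetical helpers `init_mlp` and `mlp`. It shows how a list like `[16, 8, 1]` turns into weight shapes. The input width comes from the data, not from `features`.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, features):
    """Hypothetical helper: one (w, b) pair per entry in `features`."""
    sizes = [in_dim] + list(features)  # e.g. [3, 16, 8, 1]
    keys = jax.random.split(key, len(features))
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-style scaling, chosen only for illustration
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def mlp(params, x):
    """Hypothetical forward pass: ReLU on hidden layers, linear output."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


features = [16, 8, 1]  # neurons in each layer, last entry = output size
params = init_mlp(jax.random.PRNGKey(0), in_dim=3, features=features)
out = mlp(params, jnp.ones((4, 3)))

print([w.shape for w, _ in params])
print(out.shape)
```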

Thanks for the comparison. Will take a detailed look at your references.

Here are some of the relevant resources from today's lecture: