#4 JAX Course - JAX Implementation of DALL·E - Generate images from a prompt

:raised_hands: :fire: We’re thrilled to host Boris Dayma, an MLE at W&B.

He will be teaching us about his work on implementing DALL·E in JAX :man_teacher:

How to Join?
:link:Register here to join live

:movie_camera:Recording: JAX Course - 4. JAX Implementation of DALL·E - YouTube

For the inference pipeline: I missed it, but why do you generate so many candidate samples for each prompt?

Training pipeline: it seems to be the VQGAN encoder, then BART, then a cross-entropy loss. Which parts of this are trained? Was the VQGAN encoder trained?
Inference pipeline: the VQGAN decoder seems to be used only during inference. Is the decoder or CLIP trained?
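
A minimal sketch of the cross-entropy step described above, in JAX. It assumes the decoder emits logits of shape (batch, seq_len, vocab_size) over the VQGAN codebook and that the targets are the integer codes produced by the VQGAN encoder; the toy sizes are assumptions, not the real model's. Because the targets are plain integers, no gradient can flow back into whatever produced them in this sketch; it does not by itself say which parts of the actual DALL·E pipeline were trained.

```python
import jax
import jax.numpy as jnp


def image_token_cross_entropy(logits, image_tokens):
    """Mean token-level cross-entropy between decoder logits and target codes.

    logits: float array of shape (batch, seq_len, vocab_size).
    image_tokens: int array of shape (batch, seq_len) holding codebook indices.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability the model assigns to each target code.
    target_log_probs = jnp.take_along_axis(
        log_probs, image_tokens[..., None], axis=-1
    )[..., 0]
    return -jnp.mean(target_log_probs)


# Toy sizes (assumed): 2 images, 4 codes each, a 16-entry codebook.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16))
image_tokens = jnp.array([[1, 5, 3, 0], [7, 7, 2, 15]])
loss = image_token_cross_entropy(logits, image_tokens)
```

With all-zero logits the loss equals log(16), the cost of a uniform guess over a 16-entry codebook, which makes a handy sanity check.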

Training is clear, but how is it connected to inference? How do you generate multiple candidate images from one BART decoder output? Is the ranking (CLIP) part learned?
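
One standard way to get several candidates out of a single trained decoder is to sample autoregressively with different random keys. The toy sketch below assumes plain sampling; a hypothetical bigram table (transition_logits) stands in for the BART decoder, and the sizes are made up. The actual DALL·E generation settings may differ.

```python
import jax
import jax.numpy as jnp

VOCAB_SIZE = 8  # toy codebook size (assumed)
SEQ_LEN = 6     # toy number of image codes per image (assumed)

# Hypothetical stand-in for a decoder: next-code logits depend only on the previous code.
transition_logits = jax.random.normal(jax.random.PRNGKey(42), (VOCAB_SIZE, VOCAB_SIZE))


def sample_codes(key, start_code=0):
    """Autoregressively sample SEQ_LEN codes, one random draw per step."""

    def step(prev_code, step_key):
        next_code = jax.random.categorical(step_key, transition_logits[prev_code])
        return next_code, next_code

    step_keys = jax.random.split(key, SEQ_LEN)
    init = jnp.array(start_code, dtype=jnp.int32)
    _, codes = jax.lax.scan(step, init, step_keys)
    return codes


def sample_candidates(key, num_candidates):
    """Same model, different random keys -> different candidate code sequences."""
    keys = jax.random.split(key, num_candidates)
    return jax.vmap(sample_codes)(keys)


candidates = sample_candidates(jax.random.PRNGKey(0), num_candidates=4)
```

Each row of candidates is one code sequence; in a pipeline like the one discussed, each sequence would then be passed through the VQGAN decoder to become an image.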

Thanks for the detailed explanation of the generation output. What are the possible output values: 0-255 in 3 channels?

How do you rank the images? Do you compare the embedding of the input text with the embedding of each generated image?
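
CLIP scores an image-text pair by the (scaled) cosine similarity of their embeddings, and the scale factor does not change the order, so ranking candidates that way can be sketched as below. It assumes the text and image embeddings were already computed by some text/image encoder pair (not shown); the function name rank_candidates is hypothetical.

```python
import jax.numpy as jnp


def rank_candidates(text_embedding, image_embeddings):
    """Order candidate images by cosine similarity to the text, best first.

    text_embedding: (dim,) embedding of the prompt (assumed precomputed).
    image_embeddings: (num_candidates, dim) embeddings of the generated images.
    """
    text = text_embedding / jnp.linalg.norm(text_embedding)
    images = image_embeddings / jnp.linalg.norm(image_embeddings, axis=-1, keepdims=True)
    scores = images @ text        # cosine similarity for each candidate
    order = jnp.argsort(-scores)  # candidate indices, highest score first
    return order, scores


order, scores = rank_candidates(
    jnp.array([1.0, 0.0]),
    jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]),
)
# order -> [1, 2, 0]
```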

Nice idea to concatenate a BART encoding of the text with the fixed VQGAN encoding of the image. Does this mean we could extend it to audio, e.g. by replacing the VQGAN encoding with an audio encoding? Or even to a graph encoding, if that is possible?
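
What makes the swap imaginable is that the text model only ever sees a sequence of integer codes, so any modality that can be turned into discrete codes could in principle take the image's place. The core of that step in a VQ-style encoder is a nearest-neighbor codebook lookup; the sketch below shows only that lookup, with a toy codebook and no encoder network, and says nothing about whether audio or graph codes would work well in practice.

```python
import jax.numpy as jnp


def quantize(features, codebook):
    """Map each feature vector to the index of its nearest codebook entry.

    features: (num_vectors, dim) continuous encoder outputs
              (image patches, audio frames, ...; assumed precomputed).
    codebook: (codebook_size, dim) embedding table.
    """
    # Squared Euclidean distance between every feature and every code.
    distances = jnp.sum((features[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(distances, axis=-1)


codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
features = jnp.array([[0.1, 0.2], [4.0, 4.5], [1.2, 0.9]])
codes = quantize(features, codebook)
# codes -> [0, 2, 1]
```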

One general question: are there things that JAX/Flax enables you to do that are harder to do in PyTorch?
Similarly, are there things that are easier in PyTorch but harder to do in JAX/Flax?

One more related question: do you use JAX/Flax for your recent experiments, or do you primarily use PyTorch?

Training question: with CNNs/RNNs we might inspect the gradients or check the gradient flow. With transformers, do we really look only at the loss value, with no debugging or issues during training? What problems did you face when training DALL·E?
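
Inspecting gradients is just as possible for a transformer in JAX: jax.grad returns gradients with the same nested structure as the parameters, so per-parameter norms are one tree_map away. A sketch on a hypothetical two-layer model; the params layout is an assumption for illustration, and a real Flax transformer's params would be a much deeper nested dict.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical two-layer model; a real transformer would have many more layers.
    hidden = jnp.tanh(x @ params["layer1"]["w"])
    pred = hidden @ params["layer2"]["w"]
    return jnp.mean((pred - y) ** 2)


def gradient_norms(params, x, y):
    """Return the gradient norm of every parameter array plus the global norm."""
    grads = jax.grad(loss_fn)(params, x, y)
    per_param = jax.tree_util.tree_map(jnp.linalg.norm, grads)
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    return per_param, global_norm


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "layer1": {"w": jax.random.normal(k1, (3, 4))},
    "layer2": {"w": jax.random.normal(k2, (4, 1))},
}
x = jax.random.normal(k3, (5, 3))
y = jax.random.normal(k4, (5, 1))
per_param, global_norm = gradient_norms(params, x, y)
```

Logging per_param at each step makes vanishing or exploding gradients in a particular layer visible, which the loss curve alone would hide.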

So encoded_images generates N image encodings. What is gen_top_k here? Are the encodings ranked?
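
gen_top_k is not explained in this thread, so the following is an assumption based on its name: in generation APIs such as Hugging Face's generate, a top_k setting restricts sampling at each step to the k most likely tokens. If gen_top_k works that way, it filters tokens during sampling rather than ranking the finished encodings. A sketch of top-k sampling in JAX (the function name is hypothetical):

```python
import jax
import jax.numpy as jnp


def top_k_sample(key, logits, k):
    """Sample one token id, considering only the k highest-scoring tokens.

    k must be a Python int (jax.lax.top_k needs it to be static).
    """
    top_logits, top_ids = jax.lax.top_k(logits, k)
    # categorical applies a softmax over the k surviving logits only.
    choice = jax.random.categorical(key, top_logits)
    return top_ids[choice]


logits = jnp.array([0.0, 5.0, 1.0, 4.0, -2.0])
keys = jax.random.split(jax.random.PRNGKey(0), 100)
samples = jax.vmap(lambda key: top_k_sample(key, logits, 2))(keys)
# every sample is either 1 or 3, the ids of the two highest logits
```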

Regarding init_weights: can you replace the weights with your own pretrained weights for BART or the VQGAN so that we don’t have to train from scratch, and if so, how?
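
Flax keeps parameters outside the model as a nested dict (a pytree), so one generic way to start from pretrained weights is to load them into the same structure and swap the relevant subtree in before training. A sketch with hypothetical subtree names (text_encoder, image_decoder); the real parameter layout and what init_weights does in the DALL·E code may differ.

```python
import jax
import jax.numpy as jnp


def replace_subtree(params, pretrained, key):
    """Return a copy of params with params[key] replaced by pretrained weights.

    Raises ValueError if any array in the replacement has a different shape
    (tree_map itself raises if the nested structures differ).
    """
    original = params[key]
    shapes_match = jax.tree_util.tree_all(
        jax.tree_util.tree_map(lambda a, b: a.shape == b.shape, original, pretrained)
    )
    if not shapes_match:
        raise ValueError(f"shape mismatch in subtree {key!r}")
    return {**params, key: pretrained}


# Hypothetical layout: freshly initialized params for two sub-models.
params = {
    "text_encoder": {"w": jnp.zeros((2, 3))},
    "image_decoder": {"w": jnp.zeros((3, 3))},
}
pretrained_text = {"w": jnp.ones((2, 3))}  # stands in for weights loaded from disk
new_params = replace_subtree(params, pretrained_text, "text_encoder")
```

The untouched subtree is reused as-is, so only the swapped part starts from pretrained values.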

How do you avoid mode collapse and noise in the images? Is this taken care of by the VQGAN decoder, i.e. does it generate reasonable output for any input vector sampled from the BART decoder’s probabilities?

Sorry, I thought Boris was running on a single-GPU machine. If so, how would pmap help? Please excuse me if this is a noob question :slight_smile:

From what I understand, pmap will not help in a single-GPU setting, but it doesn’t hurt either. If you run the same code on a multi-GPU setup, it benefits without changing any part of the code, as long as the batch is split across however many devices are available.
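
A sketch of that point: if the batch is split using jax.local_device_count() instead of a hard-coded number, the same pmap-ed function runs on one device (a leading axis of size 1) or on several, and the psum makes the result independent of the device count. The helper names shard and global_mean are hypothetical.

```python
import jax
import jax.numpy as jnp


def shard(batch):
    """Reshape (global_batch, ...) into (num_devices, per_device_batch, ...)."""
    n = jax.local_device_count()
    assert batch.shape[0] % n == 0, "global batch must divide evenly across devices"
    return batch.reshape((n, batch.shape[0] // n) + batch.shape[1:])


def global_mean(x):
    # Sum locally, then combine the sums and element counts across all devices.
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    count = jax.lax.psum(x.size, axis_name="devices")
    return total / count


p_global_mean = jax.pmap(global_mean, axis_name="devices")

batch = jnp.arange(8.0).reshape(8, 1)  # a global batch of 8 examples
result = p_global_mean(shard(batch))   # one identical value per device
```

On a single GPU (or CPU) this runs with one shard; on a multi-GPU machine the same lines split the work without edits.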
