We’re thrilled to host Boris Dayma, an MLE at W&B.
He will be teaching us about his work on implementing DALL·E in JAX.
How to Join?
Register here to join live
Recording: JAX Course - 4. JAX Implementation of DALL·E - YouTube
For the inference pipeline: I missed it, but why do you generate so many sample candidates per prompt?
The training pipeline seems to be VQGAN encoder, then BART, then cross-entropy loss. Which parts of this are trained?
Was the VQGAN encoder trained?
Inference Pipeline:
The VQGAN decoder seems to be used only during inference. Is the decoder or CLIP trained?
Training is clear, but how is it connected to inference? How do you generate multiple candidate images from one BART decoder output?
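A minimal sketch of one way multiple candidates can come from a single conditioning input: because sampling is stochastic, drawing from the same decoder logits with different PRNG keys yields different token sequences. This is a generic illustration with toy logits, not the actual dalle-mini generation API.

```python
import jax
import jax.numpy as jnp

def sample_tokens(key, logits, temperature=1.0):
    """Sample one token id per position from an array of logits."""
    return jax.random.categorical(key, logits / temperature, axis=-1)

# Toy stand-in for decoder logits: 8 positions over a codebook of 16 entries.
logits = jnp.zeros((8, 16))  # uniform logits, for illustration only

# One PRNG key per candidate: same logits, independent samples.
keys = jax.random.split(jax.random.PRNGKey(0), num=4)
candidates = jnp.stack([sample_tokens(k, logits) for k in keys])
print(candidates.shape)  # (4, 8): four candidate token sequences
```

Each row of `candidates` would then be handed to the VQGAN decoder to produce one image.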
Is the ranking / CLIP part learned?
Thanks for the detailed explanation of the generation output. What are the possible output values: 0–255 in 3 channels?
How do you rank the images? By comparing the embeddings of the input text to those of each generated image?
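The ranking idea described in the question can be sketched as cosine similarity between a text embedding and each image embedding, which is how CLIP-style scoring works in general. The embeddings below are toy vectors, not real CLIP outputs.

```python
import jax.numpy as jnp

def clip_style_rank(text_emb, image_embs):
    """Rank candidate images by cosine similarity to the text embedding."""
    t = text_emb / jnp.linalg.norm(text_emb)
    im = image_embs / jnp.linalg.norm(image_embs, axis=-1, keepdims=True)
    scores = im @ t               # one similarity score per image
    order = jnp.argsort(-scores)  # indices of images, best first
    return scores, order

# Toy embeddings: one text vector, three candidate image vectors.
text_emb = jnp.array([1.0, 0.0])
image_embs = jnp.array([[0.9, 0.1], [0.0, 1.0], [0.7, 0.7]])
scores, order = clip_style_rank(text_emb, image_embs)
print(order)  # best-first ordering of the candidates
```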
Nice idea to concatenate a BART encoding of the text with the fixed VQGAN encoding of the image. Does this mean we could extend this to, say, audio by replacing the VQGAN encoding with an audio encoding, or even a graph encoding if that's possible?
One general question: are there things that JAX/Flax enables you to do that are harder to do in PyTorch?
Similarly, are there things that are easier in PyTorch but harder in JAX/Flax?
One more related question: do you use JAX/Flax for any recent experiments, or do you primarily prefer PyTorch?
Training question: with CNNs/RNNs, we might look at gradients or check gradient flow. With Transformers, do we primarily only look at the loss value, with no debugging issues during training? What problems did you face in training DALL·E?
So, encoded_images generates some number N of image encodings. What is gen_top_k here? Are the encodings ranked?
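Assuming gen_top_k refers to top-k filtering during sampling (a common generation parameter; this is an inference about the name, not confirmed by the talk), the idea is to restrict sampling to the k most likely codebook entries at each step. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def top_k_sample(key, logits, k):
    """Keep only the k highest logits, then sample among them."""
    top_vals, top_idx = jax.lax.top_k(logits, k)
    choice = jax.random.categorical(key, top_vals)  # index into the top-k set
    return top_idx[choice]

# Toy next-token logits over 5 codebook entries.
logits = jnp.array([0.1, 3.0, 2.5, -1.0, 0.5])
token = top_k_sample(jax.random.PRNGKey(0), logits, k=2)
# With k=2, the sampled token is always index 1 or 2 (the two highest logits).
```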
Regarding init_weights: can you, and if so how do you, swap in your own pretrained weights for BART or the VQGAN so we don't have to train from scratch?
How do you avoid mode collapse and noise in the images? Is this handled by the VQGAN decoder, in that it generates reasonable output given an input vector sampled from the BART decoder's output probabilities?
Sorry, I thought Boris is running on a single-GPU machine; if so, how would pmap help? Please excuse me if this is a noob question.
From what I understand, pmap will not help in a single-GPU setting, but it also doesn't hurt. So if you run the same code in a multi-GPU setup, it will benefit without changing any part of the code.
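The point above can be sketched directly: a pmapped step maps over the leading axis of its input, one slice per device, so the same code runs whether jax.local_device_count() is 1 or 8. This toy step just doubles its input.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 1 on a single-GPU (or CPU-only) machine

# The same pmapped step runs unchanged on 1 or many devices.
step = jax.pmap(lambda x: x * 2)

# Shard the batch across devices: the leading axis must equal n.
batch = jnp.arange(4 * n, dtype=jnp.float32).reshape(n, 4)
out = step(batch)
print(out.shape)  # (n, 4)
```

On a single device this behaves like calling the function on one batch; on a multi-GPU host each row of `batch` is processed on its own device in parallel.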