Hi,
I am struggling to get sweeps to work with Hugging Face’s Accelerate library. Specifically, the first run of the sweep works fine, but every run thereafter fails due to re-initialising the Accelerator for every run. In every run from the 2nd, I get the error: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass mixed_precision='bf16' to Accelerate().
Below is a minimal example of a script which I’m launching using accelerate launch
. I’d appreciate any suggestions. Thanks!
import os
from typing import Any, List, Tuple
from accelerate import Accelerator
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from transformers import (
Adafactor,
PreTrainedTokenizerFast,
T5ForConditionalGeneration,
T5TokenizerFast,
)
import wandb
class TestDataset(Dataset[Any]):
def __init__(self, tokenizer: PreTrainedTokenizerFast) -> None:
super().__init__()
self._str_prompt = "This is a "
self._str_target = "test."
self._tokenizer = tokenizer
def __len__(self) -> int:
return 1
def __getitem__(self, idx: int) -> Tuple[str, str]:
return self._str_prompt, self._str_target
def collate(self, batch: List[Tuple[str, str]]) -> Tuple[Tensor, Tensor]:
prompts = [b[0] for b in batch]
targets = [b[1] for b in batch]
prompts_tokenized = self._tokenizer(prompts, return_tensors="pt")
targets_tokenized = self._tokenizer(targets, return_tensors="pt")
return prompts_tokenized["input_ids"], targets_tokenized["input_ids"]
def main() -> None:
accelerator = Accelerator(log_with="wandb", mixed_precision="bf16")
if accelerator.is_main_process:
accelerator.init_trackers(os.environ.get("WANDB_PROJECT"))
accelerator.wait_for_everyone()
wandb_tracker = accelerator.get_tracker("wandb")
multiplier = wandb_tracker.config["multiplier"]
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5TokenizerFast.from_pretrained("t5-small")
opt = Adafactor(params=model.parameters())
dataset = TestDataset(tokenizer=tokenizer)
data_loader = DataLoader(dataset=dataset, collate_fn=dataset.collate)
model, opt, data_loader = accelerator.prepare(model, opt, data_loader)
input_ids, labels = next(iter(data_loader))
loss = model(input_ids=input_ids, labels=labels).loss
loss_gathered = accelerator.gather_for_metrics(loss).mean()
accelerator.log({"loss": loss_gathered.item() * multiplier})
accelerator.end_training()
if __name__ == "__main__":
sweep_configuration = {
"method": "random",
"metric": {"goal": "maximize", "name": "loss"},
"parameters": {"multiplier": {"values": list(range(100))}},
}
sweep_id = wandb.sweep(
sweep=sweep_configuration,
project=os.environ.get("WANDB_PROJECT"),
)
wandb.agent(sweep_id, function=main, count=3)