Sweeps + Accelerate (multi-GPU) + Trainer

I am trying to use the Hugging Face Trainer API together with Accelerate (with DeepSpeed ZeRO stage 3) to perform a hyperparameter sweep. I have a single node with two GPUs available.
I want each run to use both GPUs.

When I perform a run without a sweep, everything works fine. But with a hyperparameter sweep, I believe each GPU spawns its own run, which might end up in a deadlock because the model, optimizer and gradients are loaded and unloaded at different times. In the end it just keeps sending requests to wandb.ai without starting any training.

My training script:

```python
import transformers
import wandb

# model, tokenizer, data, output_dir, project_name, sweep_config, num_sweeps,
# set_up_model and print_trainable_parameters are defined elsewhere in the script.


def train(config=None):
    with wandb.init(project=project_name):
        print("SETTING MODEL")
        config = wandb.config

        model_peft = set_up_model(model, config["lora_rank"], config["lora_alpha"], config["lora_dropout"])
        print('LOADED PEFT MODEL')

        trainable_params, all_param, percentage = print_trainable_parameters(model_peft)
        wandb.log({'trainable_params': trainable_params, 'all_param': all_param, 'percentage_tain': percentage})

        print("SETTING TRAINER")
        trainer = transformers.Trainer(
            model=model_peft,
            tokenizer=tokenizer,
            train_dataset=data['train'],
            args=transformers.TrainingArguments(
                per_device_train_batch_size=config["batch_size"],
                gradient_accumulation_steps=2,
                evaluation_strategy='steps',
                num_train_epochs=config["epochs"],
                warmup_steps=100,
                learning_rate=config["lr"],
                fp16=True,
                eval_steps=10,
                logging_steps=1,
                output_dir=output_dir,
                gradient_checkpointing=True,
                report_to='wandb',
            ),
            data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
            eval_dataset=data['test'],
        )

        model_peft.config.use_cache = False  # silence the warnings. Please re-enable for inference!
        print("STARTING TRAINING:...")
        result = trainer.train()


# Module level: executed by every process that runs this file.
sweep_id = wandb.sweep(sweep=sweep_config, project=project_name)
wandb.agent(sweep_id=sweep_id, function=train, count=num_sweeps)
```

This gives the following output, after which it just keeps checking the connection with wandb.ai:

2024-06-03 00:38:48 DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
2024-06-03 00:38:49 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 196
Create sweep with ID: mzu3ccc6
Sweep URL: https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3/sweeps/mzu3ccc6
2024-06-03 00:38:49 INFO Starting sweep agent: entity=None, project=None, count=9
2024-06-03 00:38:49 DEBUG Agent._setup()
2024-06-03 00:38:49 DEBUG Agent._register()
2024-06-03 00:38:49 DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
2024-06-03 00:38:49 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 196
Create sweep with ID: holmp7kn
Sweep URL: https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3/sweeps/holmp7kn
2024-06-03 00:38:49 INFO Starting sweep agent: entity=None, project=None, count=9
2024-06-03 00:38:49 DEBUG Agent._setup()
2024-06-03 00:38:49 DEBUG Agent._register()
2024-06-03 00:38:49 DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
2024-06-03 00:38:49 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 68
2024-06-03 00:38:49 DEBUG agent_id = QWdlbnQ6Z2U5N2Ezc24=
2024-06-03 00:38:49 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 68
2024-06-03 00:38:49 DEBUG agent_id = QWdlbnQ6ZnQ5OHkwdW8=
2024-06-03 00:38:49 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 281
2024-06-03 00:38:49 DEBUG Job received: Job(m4569mfa,{'batch_size': {'value': 4}, 'epochs': {'value': 1}, 'lora_alpha': {'value': 16}, 'lora_dropout': {'value': 0.05}, 'lora_rank': {'value': 16}, 'lr': {'value': 9e-06}})
2024-06-03 00:38:49 DEBUG Spawning new thread for run m4569mfa.
2024-06-03 00:38:49 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 284
2024-06-03 00:38:49 DEBUG Job received: Job(ke4vgbe6,{'batch_size': {'value': 8}, 'epochs': {'value': 1}, 'lora_alpha': {'value': 16}, 'lora_dropout': {'value': 0.05}, 'lora_rank': {'value': 64}, 'lr': {'value': 9e-06}})
2024-06-03 00:38:49 DEBUG Spawning new thread for run ke4vgbe6.
wandb: Agent Starting Run: m4569mfa with config:
wandb:  batch_size: 4
wandb:  epochs: 1
wandb:  lora_alpha: 16
wandb:  lora_dropout: 0.05
wandb:  lora_rank: 16
wandb:  lr: 9e-06
2024-06-03 00:38:50 DEBUG git repository is invalid
wandb: Agent Starting Run: ke4vgbe6 with config:
wandb:  batch_size: 8
wandb:  epochs: 1
wandb:  lora_alpha: 16
wandb:  lora_dropout: 0.05
wandb:  lora_rank: 64
wandb:  lr: 9e-06
2024-06-03 00:38:50 DEBUG git repository is invalid
2024-06-03 00:38:51 DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
2024-06-03 00:38:51 DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
2024-06-03 00:38:51 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 1879
2024-06-03 00:38:51 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 1879
2024-06-03 00:38:51 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 353
wandb: Currently logged in as: tnf. Use `wandb login --relogin` to force relogin
wandb: WARNING Ignored wandb.init() arg project when running a sweep.
2024-06-03 00:38:51 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 353
wandb: Currently logged in as: tnf. Use `wandb login --relogin` to force relogin
wandb: WARNING Ignored wandb.init() arg project when running a sweep.
2024-06-03 00:38:55 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:38:55 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:00 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:00 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
wandb: Tracking run with wandb version 0.17.0
wandb: Run data is saved locally in /cluster/home/terjenf/norwAI_All/wandb/run-20240603_003851-ke4vgbe6
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run happy-sweep-1
wandb: ⭐️ View project at https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3
wandb: 🧹 View sweep at https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3/sweeps/holmp7kn
wandb: 🚀 View run at https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3/runs/ke4vgbe6
wandb: Tracking run with wandb version 0.17.0
wandb: Run data is saved locally in /cluster/home/terjenf/norwAI_All/wandb/run-20240603_003851-m4569mfa
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run expert-sweep-1
wandb: ⭐️ View project at https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3
wandb: 🧹 View sweep at https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3/sweeps/mzu3ccc6
wandb: 🚀 View run at https://wandb.ai/tnf/NOR_GLM_369_sweep_evaluate_3/runs/m4569mfa
2024-06-03 00:39:05 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:05 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
SETTING MODELSETTING MODEL

/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/peft/tuners/lora/layer.py:1119: UserWarning: fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True.
  warnings.warn(
/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/peft/tuners/lora/layer.py:1119: UserWarning: fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True.
  warnings.warn(
LOADED PEFT MODEL
trainable params: 1572864 || all params: 371517440 || trainable%: 0.4234
SETTING TRAINER
/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
LOADED PEFT MODEL
trainable params: 6291456 || all params: 376236032 || trainable%: 1.6722
SETTING TRAINER
/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
[2024-06-03 00:39:07,418] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-03 00:39:07,436] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
[2024-06-03 00:39:07,765] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-06-03 00:39:07,765] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-06-03 00:39:07,766] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
STARTING TRAINING:...
STARTING TRAINING:...
Installed CUDA version 12.3 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
Using /cluster/home/terjenf/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /cluster/home/terjenf/.cache/torch_extensions/py311_cu121/cpu_adam/build.ninja...
/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Installed CUDA version 12.3 does not match the version torch was compiled with 12.1 but since the APIs are compatible, accepting this combination
Using /cluster/home/terjenf/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...
2024-06-03 00:39:10 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:10 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:15 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:15 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
[1/4] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output custom_cuda_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/TH -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++17 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ --threads=8 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80 -DBF16_AVAILABLE -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -c /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/deepspeed/ops/csrc/common/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o 
2024-06-03 00:39:20 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:20 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:26 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:26 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:31 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:31 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:36 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:36 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:41 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:41 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:46 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:46 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:51 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:51 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:57 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:39:57 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:02 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:02 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
[2/4] c++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/TH -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/usr/local/cuda/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -DBF16_AVAILABLE -c /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o 
[3/4] c++ -MMD -MF cpu_adam_impl.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/TH -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /cluster/home/terjenf/.conda/envs/vgdebatt/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/usr/local/cuda/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -DBF16_AVAILABLE -c /cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp -o cpu_adam_impl.o 
[4/4] c++ cpu_adam.o cpu_adam_impl.o custom_cuda_kernel.cuda.o -shared -lcurand -L/cluster/home/terjenf/.conda/envs/vgdebatt/lib/python3.11/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o cpu_adam.so
Loading extension module cpu_adam...
Time to load cpu_adam op: 55.60796356201172 seconds
Loading extension module cpu_adam...
Time to load cpu_adam op: 55.2439079284668 seconds
2024-06-03 00:40:07 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:07 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
Parameter Offload: Total persistent parameters: 1894400 in 218 params
2024-06-03 00:40:12 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:12 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:17 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:18 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:22 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:23 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:28 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:28 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:33 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
2024-06-03 00:40:33 DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 57
....

Hello, this is a reply from our support bot, designed to assist you with your Weights & Biases related queries. To reach a human, please reply to this message.

-WandBot :robot:

yeah, still need help

Hi Terje!

To avoid the deadlock that occurs when each GPU process starts its own sweep run, make sure that only one sweep agent is running and that each run it starts uses both GPUs. Here's how you can adjust your script and the configuration to achieve this:

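  1. Ensure that only one sweep agent is running: with `accelerate launch`, each GPU gets its own Python process, and each process executes the whole script. Both processes therefore call `wandb.sweep` and `wandb.agent`, which is why the log shows two sweeps (mzu3ccc6 and holmp7kn) and two runs with different configs (batch size 4 vs. 8, LoRA rank 16 vs. 64). With DeepSpeed ZeRO-3 the two processes then most likely wait on collective operations that never match. Only one process should create the sweep and pull configs from it.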

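  2. Utilize both GPUs for training:

The GPU count cannot be set through `transformers.TrainingArguments`: it has no `n_gpu` constructor argument (`n_gpu` is a read-only property). It comes from how the script is launched. A single plain `python` process sees both GPUs and the Trainer wraps the model in `torch.nn.DataParallel`, while `accelerate launch --num_processes 2` starts one process per GPU. In the second case, every process must train with the same hyperparameters.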
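To check which process a given line of the script runs in, a small rank check helps. This is a minimal sketch, assuming the launcher exports the torchrun-style environment variables `RANK` and `WORLD_SIZE` (torchrun does; check what your launcher sets). `dist_info` and `is_main_process` are hypothetical helper names, not part of any library.

```python
import os


def dist_info() -> tuple[int, int]:
    """Hypothetical helper: return (rank, world_size) from torchrun-style env vars.

    Falls back to a single-process view (0, 1) when the variables are absent,
    e.g. when the script is started with plain `python`.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, world_size


def is_main_process() -> bool:
    """Hypothetical helper: True only for global rank 0, the process that should own the sweep."""
    rank, _ = dist_info()
    return rank == 0


if __name__ == "__main__":
    rank, world_size = dist_info()
    # Under a 2-process launch this line prints twice, once per rank,
    # just like the duplicated output in the log above.
    print(f"rank {rank} of {world_size}, main process: {is_main_process()}")
```

Gating `wandb.sweep` and `wandb.agent` on `is_main_process()` alone is not enough: the other rank still has to enter training with the same hyperparameters, otherwise rank 0 blocks in the first collective. The config therefore has to reach every rank, either by broadcasting it from rank 0 or by starting the whole multi-process launch once per trial.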

Here’s an updated version of your script:

```python
import wandb
import transformers
from transformers import Trainer, TrainingArguments
from transformers import set_seed

# model, tokenizer, data, output_dir, project_name, num_sweeps, set_up_model and
# print_trainable_parameters come from the rest of your script.

def train(config=None):
    with wandb.init():
        config = wandb.config
        set_seed(42)

        # Setting up the model with the provided configuration
        model_peft = set_up_model(model, config["lora_rank"], config["lora_alpha"], config["lora_dropout"])
        trainable_params, all_param, percentage = print_trainable_parameters(model_peft)
        wandb.log({'trainable_params': trainable_params, 'all_param': all_param, 'percentage_tain': percentage})

        # Setting up the Trainer
        trainer = Trainer(
            model=model_peft,
            tokenizer=tokenizer,
            train_dataset=data['train'],
            eval_dataset=data['test'],
            args=TrainingArguments(
                per_device_train_batch_size=config["batch_size"],
                gradient_accumulation_steps=2,
                eval_strategy='steps',  # evaluation_strategy is deprecated (see the FutureWarning in the log)
                num_train_epochs=config["epochs"],
                warmup_steps=100,
                learning_rate=config["lr"],
                fp16=True,
                eval_steps=10,
                logging_steps=1,
                output_dir=output_dir,
                gradient_checkpointing=True,
                report_to='wandb',
                dataloader_num_workers=4,
                save_steps=10,
                # No n_gpu here: TrainingArguments does not accept it as an argument.
                # n_gpu is a read-only property derived from the visible GPUs and the launcher.
            ),
            data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
        )

        model_peft.config.use_cache = False  # silence the warnings. Please re-enable for inference!
        result = trainer.train()

# Define the sweep configuration
sweep_config = {
    'method': 'bayes',
    # The Trainer's W&B integration logs the evaluation loss as "eval/loss", not "eval_loss".
    'metric': {'name': 'eval/loss', 'goal': 'minimize'},
    'parameters': {
        'batch_size': {'values': [4, 8]},
        'epochs': {'values': [1, 2]},
        'lora_alpha': {'values': [16, 32]},
        'lora_dropout': {'values': [0.05, 0.1]},
        'lora_rank': {'values': [16, 32]},
        'lr': {'max': 0.0001, 'min': 1e-6}
    }
}

# Run this file as ONE process (plain `python train.py`, not `accelerate launch`),
# so only one sweep is created and only one agent pulls configs. The Trainer then
# sees both GPUs and wraps the model in torch.nn.DataParallel; DeepSpeed settings
# from `accelerate config` are only applied under `accelerate launch`.

# Initialize the sweep
sweep_id = wandb.sweep(sweep_config, project=project_name)

# Start the sweep agent
wandb.agent(sweep_id, function=train, count=num_sweeps)
```
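If DeepSpeed ZeRO-3 across both GPUs is required, the single agent can instead start a fresh two-process launch for every trial through the sweep's `command` key, so both processes of a trial receive exactly the same arguments. This is a sketch under assumptions: the training code lives in a separate script, called `train_one_trial.py` here (a hypothetical name), which reads its hyperparameters from the command line and never calls `wandb.sweep` or `wandb.agent` itself. `${program}` and `${args}` are W&B sweep command macros.

```python
# Sweep config for one agent that starts every trial as a 2-process
# `accelerate launch` job. "train_one_trial.py" is a hypothetical file name.
sweep_config = {
    "program": "train_one_trial.py",
    "method": "bayes",
    # The Trainer's W&B integration reports the evaluation loss as "eval/loss".
    "metric": {"name": "eval/loss", "goal": "minimize"},
    "parameters": {
        "batch_size": {"values": [4, 8]},
        "epochs": {"values": [1, 2]},
        "lora_alpha": {"values": [16, 32]},
        "lora_dropout": {"values": [0.05, 0.1]},
        "lora_rank": {"values": [16, 32]},
        "lr": {"max": 0.0001, "min": 1e-6},
    },
    # The agent expands ${program} and ${args}, so each trial runs something like
    #   accelerate launch --num_processes 2 train_one_trial.py --batch_size=4 ... --lr=9e-06
    # and both launched processes see the same command line.
    "command": [
        "accelerate", "launch", "--num_processes", "2",
        "${program}",
        "${args}",
    ],
}
```

Create the sweep once with `wandb.sweep(sweep_config, project=project_name)` and start a single agent from the directory that contains the script, e.g. `wandb agent --count 9 <entity>/<project>/<sweep_id>`. Only that agent talks to the sweep, so no process inside a trial can pull a second config.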

Key Adjustments:

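  1. Single agent: only one process should create the sweep and run `wandb.agent`. `n_gpu` cannot be passed to `TrainingArguments` (it is a read-only property), so whether both GPUs are used depends on how the script is launched.
  2. Sweep Configuration: the search space is now defined in the script. The sweep configuration does not limit how many agents run; that depends on how many processes call `wandb.agent`.

Important Considerations: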
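  • Ensure your environment is set up to recognize both GPUs (for example, `torch.cuda.device_count()` returns 2).
  • DeepSpeed ZeRO-3 settings from your accelerate config take effect under `accelerate launch`; a single plain `python` process does not use them.
  • Check that your `set_up_model` function leaves device placement to the Trainer or DeepSpeed (for example, no `device_map="auto"`) and that every process builds the model with the same LoRA settings.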
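When each trial is started as its own multi-process launch (for instance by a W&B agent whose sweep `command` runs `accelerate launch ... ${program} ${args}`), the trial script can take its hyperparameters from the command line instead of `wandb.config`. Every rank then parses the same argv, so `set_up_model` and the Trainer receive identical hyperparameters in both processes. A minimal sketch with `argparse`; `parse_hparams` is a hypothetical helper, and the flag names simply mirror the sweep parameters used in this thread.

```python
import argparse


def parse_hparams(argv=None) -> dict:
    """Hypothetical helper: read one trial's hyperparameters from the command line.

    Every process started by the launcher gets the same argv, so every rank
    ends up with an identical config. W&B's ${args} macro passes values as
    --name=value, which argparse accepts.
    """
    parser = argparse.ArgumentParser(description="One sweep trial")
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--epochs", type=int, required=True)
    parser.add_argument("--lora_alpha", type=int, required=True)
    parser.add_argument("--lora_dropout", type=float, required=True)
    parser.add_argument("--lora_rank", type=int, required=True)
    parser.add_argument("--lr", type=float, required=True)
    # argv=None makes argparse read sys.argv[1:], i.e. the real command line.
    return vars(parser.parse_args(argv))
```

In the trial script, `hparams = parse_hparams()` would take the place of `config = wandb.config`, with the rest of `train()` reading `hparams["lora_rank"]` and so on. With `report_to='wandb'`, the Trainer's W&B callback is expected to initialize the run on the main process only, so any explicit `wandb.init()` / `wandb.log()` calls are best restricted to rank 0.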

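By making these adjustments, only one agent pulls configs from the sweep, so every trial trains on both GPUs with a single set of hyperparameters. That removes the config mismatch between processes that makes the Hugging Face Trainer and DeepSpeed deadlock during the sweep.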