When working with Weights and Biases (W&B/wandb) for hyperparameter (hp) optimization, you can use sweeps to systematically explore different combinations of hyperparameters to find the best performing set.
What are Sweeps?
Sweeps in W&B allow you to define a space of hyperparameters to search over. When you create a sweep (in the CLI or in Python), this defines the set (e.g., grid) of hp combinations to try. Afaik, this registers the sweep, with all the possible hp combinations, on your wandb account/server; later, when you run an agent, it fetches one combination at a time, tries it, and logs the result to wandb.
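For concreteness, here is a minimal sketch of that lifecycle (the parameter names, values, and project name are placeholders I made up for illustration):
import wandb

# A tiny grid: 2 x 2 = 4 hp combinations get registered on the wandb server
sweep_config = {
    "method": "grid",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "lr": {"values": [1e-3, 1e-2]},
        "batch_size": {"values": [32, 64]},
    },
}

def train():
    wandb.init()  # the agent injects one hp combination into wandb.config
    fake_loss = wandb.config.lr * wandb.config.batch_size  # stand-in for real training
    wandb.log({"loss": fake_loss})

sweep_id = wandb.sweep(sweep_config, project="sweep-demo")  # registers the sweep on the server
wandb.agent(sweep_id, function=train)  # fetches and runs hp combinations one at a time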
Understanding count and Agents
- Agents: These are the workers that run the trials/hp attempts/runs; basically, each one tries hps. An agent pulls a set of hyperparameters (hps) from the W&B/wandb server, runs the (usually) training script with them (in our example, a search for the optimal Chinchilla scaling law fit), logs the results, and repeats.
- Count: This is the number of trials/runs the agent will run. If you set count to 100, the agent will run 100 trials, each with a different combination of hyperparameters.
Afaik, if you use grid search and count is higher than the total number of combinations, the sweep stops after running all combinations. If count is lower, it runs only the specified number of trials.
Afaik, as long as an agent is running and the sweep has hps left to try, it will keep fetching them from the sweep on your wandb account/server. You can see the running agents (and kill them, pause them, etc.) for a sweep in your wandb account.
I think the crux is that an agent continually fetches hps from your sweep on the wandb website until the sweep is finished (or you kill the agent). You can run multiple agents in parallel until the sweep's hps are exhausted on your wandb account.
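As a concrete sketch of the count/agent interplay (using the same grid shape as the sweep config later in this post), you can compute the total number of grid combinations yourself:
import math

parameters = {
    "e": {"values": [-1, 0, 1]},
    "a": {"values": [0, 5, 10]},
    "b": {"values": [0, 5, 10]},
    "alpha": {"values": [0, 1, 2]},
    "beta": {"values": [0, 1, 2]},
}

total = math.prod(len(p["values"]) for p in parameters.values())
print(total)  # 3^5 = 243 grid combinations

# wandb.agent(sweep_id, function=train, count=10)   # this agent stops after 10 trials
# wandb.agent(sweep_id, function=train, count=500)  # a grid sweep still stops at 243
# wandb.agent(sweep_id, function=train)             # runs until the grid is exhausted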
I will provide an example without multiprocessing, then make it use multiprocessing (mp):
1 Example without Multiprocessing
import numpy as np
import scipy.optimize as opt
import scipy.special
import wandb
# Define the synthetic scaling law function
def scaling_law(c, e, a, b, alpha, beta):
return np.exp(e) + np.exp(a) * c[:, 0] ** (-alpha) + np.exp(b) * c[:, 1] ** (-beta)
# Generate synthetic data
np.random.seed(0)
C = np.array([[7e9, 2e12], [13e9, 2e12], [34e9, 2e12], [70e9, 2e12]]) # [m, Din] = [m, 2]
e_true, a_true, b_true = np.log(1.8172), np.log(482.01), np.log(2085.43)
alpha_true, beta_true = 0.3478, 0.3658
L_target = scaling_law(C, e_true, a_true, b_true, alpha_true, beta_true).reshape(-1, 1) # [m, K]
L_target = np.repeat(L_target, 1, axis=1) # [m, K]; a no-op here since K = 1
# Define the cost function using the Huber loss
def aggregate_huber_loss(theta_sl, c, l_target, delta=1e-3):
e, a, b, alpha, beta = theta_sl
E, A, B = np.exp(e), np.exp(a), np.exp(b)
    l_pred = E + A * c[:, 0] ** (-alpha) + B * c[:, 1] ** (-beta)  # linear-space prediction; the loss below uses its log
    log_l_target = np.log(l_target)
    # x1, x2, x3 are the log-space summands: logsumexp([x1, x2, x3]) == log(l_pred)
    x1 = a - alpha * np.log(c[:, 0]).reshape(-1, 1)
    x2 = b - beta * np.log(c[:, 1]).reshape(-1, 1)
    x3 = e * np.ones((c.shape[0], 1))
    lse = scipy.special.logsumexp([x1, x2, x3], axis=0)
    h = scipy.special.huber(delta, lse - log_l_target)  # Huber loss on the log-residual, as in the Chinchilla fit
return h.sum()
# Training function to run each trial
def train():
wandb.init()
config = wandb.config
initial_params = [config.e, config.a, config.b, config.alpha, config.beta]
# Perform the optimization
result = opt.minimize(aggregate_huber_loss, initial_params, args=(C, L_target), method='BFGS')
optimized_params = result.x
e_opt, a_opt, b_opt, alpha_opt, beta_opt = optimized_params
loss = aggregate_huber_loss(optimized_params, C, L_target)
wandb.log({
"e": e_opt,
"a": a_opt,
"b": b_opt,
"alpha": alpha_opt,
"beta": beta_opt,
"loss": loss
})
# Sweep configuration for grid search
sweep_config = {
"method": "grid",
"metric": {
"name": "loss",
"goal": "minimize"
},
"parameters": {
"e": {
"values": [-1, 0, 1]
},
"a": {
"values": [0, 5, 10]
},
"b": {
"values": [0, 5, 10]
},
"alpha": {
"values": [0, 1, 2]
},
"beta": {
"values": [0, 1, 2]
}
}
}
# Initialize the sweep
sweep_id = wandb.sweep(sweep_config, project="scaling-law-optimization")
# Print the sweep URL and ID (wandb.run is None before wandb.init, so get the entity from the API)
entity = wandb.Api().default_entity
print(f"Sweep URL: https://wandb.ai/{entity}/scaling-law-optimization/sweeps/{sweep_id}")
print(f"Sweep ID: {sweep_id}")
# Run the sweep
# wandb.agent(sweep_id, function=train, count=10) # only tries 10 of the 3^5 = 243 grid combinations
wandb.agent(sweep_id, function=train) # tries all 3^5 hps! Sweeps them all!
My understanding is that as long as the agents are running, they keep fetching hps from the server until the sweep on the server (your wandb site/account) is exhausted. Some sweep methods, like random and Bayesian, can afaik run forever! So count is important there (or manually killing the agents).
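For example, with method random the server can keep proposing samples indefinitely, so count is what bounds the work. A sketch, reusing the train function from the script above and switching the fixed values to continuous ranges (the min/max bounds here are just illustrative):
random_sweep_config = {
    "method": "random",  # samples hp combinations; can keep going indefinitely
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "e": {"min": -1.0, "max": 1.0},
        "a": {"min": 0.0, "max": 10.0},
        "b": {"min": 0.0, "max": 10.0},
        "alpha": {"min": 0.0, "max": 2.0},
        "beta": {"min": 0.0, "max": 2.0},
    },
}
sweep_id_rand = wandb.sweep(random_sweep_config, project="scaling-law-optimization")
wandb.agent(sweep_id_rand, function=train, count=50)  # without count, this agent would never stop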
2 Example with Multiprocessing
The main idea, I think, is realizing that when you create a sweep (in Python or in the CLI), the process that fetches hps/trials to try is the agent. So my suggestion is to parallelize over agents, e.g.:
import numpy as np
import scipy.optimize as opt
import scipy.special
import wandb
from multiprocessing import Process, cpu_count
# Define the synthetic scaling law function
def scaling_law(c, e, a, b, alpha, beta):
return np.exp(e) + np.exp(a) * c[:, 0] ** (-alpha) + np.exp(b) * c[:, 1] ** (-beta)
# Generate synthetic data
np.random.seed(0)
C = np.array([[7e9, 2e12], [13e9, 2e12], [34e9, 2e12], [70e9, 2e12]]) # [m, Din] = [m, 2]
e_true, a_true, b_true = np.log(1.8172), np.log(482.01), np.log(2085.43)
alpha_true, beta_true = 0.3478, 0.3658
L_target = scaling_law(C, e_true, a_true, b_true, alpha_true, beta_true).reshape(-1, 1) # [m, K]
L_target = np.repeat(L_target, 1, axis=1) # [m, K]; a no-op here since K = 1
# Define the cost function using the Huber loss
def aggregate_huber_loss(theta_sl, c, l_target, delta=1e-3):
e, a, b, alpha, beta = theta_sl
E, A, B = np.exp(e), np.exp(a), np.exp(b)
    l_pred = E + A * c[:, 0] ** (-alpha) + B * c[:, 1] ** (-beta)  # linear-space prediction; the loss below uses its log
    log_l_target = np.log(l_target)
    # x1, x2, x3 are the log-space summands: logsumexp([x1, x2, x3]) == log(l_pred)
    x1 = a - alpha * np.log(c[:, 0]).reshape(-1, 1)
    x2 = b - beta * np.log(c[:, 1]).reshape(-1, 1)
    x3 = e * np.ones((c.shape[0], 1))
    lse = scipy.special.logsumexp([x1, x2, x3], axis=0)
    h = scipy.special.huber(delta, lse - log_l_target)  # Huber loss on the log-residual, as in the Chinchilla fit
return h.sum()
# Training function to run each trial
def train():
wandb.init()
config = wandb.config
initial_params = [config.e, config.a, config.b, config.alpha, config.beta]
# Perform the optimization
result = opt.minimize(aggregate_huber_loss, initial_params, args=(C, L_target), method='BFGS')
optimized_params = result.x
e_opt, a_opt, b_opt, alpha_opt, beta_opt = optimized_params
loss = aggregate_huber_loss(optimized_params, C, L_target)
wandb.log({
"e": e_opt,
"a": a_opt,
"b": b_opt,
"alpha": alpha_opt,
"beta": beta_opt,
"loss": loss
})
# Sweep configuration for grid search
sweep_config = {
"method": "grid",
"metric": {
"name": "loss",
"goal": "minimize"
},
"parameters": {
"e": {
"values": [-1, 0, 1]
},
"a": {
"values": [0, 5, 10]
},
"b": {
"values": [0, 5, 10]
},
"alpha": {
"values": [0, 1, 2]
},
"beta": {
"values": [0, 1, 2]
}
}
}
# Function to run an agent; each agent keeps fetching hps until the sweep is exhausted
def run_agent(sweep_id):
    # wandb.agent(sweep_id, function=train, count=10) # runs only a subset: 10 <= 3^5 trials
    wandb.agent(sweep_id, function=train) # keeps fetching hps until all 3^5 = 243 hps in the sweep are done

if __name__ == "__main__":
    # Initialize the sweep inside the main guard so that child processes (which re-import
    # this module under the "spawn" start method) do not each create their own sweep
    sweep_id = wandb.sweep(sweep_config, project="scaling-law-optimization")
    # Print the sweep URL and ID (wandb.run is None here, so get the entity from the API)
    entity = wandb.Api().default_entity
    print(f"Sweep URL: https://wandb.ai/{entity}/scaling-law-optimization/sweeps/{sweep_id}")
    print(f"Sweep ID: {sweep_id}")
    # Number of agents to run in parallel
    num_agents = min(cpu_count(), 72) # Adjust this number based on your system
    processes = []
    for _ in range(num_agents):
        p = Process(target=run_agent, args=(sweep_id,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print('Done!\a')
In the CLI (Bash)
1 Without multiprocessing
Create a sweep from the config .yaml file and then pass that sweep ID to the multiple agents you create (e.g., with a SLURM array or tmux sessions).
In detail:
1: Create a YAML file for the sweep
program: ~/github_repo_proj_folder/scaling_laws.py
method: grid
metric:
name: loss
goal: minimize
parameters:
e:
values: [-1, 0, 1]
a:
values: [0, 5, 10]
b:
values: [0, 5, 10]
alpha:
values: [0, 1, 2]
beta:
values: [0, 1, 2]
2: Initialize the Sweep in the CLI
wandb sweep sweep_config.yaml
This command will output a sweep ID in the format entity/project/sweep_ID. Note this sweep ID for the next steps.
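For reference, the output looks roughly like this (exact wording varies by wandb version; the entity and ID here are made up):
wandb: Creating sweep from: sweep_config.yaml
wandb: Created sweep with ID: abc123xy
wandb: View sweep at: https://wandb.ai/your-entity/scaling-law-optimization/sweeps/abc123xy
wandb: Run sweep agent with: wandb agent your-entity/scaling-law-optimization/abc123xy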
3: Run a Single Agent
A running agent will continually fetch hps from the sweep on your wandb server until the sweep is done:
wandb agent <sweep_id>
replacing <sweep_id> with your actual sweep ID, e.g., from the output of the previous command.
2 With multiprocessing/parallelization in the CLI
One way, once the wandb sweep is initialized, is to run each agent in lots of tmux sessions, with the & operator, with SLURM arrays (see the sketch after the script below), or even with nohup:
#!/bin/bash
# start sweep
## wandb sweep sweep_config.yaml
# Number of agents to run
NUM_AGENTS=4
SWEEP_ID=<sweep_id> # Replace with your actual sweep ID
# Run agents in parallel
for i in $(seq 1 $NUM_AGENTS); do
nohup wandb agent $SWEEP_ID > agent_$i.log 2>&1 &
done
# Wait for all agents to finish (optional)
wait
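If you are on a cluster, a SLURM job array achieves the same as the loop above: one agent per array task, all pulling from the same sweep. A sketch (the resource values are placeholders to adjust for your setup):
#!/bin/bash
#SBATCH --job-name=wandb-agents
#SBATCH --array=1-4            # 4 array tasks -> 4 parallel agents
#SBATCH --cpus-per-task=1
#SBATCH --time=01:00:00        # placeholder; adjust for your workload

# Each array task runs one agent against the same sweep
wandb agent <sweep_id>  # replace with your actual entity/project/sweep_ID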
ref: multithreading - How to implement multiprocessing with Weights and Biases wandb sweeps for maximum parallelization, especially how the count var work in this setting? - Stack Overflow