# Login to wandb
wandb.login()
#create wandb sweep id
with open("pathto_config", 'r') as stream:
sweep_config = yaml.safe_load(stream)
sweep_id = wandb.sweep(sweep=sweep_config, entity="kishimita", project="Simple-Unet-Training", prior_runs=["run-1"])
print("Sweep config: ", sweep_config)
def get_optimizer(optimizer_name, model, learning_rate):
#optimizer_name = optimizer_name.strip() # Remove any leading/trailing white spaces
if optimizer_name == "Adam":
return torch.optim.Adam(model.parameters(), lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(model.parameters(), lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(model.parameters(), lr=learning_rate)
elif optimizer_name == "Adamax":
return torch.optim.Adamax(model.parameters(), lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(model.parameters(), lr=learning_rate)
elif optimizer_name == "Adagrad":
return torch.optim.Adagrad(model.parameters(), lr=learning_rate)
else:
raise ValueError(f"Unknown optimizer: {optimizer_name}")
def train():
global device
config = sweep_config["parameters"]
model.to(device)
count = 0
optimizer = get_optimizer(config["optim"]['values'][count], model, config["learning_rate"]['values'][count])
lr = config["learning_rate"]['values'][count]
epochs = config["epochs"]['values'][count]
print(len("------------------------------------------------------------------------------------------------------------"))
run = wandb.init(project="Simple-Unet-Training",
config={
"learning_rate": lr,
"architecture": "Simple Unet",
"dataset": "military planes",
"epochs": epochs,
"optimizer": optimizer,
"loss": "L1",
"metric": "L1",
"framework": "PyTorch",
"device": DEVICE,
"torch_seed" : seed
},
name="genesis-run" + "-" +str(count+1),
save_code=False,)
run.config.update(config)
print("*~+~*"*22)
print("\t\t\tThis is the start of training in mins: ", datetime.datetime.now())
print("*~+~*"*22)
memory_count = 0
for epoch in range(epochs):
epoch_start = datetime.datetime.now()
print("--------------------------------------------------------------------------------------------------------------")
print(f"\t\t\t\tThis is epoch : {epoch}'s start time: {epoch_start}")
print("--------------------------------------------------------------------------------------------------------------\n")
total_loss = 0
total_accuracy = 0
#print(f"Epoch :{epoch}")
for step, batch in tqdm(enumerate(train_loader), desc= "Step Loop", ncols=100):
optimizer.zero_grad()
t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
# Move the input data to the GPU
batch_gpu = batch[0].to(device)
loss = get_loss(model, batch_gpu, t)
loss.backward()
optimizer.step()
# Calculate accuracy
accuracy = accuracy_l1(model, batch_gpu, t)
total_accuracy += accuracy.item()
total_loss += loss.item()
print(f"This is memory usage after inner loop ends time {memory_count}")
memory_count += 1
print_memory_usage()
# Select the first image from the batch
input_image = batch_gpu[0]
output_image = model(input_image.unsqueeze(0), t)[0]
#log input and output image in the same log
wandb.log({"Input Image": wandb.Image(input_image.detach().cpu(), caption="Input Image-" + str(count))
,"Output Image": wandb.Image(output_image.detach().cpu(), caption="Output Image-" + str(count))})
del batch_gpu
if epoch % 5 == 0 and step == 0:
print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
#sample_plot_image()
wandb.log({"Lr": lr})
wandb.log({"epoch": epoch})
wandb.log({"Loss": total_loss/len(train_loader)})
wandb.log({"Accuracy": total_accuracy/len(train_loader)})
print(f"Total Epochs : {epochs}")
print(f"Current Epoch : {epoch}")
print(f"Optimizer : {optimizer}")
print(f"Lr : {lr}")
print(f"Loss : {loss.item()}")
del loss
print(f"Accuracy : {accuracy.item()}")
del accuracy
print(f"Total Loss : {total_loss/len(train_loader)}")
del total_loss
print(f"Total Accuracy : {total_accuracy/len(train_loader)}")
del total_accuracy
epoch_end = datetime.datetime.now()
print("--------------------------------------------------------------------------------------------------------------")
print(f"\t\t\tThis is epoch :{epoch}'s end time: {epoch_end}")
print("--------------------------------------------------------------------------------------------------------------\n")
count += 1
run.finish()
print("*~+~*"*12)
print(f"\t\t\t\tThis is the end of training in mins: {datetime.datetime.now()}")
print("*~+~*"*12)
config.finish()
wandb.agent(sweep_id="shtf1crd", function=train, project="Simple-Unet-Training", entity="kishimita")
here is the yaml config file
program: train.py
name: 'sweep 1'
method: bayes
metric:
goal: minimize
name: L1
parameters:
batch_size:
values: [128]
learning_rate:
values: [0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05, 0.055, 0.06, 0.065, 0.07, 0.075, 0.08, 0.085, 0.09, 0.095, 0.1]
optim:
values: ["Adam", "Adamax", "AdamW", "SGD", "RMSprop", "Adagrad"]
epochs:
values: [100, 150, 200, 250, 300, 350]
loss:
values: ['L1', 'MSE', 'BCE', 'CrossEntropy']
accuracy:
values: ['1-L1_loss', '1-MSE_loss', '1-BCE_loss', '1-CrossEntropy_loss']
activation:
values : ['ReLU', 'Sigmoid', 'Tanh', 'LeakyReLU']
early_terminate:
type: hyperband
min_iter: 3
command:
- ${env}
- path to python executable
- CUDA_VISIBLE_DEVICES = 1
- train.py
- ${args}
Thank you for your help