Wandb with Huggingface Trainer logs only the first model

I am fine-tuning multiple models in a for loop as follows:

for file in os.listdir(args.data_dir):
    finetune(args, file)

But wandb only shows logs for the first file in data_dir, even though the models for the other files are trained and saved. This seems very strange.

wandb: Synced bertweet-base-finetuned-file1: https://wandb.ai/***/huggingface/runs/***

Here is a small snippet of the fine-tuning code with Huggingface:

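from transformers import Trainer, TrainingArguments

# model_name, model, tokenized_dataset and data_collator are defined elsewhere in the script
def finetune(args, file):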
    training_args = TrainingArguments(
        output_dir=f'{model_name}-finetuned-{file}',
        overwrite_output_dir=True,
        evaluation_strategy='no',
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        weight_decay=args.decay,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        fp16=True, # mixed-precision training to boost speed
        save_strategy='no',
        seed=args.seed,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset['train'],
        eval_dataset=None,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model()

@kgarg8, you've set save_strategy to 'no' in your code, so no checkpoints are saved during training; the only model saved is the final one, written by trainer.save_model() once training is done. If you change it to save_strategy="epoch", a checkpoint will be saved at the end of every epoch.

Alternatively, to log models to W&B, you can set the env var WANDB_LOG_MODEL as described in our docs. Once this env var is set, any Trainer you initialize from then on will upload models to your W&B project. Note that your model will be saved to W&B Artifacts as run-{run_name}.
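A minimal sketch of setting the variable from inside the script rather than the shell, assuming it runs before the Trainer is created. The accepted values have changed between releases, so treat the ones in the comment as an assumption to check against the docs for your installed transformers/wandb versions.

```python
import os

# Set this before creating the Trainer. In recent transformers versions the
# accepted values are "end" (upload the final model), "checkpoint" (upload each
# saved checkpoint, so it needs a save_strategy other than "no") and "false".
# Older docs used "true"; check the docs for the version you have installed.
os.environ["WANDB_LOG_MODEL"] = "end"
```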

Calling wandb.init(reinit=True) at the start of each fine-tuning run and run.finish() at the end helped me log each model as a separate run on the wandb website.

The working code looks like this:

import os
import wandb

def finetune(args, file):
    run = wandb.init(reinit=True)  # start a new W&B run for this file
    ...
    run.finish()  # close the run so the next file gets its own

for file in os.listdir(args.data_dir):
    finetune(args, file)
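If training can fail partway through the loop, wrapping each file in try/finally keeps one file's error from leaving its run open. This is only a sketch: finetune_each and its start_run parameter are illustrative names, not part of wandb, and the only assumption about the run object is that it has a finish() method, which the object returned by wandb.init() does.

```python
from typing import Callable, Iterable, List, Protocol


class Run(Protocol):
    """Anything with a finish() method, such as the run returned by wandb.init()."""

    def finish(self) -> None: ...


def finetune_each(
    files: Iterable[str],
    train: Callable[[str], None],
    start_run: Callable[[str], Run],
) -> List[str]:
    """Fine-tune once per file, each inside its own run (illustrative helper).

    Returns the files whose runs were started and finished. If train() raises,
    the current run is still finished before the exception propagates.
    """
    finished: List[str] = []
    for file in files:
        run = start_run(file)  # e.g. wandb.init(reinit=True, name=file)
        try:
            train(file)  # e.g. finetune(args, file), without its own init/finish
        finally:
            run.finish()  # close this run so the next file does not reuse it
            finished.append(file)
    return finished
```

With wandb, the call might look like finetune_each(os.listdir(args.data_dir), lambda f: finetune(args, f), lambda f: wandb.init(reinit=True, name=f)), with the wandb.init/run.finish lines removed from finetune itself. Passing name= to wandb.init also makes the per-file runs easy to tell apart in the UI.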

Reference: Launch Experiments with wandb.init - Documentation
