Initially I wanted to set up a Hugging Face run so that the user could either run a sweep (merging the sweep values with the given command line arguments) or just execute the run with the arguments from the command line. The point of the merging is that the train script uses a single args object (e.g. tuple[DataClass, …]) to execute its run, so the arguments from the sweep and from the command line would have to be merged. But then I realized that if the user wanted to call wandb.init in a custom way through the arguments, then one couldn't do the standard run = wandb.init() with no arguments that is common for sweeps, since the wandb config usually specifies the run fully. So I'd need two wandb.init() calls. Then the code got ugly and confusing, and I realized that perhaps running only from the cmd arguments or only from the sweep, separately, is best. That made me wonder: how do people actually use wandb sweeps with Hugging Face officially?
So what is an example demo of how to run wandb sweeps with Hugging Face transformers? At some point the wandb_config and the run arguments have to merge so that the HF run executes correctly. And I assume report_to='wandb' is needed for the trainer to call wandb.init() properly (or else one needs to call it manually).
Pseudo Python
from pathlib import Path

import transformers
import wandb
import yaml
from transformers import HfArgumentParser, Trainer


def exec_train(args: tuple):
    """
    note:
        - decided against a named obj to simplify the code, i.e. didn't know how to have the code collect the
          variables model_args, data_args, training_args, general_args on its own. Would Namespace(**tup) work?
          Don't want to do d['x'] = x manually.
          I don't think automatically naming an obj is possible in python: https://chat.openai.com/share/b1d58369-ce27-4ee3-a588-daf28137f774
          better reference maybe some day.
        - separates the logic of the wandb setup from the actual training code a little bit, so the code is cleaner to reason about.
        - passes run var just in case it's needed.
    """
    # the callers build a 4-tuple that ends with general_args, so only unpack the first three here
    model_args, data_args, training_args = args[:3]
    print(training_args.report_to)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    special_tokens_dict = get_special_tokens_dict()
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()


def train(args: tuple):
    """
    Runs train but separates the wandb setup from the actual training code.
    """
    # - init wandb run
    run = wandb.init()
    print(f'{run.get_sweep_url()=}')
    # - exec run
    # args[3].run = run  # just in case the GeneralArguments has a pointer to run. Decided against this to avoid multiple pointers to the same object.
    exec_train(args)
    # - finish wandb
    run.finish()


def exec_run_from_sweep():
    """ Run standard sweep.

    In uutils since this is standard code. (You can write optional expansions in your private repo.)
    """
    # -- 1. Define the sweep configuration in a YAML file and load it in Python as a dict.
    path2sweep_config = '~/ultimate-utils/tutorials_for_myself/my_wandb_uu/my_wandb_sweeps_uu/sweep_in_python_yaml_config/sweep_config.yaml'
    config_path = Path(path2sweep_config).expanduser()
    with open(config_path, 'r') as file:
        sweep_config = yaml.safe_load(file)
    # -- 2. Initialize the sweep in Python, which creates it in your project/entity on the wandb platform, and get the sweep_id.
    sweep_id = wandb.sweep(sweep_config, entity=sweep_config['entity'], project=sweep_config['project'])
    # -- 3. Finally, once the sweep_id is acquired, execute the sweep using the desired number of agents in Python.
    # NOTE: wandb.agent calls `function` with no arguments, so a train(args) that takes arguments has to be wrapped first.
    wandb.agent(sweep_id, function=train, count=5)
    print(f"Sweep URL: https://wandb.ai/{sweep_config['entity']}/{sweep_config['project']}/sweeps/{sweep_id}")


def get_args_for_run_from_cmd_args_or_sweep():
    """
    Simply execs a run either from a wandb sweep file or from the command line arguments. Ignore the wandb sweep details
    if they confuse you.
    """
    # 1. parse all the arguments from the command line
    parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments, GeneralArguments))
    model_args, data_args, training_args, general_args = parser.parse_args_into_dataclasses()  # default is to parse sys.argv
    # 2. if the sweep config option is on, then overwrite the cmd line run configuration in favor of the sweep_config.
    if general_args.path2sweep_config:  # None => False => not using a sweep config
        # overwrite run configuration with the sweep configuration (get config and create new args)
        config_path = Path(general_args.path2sweep_config).expanduser()
        with open(config_path, 'r') as file:
            sweep_config = dict(yaml.safe_load(file))
        # NOTE: this assumes every top-level key of the yaml is also a field of one of the dataclasses above
        sweep_args: list[str] = [item for pair in [[f'--{k}', str(v)] for k, v in sweep_config.items()] for item in pair]
        model_args, data_args, training_args, general_args = parser.parse_args_into_dataclasses(args=sweep_args)
        args: tuple = (model_args, data_args, training_args, general_args)  # decided against named obj to simplify code
        # 3. execute run from sweep
        # Initialize the sweep in Python, which creates it in your project/entity on the wandb platform, and get the sweep_id.
        sweep_id = wandb.sweep(sweep_config, entity=sweep_config['entity'], project=sweep_config['project'])
        # Finally, once the sweep_id is acquired, execute the sweep using the desired number of agents in Python.
        # pkg train with args under a new name; reusing the name `train` would make the lambda call itself.
        train_fn = lambda: train(args)  # i.e., when the agent calls train_fn() it will call train(args)
        wandb.agent(sweep_id, function=train_fn, count=general_args.count)
        # print(f"Sweep URL: https://wandb.ai/{sweep_config['entity']}/{sweep_config['project']}/sweeps/{sweep_id}")
    else:
        # 3. use the args already parsed from the command line
        args: tuple = (model_args, data_args, training_args, general_args)  # decided against named obj to simplify code
        # train(args)
    return args

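
if __name__ == '__main__':
    import time
    start_time = time.time()
    # the function defined above is get_args_for_run_from_cmd_args_or_sweep (exec_run_from_cmd_args_or_sweep does not exist)
    get_args_for_run_from_cmd_args_or_sweep()
    print(f"The main function executed in {time.time() - start_time} seconds.\a")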
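
One detail the pseudo code above runs into: wandb.agent calls its function with no arguments, so train(args) has to be packaged into a zero-argument callable first. Below is a minimal sketch of that packaging using functools.partial. The fake_agent helper and the two-element args tuple are stand-ins I made up so the snippet runs without wandb; they are not part of any library.

```python
from functools import partial


def train(args: tuple) -> str:
    """Stand-in for the real train(args); returns a string so the call is visible."""
    model_args, data_args = args
    return f"training {model_args} on {data_args}"


def fake_agent(function, count: int) -> list:
    """Hypothetical stand-in for wandb.agent: calls function() with no arguments, count times."""
    return [function() for _ in range(count)]


args = ("my-model", "my-data")  # placeholder for the tuple of dataclasses

# Bind args under a *new* name. Writing `train = lambda: train(args)` rebinds `train`
# to the lambda itself, which then calls itself with an argument it does not accept.
train_fn = partial(train, args)

results = fake_agent(train_fn, count=2)
```

The same effect can be had with a lambda, as long as it is stored under a different name than the function it wraps.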
Some Notes
Wandb sweeps current thoughts:
Major Assumption: wandb.config comes from a .yaml that has a specific structure that doesn’t change (since the website needs this structure to set up the UI correctly)
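
To make the assumed structure concrete, here is a rough sketch of what a sweep .yaml looks like once loaded as a dict, with two small helpers that read it. The entity, project, metric name and hyperparameter names are hypothetical placeholders, and the helper functions are mine, not wandb's. The point is that the hyperparameters live under the `parameters` key rather than at the top level, which matters when turning the config into arguments.

```python
# Shape of a sweep config after yaml.safe_load (all values here are placeholders).
sweep_config = {
    "entity": "my-entity",
    "project": "my-project",
    "method": "random",
    "metric": {"name": "train_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-6, "max": 1e-3},
        "num_train_epochs": {"values": [1, 2, 3]},
        "model_name_or_path": {"value": "gpt2"},
    },
}


def swept_parameter_names(config: dict) -> list[str]:
    """Names of the hyperparameters the sweep controls (hypothetical helper)."""
    return sorted(config.get("parameters", {}))


def fixed_parameters(config: dict) -> dict:
    """Hyperparameters pinned to a single constant via the `value` key (hypothetical helper)."""
    params = config.get("parameters", {})
    return {name: spec["value"] for name, spec in params.items() if "value" in spec}


names = swept_parameter_names(sweep_config)
fixed = fixed_parameters(sweep_config)
```

Inside a run started by the agent, run.config holds one sampled value per name in `parameters`, so that is the dict to merge with the HF args, not the raw yaml.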
- soln1: have a ScriptArguments dataclass that has the same structure as wandb.config and merge it. The merging still needs to respect the wandb structure and the custom HF args structure.
    - this is under the assumption that wandb.config has a specific structure that doesn’t change
- soln2: loop through the wandb.config (dict) and create a list of strings that looks like sys.argv arguments, i.e. --{name} {value}, and have the HF argparser parse it and join it with the previous structure (mdl, data, train) we specified for the args in the code.
    run = wandb.init()
    run.get_sweep_url()
    sweep_config = run.config
    # might need to change a little bit to respect the wandb_config structure
    args: list[str] = [item for pair in [[f'--{k}', str(v)] for k, v in sweep_config.items()] for item in pair]
    parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=args)
    # make sure the 3 or X args have the fields from the wandb_config
    - this is under the assumption that wandb.config has a specific structure that doesn’t change
    - I’m also assuming that parser.parse_args_into_dataclasses(args=args) will do the recursive matching of names I want
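
A stdlib-only sketch of soln2, to check the idea without wandb or transformers installed: flatten a config dict into sys.argv-style strings and parse them into several dataclasses at once. Here plain argparse stands in for HfArgumentParser, and the two dataclasses and both helper functions are hypothetical. Whether HfArgumentParser matches names in exactly the same way is an assumption to verify.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class ModelArguments:  # hypothetical, trimmed down
    model_name_or_path: str = "gpt2"


@dataclass
class TrainingArguments:  # hypothetical, trimmed down
    learning_rate: float = 5e-5
    num_train_epochs: int = 3


def config_to_argv(config: dict) -> list[str]:
    """Turn {'name': value} into ['--name', 'value', ...]."""
    argv: list[str] = []
    for name, value in config.items():
        argv += [f"--{name}", str(value)]
    return argv


def parse_into_dataclasses(dataclass_types: tuple, argv: list[str]) -> tuple:
    """Register every dataclass field as a flag, parse argv, then rebuild one instance per dataclass."""
    parser = argparse.ArgumentParser()
    for dc in dataclass_types:
        for f in fields(dc):
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    ns = parser.parse_args(argv)
    return tuple(dc(**{f.name: getattr(ns, f.name) for f in fields(dc)}) for dc in dataclass_types)


config = {"learning_rate": 0.001, "num_train_epochs": 1}  # e.g. dict(run.config)
argv = config_to_argv(config)
model_args, training_args = parse_into_dataclasses((ModelArguments, TrainingArguments), argv)
```

Fields missing from the config keep their dataclass defaults. One caveat with this plain argparse version: boolean values need extra care, because str(False) becomes "False" and bool("False") is True.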
- soln3: recursively loop through the args generated from the HF parser and replace the values with the ones from wandb.config
    - this is under the assumption that wandb.config has a specific structure that doesn’t change
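
A minimal sketch of soln3, assuming the args are flat dataclasses and the config is a flat dict of name to value (e.g. dict(run.config)). It loops over the tuple of already-parsed args and replaces matching fields, failing loudly on config keys that match nothing. The dataclasses and override_from_config are hypothetical. Nested dataclasses would need a recursive version of the same loop.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class ModelArguments:  # hypothetical, trimmed down
    model_name_or_path: str = "gpt2"


@dataclass
class TrainingArguments:  # hypothetical, trimmed down
    learning_rate: float = 5e-5
    num_train_epochs: int = 3


def override_from_config(args: tuple, config: dict) -> tuple:
    """Return new dataclass instances whose fields are overwritten by matching config keys."""
    used: set = set()
    new_args = []
    for dc in args:
        names = {f.name for f in fields(dc)}
        updates = {k: v for k, v in config.items() if k in names}
        used.update(updates)
        new_args.append(replace(dc, **updates))  # copies; the originals stay untouched
    unknown = set(config) - used
    if unknown:
        raise KeyError(f"config keys not found in any args dataclass: {sorted(unknown)}")
    return tuple(new_args)


cmd_args = (ModelArguments(), TrainingArguments())
new_model_args, new_training_args = override_from_config(cmd_args, {"learning_rate": 1e-4})
```

Compared with soln2 this skips the round trip through strings, so values keep their Python types.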
Decision is to keep it simple. Ideally we give a flag that says to either
- use the given arguments to the python cmd or
- use the wandb_config
I guess the easiest thing would be to do this:
→ Key Decision: if arg says config, then overwrite the args using config else don’t use the config.
refs: