from ray import tune
from ray.air.integrations.wandb import setup_wandb
def train_fn(config):
    """Report a toy loss for one trial to both Ray Tune and Weights & Biases.

    Args:
        config: Trial config sampled from ``param_space``. Must contain the
            numeric hyperparameters ``"a"`` and ``"b"``, plus a ``"wandb"``
            dict with ``"project"`` and ``"api_key_file"`` entries.
    """
    wandb_settings = config["wandb"]
    # setup_wandb() performs wandb.init() itself and returns the initialized
    # run, so W&B options are passed here rather than to a second init call
    # (the returned run object has no .init()).
    run = setup_wandb(
        config,
        project=wandb_settings["project"],
        api_key_file=wandb_settings["api_key_file"],
    )
    try:
        for _ in range(10):
            loss = config["a"] + config["b"]
            run.log({"loss": loss})
            # The current Tune reporting API takes a metrics dict, not kwargs.
            tune.report({"loss": loss})
    finally:
        # Close the W&B run even if an iteration raises.
        run.finish()
# Search space for the tuned hyperparameters, plus the W&B settings that
# train_fn reads from each trial's config.
param_space = {
    "a": tune.choice([1, 2, 3]),
    "b": tune.choice([4, 5, 6]),
    "wandb": {
        "project": "Optimization_Project",
        "api_key_file": "/path/to/file",
    },
}

tuner = tune.Tuner(train_fn, param_space=param_space)
results = tuner.fit()