This guide provides recommendations on how to integrate W&B into your Python training script or notebook.

Original training script

Suppose you have a Python script that trains a model (see below). Your goal is to find the hyperparameters that maximize the validation accuracy (val_acc). In your Python script, you define two functions: train_one_epoch and evaluate_one_epoch. The train_one_epoch function simulates training for one epoch and returns the training accuracy and loss. The evaluate_one_epoch function simulates evaluating the model on the validation data set and returns the validation accuracy and loss. You define a configuration dictionary (config) that contains hyperparameter values such as the learning rate (lr), batch size (batch_size), and number of epochs (epochs). The values in the configuration dictionary control the training process. Next, you define a function called main that mimics a typical training loop. For each epoch, the accuracy and loss are computed on the training and validation data sets.
This code is a mock training script. It does not train a model, but simulates the training process by generating random accuracy and loss values. The purpose of this code is to demonstrate how to integrate W&B into your training script.
import random
import numpy as np

def train_one_epoch(epoch, lr, batch_size):
    # Simulate training for one epoch and return mock accuracy and loss values
    acc = 0.25 + ((epoch / 30) + (random.random() / 10))
    loss = 0.2 + (1 - ((epoch - 1) / 10 + random.random() / 5))
    return acc, loss

def evaluate_one_epoch(epoch):
    # Simulate evaluating on the validation set and return mock accuracy and loss values
    acc = 0.1 + ((epoch / 20) + (random.random() / 10))
    loss = 0.25 + (1 - ((epoch - 1) / 10 + random.random() / 6))
    return acc, loss

# config variable with hyperparameter values
config = {"lr": 0.0001, "batch_size": 16, "epochs": 5}

def main():
    lr = config["lr"]
    batch_size = config["batch_size"]
    epochs = config["epochs"]

    for epoch in np.arange(1, epochs):
        train_acc, train_loss = train_one_epoch(epoch, lr, batch_size)
        val_acc, val_loss = evaluate_one_epoch(epoch)

        print("epoch: ", epoch)
        print("training accuracy:", train_acc, "training loss:", train_loss)
        print("validation accuracy:", val_acc, "validation loss:", val_loss)

if __name__ == "__main__":
    main()
In the next section, you will add W&B to your Python script to track hyperparameters and metrics during training. You want to use W&B to find the best hyperparameters that maximize the validation accuracy (val_acc).

Training script with W&B Python SDK

How you integrate W&B into your Python script or notebook depends on how you manage sweeps. You can start a sweep job from within a Python script or notebook, or from the command line.
Add the following to your Python script:
  1. Create a dictionary object where the key-value pairs define a sweep configuration. The sweep configuration defines the hyperparameters you want W&B to explore on your behalf, along with the metric you want to optimize. Continuing from the previous example, the batch size (batch_size), number of epochs (epochs), and learning rate (lr) are the hyperparameters to vary during each sweep. You want to maximize the validation accuracy, so you set "goal": "maximize" and the name of the metric you want to optimize, in this case val_acc ("name": "val_acc").
  2. Pass the sweep configuration dictionary to wandb.sweep(). This initializes the sweep and returns a sweep ID (sweep_id). For more information, see Initialize sweeps.
  3. At the top of your script, import the W&B Python SDK (wandb).
  4. Within your main function, use the wandb.init() API to generate a background process to sync and log data as a W&B Run. Pass the project name as a parameter to the wandb.init() method. If you do not pass a project name, W&B uses the default project name.
  5. Fetch the hyperparameter values from the wandb.Run.config object. This allows you to use the hyperparameter values defined in the sweep configuration dictionary instead of hard-coded values.
  6. Log the metric you are optimizing for to W&B using wandb.Run.log(). You must log the metric defined in your configuration. For example, if you define the metric to optimize as val_acc, you must log val_acc. If you do not log the metric, W&B does not know what to optimize for. Within the configuration dictionary (sweep_configuration in this example), you define the sweep to maximize the val_acc value.
  7. Start the sweep with wandb.agent(). Provide the sweep ID and the function the sweep executes (function=main), and set the maximum number of runs to try to four (count=4).
Putting this all together, your script might look similar to the following:
import wandb # Import the W&B Python SDK
import numpy as np
import random
import argparse

def train_one_epoch(epoch, lr, batch_size):
    acc = 0.25 + ((epoch / 30) + (random.random() / 10))
    loss = 0.2 + (1 - ((epoch - 1) / 10 + random.random() / 5))
    return acc, loss

def evaluate_one_epoch(epoch):
    acc = 0.1 + ((epoch / 20) + (random.random() / 10))
    loss = 0.25 + (1 - ((epoch - 1) / 10 + random.random() / 6))
    return acc, loss

def main(args=None):
    # When called by sweep agent, args will be None,
    # so we use the project from sweep config
    project = args.project if args else None
    
    with wandb.init(project=project) as run:
        # Fetches the hyperparameter values from `wandb.Run.config` object
        lr = run.config["lr"]
        batch_size = run.config["batch_size"]
        epochs = run.config["epochs"]

        # Execute the training loop and log the performance values to W&B
        for epoch in np.arange(1, epochs):
            train_acc, train_loss = train_one_epoch(epoch, lr, batch_size)
            val_acc, val_loss = evaluate_one_epoch(epoch)
            run.log(
                {
                    "epoch": epoch,
                    "train_acc": train_acc,
                    "train_loss": train_loss,
                    "val_acc": val_acc, # Metric optimized
                    "val_loss": val_loss,
                }
            )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--project", type=str, default="sweep-example", help="W&B project name")
    args = parser.parse_args()

    # Define a sweep config dictionary
    sweep_configuration = {
        "method": "random",
        "name": "sweep",
        # Metric that you want to optimize
        # For example, if you want to maximize validation
        # accuracy set "goal": "maximize" and the name of the variable 
        # you want to optimize for, in this case "val_acc"
        "metric": {
            "goal": "maximize",
            "name": "val_acc"
            },
        "parameters": {
            "batch_size": {"values": [16, 32, 64]},
            "epochs": {"values": [5, 10, 15]},
            "lr": {"max": 0.1, "min": 0.0001},
        },
    }

    # Initialize the sweep by passing in the config dictionary
    sweep_id = wandb.sweep(sweep=sweep_configuration, project=args.project)

    # Start the sweep job
    wandb.agent(sweep_id, function=main, count=4)
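After the sweep finishes, you typically want to know which hyperparameter values produced the best validation accuracy. The following is a minimal sketch that queries the finished sweep with the W&B public API; the "<entity>/<project>/<sweep_id>" path is a placeholder that you replace with your own entity, the project you passed to wandb.sweep(), and the sweep ID it returned.
import wandb

api = wandb.Api()

# Placeholder path: replace <entity>, <project>, and <sweep_id> with your own values
sweep = api.sweep("<entity>/<project>/<sweep_id>")

# best_run() returns the run that scored best on the metric named in the
# sweep configuration (val_acc in this example)
best_run = sweep.best_run()
print("Best hyperparameters:", best_run.config)
print("Best val_acc:", best_run.summary.get("val_acc"))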
Logging metrics to W&B in a sweep
You must log the metric you are optimizing for with wandb.Run.log(), using the same name you define in your sweep configuration. For example, if you define the metric to optimize as val_acc within your sweep configuration, you must also log val_acc to W&B. If you do not log the metric, W&B does not know what to optimize for.
with wandb.init() as run:
    val_loss, val_acc = train()
    run.log(
        {
            "val_loss": val_loss,
            "val_acc": val_acc
            }
        )
The following is an incorrect example of logging the metric to W&B. The metric that is optimized for in the sweep configuration is val_acc, but the code logs val_acc within a nested dictionary under the key validation. You must log the metric directly, not within a nested dictionary.
with wandb.init() as run:
    val_loss, val_acc = train()
    run.log(
        {
            "validation": {
                "val_loss": val_loss, 
                "val_acc": val_acc
                }
            }
        )
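If your training code naturally produces nested metrics, one option, sketched below using the same placeholder train() function as the examples above, is to pull the sweep metric out of the nested structure and log it under a top-level key that matches the name in the sweep configuration.
with wandb.init() as run:
    val_loss, val_acc = train()

    # Nested metrics produced elsewhere in your code
    metrics = {"validation": {"val_loss": val_loss, "val_acc": val_acc}}

    # Log the sweep metric as a top-level key so it matches "name": "val_acc"
    run.log(
        {
            "val_acc": metrics["validation"]["val_acc"],
            "val_loss": metrics["validation"]["val_loss"],
        }
    )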