Custom charts in W&B are programmable through a group of functions in the wandb.plot namespace. These functions create interactive visualizations in W&B project dashboards, and support common ML visualizations such as confusion matrices, ROC curves, and distribution plots.

Available Chart Functions

Function            Description
confusion_matrix()  Generate confusion matrices for classification performance visualization.
roc_curve()         Create Receiver Operating Characteristic curves for binary and multi-class classifiers.
pr_curve()          Build Precision-Recall curves for classifier evaluation.
line()              Construct line charts from tabular data.
scatter()           Create scatter plots for variable relationships.
bar()               Generate bar charts for categorical data.
histogram()         Build histograms for data distribution analysis.
line_series()       Plot multiple line series on a single chart.
plot_table()        Create custom charts using Vega-Lite specifications.

Common Use Cases

Model Evaluation

  • Classification: confusion_matrix(), roc_curve(), and pr_curve() for classifier evaluation
  • Regression: scatter() for prediction vs. actual plots and histogram() for residual analysis
  • Vega-Lite Charts: plot_table() for domain-specific visualizations
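As a rough sketch of the classification case, the snippet below logs ROC and Precision-Recall curves from per-class probabilities. The `softmax` helper, the `log_classifier_curves` wrapper, and the sample scores are illustrative assumptions, not part of W&B. `roc_curve()` and `pr_curve()` expect one probability per class for each sample, and W&B relies on scikit-learn to compute the curves, so it should be installed.

```python
import math


def softmax(logits):
    """Convert one row of raw class scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_classifier_curves(y_true, y_probas, class_names,
                          project="custom-charts-demo", mode="offline"):
    """Hypothetical wrapper: log ROC and PR curves for per-class probabilities."""
    import wandb  # imported here so softmax() runs without wandb installed

    with wandb.init(project=project, mode=mode) as run:
        run.log({
            "roc": wandb.plot.roc_curve(y_true, y_probas, labels=class_names),
            "pr": wandb.plot.pr_curve(y_true, y_probas, labels=class_names),
        })


# Made-up raw scores for six samples and three classes
logits = [
    [2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.4, 1.8],
    [1.2, 0.9, 0.2], [0.3, 0.8, 1.1], [0.5, 0.2, 2.2],
]
y_true = [0, 1, 2, 0, 1, 2]
y_probas = [softmax(row) for row in logits]  # shape: (n_samples, n_classes)
class_names = ["class_0", "class_1", "class_2"]

# log_classifier_curves(y_true, y_probas, class_names)  # uncomment to log
```

With `mode="offline"` the run is stored locally, which is convenient for trying the charts before logging to a real project.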

Training Monitoring

  • Learning Curves: line() or line_series() for tracking metrics over epochs
  • Hyperparameter Comparison: bar() charts for comparing configurations
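The two monitoring cases above can be sketched together as follows. The loss values and configuration scores are made up for illustration, and `learning_curves` and `log_training_charts` are hypothetical helper names. `line_series()` takes the shared x values plus one list of y values per line, while `bar()` reads its labels and values from a `wandb.Table`.

```python
def learning_curves(num_epochs):
    """Return (epochs, train_loss, val_loss) filled with made-up, decaying values."""
    epochs = list(range(1, num_epochs + 1))
    train_loss = [1.0 / e for e in epochs]
    val_loss = [1.0 / e + 0.1 for e in epochs]
    return epochs, train_loss, val_loss


def log_training_charts(curves, config_scores,
                        project="custom-charts-demo", mode="offline"):
    """Hypothetical wrapper: log loss curves and a per-configuration bar chart."""
    import wandb  # imported here so learning_curves() runs without wandb installed

    epochs, train_loss, val_loss = curves
    score_table = wandb.Table(data=config_scores, columns=["config", "val_accuracy"])
    with wandb.init(project=project, mode=mode) as run:
        run.log({
            # One chart with two lines that share the same x axis
            "loss_curves": wandb.plot.line_series(
                xs=epochs, ys=[train_loss, val_loss], keys=["train", "val"],
                title="Loss per epoch", xname="epoch",
            ),
            # One bar per configuration
            "config_comparison": wandb.plot.bar(
                score_table, label="config", value="val_accuracy",
                title="Validation accuracy by configuration",
            ),
        })


curves = learning_curves(num_epochs=10)
# Hypothetical results for three configurations (not real measurements)
config_scores = [["lr=0.1", 0.81], ["lr=0.01", 0.88], ["lr=0.001", 0.85]]

# log_training_charts(curves, config_scores)  # uncomment to log
```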

Data Analysis

  • Distribution Analysis: histogram() for feature distributions
  • Correlation Analysis: scatter() plots for variable relationships
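For distribution analysis, a minimal sketch might look like the one below. It assumes a single numeric feature stored in a one-column `wandb.Table`. The `feature_rows` and `log_feature_histogram` helpers and the synthetic Gaussian data are illustrative, not part of W&B.

```python
import random


def feature_rows(n, seed=0):
    """Return n single-column rows of synthetic, normally distributed values."""
    rng = random.Random(seed)  # seeded so the sample is reproducible
    return [[rng.gauss(0.0, 1.0)] for _ in range(n)]


def log_feature_histogram(rows, column="feature_1",
                          project="custom-charts-demo", mode="offline"):
    """Hypothetical wrapper: log a histogram of one table column."""
    import wandb  # imported here so feature_rows() runs without wandb installed

    table = wandb.Table(data=rows, columns=[column])
    with wandb.init(project=project, mode=mode) as run:
        run.log({
            "feature_hist": wandb.plot.histogram(
                table, value=column, title=f"{column} distribution"
            )
        })


rows = feature_rows(200)

# log_feature_histogram(rows)  # uncomment to log
```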

Getting Started

Log a confusion matrix

import wandb

y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 2, 0, 1, 1]
class_names = ["class_0", "class_1", "class_2"]

# Initialize a run
with wandb.init(project="custom-charts-demo") as run:
    run.log({
        "conf_mat": wandb.plot.confusion_matrix(
            y_true=y_true, 
            preds=y_pred,
            class_names=class_names
        )
    })

Build a scatter plot for feature analysis

import numpy as np
import wandb

# Create an empty table to hold the synthetic data
data_table = wandb.Table(columns=["feature_1", "feature_2", "label"])

with wandb.init(project="custom-charts-demo") as run:

    # Fill the table with 100 random points, each tagged with a random label
    for _ in range(100):
        data_table.add_data(
            np.random.randn(),
            np.random.randn(),
            np.random.choice(["A", "B"])
        )

    run.log({
        "feature_scatter": wandb.plot.scatter(
            data_table, x="feature_1", y="feature_2",
            title="Feature Distribution"
        )
    })
