Supervised Learning with PyTorch: A Classification Demo#

Introduction#

What is this example about?#

In this example, you will use supervised learning to train a neural network. The objective is to keep the workflow simple enough to run quickly on a CPU while still using either a standard real dataset or a simple toy example. OK, but what is a neural network?

What is an artificial neural network?#

import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from IPython.display import HTML, YouTubeVideo, clear_output, display


def show_plotly(fig):
    """Render Plotly figures as HTML so Jupyter Book keeps them in the page."""
    display(HTML(fig.to_html(full_html=False, include_plotlyjs="cdn")))


torch.manual_seed(10)
<torch._C.Generator at 0xffff1efc2850>
YouTubeVideo("aircAruvnKk")

Why do we need data to learn?#

Learning, in machine learning, means building a model that can detect a relationship between inputs and outputs from examples rather than from an explicit hand-written rule. The model is adjusted progressively so that its predictions become more consistent with the examples it is shown. We need data because the model does not invent the rule by itself. The data provide the examples from which it estimates the structure of the problem. Without data, there is nothing to compare the predictions to, so there is no way to improve the model or to check whether it has actually learned anything useful.

OK, we need data!

OK, how should we use data?#

In machine learning, learning means adjusting the parameters of a model so that it can transform an input into a useful prediction. In supervised learning, each example contains both an input and the expected output. The model starts with arbitrary parameters, makes a prediction, compares that prediction with the expected answer, and then modifies its parameters to reduce the error. Repeating this process over many examples gradually produces a model that captures part of the relationship between the inputs and the outputs.
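This predict-compare-adjust cycle can be sketched in a few lines. The example below is illustrative and not part of the notebook's pipeline: it uses a single-parameter model `y = w * x` with a squared-error loss, and the names (`update_step`, `w`, `lr`) are made up for the sketch.

```python
import numpy as np

# Minimal sketch of the supervised learning loop, assuming a
# one-parameter model y = w * x and a squared-error loss (y - t)^2.
def update_step(w, x, t, lr=0.1):
    y = w * x                 # prediction with the current parameter
    error = y - t             # compare with the expected output
    grad = 2 * error * x      # gradient of the loss with respect to w
    return w - lr * grad      # adjust the parameter to reduce the error

w = 0.0                       # start from an arbitrary parameter
for _ in range(50):
    w = update_step(w, x=1.0, t=2.0)  # repeatedly show the same example
# w converges toward 2.0, the value that makes the prediction match the target
```

The real network in this notebook does exactly this, just with many parameters at once and with gradients computed automatically by PyTorch.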

Data are necessary because the model does not know the rule in advance. The only way it can improve is by being exposed to examples. The data play two roles at once: they show what kinds of inputs exist, and they provide the correct answers that tell the model whether its current prediction is good or bad. In that sense, the data are both the source of information and the reference used for correction.

The reason we split the data into a training set and a test set is that good performance on already-seen examples is not enough. A model can become very good at reproducing the data it was trained on without having learned a rule that works more generally. The training set is used to modify the model. The test set is kept aside and not used during learning. After training, we evaluate the model on the test set to see whether it can also make correct predictions on new examples. If it performs well on both sets, we say that it generalizes reasonably well. If it performs well only on the training set, then it has likely overfit the data.

So the distinction is essential:

  • the training set is used to learn

  • the test set is used to evaluate whether learning is genuinely useful

Without that separation, we could not tell whether the model has actually understood a pattern or has simply memorized the examples it has already seen.
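A caricature makes the point concrete. The sketch below (illustrative only, not from the notebook) builds a model that purely memorizes its training examples: it is perfect on the data it has seen, yet no better than chance on unseen inputs when the labels carry no real rule.

```python
import numpy as np

# A "model" that memorizes: a lookup table from training inputs to labels.
rng = np.random.default_rng(0)
x_train = rng.random(100)
y_train = rng.integers(0, 2, size=100)           # random labels: nothing to learn
lookup = dict(zip(x_train.tolist(), y_train.tolist()))

# Perfect on the training set: every input is in the table.
train_acc = np.mean([lookup[v] == y for v, y in zip(x_train.tolist(), y_train)])

# On fresh inputs the table has no entry, so the model falls back to a guess.
x_test = rng.random(100)
y_test = rng.integers(0, 2, size=100)
test_pred = [lookup.get(v, 0) for v in x_test.tolist()]
test_acc = np.mean([p == y for p, y in zip(test_pred, y_test)])
# train_acc is 1.0 while test_acc hovers around 0.5 (chance level)
```

Only the test set reveals that this "learning" was memorization.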

OK, so now we need datasets!

Datasets#

This notebook can be used in two ways. You can work with a standard real dataset (iris) or with a toy dataset generated from an exact classification function (toy_easy or toy_hard).

In every case, we work with a binary classification problem in two input variables.

The real dataset is useful for showing that the method applies to real observations. The toy datasets are useful because they provide an exact target function that can be compared directly with the learned model.

The functions defined below all play the same role: they produce a set of input points in 2D and a binary label for each point. This common format allows the rest of the notebook to stay unchanged when you switch from one dataset to another.

Here is some code to declare the datasets; don't modify it!

# Toy classification functions.
# Each function receives a set of points with coordinates (x, y)
# and returns a binary label in {0, 1}.
def func_easy(inp):
    """Simple toy classification function."""
    xc1, yc1 = 0.3, 0.6
    xc2, yc2 = 0.55, 0.15
    r1, r2 = 0.25, 0.15
    X, Y = np.asarray(inp).T
    R12 = (X - xc1) ** 2 + (Y - yc1) ** 2
    R22 = (X - xc2) ** 2 + (Y - yc2) ** 2
    return (((R22 <= r2**2) | (R12 <= r1**2)) * 1).astype(np.float32)


def func_hard(inp):
    """More oscillatory toy classification function."""
    x, y = np.asarray(inp).T - 0.5
    r = np.sqrt(x**2 + y**2)
    theta = np.arctan2(y, x)
    out = (np.cos(3 * theta + 10 * r) >= 0.0).astype(np.float32)
    out[r <= 0.1] = 1.0
    return out


def func_xor(inp):
    """XOR-like pattern based on the four quadrants."""
    x, y = np.asarray(inp).T
    out = np.logical_xor(x >= 0.5, y >= 0.5).astype(np.float32)
    return out


def func_cross(inp, width=0.12):
    """Cross pattern: a vertical or horizontal band belongs to class 1."""
    x, y = np.asarray(inp).T
    out = ((np.abs(x - 0.5) <= width) | (np.abs(y - 0.5) <= width)).astype(np.float32)
    return out


def func_ring(inp, r_inner=0.18, r_outer=0.32):
    """Ring pattern: points in an annulus belong to class 1."""
    x, y = np.asarray(inp).T - 0.5
    r = np.sqrt(x**2 + y**2)
    out = ((r >= r_inner) & (r <= r_outer)).astype(np.float32)
    return out


# Dataset loaders.
# They all return the same objects so that the rest of the notebook
# does not depend on the specific dataset being used.
def load_iris_dataset():
    iris = load_iris()
    feature_ids = [2, 3]
    feature_names = [iris.feature_names[i] for i in feature_ids]
    class_names = ["versicolor", "virginica"]
    mask = iris.target >= 1
    points = iris.data[mask][:, feature_ids].astype(np.float32)
    labels = (iris.target[mask] == 2).astype(np.float32)
    # Normalize the two retained features to the [0, 1] interval.
    mins = points.min(axis=0)
    maxs = points.max(axis=0)
    points = (points - mins) / (maxs - mins)
    axis_label_suffix = " (normalized)"
    exact_model = None
    dataset_title = "Iris dataset"
    return (
        points,
        labels,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    )


def load_toy_dataset(exact_model, dataset_title, n_samples=4000, seed=2):
    # Sample points uniformly in the unit square and label them
    # with the chosen exact classification function.
    rng = np.random.default_rng(seed)
    points = rng.random((n_samples, 2), dtype=np.float32)
    labels = exact_model(points).astype(np.float32)
    feature_names = ["x", "y"]
    class_names = ["class 0", "class 1"]
    axis_label_suffix = ""
    return (
        points,
        labels,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    )

Choose the dataset here:

# Choose the dataset here. The rest of the notebook will adapt automatically.
dataset_name = (
    "toy_easy"  # "toy_easy", "toy_hard", "toy_xor", "toy_cross", "toy_ring" or "iris"
)
if dataset_name == "iris":
    (
        pointsa,
        solla,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    ) = load_iris_dataset()
elif dataset_name == "toy_easy":
    (
        pointsa,
        solla,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    ) = load_toy_dataset(func_easy, "Toy dataset: easy")
elif dataset_name == "toy_hard":
    (
        pointsa,
        solla,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    ) = load_toy_dataset(func_hard, "Toy dataset: hard")
elif dataset_name == "toy_xor":
    (
        pointsa,
        solla,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    ) = load_toy_dataset(func_xor, "Toy dataset: XOR")
elif dataset_name == "toy_cross":
    (
        pointsa,
        solla,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    ) = load_toy_dataset(func_cross, "Toy dataset: cross pattern")
elif dataset_name == "toy_ring":
    (
        pointsa,
        solla,
        feature_names,
        class_names,
        axis_label_suffix,
        exact_model,
        dataset_title,
    ) = load_toy_dataset(func_ring, "Toy dataset: ring pattern")
else:
    raise ValueError(f"Unknown dataset_name: {dataset_name}")

pointsa.shape, solla.shape, dataset_title
((4000, 2), (4000,), 'Toy dataset: easy')

We first plot the selected dataset and create a regular grid in the feature space. This grid will later be used to visualize the decision map learned by the neural network.

The grid is not used for learning. It is only used for visualization: after training, we ask the network to predict a class probability at every point of the grid in order to draw the learned decision regions.

# Build a regular 2D grid that will later be used only for visualization.
nxm, nym = 200, 200
xm = np.linspace(0.0, 1.0, nxm)
ym = np.linspace(0.0, 1.0, nym)
Xm, Ym = np.meshgrid(xm, ym)
pointsm = np.array([Xm.flatten(), Ym.flatten()]).T.astype(np.float32)
pointsm.shape
(40000, 2)
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=pointsa[solla == 1][:, 0],
        y=pointsa[solla == 1][:, 1],
        mode="markers",
        marker=dict(size=5, color="firebrick"),
        name=class_names[1],
    )
)
fig.add_trace(
    go.Scatter(
        x=pointsa[solla == 0][:, 0],
        y=pointsa[solla == 0][:, 1],
        mode="markers",
        marker=dict(size=5, color="royalblue"),
        name=class_names[0],
    )
)
fig.update_layout(
    template="plotly_white",
    xaxis_title=feature_names[0] + axis_label_suffix,
    yaxis_title=feature_names[1] + axis_label_suffix,
    legend_title_text="Class",
    width=700,
    height=500,
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
show_plotly(fig)

Training and Test Sets#

In a real supervised learning problem, we do not train on all available observations. Instead, we split the dataset into a training subset and a test subset.

The training subset is used to optimize the neural network weights. The test subset is used only to evaluate whether the learned model generalizes to data that were not seen during training.

This distinction is essential. A model may perform very well on the training set simply because it has adapted to the examples it has seen. The test set provides a more honest estimate of how the model behaves on new data.

# Keep part of the data for learning and part for evaluation.
# The option stratify=solla preserves the class proportions in both subsets.
pointsl, pointst, soll, solt = train_test_split(
    pointsa,
    solla,
    test_size=0.25,
    random_state=2,
    stratify=solla,
)

pointsl.shape, pointst.shape
((3000, 2), (1000, 2))
fig = make_subplots(rows=1, cols=2, subplot_titles=("Training set", "Test set"))

for col, points, labels in [(1, pointsl, soll), (2, pointst, solt)]:
    fig.add_trace(
        go.Scatter(
            x=points[labels == 1][:, 0],
            y=points[labels == 1][:, 1],
            mode="markers",
            marker=dict(size=7, color="firebrick"),
            name=class_names[1],
            legendgroup=class_names[1],
            showlegend=(col == 1),
        ),
        row=1,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=points[labels == 0][:, 0],
            y=points[labels == 0][:, 1],
            mode="markers",
            marker=dict(size=7, color="royalblue"),
            name=class_names[0],
            legendgroup=class_names[0],
            showlegend=(col == 1),
        ),
        row=1,
        col=col,
    )

fig.update_layout(template="plotly_white", width=950, height=450)
fig.update_xaxes(title_text=feature_names[0] + axis_label_suffix, row=1, col=1)
fig.update_xaxes(title_text=feature_names[0] + axis_label_suffix, row=1, col=2)
fig.update_yaxes(
    title_text=feature_names[1] + axis_label_suffix,
    scaleanchor="x",
    scaleratio=1,
    row=1,
    col=1,
)
fig.update_yaxes(
    title_text=feature_names[1] + axis_label_suffix,
    scaleanchor="x2",
    scaleratio=1,
    row=1,
    col=2,
)
show_plotly(fig)

Neural Network Model#

A neural network is a sequence of alternating layers. Linear layers apply affine transformations to the data, and nonlinear layers apply activation functions. Training consists of optimizing the weights in the linear layers.

During inference, data move from left to right through the network. During training, backpropagation computes how each weight contributes to the error.
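As a small illustration of backpropagation (assumed example, separate from the training cell below), one forward pass followed by `loss.backward()` fills in a gradient tensor for every weight:

```python
import torch

# One forward + backward pass on a tiny network (illustrative sketch).
x = torch.tensor([[0.2, 0.7]])    # a single 2D input point
t = torch.tensor([[1.0]])         # its expected label
net = torch.nn.Sequential(
    torch.nn.Linear(2, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1)
)
loss = torch.nn.BCEWithLogitsLoss()(net(x), t)
loss.backward()                   # backpropagation fills the .grad attributes
grad = net[0].weight.grad         # how each first-layer weight contributes to the error
```

An optimizer such as Adam then uses these `.grad` tensors to update the weights, which is exactly what the training loop below does at every epoch.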

The network below outputs a single value called a logit. We convert this value into a probability with a sigmoid function only when we want to interpret the prediction.
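Concretely, the sigmoid maps any logit to the interval (0, 1), and a 0.5 threshold on the probability turns it into a class decision. The logit values below are arbitrary examples, not outputs of the trained network:

```python
import torch

# Converting logits into probabilities and then into class decisions.
logits = torch.tensor([-2.0, 0.0, 3.0])
probs = torch.sigmoid(logits)     # sigmoid(z) = 1 / (1 + exp(-z))
preds = (probs >= 0.5).float()    # decision threshold at probability 0.5
# probs ≈ [0.12, 0.50, 0.95]; preds = [0., 1., 1.]
```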

The loss used during training is the binary cross-entropy. It does not have a direct physical meaning. It is a statistical quantity that becomes small when the network assigns a high probability to the correct class, and large when it is confidently wrong.

In practice, accuracy is easier to interpret, while the loss is more useful for training because it varies smoothly and provides a usable gradient for optimization.
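For one example with label t and predicted probability p, the binary cross-entropy is -(t log p + (1 - t) log(1 - p)). The sketch below (arbitrary example values) checks that this formula agrees with `BCEWithLogitsLoss`, which applies the sigmoid internally in a numerically stable way:

```python
import torch

logits = torch.tensor([2.0, -1.0])
targets = torch.tensor([1.0, 0.0])

# Manual binary cross-entropy: sigmoid first, then the formula above.
p = torch.sigmoid(logits)
manual = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()

# PyTorch's fused, numerically stable version works directly on logits.
builtin = torch.nn.BCEWithLogitsLoss()(logits, targets)
# manual and builtin agree (up to floating-point tolerance)
```

This is why the training cell below never applies the sigmoid before the loss: `BCEWithLogitsLoss` expects raw logits.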

Required work#

  1. Run the code as provided and compare the training accuracy with the test accuracy. Are they both satisfactory?

  2. Try to improve the result by changing the network structure, the number of hidden neurons, the learning rate, or the number of training epochs. Summarize your conclusions.

  3. Try different activation functions and compare their behavior.

  4. Switch between toy_easy, toy_hard, toy_xor, toy_cross, toy_ring, and iris. Which dataset is the easiest to learn? Which one is the most difficult?

  5. If you use iris, replace the selected features with another pair of Iris features. Which pairs are easy to separate? Which ones are more difficult?

device = "cpu"
errors = []
steps = []
train_accuracies = []
test_accuracies = []

# BUILD LAYERS WITH:
# EXAMPLE OF A LINEAR LAYER:
torch.nn.Linear(2, 8, bias=True)
# EXAMPLES OF ACTIVATION FUNCTIONS:
torch.nn.ELU, torch.nn.ReLU, torch.nn.Tanh
activation_func = torch.nn.ReLU

# NETWORK DEFINITION
# The input has size 2 because each sample is a point (x, y).
# The output has size 1 because this is a binary classification problem.
# Hidden layers give the network the flexibility needed to learn nonlinear boundaries.
layers = [
    torch.nn.Linear(2, 4, bias=True),
    activation_func(),
    torch.nn.Linear(4, 1, bias=True),
]
model = torch.nn.Sequential(*layers).to(device)
layers
[Linear(in_features=2, out_features=4, bias=True),
 ReLU(),
 Linear(in_features=4, out_features=1, bias=True)]
# TRAINING INPUTS / OUTPUTS
# YOU CAN RERUN THIS CELL TO CONTINUE TRAINING THE SAME MODEL
run_training = True  # SET TO True TO RUN TRAINING
if run_training:
    # Convert NumPy arrays into PyTorch tensors.
    # x and t are used for learning; x_test and t_test are used only for evaluation.
    x = torch.tensor(pointsl, dtype=torch.float32).to(device)
    t = torch.tensor(soll[:, None], dtype=torch.float32).to(device)
    x_test = torch.tensor(pointst, dtype=torch.float32).to(device)
    t_test = torch.tensor(solt[:, None], dtype=torch.float32).to(device)

    # Adam updates the model parameters using the gradients of the loss.
    # BCEWithLogitsLoss combines a sigmoid and a binary cross-entropy in a numerically stable way.
    learning_rate = 5e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.BCEWithLogitsLoss().to(device)
    Ne = 200  # Number of training epochs
    error = np.zeros(Ne)
    train_accuracy = np.zeros(Ne)
    test_accuracy = np.zeros(Ne)
    step = np.arange(Ne)
    if len(steps) != 0:
        step += steps[-1].max() + 1

    for e in range(Ne):
        # Forward pass: compute the network output and the loss on the training set.
        logits = model(x)
        loss = loss_fn(logits, t)

        # Backward pass: compute gradients and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert logits into probabilities, then into binary decisions with a 0.5 threshold.
        with torch.no_grad():
            train_proba = torch.sigmoid(model(x))
            test_proba = torch.sigmoid(model(x_test))
            train_pred = (train_proba >= 0.5).float()
            test_pred = (test_proba >= 0.5).float()

        error[e] = loss.item()
        train_accuracy[e] = (train_pred == t).float().mean().item()
        test_accuracy[e] = (test_pred == t_test).float().mean().item()
        status = (
            f"Epoch {e + 1:02d}/{Ne:02d} | loss = {error[e]:.4f} | "
            f"train acc = {100 * train_accuracy[e]:5.1f}% | "
            f"test acc = {100 * test_accuracy[e]:5.1f}%"
        )
        clear_output(wait=True)
        print(status)

    errors.append(error)
    steps.append(step)
    train_accuracies.append(train_accuracy)
    test_accuracies.append(test_accuracy)

# POST-PROCESSING
if run_training:
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Training Loss", "Accuracy on Training and Test Sets"),
    )

    for i, err in enumerate(errors):
        fig.add_trace(
            go.Scatter(
                x=steps[i],
                y=err,
                mode="lines+markers",
                name=f"Loss run {i + 1}",
            ),
            row=1,
            col=1,
        )

    for i, acc in enumerate(train_accuracies):
        fig.add_trace(
            go.Scatter(
                x=steps[i],
                y=100 * acc,
                mode="lines+markers",
                name=f"Train run {i + 1}",
            ),
            row=1,
            col=2,
        )
        fig.add_trace(
            go.Scatter(
                x=steps[i],
                y=100 * test_accuracies[i],
                mode="lines+markers",
                line=dict(dash="dash"),
                name=f"Test run {i + 1}",
            ),
            row=1,
            col=2,
        )

    fig.update_layout(template="plotly_white", width=1000, height=450)
    fig.update_xaxes(title_text="Epoch", row=1, col=1)
    fig.update_xaxes(title_text="Epoch", row=1, col=2)
    fig.update_yaxes(title_text="Binary cross-entropy", row=1, col=1)
    fig.update_yaxes(title_text="Accuracy [%]", range=[0, 105], row=1, col=2)
    show_plotly(fig)
Epoch 200/200 | loss = 0.5669 | train acc =  73.3% | test acc =  73.3%
if run_training:
    with torch.no_grad():
        yt = (
            torch.sigmoid(model(torch.tensor(pointsm, dtype=torch.float32).to(device)))
            .cpu()
            .numpy()
            .ravel()
        )
        train_proba = (
            torch.sigmoid(model(torch.tensor(pointsl, dtype=torch.float32).to(device)))
            .cpu()
            .numpy()
            .ravel()
        )
        test_proba = (
            torch.sigmoid(model(torch.tensor(pointst, dtype=torch.float32).to(device)))
            .cpu()
            .numpy()
            .ravel()
        )

    train_pred = (train_proba >= 0.5).astype(int)
    test_pred = (test_proba >= 0.5).astype(int)
    train_ok = train_pred == soll.astype(int)
    test_ok = test_pred == solt.astype(int)
    Z = yt.reshape(nym, nxm)
    exact_map = (
        exact_model(pointsm).reshape(nym, nxm) if exact_model is not None else None
    )
    region_map = (Z >= 0.5).astype(float)
    region_colorscale = [
        [0.0, "rgba(65, 105, 225, 0.18)"],
        [1.0, "rgba(178, 34, 34, 0.18)"],
    ]
    class_specs = [
        (0, "royalblue", class_names[0]),
        (1, "firebrick", class_names[1]),
    ]

    # Figure 1: what the dataset looks like and which points are used for learning.
    fig1 = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(
            "Exact model" if exact_map is not None else "Full dataset",
            "Training set used for learning",
        ),
    )
    if exact_map is not None:
        fig1.add_trace(
            go.Heatmap(
                x=xm,
                y=ym,
                z=exact_map,
                colorscale=region_colorscale,
                showscale=False,
                opacity=0.30,
                hoverinfo="skip",
            ),
            row=1,
            col=1,
        )

    for class_value, class_color, class_label in class_specs:
        fig1.add_trace(
            go.Scatter(
                x=pointsa[solla == class_value][:, 0],
                y=pointsa[solla == class_value][:, 1],
                mode="markers",
                marker=dict(size=7, color=class_color),
                name=class_label,
                legendgroup=class_label,
                showlegend=True,
            ),
            row=1,
            col=1,
        )
        fig1.add_trace(
            go.Scatter(
                x=pointsl[soll == class_value][:, 0],
                y=pointsl[soll == class_value][:, 1],
                mode="markers",
                marker=dict(size=7, color=class_color),
                name=class_label,
                legendgroup=class_label,
                showlegend=False,
            ),
            row=1,
            col=2,
        )

    fig1.update_layout(
        template="plotly_white",
        title_text=f"Dataset overview: {dataset_title}",
        width=1100,
        height=500,
    )
    fig1.update_xaxes(title_text=feature_names[0] + axis_label_suffix, row=1, col=1)
    fig1.update_xaxes(title_text=feature_names[0] + axis_label_suffix, row=1, col=2)
    fig1.update_yaxes(
        title_text=feature_names[1] + axis_label_suffix,
        scaleanchor="x",
        scaleratio=1,
        row=1,
        col=1,
    )
    fig1.update_yaxes(
        title_text=feature_names[1] + axis_label_suffix,
        scaleanchor="x2",
        scaleratio=1,
        row=1,
        col=2,
    )
    show_plotly(fig1)

    # Figure 2: where the model succeeds and fails on the training and test sets.
    fig2 = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(
            f"Training set after learning ({100 * train_ok.mean():.1f}% correct)",
            f"Test set: unseen data ({100 * test_ok.mean():.1f}% correct)",
        ),
    )

    for col in [1, 2]:
        fig2.add_trace(
            go.Heatmap(
                x=xm,
                y=ym,
                z=region_map,
                colorscale=region_colorscale,
                showscale=False,
                opacity=0.25,
                hoverinfo="skip",
            ),
            row=1,
            col=col,
        )
        fig2.add_trace(
            go.Contour(
                x=xm,
                y=ym,
                z=Z,
                contours=dict(start=0.5, end=0.5, size=1, coloring="lines"),
                line=dict(color="black", width=2),
                showscale=False,
                hoverinfo="skip",
            ),
            row=1,
            col=col,
        )

    for col, points, labels, ok_mask in [
        (1, pointsl, soll.astype(int), train_ok),
        (2, pointst, solt.astype(int), test_ok),
    ]:
        for class_value, class_color, class_label in class_specs:
            mask = labels == class_value
            correct_mask = mask & ok_mask
            wrong_mask = mask & (~ok_mask)
            fig2.add_trace(
                go.Scatter(
                    x=points[correct_mask, 0],
                    y=points[correct_mask, 1],
                    mode="markers",
                    marker=dict(
                        symbol="circle",
                        size=8,
                        color=class_color,
                        line=dict(color="black", width=1),
                    ),
                    name=f"{class_label}: correct",
                    legendgroup=f"{class_label}-correct",
                    showlegend=(col == 1),
                ),
                row=1,
                col=col,
            )
            if wrong_mask.any():
                fig2.add_trace(
                    go.Scatter(
                        x=points[wrong_mask, 0],
                        y=points[wrong_mask, 1],
                        mode="markers",
                        marker=dict(symbol="x", size=10, color=class_color),
                        name=f"{class_label}: incorrect",
                        legendgroup=f"{class_label}-incorrect",
                        showlegend=(col == 1),
                    ),
                    row=1,
                    col=col,
                )

    fig2.update_layout(
        template="plotly_white",
        title_text="Model evaluation: success and failure",
        width=1100,
        height=500,
    )
    fig2.update_xaxes(title_text=feature_names[0] + axis_label_suffix, row=1, col=1)
    fig2.update_xaxes(title_text=feature_names[0] + axis_label_suffix, row=1, col=2)
    fig2.update_yaxes(
        title_text=feature_names[1] + axis_label_suffix,
        scaleanchor="x",
        scaleratio=1,
        row=1,
        col=1,
    )
    fig2.update_yaxes(
        title_text=feature_names[1] + axis_label_suffix,
        scaleanchor="x2",
        scaleratio=1,
        row=1,
        col=2,
    )
    show_plotly(fig2)