Learning curve

This tutorial demonstrates how the accuracy of three common optimization algorithms converges with increasing training set size. Specifically, it considers standard least-squares regression, LASSO, and automatic relevance determination regression (ARDR).
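For orientation, the three methods can be written as optimization problems over a parameter vector x, given a fit matrix A and a target vector f (here, the force components). The symbols A, x, f, α and λ_i are introduced only for this illustration. The formulation is the standard textbook one, up to normalization conventions that differ between implementations (scikit-learn's Lasso, for example, scales the squared error by 1/(2n)).

```latex
\begin{aligned}
\text{least squares:}\quad & \hat{\mathbf{x}} = \arg\min_{\mathbf{x}} \lVert \mathbf{A}\mathbf{x} - \mathbf{f} \rVert_2^2 \\
\text{LASSO:}\quad & \hat{\mathbf{x}} = \arg\min_{\mathbf{x}} \lVert \mathbf{A}\mathbf{x} - \mathbf{f} \rVert_2^2 + \alpha \lVert \mathbf{x} \rVert_1 \\
\text{ARDR:}\quad & p(x_i \mid \lambda_i) = \mathcal{N}\!\left(x_i \mid 0, \lambda_i^{-1}\right), \quad \lambda_i \text{ estimated by maximizing the marginal likelihood}
\end{aligned}
```

In ARDR each parameter has its own prior precision λ_i; parameters whose λ_i grow large are driven toward zero, so ARDR, like LASSO, tends to produce sparse models.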

After several reference structures have been generated, a fourth-order model for Ni, described by the effective medium theory (EMT) calculator, is constructed. The CrossValidationEstimator is then used to compute the cross validation (CV) score as a function of training set size, and the results are plotted.
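The idea behind a learning curve can be sketched without hiphive: for each training size, the data are repeatedly split at random into training and validation rows, a model is fitted on the training rows, and the RMSE on the held-out rows is averaged over the splits. The sketch below uses a straight-line least-squares fit on synthetic data as a stand-in for the force constant model. The helper names (fit_line, rmse, learning_curve) are hypothetical, and this is not how CrossValidationEstimator is implemented internally.

```python
import random
import statistics


def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a*x + b; returns (a, b)."""
    mx = statistics.fmean(xs)
    my = statistics.fmean(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    a = sxy / sxx
    return a, my - a * mx


def rmse(model, xs, ys):
    """Root-mean-square error of the fitted line on the given points."""
    a, b = model
    return (sum((a * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)) ** 0.5


def learning_curve(xs, ys, train_sizes, n_splits, seed):
    """Shuffle-split CV: for each training size, return (size, mean, std) of
    the validation RMSE over n_splits random training/validation partitions."""
    rng = random.Random(seed)
    indices = list(range(len(xs)))
    curve = []
    for n_train in train_sizes:
        scores = []
        for _ in range(n_splits):
            rng.shuffle(indices)
            train, test = indices[:n_train], indices[n_train:]
            model = fit_line([xs[i] for i in train], [ys[i] for i in train])
            scores.append(rmse(model, [xs[i] for i in test], [ys[i] for i in test]))
        curve.append((n_train, statistics.fmean(scores), statistics.pstdev(scores)))
    return curve


# synthetic data: a straight line plus Gaussian noise with standard deviation 0.5
data_rng = random.Random(0)
xs = [data_rng.uniform(0, 10) for _ in range(200)]
ys = [2.0 * x + 1.0 + data_rng.gauss(0, 0.5) for x in xs]

for n_train, ave, std in learning_curve(xs, ys, [5, 20, 100], n_splits=5, seed=1):
    print(n_train, round(ave, 3), round(std, 3))
```

As the training size grows, the averaged validation RMSE should settle near the noise level of the data, which plays the same role as the plateau of the learning curves in this tutorial.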


  • The performance of LASSO and ARDR can vary with the choice of hyperparameters, especially in the underdetermined regime.

  • The CV score is an estimate of the predictive power of a model. Since it depends on which rows end up in the training and validation sets, different training set choices will lead to slightly different results.
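In the underdetermined regime there are fewer force components than parameters, so the least-squares problem has infinitely many exact solutions and the result depends on how a method picks one (for LASSO and ARDR, partly via their hyperparameters). The toy NumPy example below, with a hypothetical 2×3 fit matrix A, illustrates this: both the minimum-norm solution returned by np.linalg.lstsq and a second vector reproduce the data exactly.

```python
import numpy as np

# hypothetical underdetermined system: 2 observations, 3 unknown parameters
A = np.array([[1.0, 1.0, 0.0],
              [0.0, 1.0, 1.0]])
f = np.array([2.0, 3.0])

# for more columns than rows, lstsq returns the minimum-norm exact solution
x_min_norm, *_ = np.linalg.lstsq(A, f, rcond=None)

# adding any vector from the null space of A leaves A @ x unchanged
null_vector = np.array([1.0, -1.0, 1.0])
x_other = x_min_norm + null_vector

print(x_min_norm)
print(np.allclose(A @ x_min_norm, f), np.allclose(A @ x_other, f))
```

Both parameter vectors fit the training data perfectly, yet they would predict differently on new data, which is why the CV score in this regime is sensitive to the fitting method.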


Please note that repeatedly calling functions that rely on pseudo-random number generation with the same seed (i.e., repeatedly falling back to the default value) is strongly discouraged, as it will lead to correlated results. To circumvent this problem one can, for example, seed a sequence of random numbers and then use these numbers in turn as seeds.
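The seeding strategy described above can be sketched with the standard library alone. One master seed drives a generator that draws a list of distinct seeds, and each subsequent call that needs randomness receives its own entry from that list. The helper make_seeds and the master seed value 2024 are hypothetical choices for illustration; check the signature of each function you call to see how it accepts a seed.

```python
import random


def make_seeds(master_seed, n):
    """Draw n distinct integer seeds in [0, 2**32) from one master seed."""
    rng = random.Random(master_seed)
    return rng.sample(range(2**32), n)


seeds = make_seeds(master_seed=2024, n=5)

# each consumer gets its own seed, so the random streams are not identical
first_draws = [random.Random(seed).random() for seed in seeds]
print(seeds)
print(first_draws)
```

Because the whole list derives from one master seed, a run stays reproducible while no two calls share the same seed.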


Source code

The structure container is built in

"""
Create and save a StructureContainer to file
"""
from ase.build import bulk
from ase.calculators.emt import EMT
from hiphive import ClusterSpace, StructureContainer
from hiphive.structure_generation import generate_mc_rattled_structures
from hiphive.utilities import prepare_structures

# parameters
cutoffs = [6.5, 5.0, 4.0]
cell_size = 5
number_of_structures = 3
rattle_std = 0.03
minimum_distance = 2.3

# setup
atoms_ideal = bulk('Ni').repeat(cell_size)
calc = EMT()

# generate structures
structures = generate_mc_rattled_structures(
    atoms_ideal, number_of_structures, rattle_std, minimum_distance)
# attach displacements and EMT forces; prepare_structures returns new structures
structures = prepare_structures(structures, atoms_ideal, calc)

# set up cluster space and structure container
cs = ClusterSpace(structures[0], cutoffs)
sc = StructureContainer(cs)
for structure in structures:
    sc.add_structure(structure)
sc.write('structure_container.sc')
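The parameters of the script above fix both the order of the model and the number of rows in the fit matrix. The cutoffs list starts at second order, so three cutoffs give a fourth-order model. As a worked check, assuming that bulk('Ni') returns the one-atom primitive fcc cell and that every force component contributes one row, the script yields 3 × 125 × 3 = 1125 rows, which is the value n_rows should take when sc.data_shape is read in the next script. The number of columns (parameters) depends on the cluster space and is not computed here.

```python
cutoffs = [6.5, 5.0, 4.0]      # copied from the script above
cell_size = 5
number_of_structures = 3
atoms_per_primitive_cell = 1   # assumption: bulk('Ni') gives a one-atom fcc cell

# the first cutoff applies to pairs (order 2), the next to triplets, and so on
order_cutoffs = {order: cutoff for order, cutoff in enumerate(cutoffs, start=2)}
max_order = max(order_cutoffs)

# each atom contributes three force components (x, y, z) per structure
n_atoms = atoms_per_primitive_cell * cell_size**3
n_rows = number_of_structures * n_atoms * 3
print(order_cutoffs, max_order, n_atoms, n_rows)
# → {2: 6.5, 3: 5.0, 4: 4.0} 4 125 1125
```

This also explains why the training sizes in the learning-curve script stop at 600: the remaining rows are needed for validation.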

The learning curves are constructed and plotted in

"""
Construct learning curves and plot them.

This script takes a few minutes to run; reduce n_splits or the number of
train_sizes to make it run faster.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
from hiphive import StructureContainer
from hiphive.fitting import CrossValidationEstimator

# parameters
fit_methods = ['lasso', 'ardr', 'least-squares']
train_sizes = np.linspace(100, 600, 15).astype(int)
n_splits = 5

# fit kwargs
fit_kwargs = defaultdict(dict)
fit_kwargs['lasso'] = dict(max_iter=5000)

# read sc
sc = StructureContainer.read('structure_container.sc')
n_rows, n_cols = sc.data_shape

# run learning curves
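columns = ['train_size', 'rmse_ave', 'rmse_std']
learning_curves = dict()
for fit_method in fit_methods:
    rows = []
    for train_size in train_sizes:
        # fractions of all rows used for training and validation; the small
        # offset keeps their sum strictly below one despite rounding errors
        train_fraction = train_size / n_rows
        test_fraction = 1.0 - train_fraction - 1e-10
        cve = CrossValidationEstimator(sc.get_fit_data(),
                                       fit_method=fit_method,
                                       validation_method='shuffle-split',
                                       train_size=train_fraction,
                                       test_size=test_fraction,
                                       n_splits=n_splits,
                                       **fit_kwargs[fit_method])
        cve.validate()
        row = dict(train_size=train_size,
                   rmse_ave=np.mean(cve.rmse_validation_splits),
                   rmse_std=np.std(cve.rmse_validation_splits))
        print(fit_method, row)
        rows.append(row)
    # DataFrame.append was removed in pandas 2.0, so build the frame at once
    learning_curves[fit_method] = pd.DataFrame(rows, columns=columns)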

# plotting
lw = 2.0
ms = 6
fs = 14
alpha = 0.5
xlim = [min(train_sizes)-20, max(train_sizes)+20]
ylim = [0.014, 1.0]
colors = {fit_method: color for fit_method, color in zip(
    fit_methods, plt.rcParams['axes.prop_cycle'].by_key()['color'])}

fig = plt.figure(figsize=(8, 5))
ax = plt.gca()

for fit_method, df in learning_curves.items():
    col = colors[fit_method]
    ax.semilogy(df.train_size, df.rmse_ave, '-o', color=col, lw=lw, ms=ms,
                label=fit_method)
    ax.fill_between(df.train_size, df.rmse_ave - df.rmse_std,
                    df.rmse_ave + df.rmse_std, color=col, alpha=alpha)

ax.set_xlabel('Training size (# force components)', fontsize=fs)
ax.set_ylabel('RMSE (eV/Å)', fontsize=fs)

ax.axvspan(0.0, n_cols, alpha=0.1, color='k')
ax.text(0.03, 0.93, 'under-determined', transform=ax.transAxes, fontsize=fs)
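ax.text(0.35, 0.93, 'over-determined', transform=ax.transAxes, fontsize=fs)

ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.legend(loc='upper right', fontsize=fs)
plt.tight_layout()
plt.show()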
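The train and test fractions in the loop are derived from absolute row counts. Our understanding, worth checking against the scikit-learn source, is that shuffle-split rejects fractions summing to more than one and converts fractions to counts by rounding the training count down and the test count up. Floating-point error in 1.0 - train_fraction could therefore tip the split past n_rows, and subtracting 1e-10 guards against this. The pure-Python sketch below mimics that rounding with a hypothetical helper split_counts, assuming n_rows = 1125 as for three 125-atom structures.

```python
import math


def split_counts(train_size, n_rows, eps=1e-10):
    """Hypothetical helper: turn an absolute training size into (n_train, n_test)
    by flooring the training fraction and ceiling the test fraction."""
    train_fraction = train_size / n_rows
    test_fraction = 1.0 - train_fraction - eps
    n_train = math.floor(train_fraction * n_rows)
    n_test = math.ceil(test_fraction * n_rows)
    return n_train, n_test


n_rows = 1125  # assumption: 3 structures x 125 atoms x 3 force components

# check every training size in the range spanned by the script's train_sizes
for train_size in range(100, 601):
    n_train, n_test = split_counts(train_size, n_rows)
    assert n_train + n_test <= n_rows, train_size

print(split_counts(100, n_rows), split_counts(600, n_rows))
```

The epsilon is far larger than double-precision rounding error yet far smaller than one row, so it never changes the intended split by more than the rounding it is meant to absorb.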