K-Fold Cross-Validation in Machine Learning

A model can score highly on its training data yet lose accuracy on data it has not seen before. This is called overfitting: the model has adapted too closely to the training examples, including their noise. Underfitting is the opposite case: the model is too simple to capture the underlying pattern and performs poorly even on the training data.

Cross-validation estimates how reliably a machine learning model will perform on new data by training and evaluating it several times on different subsets of the dataset.

Prerequisites

  • Basic Knowledge of Machine Learning – Familiarity with model training, evaluation metrics, and overfitting.
  • Python Programming Skills – Comfort with Python and libraries such as scikit-learn, numpy, and pandas.
  • Dataset Preparation – A cleaned and preprocessed dataset prepared for model training.
  • Scikit-Learn Installed – Install it with pip install scikit-learn if it is not already installed.
  • Understanding of Model Performance Metrics – Knowledge of accuracy, precision, recall, RMSE, and related metrics depending on the task.

Common Cross-Validation Methods

  • K-Fold Cross-Validation: The dataset is divided into k folds of roughly equal size, and the model is trained k times, with a different fold serving as the validation set in each round.
  • Stratified K-Fold: This variation keeps the same class proportions in every fold for classification tasks. It is especially useful when the classes are imbalanced, that is, when some classes are much rarer than others.
  • Leave-One-Out (LOO): A single observation is used for validation while all remaining observations are used for training, and this is repeated for every instance.
  • Time-Series Cross-Validation: Designed for sequential datasets so that the training data always comes before the validation data.
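To see how these strategies differ, the sketch below prints the train/validation indices that each scikit-learn splitter produces on a tiny toy dataset. The 10-sample array, its imbalanced labels, and n_splits=3 are illustrative assumptions, not data from this tutorial.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, LeaveOneOut, TimeSeriesSplit

# Toy data (illustrative): 10 samples, 1 feature; labels are imbalanced (7 zeros, 3 ones)
X_toy = np.arange(10).reshape(-1, 1)
y_toy = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])

splitters = {
    "KFold": KFold(n_splits=3, shuffle=True, random_state=0),
    "StratifiedKFold": StratifiedKFold(n_splits=3, shuffle=True, random_state=0),
    "LeaveOneOut": LeaveOneOut(),
    "TimeSeriesSplit": TimeSeriesSplit(n_splits=3),
}

n_splits_seen = {}
for name, cv in splitters.items():
    # split() yields (train_indices, validation_indices) pairs
    folds = list(cv.split(X_toy, y_toy))
    n_splits_seen[name] = len(folds)
    print(f"{name}: {len(folds)} splits")
    for train_idx, val_idx in folds[:3]:  # show at most the first three splits
        print(f"  train={train_idx.tolist()} validate={val_idx.tolist()}")
```

TimeSeriesSplit always validates on later indices than it trains on, LeaveOneOut produces one split per sample, and StratifiedKFold places one of the three minority-class samples in each validation fold.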

Cross-validation is valuable for choosing the best model and hyperparameters while reducing the risk of overfitting.

In this guide, we’ll explore:

  • What K-Fold Cross-Validation is
  • How it differs from a traditional train-test split
  • A step-by-step implementation with scikit-learn
  • Advanced versions such as Stratified K-Fold, Repeated K-Fold, Nested K-Fold, and Group K-Fold
  • Ways to work with imbalanced datasets

What is K-Fold Cross-Validation?

K-Fold Cross-Validation is a resampling method for assessing machine learning models. The dataset is divided into K folds of (nearly) equal size; the model is trained on K-1 folds and validated on the remaining one. This repeats K times, with each fold serving as the validation set exactly once, and the final performance score is the average of the K validation scores.
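The procedure can be written out by hand in a few lines. The following is a minimal from-scratch sketch, assuming NumPy and scikit-learn are available; the synthetic data from make_classification, the helper name k_fold_scores, and the choice of LogisticRegression are illustrative assumptions. In practice, scikit-learn's KFold and cross_val_score do this for you.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def k_fold_scores(model, X, y, k=5, seed=0):
    """Hypothetical helper: train on K-1 folds, validate on the remaining fold, K times."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(X))      # shuffle the row indices once
    folds = np.array_split(indices, k)     # K folds of (nearly) equal size
    scores = []
    for i in range(k):
        val_idx = folds[i]                 # fold i is the validation set this round
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        model.fit(X[train_idx], y[train_idx])
        scores.append(model.score(X[val_idx], y[val_idx]))  # accuracy for classifiers
    return np.array(scores)


# Synthetic stand-in data (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)
scores = k_fold_scores(LogisticRegression(max_iter=1000), X_demo, y_demo, k=5)
print("Per-fold accuracy:", scores.round(3))
print(f"Mean accuracy: {scores.mean():.3f}")
```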

Why Use K-Fold Cross-Validation?

  • Instead of relying on a single train-test split, K-Fold uses multiple splits, which lowers the variance of the performance estimate. The model itself does not improve, but you get a more trustworthy picture of how it will perform on unseen data.
  • Every data point is used for both training and validation across the different rounds, which makes better use of the available data, especially when the dataset is small.
  • Because validation is repeated on different portions of the data, overfitting is easier to detect, for example as a consistent gap between training and validation scores. Detecting it lets you simplify or regularize the model so that it learns general patterns rather than memorizing training examples.
  • Averaging results from several folds gives a more dependable estimate of the model’s actual performance: each model is trained on (K-1)/K of the data, and the average smooths out the luck of any single split.
  • K-Fold Cross-Validation is often paired with grid search and randomized search to tune hyperparameters, so that the chosen values are not fitted to the quirks of a single train-test split.
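As a minimal sketch of that pairing, assuming synthetic data and a decision tree whose max_depth is the only tuned parameter (both illustrative choices), GridSearchCV can be given a KFold object so that every candidate is scored on the same five folds:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=0)

# Every candidate max_depth is scored on the same 5 folds, not on one lucky split
cv = KFold(n_splits=5, shuffle=True, random_state=42)
search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={"max_depth": [2, 4, 6, None]},
    cv=cv,
    scoring="accuracy",
)
search.fit(X_demo, y_demo)
print("Best params:", search.best_params_)
print(f"Mean CV accuracy of best params: {search.best_score_:.3f}")
```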

K-Fold vs. Train-Test Split

| Aspect | K-Fold Cross-Validation | Train-Test Split |
|---|---|---|
| Data utilization | The dataset is split into K folds, so every data point is used for training in some rounds and for validation in exactly one round. | The dataset is divided once into fixed training and test portions; each point is used for only one of the two. |
| Variance of the performance estimate | Lower: the score is averaged over K different validation folds, so it depends less on which rows happen to be held out. | Higher: the score depends heavily on which rows land in the single test set. |
| Risk of overfitting to the split | Lower: choices made from the score (model, hyperparameters) are checked against several folds. | Higher: if the single split does not represent the full dataset well, choices can end up tuned to its quirks. |
| Performance evaluation | Provides a more stable estimate, reported as a mean (and spread) over K folds. | Relies on one train-test split, so the estimate can be optimistic or pessimistic depending on that split. |
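The difference in variance can be checked empirically. This sketch, on synthetic data (an assumption for illustration), repeats a single 80/20 split and a 5-fold cross-validation with 20 different random seeds and compares how much each score moves between seeds:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score, train_test_split

# Synthetic stand-in data with some label noise (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=300, n_features=10, flip_y=0.1, random_state=0)
model = LogisticRegression(max_iter=1000)

holdout_scores, kfold_means = [], []
for seed in range(20):
    # One train-test split: the score depends on which 60 rows land in the test set
    X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.2, random_state=seed)
    holdout_scores.append(model.fit(X_tr, y_tr).score(X_te, y_te))
    # 5-fold CV: the average of five validation scores covering all 300 rows
    cv = KFold(n_splits=5, shuffle=True, random_state=seed)
    kfold_means.append(cross_val_score(model, X_demo, y_demo, cv=cv).mean())

print(f"Hold-out accuracy spread across seeds: std={np.std(holdout_scores):.3f}")
print(f"5-fold mean accuracy spread across seeds: std={np.std(kfold_means):.3f}")
```

On data like this, the hold-out score typically varies several times more between seeds, because it is computed on only 60 test rows while the K-Fold mean uses all 300.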

Implementing K-Fold Cross-Validation in Python

Let’s apply K-Fold Cross-Validation using scikit-learn.

Step 1: Import Dependencies

We begin by importing the required libraries.

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn import tree

Step 2: Load and Explore the Titanic Dataset

For this example, we use the Titanic dataset, a well-known binary classification dataset (predicting whether a passenger survived) with 891 rows, which keeps the mechanics of k-fold cross-validation easy to follow.

df = pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv")
print(df.head(3))
print(df.info())

 

PassengerId  Survived  Pclass  \
0            1         0       3   
1            2         1       1   
2            3         1       3   

                                                Name     Sex   Age  SibSp  \
0                            Braund, Mr. Owen Harris    male  22.0      1   
1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   
2                             Heikkinen, Miss. Laina  female  26.0      0   

   Parch            Ticket     Fare Cabin Embarked  
0      0         A/5 21171   7.2500   NaN        S  
1      0          PC 17599  71.2833   C85        C  
2      0  STON/O2. 3101282   7.9250   NaN        S  
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None

Step 3: Data Preprocessing

Next, we prepare the data: select a handful of numeric and categorical features, drop rows with missing values, and encode the Sex column as integers so the model can use it.

# Select relevant features and drop rows with missing values (Age is missing for 177 passengers)
df = df[['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']].dropna()

# Encode the categorical Sex column as integers (female -> 0, male -> 1)
label_encoder = LabelEncoder()
df['Sex'] = label_encoder.fit_transform(df['Sex'])

# Split features and target
X = df.drop(columns=['Survived'])
y = df['Survived']
print(df.shape)

(714, 7)
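Dropping rows and label-encoding Sex do not compute statistics such as means or medians from the data, so doing them once before cross-validation is safe. If you instead fill in missing Age values (for example with the median), the imputer should be fitted on the training folds only; otherwise statistics from the validation fold leak into training. One way to do this, sketched here as an alternative to the tutorial's dropna approach, is to wrap preprocessing and model in a scikit-learn Pipeline. The small DataFrame below is a randomly generated stand-in for the Titanic columns, not real data.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

# Hypothetical stand-in for some Titanic columns (random values, illustration only)
rng = np.random.default_rng(0)
n = 200
toy = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n),
    "Sex": rng.choice(["male", "female"], n),
    "Age": np.where(rng.random(n) < 0.2, np.nan, rng.uniform(1, 80, n)),  # ~20% missing
    "Fare": rng.uniform(5, 100, n),
})
# Hypothetical target: mostly follows Sex, with about 20% of labels flipped
target = ((toy["Sex"] == "female") ^ (rng.random(n) < 0.2)).astype(int)

preprocess = ColumnTransformer([
    ("num", SimpleImputer(strategy="median"), ["Pclass", "Age", "Fare"]),
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["Sex"]),
])
pipe = Pipeline([
    ("prep", preprocess),
    ("clf", RandomForestClassifier(n_estimators=50, random_state=42)),
])

# cross_val_score refits the whole pipeline per fold, so the median is learned from training rows only
kf = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(pipe, toy, target, cv=kf, scoring="accuracy")
print(scores.round(3), round(scores.mean(), 3))
```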

Step 4: Define the K-Fold Split

kf = KFold(n_splits=5, shuffle=True, random_state=42)

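In this example, n_splits=5 divides the dataset into five folds. shuffle=True shuffles the rows before splitting, which matters when the data is stored in some order (for example, sorted by class), and random_state=42 makes the shuffle reproducible so that every run uses the same folds.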
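To check what kf actually produces, you can iterate over its split() method, which yields (training indices, validation indices) pairs. This sketch assumes the 714 rows left after Step 3 and uses a placeholder array of that length, since KFold only needs the number of rows:

```python
import numpy as np
from sklearn.model_selection import KFold

n_rows = 714  # rows left after preprocessing in Step 3
kf = KFold(n_splits=5, shuffle=True, random_state=42)

val_sizes = []
seen = np.zeros(n_rows, dtype=int)
for fold, (train_idx, val_idx) in enumerate(kf.split(np.zeros((n_rows, 1))), start=1):
    val_sizes.append(len(val_idx))
    seen[val_idx] += 1  # count how often each row is used for validation
    print(f"Fold {fold}: train={len(train_idx)} rows, validation={len(val_idx)} rows")

print("Every row validated exactly once:", bool((seen == 1).all()))
```

The validation folds come out as 143, 143, 143, 143, and 142 rows, which is why the per-fold accuracies in the next step are fractions of 143 or 142.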

Step 5: Train and Evaluate the Model

We train a random forest with 100 trees on the five folds defined in Step 4 and print the accuracy on each validation fold.

model = RandomForestClassifier(n_estimators=100, random_state=42)
scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')
print(f'Cross-validation accuracy scores: {scores}')
print(f'Average Accuracy: {np.mean(scores):.4f}')

Cross-validation accuracy scores: [0.77622378 0.8041958  0.79020979 0.88111888 0.80985915]
Average Accuracy: 0.8123

For comparison, a single decision tree evaluated on the same folds scores lower on average:

score = cross_val_score(tree.DecisionTreeClassifier(random_state=42), X, y, cv=kf, scoring='accuracy')
print(f'Scores for each fold are: {score}')
print(f'Average score: {score.mean():.2f}')

Scores for each fold are: [0.72727273 0.79020979 0.76923077 0.81818182 0.8028169 ]
Average score: 0.78
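cross_val_score reports a single metric. If you also want precision, recall, or training scores (useful for spotting overfitting), scikit-learn's cross_validate accepts a list of metrics. The sketch below runs on synthetic data as a stand-in; with the tutorial's data, pass X, y, and kf from the earlier steps instead.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_validate

# Synthetic stand-in data (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=300, n_features=6, random_state=0)
kf = KFold(n_splits=5, shuffle=True, random_state=42)

results = cross_validate(
    RandomForestClassifier(n_estimators=100, random_state=42),
    X_demo, y_demo,
    cv=kf,
    scoring=["accuracy", "precision", "recall"],
    return_train_score=True,
)
for metric in ["accuracy", "precision", "recall"]:
    val = results[f"test_{metric}"]  # one value per validation fold
    print(f"{metric}: mean={val.mean():.3f} std={val.std():.3f}")

# A large gap between training and validation accuracy is a sign of overfitting
print(f"train accuracy={results['train_accuracy'].mean():.3f}")
```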

Advanced Cross-Validation Techniques

1. Stratified K-Fold (For Imbalanced Datasets)

With imbalanced classes, a plain K-Fold split can by chance give some folds noticeably fewer minority-class samples than others. Stratified K-Fold prevents this by preserving (approximately) the class distribution of the full dataset in every fold, which makes it the usual choice for classification problems with unequal class frequencies.

from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=skf, scoring='accuracy')
print(f'Average Accuracy (Stratified K-Fold): {np.mean(scores):.4f}')

Average Accuracy (Stratified K-Fold): 0.8124
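The effect of stratification is easiest to see on clearly imbalanced labels. The following sketch uses hypothetical labels with 90 negatives and 10 positives (an assumption for illustration) and counts how many positives land in each validation fold:

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical imbalanced labels: 90 negatives, 10 positives
y_imb = np.array([0] * 90 + [1] * 10)
X_imb = np.zeros((len(y_imb), 1))  # features do not affect how the folds are formed

counts = {}
for name, cv in [
    ("KFold", KFold(n_splits=5, shuffle=True, random_state=0)),
    ("StratifiedKFold", StratifiedKFold(n_splits=5, shuffle=True, random_state=0)),
]:
    positives = [int(y_imb[val].sum()) for _, val in cv.split(X_imb, y_imb)]
    counts[name] = positives
    print(f"{name}: positives per validation fold = {positives}")
```

StratifiedKFold puts exactly 2 positives in every fold; the KFold counts depend on the shuffle and can be uneven.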

2. Repeated K-Fold Cross-Validation

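Repeated K-Fold runs K-Fold several times, reshuffling the data before each repetition, and averages all the scores. This further reduces the variance that comes from how the rows happen to be partitioned. It is most useful for small datasets and for models that are cheap to train, because the model is fitted n_splits * n_repeats times (50 times in the example below).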

from sklearn.model_selection import RepeatedKFold
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)
scores = cross_val_score(model, X, y, cv=rkf, scoring='accuracy')
print(f'Average Accuracy (Repeated K-Fold): {np.mean(scores):.4f}')

Average Accuracy (Repeated K-Fold): 0.8011
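A related variant is RepeatedStratifiedKFold, which combines repetition with stratification. This sketch, on synthetic imbalanced data (an illustrative assumption), also averages the 50 scores per repeat to show how much a single 5-fold estimate moves between reshuffles:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score

# Synthetic stand-in data with a 70/30 class split (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=200, n_features=5, weights=[0.7, 0.3], random_state=0)

# 5 folds x 10 repeats = 50 fits; the data is reshuffled for each repeat, stratified every time
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)
scores = cross_val_score(LogisticRegression(max_iter=1000), X_demo, y_demo, cv=rskf, scoring="accuracy")
print(f"{len(scores)} scores, mean={scores.mean():.3f}, std={scores.std():.3f}")

# Splits are generated repeat by repeat, so each row of the reshape is one full 5-fold run
per_repeat = scores.reshape(10, 5).mean(axis=1)
print("Per-repeat means:", per_repeat.round(3))
```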

3. Nested K-Fold Cross-Validation (For Hyperparameter Tuning)

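Nested K-Fold separates two jobs: an inner loop tunes hyperparameters, and an outer loop evaluates the tuned model on folds that the tuning never saw. This avoids the optimistic bias that arises when the same folds are used both to pick hyperparameters and to report performance.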

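from sklearn.model_selection import GridSearchCV, cross_val_score
param_grid = {'n_estimators': [50, 100, 150], 'max_depth': [None, 10, 20]}
# Inner loop: GridSearchCV picks hyperparameters with 5-fold CV on each outer training split
gs = GridSearchCV(model, param_grid, cv=5)
# Outer loop: each of the 5 outer validation folds scores a model tuned without it.
# For a classifier, an integer cv means (unshuffled) StratifiedKFold in scikit-learn.
scores = cross_val_score(gs, X, y, cv=5)
print(f'Average Accuracy (Nested K-Fold): {np.mean(scores):.4f}')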
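The same nested procedure can be written as an explicit loop, which makes it clearer what each level does and lets you inspect the hyperparameters chosen in every outer fold. This sketch assumes synthetic data and a decision tree with a small max_depth grid, chosen only to keep it fast:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=0)
param_grid = {"max_depth": [2, 4, 8, None]}

outer = KFold(n_splits=5, shuffle=True, random_state=1)
inner = KFold(n_splits=3, shuffle=True, random_state=2)

outer_scores = []
for train_idx, test_idx in outer.split(X_demo):
    # Inner loop: choose hyperparameters using the outer-training rows only
    search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=inner)
    search.fit(X_demo[train_idx], y_demo[train_idx])
    # Outer loop: score the refitted best model on rows the search never saw
    outer_scores.append(search.score(X_demo[test_idx], y_demo[test_idx]))
    print("best:", search.best_params_, f"outer accuracy={outer_scores[-1]:.3f}")

print(f"Nested CV accuracy: {np.mean(outer_scores):.3f}")
```

The best max_depth may differ between outer folds. That is expected: the reported score estimates the whole tuning procedure rather than one fixed configuration.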

4. Group K-Fold (For Non-Independent Samples)

If the dataset contains groups, such as multiple images from the same patient, Group K-Fold makes sure samples from the same group are not split between training and validation. This is especially helpful for hierarchical data structures.

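from sklearn.model_selection import GroupKFold
gkf = GroupKFold(n_splits=5)
# Illustrative only: this example has no group column of its own, so we assign random
# group labels (0-4). In practice, pass real identifiers such as patient IDs.
rng = np.random.default_rng(42)
groups = rng.integers(0, 5, size=len(y))
scores = cross_val_score(model, X, y, cv=gkf, groups=groups, scoring='accuracy')
print(f'Average Accuracy (Group K-Fold): {np.mean(scores):.4f}')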
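To confirm that no group leaks across the split, you can compare the group labels on each side of every fold. The patients, samples, and labels below are hypothetical:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical data: 12 samples from 4 patients (3 samples each)
X_g = np.arange(12).reshape(-1, 1)
y_g = np.array([0, 1, 0] * 4)
patient_ids = np.repeat(["p1", "p2", "p3", "p4"], 3)

gkf = GroupKFold(n_splits=4)
for train_idx, val_idx in gkf.split(X_g, y_g, groups=patient_ids):
    train_groups = set(patient_ids[train_idx])
    val_groups = set(patient_ids[val_idx])
    assert train_groups.isdisjoint(val_groups)  # no patient appears on both sides
    print(f"validate on {sorted(val_groups)}, train on {sorted(train_groups)}")
```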

FAQs

How to run K-Fold Cross-Validation in Python?

Create a KFold object and pass it as the cv argument of scikit-learn's cross_val_score().

What’s the difference between K-Fold and Stratified K-Fold?

KFold ignores the class labels when forming folds (it splits the rows in order, or randomly if shuffle=True), while StratifiedKFold keeps the class proportions roughly the same in every fold.

How do I choose the right number of folds?

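  • Using 5-fold or 10-fold validation is standard in most situations.
  • A higher number of folds, such as 20, trains each model on a larger share of the data, which reduces the pessimistic bias of the estimate, but it also multiplies the computational cost (20 fits instead of 5).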
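A quick way to see the trade-off is to run the same model with several values of k and compare the number of fits, the training-set size, the scores, and the run time. The synthetic 500-row dataset and logistic regression model here are illustrative assumptions:

```python
import time

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic stand-in data (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=500, n_features=10, random_state=0)
model = LogisticRegression(max_iter=1000)

summary = {}
for k in (5, 10, 20):
    start = time.perf_counter()
    cv = KFold(n_splits=k, shuffle=True, random_state=42)
    scores = cross_val_score(model, X_demo, y_demo, cv=cv)
    elapsed = time.perf_counter() - start
    summary[k] = scores
    # k fits, each trained on (k-1)/k of the 500 rows
    print(f"k={k:2d}: fits={len(scores)}, train size={500 * (k - 1) // k}, "
          f"mean={scores.mean():.3f}, std={scores.std():.3f}, time={elapsed:.2f}s")
```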

What does the KFold class do in Python?

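KFold generates the splits: its split() method yields, for each of the n_splits folds, the indices of the training rows and of the validation rows. It does not train anything itself; you pass it as cv to functions such as cross_val_score or GridSearchCV, or loop over split() yourself.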

Conclusion

Cross-validation is an essential step for checking that a machine learning model performs well on unseen data. K-Fold cross-validation is one of the most widely used methods: by partitioning the dataset into multiple folds and training and validating on each in turn, it gives a more reliable estimate of how the model will behave on new data than a single split, and it exposes overfitting that one lucky split could hide.

In Python, K-Fold Cross-Validation is simple to implement with scikit-learn, which provides KFold for general use and StratifiedKFold for imbalanced classification datasets. Adding K-Fold Cross-Validation to your workflow helps you tune hyperparameters more effectively, compare models more reliably, and choose models that generalize better in real-world use.

Whether you are building regression, classification, or deep learning models, this validation strategy remains an essential part of machine learning pipelines.

Source: digitalocean.com
