Stroke Prediction

Data Analysis Project


Vilmantas Gėgžna





Stroke Risk Prediction project logo. Originally generated with Leonardo.Ai.

Data analysis tools: Python
Helper tools: VS Code, Quarto, Git

Technical requirements:


  • AUC: area under the ROC curve;
  • BAcc: balanced accuracy score;
  • CI: confidence interval;
  • EDA: exploratory data analysis;
  • et al.: and others (Latin);
  • F1: F1 score;
  • F1_neg: F1 score for the negative class;
  • FE: feature engineering;
  • Kappa: Cohen’s kappa;
  • kNN: k-nearest neighbors;
  • ML: machine learning;
  • NA: missing values (“not available”);
  • NB: naive Bayes;
  • NPV: negative predictive value;
  • PPV: positive predictive value (precision);
  • RF: random forest;
  • ROC: receiver operating characteristic;
  • SFS: sequential feature selection;
  • SHAP: SHapley Additive exPlanations;
  • SVC: support vector machine for classification;
  • TNR: true negative rate;
  • TPR: true positive rate (recall);
  • VE: virtual environment;
  • WD: working directory.


According to the World Health Organization (WHO), stroke is the leading cause of disability and the second leading cause of death on a global scale. Consequently, tools that anticipate stroke in advance hold significant potential for its prevention. This project was devoted to a thorough analysis of a stroke prediction dataset. It encompassed exploratory data analysis, feature engineering, and predictive modeling with the aim of constructing an effective predictive model. After careful evaluation, the best-performing model (F1 score of 32.1%, balanced accuracy of 73.7%, and ROC AUC of 0.801) was chosen and subsequently deployed in a cloud-based environment. Detailed findings are presented in this report.


This project does not provide any medical advice; it is solely for educational and research purposes. If you require medical advice, please consult your physician.

1 Introduction

1.1 The Issue

According to the World Health Organization (WHO), stroke is the leading cause of disability and the 2nd leading cause of death globally [1], responsible for approximately 11% of total deaths [2].

The objective of this project is to develop a machine learning model capable of predicting a patient’s likelihood of experiencing a stroke. The ability to identify patients at high risk of stroke will enable doctors to advise both the patients and their families on how to respond in the event of an emergency.

1.2 Notes on Methodology

In statistical inference, the significance level is 0.05 and the confidence level is 95%.

In machine learning, the F1 score is used as the main performance metric because, unlike plain accuracy, it remains informative when the classes are imbalanced.
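As a rough illustration of this choice, the sketch below scores a trivial classifier that always predicts the negative class. The labels and the 5% positive rate are made up for illustration and are not taken from the project's dataset.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Hypothetical imbalanced labels: 5 positives out of 100 (illustrative only)
y_true = np.array([1] * 5 + [0] * 95)

# A trivial "model" that always predicts the negative class
y_pred = np.zeros_like(y_true)

acc = accuracy_score(y_true, y_pred)
bacc = balanced_accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy: {acc:.2f}, BAcc: {bacc:.2f}, F1: {f1:.2f}")
```

Accuracy looks high here only because the negative class dominates, while F1 and balanced accuracy expose that no positive case was found.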

The whole dataset was split into 3 parts (size ratios 70:15:15):

  • The training set:
    • Extensive EDA was done only on the training set.
    • The decisions in feature engineering were based only on the training set too.
    • The training set was used to train the ML models.
  • The validation set was used for model comparison and selection.
  • The test set was used only for the final evaluation of the model.
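A 70:15:15 split of this kind can be sketched with two calls to scikit-learn's train_test_split. This is a minimal sketch on synthetic data, not the project's actual code: the data, the seed, and the use of stratification are assumptions.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 1000 rows, roughly 5% positive class (illustrative only)
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 3))
y = (rng.random(1000) < 0.05).astype(int)

# Step 1: keep 70% for training, set 30% aside
X_train, X_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.30, stratify=y, random_state=42
)

# Step 2: split the remaining 30% in half: 15% validation, 15% test
X_val, X_test, y_val, y_test = train_test_split(
    X_rest, y_rest, test_size=0.50, stratify=y_rest, random_state=42
)

print(len(X_train), len(X_val), len(X_test))
```

Stratifying on the target keeps the share of the rare positive class similar across the three subsets, which matters when the classes are imbalanced.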

1.3 Notes on Reproducibility

Create and set up a virtual environment (VE) for Python

conda, pip and other required command line tools must be installed.

Create a virtual environment and install the required packages:

conda create --name proj-stroke-prediction python=3.11
conda activate proj-stroke-prediction
pip install -r requirements.txt

To use the Python package graphviz, you may need to install the Graphviz software separately and add it to the PATH environment variable.

Set WD and activate VE

Every time before working on the project, make sure that the working directory (WD) of the Jupyter Notebook matches the working directory set in the terminal.

To get your current WD in Jupyter Notebook, run the following code in a Python cell:

!echo '%cd%' # on Windows
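As a platform-independent alternative (an assumption, not part of the original workflow), the working directory can also be queried from Python itself:

```python
from pathlib import Path

# Current working directory of the Python process (i.e., of the notebook kernel)
wd = Path.cwd()
print(wd)
```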

In the terminal, to set the WD to the project’s directory, use cd with the path printed by the above code, e.g.:

# This is just an example, you should choose the correct path in your case
cd 'D:/Data Science/projects/proj-stroke-prediction'

Next, activate the virtual environment (VE) for this project:

conda activate proj-stroke-prediction

In Jupyter Notebook, use the kernel from this environment.

Download data

To download the dataset:

  1. If you do not have it, create a Kaggle API token and save it locally:

    1. Log in to Kaggle.
    2. Go to
    3. Under the API section, click “Create New Token”; you will be prompted to download the kaggle.json file.
    4. Save kaggle.json locally as ~/.kaggle/kaggle.json.
  2. To download data, run in the terminal:

    mkdir data
    kaggle datasets download -d $url -p data/ --unzip

    The data file will be called healthcare-dataset-stroke-data.csv.

Create/Update HTML report

  • First, run all the code cells in the notebook.

  • Next, create the HTML report (the command line tool quarto must be installed):

    quarto render stroke-prediction.ipynb --to html --output index.html

  • Last, open the HTML file index.html in a browser. On Windows, you can use the following command in the terminal:

    start index.html

Random seeds and Optuna (important)

To enforce reproducible results, (pseudo)random number generator seeds (parameter random_state) are set for all models and functions that use random numbers. Unfortunately, the results of the OptunaSearchCV function (used for hyperparameter tuning) are not fully reproducible even when the seed is set, most likely because the search ran in parallel (I used 8 CPU threads for computing).
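The effect of a fixed random_state can be checked with a small sketch such as the one below. The synthetic data and the model settings are assumptions chosen for illustration and do not come from the project.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data (illustrative only)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)


def fit_and_predict(seed):
    """Fit a small random forest with the given seed and return probabilities."""
    model = RandomForestClassifier(n_estimators=20, random_state=seed)
    return model.fit(X, y).predict_proba(X)[:, 1]


# Two runs with the same seed give identical predictions
same_seed_equal = bool((fit_and_predict(1) == fit_and_predict(1)).all())
print(same_seed_equal)
```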

1.4 Python Packages and Functions

Code: The main Python setup
# Automatically reload certain modules
%reload_ext autoreload
%autoreload 1

# Plotting
%matplotlib inline

# Packages and modules -------------------------------
# Utilities
import os
import warnings
import joblib
import html
import numpy as np
from io import StringIO
from functools import partial
from pprint import pprint
from scipy.optimize import curve_fit

# Dataframes
import pandas as pd

# EDA and plotting
import seaborn as sns
import matplotlib.pyplot as plt

import sweetviz
import klib

# Machine learning
from sklearn import set_config

from sklearn.base import (BaseEstimator, TransformerMixin, clone)
from sklearn.pipeline import Pipeline
from sklearn.compose import (
from sklearn.preprocessing import (StandardScaler, OneHotEncoder)
from sklearn.impute import SimpleImputer

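from sklearn.metrics import (  # only the names used below; the full list is missing in the source
    roc_curve,
    roc_auc_score,
    ConfusionMatrixDisplay,
)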
from sklearn.model_selection import (

# ML: classification models
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier

# ML: optimization and hyperparameter tuning
from optuna.integration import OptunaSearchCV
from optuna.distributions import (

# ML: feature selection
from mlxtend.feature_selection import SequentialFeatureSelector
from feature_engine.selection import SmartCorrelatedSelection

# ML: explainability
import shap

# Flowcharts and display
from graphviz import Digraph
from IPython.display import Image, display

# Custom functions
import functions.fun_utils as my
import functions.fun_analysis as an
import functions.fun_ml as ml
from functions.utils import (get_stroke_risk_trend, ColumnSelector, FeatureEngineer)

%aimport functions.fun_utils
%aimport functions.fun_analysis
%aimport functions.fun_ml
%aimport functions.utils

# Settings --------------------------------------------
# Default plot options
plt.rc("figure", titleweight="bold")
plt.rc("axes", labelweight="bold", titleweight="bold")
plt.rc("font", weight="normal", size=10)
plt.rc("figure", figsize=(7, 3))

# Pandas options
pd.set_option("display.max_rows", 1000)
pd.set_option("display.max_columns", 300)
pd.set_option("display.max_colwidth", 50)  # Possible option: None
pd.set_option("display.float_format", lambda x: f"{x:.2f}")
pd.set_option("styler.format.thousands", ",")

# Turn off the scientific notation for floating point numbers.

# Scikit-learn options

# Analysis parameters: use Sweetviz for eda?
do_eda = True

# For caching results ---------------------------------
dir_cache = ".saved_results/"
os.makedirs(dir_cache, exist_ok=True)

Code: Custom ad-hoc functions
# NOTE: Some ad-hoc functions are defined in other places of the file
#       when it seemed that it makes the analysis easier to follow.

# Flowchart =================================================================

def create_flowchart_data_split(
    n_train: int,
    n_validation: int,
    n_test: int,
    n_excluded: int,
    output_path: str,
):
    """Create a flowchart to visualize the sample sizes of the data subsets
    after the data split into training, validation and test sets.

    Args:
        n_train: Number of samples in the training set.
        n_validation: Number of samples in the validation set.
        n_test: Number of samples in the test set.
        n_excluded: Number of samples excluded from the analysis.
        output_path: Path to save the flowchart.
                     Must include extension (usually .png)
    """

    def add_node_text(label: str, n: int, perc: float, fillcolor: str = "lightblue"):
        """Add pre-defined text to a node in the flowchart."""
        if perc < 1:
            perc_txt = html.escape("<1%")
        else:
            perc_txt = f"{perc:.0f}%"

        return {
            "label": f"<<b>{label}</b><br/>n = {n} ({perc_txt})>",
            "fillcolor": fillcolor,
        }

    n = [n_train, n_validation, n_test, n_excluded]
    n_total = sum(n)
    perc = [100 * i / n_total for i in n]

    dot = Digraph(
        comment="Sample Size Flowchart",
        edge_attr={"arrowhead": "vee"},
        node_attr={"shape": "box", "style": "filled", "fillcolor": "lightgray"},
    )

    # Level 1 nodes
    dot.node(
        "whole_dataset",
        **add_node_text("Whole dataset", n_total, 100, fillcolor="lightgray"),
    )

    # Level 2 nodes
    dot.node("training", **add_node_text("Training set", n[0], perc[0]))
    dot.node("validation", **add_node_text("Validation set", n[1], perc[1]))
    dot.node("test", **add_node_text("Test set", n[2], perc[2]))
    dot.node("excluded", **add_node_text("Excluded", n[3], perc[3], fillcolor="salmon"))

    # Edges - Level 1 to Level 2
    dot.edge("whole_dataset", "training")
    dot.edge("whole_dataset", "validation")
    dot.edge("whole_dataset", "test")

    with dot.subgraph() as sg:
        sg.attr(rank="same")  # Keep "whole_dataset" and "excluded" on the same rank
        sg.edge("whole_dataset", "excluded")

    # Save the flowchart
    path_without_ext, ext = os.path.splitext(output_path)

    dot.format = ext[1:]
    dot.render(path_without_ext, view=False)
    # Show the flowchart
    display(Image(path_without_ext + "." + dot.format))

# Plotting ==================================================================

def violinplot_with_roc_results(gr, y, data=None, accuracy=1, linecolor="red"):
    """Plot a violin plot with the best separation threshold determined by ROC analysis.

    Args:
        gr: Name of the grouping variable.
        y: Name of the target variable.
        data: Dataframe containing the variables.
        accuracy: Number of decimal places to display for the threshold.
        linecolor: Color of the threshold line.
    """

    if data is not None:
        gr = data[gr]
        y = data[y]

    # Calculate ROC curve and AUC
    fpr, tpr, thresholds = roc_curve(gr, y)
    roc_auc = roc_auc_score(gr, y)

    # Find the optimal threshold (maximum of Youden's J = TPR - FPR)
    optimal_threshold = thresholds[np.argmax(tpr - fpr)]

    text = (
        f"Optimal threshold: {optimal_threshold:.{accuracy}f}, "
        + f"ROC AUC = {roc_auc:.2f}"
    )

    # Create a violin plot
    plt.figure(figsize=(8, 4))
    ax = sns.violinplot(x=gr, y=y)

    # Add a horizontal line for the threshold
    # NOTE: the two calls below are missing in the source listing; they are
    # reconstructed from the docstring and the otherwise unused variables.
    ax.axhline(optimal_threshold, color=linecolor, linestyle="--")
    ax.set_title(text)

def plot_confusion_matrices(y_train, y_pred, figsize=(11, 3)):
    """Plot confusion matrices.

    Args:
        y_train (array-like): True labels.
        y_pred (array-like): Predicted labels.
        figsize (tuple, optional): Figure size. Defaults to (11, 3).

    Returns:
        tuple: Figure and axes objects.
    """

    fig, ax = plt.subplots(1, 3, figsize=figsize)

    ax[1].set_title("Proportions (true labels)")
    ax[2].set_title("Proportions (predicted labels)")

    ConfusionMatrixDisplay.from_predictions(y_train, y_pred, ax=ax[0])
    ConfusionMatrixDisplay.from_predictions(
        y_train, y_pred, ax=ax[1], normalize="true", values_format="0.3f"
    )
    ConfusionMatrixDisplay.from_predictions(
        y_train, y_pred, ax=ax[2], normalize="pred", values_format="0.3f"
    )
    return fig, ax

# Display ===================================================================

def display_crosstab(crosstab, percentage="column"):
    """Display a cross-tabulation with counts and column or row percentages.

    Args:
        crosstab: Instance of class Crosstab.
        percentage: Which percentages to show? "column" or "row"
    """
    # NOTE: display(pd.concat(..., axis=1)) is an assumed reconstruction of missing lines.
    if percentage == "column":
        display(
            pd.concat(
                [crosstab.counts, crosstab.column_percentage],
                keys=["Counts", "% (column)"],
                axis=1,
            )
        )
    elif percentage == "row":
        display(
            pd.concat(
                [crosstab.counts, crosstab.row_percentage],
                keys=["Counts", "% (row)"],
                axis=1,
            )
        )

# Feature engineering helpers ================================================
# Define the power function
def power_function(x: float, a: float, b: float):
    """Function to fit the power function to the data.
    output = a * x^b

    Args:
        x (array-like): numeric values.
        a (float): Parameter a.
        b (float): Parameter b.

    Returns:
        array-like: Output values calculated as a * x^b.
    """
    return a * np.power(x, b)

def get_stroke_risk_trend(age, base_prob=1, age_threshold=40):
    """Calculate so-called 'stroke risk trend': a function based on the age.

    Args:
        age (array-like): Age values.
        base_prob (float): Base probability of stroke (constant for
            age < age_threshold).
        age_threshold (float): Age threshold after which the risk increases
            (doubles every 10 years).

    Returns:
        array-like: Stroke risk trend.
    """
    return np.where(
        age < age_threshold,
        base_prob,  # constant below the threshold (as stated in the docstring)
        base_prob * 2 ** ((age - age_threshold) / 10),
    )

"""Ad-hoc functions and classes for the project.

NOTE: The following functions are in a separate file as this is required for 
correct pickling (saving the final pre-processing pipeline and the model). 
Otherwise, the pipeline and the model cannot be loaded from the pickle file on 
the cloud."""

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

def get_stroke_risk_trend(age, base_prob=1, age_threshold=40):
    """Calculate so-called 'stroke risk trend': a function based on the age.

    Args:
        age (array-like): Age values.
        base_prob (float): Base probability of stroke (constant for
            age < age_threshold).
        age_threshold (float): Age threshold after which the risk increases
            (doubles every 10 years).

    Returns:
        array-like: Stroke risk trend.
    """
    return np.where(
        age < age_threshold,
        base_prob,  # constant below the threshold (as stated in the docstring)
        base_prob * 2 ** ((age - age_threshold) / 10),
    )

# Pre-processing pipeline ====================================================

class ColumnSelector(BaseEstimator, TransformerMixin):
    """Keeps only the indicated DataFrame columns
    and drops the rest.

    Attributes:
        keep (list): List of column names to keep.

    Methods:
        fit(X, y=None):
            Fit method (Returns self).
        transform(X):
            Transform method to select columns of interest.
            Returns a DataFrame with the selected columns only.
    """

    def __init__(self, keep):
        """Constructor.

        Args:
            keep (list): List of column names to keep.
        """
        self.keep = keep

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # Select the indicated features from the input DataFrame X
        selected_features = X[self.keep]
        return pd.DataFrame(selected_features, columns=self.keep)

class FeatureEngineer(BaseEstimator, TransformerMixin):
    """Transformer to do required feature engineering for the final model.

    From variables "age", "health_risk_score", "smoking_status"
    it creates "stroke_risk_40", "health_risk_score", "age_smoking_interaction".
    """

    def __init__(self):
        pass

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        X = X.assign(
            age_smoking_interaction=(
                X["age"] * (X["smoking_status"] != "never smoked")
            ),
            stroke_risk_40=get_stroke_risk_trend(X["age"], age_threshold=40),
        )

        # Output columns as named in the class docstring
        # (the list itself is missing in the source listing)
        cols_out = [
            "stroke_risk_40",
            "health_risk_score",
            "age_smoking_interaction",
        ]

        return X[cols_out]
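
To illustrate the “stroke risk trend” feature, the sketch below evaluates a standalone copy of the function for a few ages. It assumes, following the docstring, that the value stays at base_prob below the threshold and doubles every 10 years above it; the name stroke_risk_trend and the example ages are chosen here for illustration.

```python
import numpy as np


def stroke_risk_trend(age, base_prob=1, age_threshold=40):
    """Standalone copy of the trend: constant below the threshold,
    doubling every 10 years above it (assumption based on the docstring)."""
    age = np.asarray(age, dtype=float)
    return np.where(
        age < age_threshold,
        base_prob,
        base_prob * 2 ** ((age - age_threshold) / 10),
    )


ages = [20, 40, 50, 60, 80]
trend = stroke_risk_trend(ages)

for a, t in zip(ages, trend):
    print(f"age {a}: trend {t:.0f}")
```

With the default settings, the trend equals 1 up to age 40 and then doubles with each additional decade (2 at 50, 4 at 60, 16 at 80).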
External files with code

The following functions, methods and classes are present in external files (Python modules). Not all of these are used in the project. They are displayed here for convenience only.

"""Various functions for data pre-processing, analysis and plotting."""

# Other Python libraries and modules
import re
import pathlib
import joblib
import functools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Union
from IPython.display import display, HTML
from matplotlib.ticker import MaxNLocator

# Utilities ==================================================================
# Check/Assert
def index_has_names(obj: Union[pd.Series, pd.DataFrame]) -> bool:
    """Check if index of a Pandas object (Series or DataFrame) has names.

    Args:
        obj: Object that has `.index` attribute.

    Returns:
        bool: True if index has names, False otherwise.

    Examples:
        >>> import pandas as pd

        >>> series1 = pd.Series([1, 2], index=pd.Index(['A', 'B'], name='Letters'))
        >>> index_has_names(series1)
        True

        >>> series2 = pd.Series([1, 2], index=pd.Index(['A', 'B']))
        >>> index_has_names(series2)
        False

        >>> dataframe1 = pd.DataFrame(
        ...    [[1, 2], [3, 4]],
        ...    index=pd.Index(['A', 'B'], name='Rows'),
        ...    columns=['X', 'Y']
        ... )
        >>> index_has_names(dataframe1)
        True

        >>> dataframe2 = pd.DataFrame([[1, 2], [3, 4]])
        >>> index_has_names(dataframe2)
        False
    """
    return None not in list(obj.index.names)

def assert_values(df: pd.DataFrame, expected_values: list[str]) -> None:
    """Assert that the values of each column in a Pandas DataFrame are among
      the expected values.

    Args:
        df (pd.DataFrame): The input DataFrame to check for expected values.
        expected_values (list[str]): The list of expected values.

    Raises:
        AssertionError: If any column in the DataFrame contains values not
        present in the expected values.

    Examples:
        >>> data = pd.DataFrame({
        ...     'col1': ['Yes', 'No', 'Yes'],
        ...     'col2': ['Yes', 'Yes', 'Yes']
        ... })
        >>> assert_values(data, ['Yes', 'No'])
        # No AssertionError is raised

        >>> data = pd.DataFrame({
        ...     'col1': ['Yes', 'No', 'Yes'],
        ...     'col2': ['Yes', 'Maybe', 'no']
        ... })
        >>> assert_values(data, ['Yes', 'No'])
        AssertionError:
        Only ['Yes', 'No'] values were expected in the following columns
        (Column name [unexpected values]):
        col2: ['Maybe', 'no']
    """

    non_matching_values = {}
    for column in df.columns:
        non_matching = df[~df[column].isin(expected_values)][column].tolist()
        if non_matching:
            non_matching_values[column] = non_matching
    if non_matching_values:
        error_message = (
            f"\nOnly {expected_values} values were expected in the following "
            "columns\n(Column name [unexpected values]):\n"
        )
        for col_name, unexpected_values in non_matching_values.items():
            error_message += f"{col_name}: {unexpected_values}\n"
        raise AssertionError(error_message)

# Display in Jupyter notebook
def display_collapsible(x, summary: str = "", sep=" ", is_open: bool = False):
    """Display data frame or other object surrounded by `<details>` tags

    (I.e., display in collapsible way)

    Args:
        x (pd.DataFrame, str, list[str]): Object to display.
        summary (str, optional): Collapsed section name. Defaults to "".
        sep (str, optional): Symbol used to join strings (when x is a list).
             Defaults to " ".
        is_open (bool, optional): Should the section be open by default?
            Defaults to False.
    """
    if is_open:
        is_open = " open"
    else:
        is_open = ""

    if hasattr(x, "to_html") and callable(x.to_html):
        html_str = x.to_html()
    elif isinstance(x, str):
        html_str = x
    else:
        html_str = sep.join([str(i) for i in x])

    display(
        HTML(
            f"<details{is_open}><summary>{summary}</summary>" + html_str + "</details>"
        )
    )

# Cache
def cache_results(file: str, force: bool = False):
    """Decorator to cache results of a function and save them to a file in
    the pickle format.

    Args:
        file (str): File name.
        force (bool, optional): Should the function be run even if the file
            exists? Defaults to False.
    """

    def decorator_cache(fun):
        @functools.wraps(fun)  # keep the name and docstring of the wrapped function
        def wrapper(*args, **kwargs):
            if pathlib.Path(file).is_file() and not force:
                # Cached results exist: load them instead of recomputing
                with open(file, "rb") as f:
                    results = joblib.load(f)
            else:
                results = fun(*args, **kwargs)
                with open(file, "wb") as f:
                    joblib.dump(results, f)
            return results

        return wrapper

    return decorator_cache

# Format values --------------------------------------------------------------
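def to_snake_case(text: str):
    """Convert a string to the snake case.

    Args:
        text (str): The input string to change to the snake case.

    Returns:
        str: The string converted to the snake case.

    Examples:
        >>> to_snake_case("Some Text")
        'some_text'
        >>> to_snake_case("SomeText2")
        'some_text_2'
    """
    assert isinstance(text, str), "Input must be a string."

    # NOTE: the first and the last two steps of the chain are missing in the
    # source listing and are reconstructed here (assumption).
    return (
        pd.Series([text])
        .str.replace("(?<=[a-z])(?=[A-Z0-9])", "_", regex=True)
        .str.replace(r"[ ]", "_", regex=True)
        .str.replace(r"_+", "_", regex=True)
        .str.lower()
        .iloc[0]
    )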

def format_p(p: float, digits: int = 3, add_p: bool = True) -> str:
    """Format p values (at 3 decimal places by default).

    Args:
        p (float): p value (number between 0 and 1).
        digits (int, optional): Number of decimal places to round to.
            Defaults to 3.
        add_p (bool, optional): Should the string start with "p"?

    Examples:
        >>> format_p(1)
        'p > 0.999'

        >>> format_p(0.12345)
        'p = 0.123'

        >>> format_p(0.00001)
        'p < 0.001'

        >>> format_p(0.00001, digits=2)
        'p < 0.01'

        >>> format_p(1, digits=5)
        'p > 0.99999'
    """

    precision = 10 ** (-digits)
    if add_p:
        prefix = ["p < ", "p > ", "p = "]
    else:
        prefix = ["<", ">", ""]

    if p < precision:
        return f"{prefix[0]}{precision}"
    elif p > (1 - precision):
        return f"{prefix[1]}{1 - precision}"
    else:
        return f"{prefix[2]}{p:.{digits}f}"

def format_percent(x: Union[float, list[float], pd.Series]) -> pd.Series:
    """Round percentages to 1 decimal place and format as strings

    Values between 0 and 0.05 are printed as <0.1%
    Values between 99.95 and 100 are printed as >99.9%

    Args:
        x: (A sequence of) percentage values ranging from 0 to 100.

    Returns:
        pd.Series[str] (for pd.Series input) or list[str] (otherwise):
        formatted values.
        Values equal to 0 are formatted as "0%", values between
        0 and 0.05 are formatted as "<0.1%", values between 99.95 and 100
        are formatted as ">99.9%", and values equal to 100 are formatted
        as "100.0%".

    Examples:
        >>> format_percent(0)
        ['0%']

        >>> format_percent(0.01)
        ['<0.1%']

        >>> format_percent(1)
        ['1.0%']

        >>> format_percent(10)
        ['10.0%']

        >>> format_percent(99.986)
        ['>99.9%']

        >>> format_percent(100)
        ['100.0%']

        >>> format_percent([100, 0, 0.2])
        ['100.0%', '0%', '0.2%']

    Author: Vilmantas Gėgžna
    """
    if not isinstance(x, (list, pd.Series)):
        x = [x]
    # fmt: off
    x_formatted = [
        "0%" if i == 0
        else "<0.1%" if i < 0.05
        else ">99.9%" if 99.95 <= i < 100
        else f"{i:.1f}%"
        for i in x
    ]
    # fmt: on

    if isinstance(x, pd.Series):
        return pd.Series(x_formatted, index=x.index)
    else:
        return x_formatted

def counts_to_percentages(x: pd.Series, name: str = "percent") -> pd.Series:
    """Express counts as percentages.

    The sum of count values is treated as 100%.

    Args:
        x (int, float): Counts data as pandas.Series.
        name (str, optional): The name for output pandas.Series with percentage
             values. Defaults to "percent".

    Returns:
        str: pandas.Series object with `x` values expressed as percentages
             and rounded to 1 decimal place, e.g., "0.2%".
             Values equal to 0 are formatted as "0%", values between
             0 and 0.1 are formatted as "<0.1%", values between 99.9 and 100
             are formatted as ">99.9%".

    Examples:
        >>> import pandas as pd
        >>> counts_to_percentages(pd.Series([1, 0, 1000, 2000, 1000, 5000, 1000]))
        0    <0.1%
        1       0%
        2    10.0%
        3    20.0%
        4    10.0%
        5    50.0%
        6    10.0%
        Name: percent, dtype: object

        >>> counts_to_percentages(pd.Series([1, 0, 10000]))
        0     <0.1%
        1        0%
        2    >99.9%
        Name: percent, dtype: object

    Author: Vilmantas Gėgžna
    """
    return format_percent(x / x.sum() * 100).rename(name)

def extract_number(text: str) -> float:
    return float(re.findall(r"-?\d+[.]?\d*", str(text))[0])

# Display -------------------------------------------------------------------
def highlight_max(s, color="green"):
    """Helper function to highlight the maximum in a Series or DataFrame

    Args:
        s: Numeric values one of which will be highlighted.
        color (str, optional): Text highlight color. Defaults to 'green'.

    Examples:
    >>> import pandas as pd
    >>> data_frame = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})
    """

    is_max = s == s.max()
    return [f"color: {color}" if cell else "" for cell in is_max]

def highlight_rows_by_index(x, values, color="green"):
    """Highlight rows/columns with certain index/column name.

    Args:
        x (pandas.DataFrame): Dataframe to highlight.
        values (list): List of index/column names to highlight.
        color (str, optional): Text highlight color. Defaults to 'green'.

    Examples:
    >>> iris.head(10).style.apply(highlight_rows_by_index, values=[8, 9], axis=1)
    """
    # NOTE: `x.name` (the label of the current row/column) is an assumed
    # reconstruction: the expression is garbled in the source listing.
    return [f"color: {color}" if (x.name in values) else "" for i in x]

# Function to change text color for data types with 'int' or 'float'
def highlight_int_float_text(value, color="deepskyblue"):
    if "int" in str(value) or "float" in str(value):
        return f"color: {color}"
    else:
        return ""

# Function to change text color for data types with 'category'
def highlight_category_text(value, color="limegreen"):
    if "category" in str(value):
        return f"color: {color}"
    else:
        return ""

def highlight_value(value, when, color="grey"):
    if value == when:
        return f"color: {color}"
    else:
        return ""

def highlight_between(value, min: float = None, max: float = None, color="yellow"):
    if min <= value <= max:
        return f"color: {color}"
    else:
        return ""

def highlight_above(value, min: float = None, color="yellow"):
    if min < value:
        return f"color: {color}"
    else:
        return ""

def highlight_above_str(value, min: float = None, color="yellow"):
    if min < extract_number(value):
        return f"color: {color}"
    else:
        return ""

def highlight_below(value, max: float = None, color="yellow"):
    if value < max:
        return f"color: {color}"
    else:
        return ""

def highlight_below_str(value, max: float = None, color="yellow"):
    if extract_number(value) < max:
        return f"color: {color}"
    else:
        return ""

# For data wrangling --------------------------------------------------------
# Index
def use_numeric_index(self, start=1):
    """Create a new sequential index that starts at the indicated number.

    Args:
        self (pd.DataFrame):
            The object the method is applied to.
        start (int):
            The start of an index
    """

    i = self.index
    self.index = range(start, len(i) + start)
    return self

# Merge all tables
def merge_all(df_list, on, how="outer"):
    """Merge multiple data frames.

    Args:
        df_list (list of pandas.DataFrame): Data frames to join.
        on (str): Column names to join on.
             See details in pandas.DataFrame.merge().
        how (str, optional): {'left', 'right', 'outer', 'inner', 'cross'}.
             Type of merge to be performed. See details in pandas.merge().
             Defaults to "outer".

    Returns:
        pandas.DataFrame: merged data frame.

    Note:
       Function is based on
    """
    merged = df_list[0]
    for to_merge in df_list[1:]:
        merged = pd.merge(left=merged, right=to_merge, how=how, on=on)
    return merged

# Data types
# Function to convert 0/1 to No/Yes
def convert_01_to_no_yes(x):
    dtype_no_yes = pd.CategoricalDtype(categories=["No", "Yes"], ordered=True)
    return x.replace({0: "No", 1: "Yes"}).astype(dtype_no_yes)

# Function to convert False/True to No/Yes
def convert_bool_to_no_yes(x):
    return x.astype(int).pipe(convert_01_to_no_yes)

# Function to convert No/Yes to 0/1
def convert_no_yes_to_01(x):
    return x.replace({"No": 0, "Yes": 1}).astype(np.int8)

# Function to convert False/True to 0/1
def convert_bool_to_01(x):
    return x.astype(np.int8)

# Function to convert No/Yes to -1/1
def convert_no_yes_to_mp1(x):
    return x.replace({"No": -1, "Yes": 1}).astype(np.int8)

# Plot counts ---------------------------------------------------------------
def plot_counts_with_labels(
    # NOTE: the parameter list is missing in the source listing; it is
    # reconstructed from the docstring and the function body (assumption).
    counts,
    title="",
    x=None,
    y="n",
    x_lab=None,
    y_lab="Count",
    label="percent",
    label_rotation=0,
    title_fontsize=13,  # assumed default: not documented in the source
    legend=False,
    ec="black",
    y_lim_max=None,
    ax=None,
    **kwargs,
):
    """Plot count data as bar plots with labels.

    Args:
        counts (pandas.DataFrame): Data frame with counts data.
        title (str, optional): Figure title. Defaults to "".
        x (str, optional): Column name from `counts` to plot on x axis.
                Defaults to None: first column.
        y (str, optional): Column name from `counts` to plot on y axis.
                Defaults to "n".
        x_lab (str, optional): X axis label.
              Defaults to value of `x` with capitalized first letter.
        y_lab (str, optional): Y axis label. Defaults to "Count".
        label (str, None, optional): Column name from `counts` for value labels.
                Defaults to "percent".
                If None, label is not added.
        label_rotation (int, optional): Angle of label rotation. Defaults to 0.
        legend (bool, optional): Should legend be shown?. Defaults to False.
        ec (str, optional): Edge color. Defaults to "black".
        y_lim_max (float, optional): Upper limit for Y axis.
                Defaults to None: do not change.
        ax (matplotlib.axes.Axes, optional): Axes object. Defaults to None.
        **kwargs: further arguments to the bar plot call.

    Returns:
        matplotlib.axes.Axes: Axes object of the generated plot.

    Author: Vilmantas Gėgžna
    """
    if x is None:
        x = counts.columns[0]

    if x_lab is None:
        x_lab = x.capitalize()

    if y_lim_max is None:
        y_lim_max = counts[y].max() * 1.15

    # NOTE: `counts.plot.bar(x=x` and the two axis label calls are garbled or
    # missing in the source listing and are reconstructed here (assumption).
    ax = counts.plot.bar(x=x, y=y, legend=legend, ax=ax, ec=ec, **kwargs)
    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlabel(x_lab)
    ax.set_ylabel(y_lab)
    if label is not None:
        ax_add_value_labels_ab(ax, labels=counts[label], rotation=label_rotation)
    ax.set_ylim(0, y_lim_max)

    return ax

def ax_xaxis_integer_ticks(min_n_ticks: int, rot: int = 0):
    """Ensure that x axis ticks have integer values

    Args:
        min_n_ticks (int): Minimal number of ticks to use.
        rot (int, optional): Rotation angle of x axis tick labels.
        Defaults to 0.
    """
    ax = plt.gca()
    ax.xaxis.set_major_locator(MaxNLocator(min_n_ticks=min_n_ticks, integer=True))
    # NOTE: the use of `rot` is missing in the source listing (assumed call)
    plt.xticks(rotation=rot)

def ax_axis_comma_format(axis: str = "xy", ax=None):
    """Write values of axis ticks with comma as thousands separator

    Args:
        axis (str, optional): which axis should be formatted:
           "x" X axis, "y" Y axis or "xy" (default) both axes.
        ax (axis object, None, optional): Axis of plot.
            Defaults to None: current axis.
    """

    if ax is None:
        ax = plt.gca()

    fmt = "{x:,.0f}"
    formatter = plt.matplotlib.ticker.StrMethodFormatter(fmt)
    if "x" in axis:
        ax.xaxis.set_major_formatter(formatter)

    if "y" in axis:
        ax.yaxis.set_major_formatter(formatter)
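def ax_add_value_labels_lr(ax, labels=None, spacing=25, size=7, weight="bold"):
    """Add value labels left/right to each bar in a bar chart.

    Args:
        ax (matplotlib.axes.Axes): Plot (axes) to annotate.
        labels (sequence of str or similar, optional): Values to be used as
            labels. Defaults to None: bar widths with one decimal place.
        spacing (int): Number of points between bar and label.
        size (int): Font size.
        weight (str): Font weight.

    Note:
        This function is based on
    """
    # `zip()` cannot iterate over None: use placeholders when no labels are given
    bar_labels = [None] * len(ax.patches) if labels is None else labels

    # For each bar: Place a label
    for rect, label in zip(ax.patches, bar_labels):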
        # Get X and Y placement of label from rect.
        y_value = rect.get_y() + rect.get_height() / 2
        x_value = rect.get_width()

        space = spacing

        # Horizontal alignment for positive values
        ha = "right"

        # If the value of a bar is negative: Place left to the bar
        if x_value < 0:
            # Invert space to place label on the left
            space *= -1
            # Horizontal alignment
            ha = "left"

        # Use X value as label and format number with one decimal place
        if labels is None:
            label = "{:.1f}".format(x_value)

        # Create annotation
        ax.annotate(
            label,
            (x_value, y_value),
            xytext=(space, 0),
            textcoords="offset points",
            ha=ha,
            va="center",
            size=size,
            weight=weight,
        )

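def ax_add_value_labels_ab(ax, labels=None, spacing=2, size=9, weight="bold", **kwargs):
    """Add value labels above/below each bar in a bar chart.

    Args:
        ax (matplotlib.axes.Axes): Plot (axes) to annotate.
        labels (sequence of str or similar, optional): Values to be used as
            labels. Defaults to None: bar heights with one decimal place.
        spacing (int): Number of points between bar and label.
        size (int): Font size.
        weight (str): Font weight.
        **kwargs: further arguments to `ax.annotate()`.

    Note:
        This function is based on
    """
    # `zip()` cannot iterate over None: use placeholders when no labels are given
    bar_labels = [None] * len(ax.patches) if labels is None else labels

    # For each bar: Place a label
    for rect, label in zip(ax.patches, bar_labels):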
        # Get X and Y placement of label from rect.
        y_value = rect.get_height()
        x_value = rect.get_x() + rect.get_width() / 2

        space = spacing

        # Vertical alignment for positive values
        va = "bottom"

        # If the value of a bar is negative: Place label below the bar
        if y_value < 0:
            # Invert space to place label below
            space *= -1
            # Vertical alignment
            va = "top"

        # Use Y value as label and format number with one decimal place
        if labels is None:
            label = "{:.1f}".format(y_value)

        # Create annotation
        ax.annotate(
            label,
            (x_value, y_value),
            xytext=(0, space),
            textcoords="offset points",
            ha="center",
            va=va,
            size=size,
            weight=weight,
            **kwargs,
        )

if __name__ == "__main__":
    # doctest
    import doctest

    doctest.testmod()
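
The tick formatter and the value-label helpers in the module above rely on two Python format specifications, `"{x:,.0f}"` and `"{:.1f}"`. The following is a minimal, standard-library-only sketch of what these specifications produce; the sample numbers are made up for illustration and are not taken from the project data.

```python
# Illustrative values only (hypothetical, not from the stroke dataset)
tick_value = 1234567.8
bar_height = 12.34

# Same format string that is passed to StrMethodFormatter in
# ax_axis_comma_format(): comma as thousands separator, no decimals
tick_label = "{x:,.0f}".format(x=tick_value)

# Same format that the value-label helpers use when `labels` is None:
# one decimal place
bar_label = "{:.1f}".format(bar_height)

print(tick_label)
print(bar_label)
```

The named field `x` in the first format string is what matplotlib's `StrMethodFormatter` fills with the tick value.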

"""Functions and classes to perform statistical analysis and output the results."""

from typing import Optional, Union
from IPython.display import display

import humanize
import pandas as pd
import numpy as np
import pingouin as pg
import statsmodels.stats.api as sms
import scipy.stats as sps
import scikit_posthocs as sp
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go

import functions.fun_utils as my  # Custom module
import functions.cld as cld  # Custom module; CLD calculations

from statsmodels.graphics.mosaicplot import mosaic
from statsmodels.stats.multitest import multipletests as p_adjust

from pandas.api.types import is_integer_dtype
from sklearn.feature_selection import mutual_info_classif

# Functions =================================================================

# Exploratory analysis ------------------------------------------------------

def count_unique(data: pd.DataFrame) -> pd.DataFrame:
    """Get number and percentage of unique values

    Args:
        data (pd.DataFrame): Data frame to analyze.

    Return: data frame with columns `n_unique` (int) and `percent_unique` (str)
    """
    n_unique = data.nunique()
    return pd.concat(
        [
            n_unique.rename("n_unique"),
            my.format_percent((n_unique / data.shape[0]).multiply(100)).rename(
                "percent_unique"
            ),
        ],
        axis=1,
    )

def summarize_numeric(x, ndigits=None):
    """Calculate some common summary statistics.

    Args:
        x (pandas.Series): Numeric variable to summarize.
        ndigits (int, None, optional): Number of decimal digits to round to.
                Defaults to None.

    Returns:
       pandas.DataFrame with summary statistics.
    """

    def mad(x):
        return sps.median_abs_deviation(x)

    def range(x):
        return x.max() - x.min()

    res = x.agg(["count", "min", "max", range, "mean", "median", "std", mad, "skew"])

    if ndigits is not None:
        summary = pd.DataFrame(round(res, ndigits=ndigits)).T
    else:
        summary = pd.DataFrame(res).T
    # Present count data as integers:
    summary = summary.assign(count=lambda d: d["count"].astype(int))

    return summary
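
Two of the statistics above, `range` and `mad`, are custom helpers. The following standard-library sketch shows what they compute on a tiny made-up sample. It assumes the unscaled median absolute deviation, which is what `scipy.stats.median_abs_deviation()` returns with its default `scale=1.0`; the names `value_range` and `sample` are hypothetical.

```python
import statistics


def mad(values):
    """Unscaled median absolute deviation: median(|x - median(x)|)."""
    center = statistics.median(values)
    return statistics.median(abs(v - center) for v in values)


def value_range(values):
    """Difference between the largest and the smallest value."""
    return max(values) - min(values)


# Made-up sample with one outlier
sample = [1, 2, 3, 4, 100]

summary = {
    "count": len(sample),
    "median": statistics.median(sample),
    "range": value_range(sample),
    "mad": mad(sample),
}
print(summary)
```

The outlier inflates the range but barely affects the median and the MAD, which is why these robust statistics are reported next to the mean and the standard deviation.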

def frequency_table(
    group: str,
    data: pd.DataFrame,
    sort: Union[bool, str] = True,
    weight: Optional[str] = None,
    n_label: str = "n",
    perc_label: str = "percent",
) -> pd.DataFrame:
    """Create frequency table that contains counts and percentages.

    Args:
        group (str): Variable that defines the groups. Column name from `data`.
        data (pandas.DataFrame): data frame.
        sort (bool or str, optional): Way to sort values:
             - True or "count" - sort by count descending.
             - "index" - sort by index ascending.
             - False - no sorting.
             Defaults to True.
        weight (str, optional): Frequency weights. Column name from `data`.
                Defaults to None: no weights are used.
        n_label (str, optional): Name for output column with counts.
        perc_label (str, optional): Name for output column with percentage.

    Return: pandas.DataFrame with 3 columns:
            - column with unique values of `group`,
            - column `n_label` (defaults to "n") with counts as int, and
            - column `perc_label` (defaults to "percent") with percentage
              values formatted as str.

    Author: Vilmantas Gėgžna
    """

    vsort = sort is True or sort == "count"

    if weight is None:
        counts = data[group].value_counts(sort=vsort)
        if sort == "index":
            counts = counts.sort_index()
    else:
        counts = data.groupby(group)[weight].sum()

    percent = my.counts_to_percentages(counts)

    return (
        pd.concat([counts.rename(n_label), percent.rename(perc_label)], axis=1)
        .rename_axis(group)
        .reset_index()
    )
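
The idea behind `frequency_table()` (counts per group, sorted by count, plus percentages formatted as strings) can be sketched with the standard library alone. This is only an illustration: the helper name `frequency_rows`, the one-decimal percentage format, and the sample values are assumptions, as the exact output format of `my.counts_to_percentages()` is not shown here.

```python
from collections import Counter


def frequency_rows(values):
    """Return (value, count, percent) tuples sorted by count, descending."""
    counts = Counter(values)
    total = sum(counts.values())
    return [
        (value, n, f"{n / total * 100:.1f}%")
        for value, n in counts.most_common()
    ]


# Made-up binary variable
rows = frequency_rows(["no", "no", "yes", "no"])
for row in rows:
    print(row)
```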

def summarize_discrete(data: pd.DataFrame, max_n_unique: int = 10, **kwargs) -> None:
    """Create and display frequency tables for columns with a low number of unique values.

    This function generates and displays frequency tables for columns in the
    input DataFrame where the number of unique values does not exceed the
    specified threshold (`max_n_unique`).

    Args:
        data (pd.DataFrame): The input data frame.

        max_n_unique (int, optional): The maximum number of unique values
          allowed for a column to be considered for generating a frequency
          table. Defaults to 10.

        **kwargs: Additional keyword arguments to pass to
          `frequency_table()`.

    Returns:
        None: This function displays the frequency tables using the
        `display()` function.

    Example:
        To create and display frequency tables for columns with at most
        10 unique values in a DataFrame `df`, you can use:

        >>> summarize_discrete(df, max_n_unique=10)
    """
    n = data.shape[0]

    tables_to_display = [
        frequency_table(column_name, data, **kwargs)
        for column_name in data.columns
        if data[column_name].nunique() <= max_n_unique
    ]

    for table in tables_to_display:
        display(
            table.style
            .bar(subset=["n"], color="grey", width=60, vmin=0, vmax=n)
            .set_properties(**{"width": "12em"})
            # See
        )

def col_info(
    data: pd.DataFrame,
    style: bool = False,
    n_unique_max: int = 10,
    p_dominant_threshold: float = 90.0,
) -> pd.DataFrame:
    """Get overview of data frame columns: data types, number of unique values,
    and number of missing values.

    Args:
        data (pd.DataFrame): Data frame to analyze.
        style (bool, optional): Flag to return styled data frame.
            If True, styled data frame is returned:
            - index is hidden;
            - `memory_size` is formatted as human-readable string;
            - `data_type` is highlighted in blue for columns with numeric
               data types ('int' or 'float' in the name) and in green for
               columns with "category" data type;
            - `p_unique` and `p_missing` are formatted as percentages (str);
            - `n_missing` and `p_missing` are grayed out for columns with no
               missing values;
            - `n_unique` is highlighted for binary variables (in green) and
                columns with more than `n_unique_max` (usually 10) unique values
                are in blue;
            - `p_unique` is highlighted in red for columns with only one unique
                value and in orange for columns with more than
                `p_dominant_threshold` (usually 90%) of unique values;
            - `p_dominant` is highlighted in red for columns with only a single
                value and in orange when the dominant value makes up more than
                `p_dominant_threshold` (usually 90%) of the values.
        n_unique_max (int, optional): Maximum number of unique values to
           treat them as categorical. When `style=True`, values > n_unique_max
           are highlighted in the same way as numeric data types.
           Defaults to 10.
        p_dominant_threshold (float, optional): Threshold for percentage of
              dominant value. When `style=True`, values > p_dominant_threshold
              are highlighted in orange. Default is 90.0.

    Returns:
        pd.DataFrame: Data frame with columns:
        `column` (column name),
        `data_type` (data type of the column),
        `memory_size` (memory size of the column in bytes),
        `n_unique` (number of unique values),
        `p_unique` (percentage of unique values, number from 0 to 100),
        `n_missing` (number of missing values),
        `p_missing` (percentage of missing values, number from 0 to 100),
        `n_dominant` (count of the most frequent, i.e. dominant, value),
        `p_dominant` (percentage of the most frequent value),
        `dominant` (the most frequent value).
    """
    n = data.shape[0]

    n_dominant = data.apply(lambda x: x.value_counts().max())
    dominant = data.apply(lambda x: x.value_counts().idxmax())

    info = pd.DataFrame({
        "column": data.columns,
        "data_type": data.dtypes,
        "memory_size": data.memory_usage(deep=True, index=False),
        "n_unique": data.nunique(),
        "p_unique": (data.nunique() / n * 100),
        "n_missing": data.isna().sum(),
        "p_missing": (data.isna().mean() * 100),
        "n_dominant": n_dominant,
        "p_dominant": (n_dominant / n * 100),
        "dominant": dominant,
    })

    if style:
        color_category = "limegreen"
        color_binary = color_category
        color_numeric = "deepskyblue"
        color_warning = "orange"
        color_danger = "red"
        color_fade = "gray"
        # NOTE: the chain below assumes that the `my.highlight_*` helpers work
        # column-wise, i.e., that they are suitable for `Styler.apply()`.
        return (
            info.assign(
                memory_size=lambda d: d["memory_size"].apply(humanize.naturalsize),
                p_unique=lambda d: my.format_percent(d["p_unique"]),
                p_missing=lambda d: my.format_percent(d["p_missing"]),
                p_dominant=lambda d: my.format_percent(d["p_dominant"]),
            )
            .style.hide(axis="index")
            .apply(
                my.highlight_int_float_text, color=color_numeric, subset=["data_type"]
            )
            .apply(
                my.highlight_category_text, color=color_category, subset=["data_type"]
            )
            .apply(
                my.highlight_value, when=0, color=color_fade, subset=["n_missing"]
            )
            .apply(
                my.highlight_value, when="0%", color=color_fade, subset=["p_missing"]
            )
            .apply(
                my.highlight_value, when="0%", color=color_danger, subset=["p_unique"]
            )
        )
    else:
        return info

# Inferential statistics -----------------------------------------------------
def ci_proportion_binomial(
    counts,
    method: str = "wilson",
    n_label: str = "n",
    percent_label: str = "percent",
    **kwargs,
) -> pd.DataFrame:
    """Calculate confidence intervals for binomial proportion.

    Calculates confidence intervals for each category separately on
    "category counts / total counts" basis.

    Wrapper around statsmodels.stats.proportion.proportion_confint()

    More information in documentation of statsmodels's proportion_confint().

    Args:
        counts (pd.Series, list or tuple): Count data.
        method (str, optional): Method. Defaults to "wilson".
        n_label (str, optional): Name for column for counts.
        percent_label (str, optional): Name for column for percentage values.
        **kwargs: Additional arguments passed to proportion_confint().

    Returns:
        pd.DataFrame: Data frame with group names, absolute counts, percentages
        and their confidence intervals.

    Examples:
    >>> ci_proportion_binomial([62, 55])
    """
    assert isinstance(counts, (pd.Series, list, tuple))
    if not isinstance(counts, pd.Series):
        counts = pd.Series(counts)

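    return pd.concat(
        [
            counts.rename(n_label),
            (counts / sum(counts)).rename(percent_label) * 100,
            pd.DataFrame(
                [
                    sms.proportion_confint(
                        count_i, sum(counts), method=method, **kwargs
                    )
                    for count_i in counts
                ],
                index=counts.index,
                columns=["ci_lower", "ci_upper"],
            )
            * 100,
        ],
        axis=1,
    )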
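
For reference, the default `"wilson"` method used above can be written out by hand. The following standard-library sketch implements the textbook Wilson score interval for a single proportion; it is an illustration under that assumption, not the statsmodels implementation, and the function name `wilson_ci` is hypothetical.

```python
from math import sqrt
from statistics import NormalDist


def wilson_ci(count, n, alpha=0.05):
    """Wilson score interval for the proportion count / n."""
    z = NormalDist().inv_cdf(1 - alpha / 2)  # about 1.96 for alpha = 0.05
    p = count / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half_width = z * sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return center - half_width, center + half_width


# Same counts as in the docstring example: each category vs. the total
counts = [62, 55]
total = sum(counts)
for count in counts:
    lower, upper = wilson_ci(count, total)
    print(f"{count / total:.1%} (95% CI: {lower:.1%} to {upper:.1%})")
```

Unlike the simple normal-approximation interval, the Wilson interval is centered on a value pulled slightly toward 50% and always stays within 0% to 100%.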

def ci_proportion_multinomial(
    counts,
    method: str = "goodman",
    n_label: str = "n",
    percent_label: str = "percent",
    **kwargs,
) -> pd.DataFrame:
    """Calculate simultaneous confidence intervals for multinomial proportion.

    Wrapper around statsmodels.stats.proportion.multinomial_proportions_confint()

    More information in documentation of statsmodels's
    multinomial_proportions_confint().

    Args:
        counts (pd.Series, list or tuple): Count data.
        method (str, optional): Method. Defaults to "goodman".
        n_label (str, optional): Name for column for counts.
        percent_label (str, optional): Name for column for percentage values.
        **kwargs: Additional arguments passed to multinomial_proportions_confint().

    Returns:
        pd.DataFrame: Data frame with group names, absolute counts, percentages
        and their confidence intervals.

    Examples:
    >>> ci_proportion_multinomial([62, 33, 55])
    """
    assert isinstance(counts, (pd.Series, list, tuple))
    if not isinstance(counts, pd.Series):
        counts = pd.Series(counts)

    return pd.concat(
        [
            counts.rename(n_label),
            (counts / sum(counts)).rename(percent_label) * 100,
            pd.DataFrame(
                sms.multinomial_proportions_confint(counts, method=method, **kwargs),
                index=counts.index,
                columns=["ci_lower", "ci_upper"],
            )
            * 100,
        ],
        axis=1,
    )

def test_chi_square_gof(
    f_obs: list[int],
    f_exp: Union[str, list[float]] = "all equal",
    output: str = "long",
) -> str:
    """Chi-squared (χ²) goodness-of-fit (gof) test

    Args:
        f_obs (list[int]): Observed frequencies.
        f_exp (str, list[float]): List of expected frequencies or "all equal"
              if all frequencies are equal to the mean of observed frequencies.
              Defaults to "all equal".
        output (str, optional): Output format (available options:
            "short", "long"). Defaults to "long".

    Returns:
        str: formatted test results including p value.
    """
    k = len(f_obs)
    n = sum(f_obs)
    exp = n / k
    dof = k - 1
    # Compare with the keyword only when a string is passed: `==` applied to
    # an array of expected frequencies would not give a single True/False
    if isinstance(f_exp, str) and f_exp == "all equal":
        f_exp = [exp for _ in range(k)]
    stat, p = sps.chisquare(f_obs=f_obs, f_exp=f_exp)
    # May also be formatted this way:
    if output == "short":
        result = f"chi-square test, {my.format_p(p)}"
    else:
        result = (
            f"chi-square test, χ²({dof}, n = {n}) = {round(stat, 2)}, {my.format_p(p)}"
        )

    return result
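
The statistic reported by the function above follows directly from its definition, χ² = Σ (observed − expected)² / expected. A minimal standard-library sketch with made-up counts is given below. The p value is computed only for the special case of 2 degrees of freedom, where the χ² survival function reduces to exp(−χ²/2); the general case is what `scipy.stats.chisquare()` handles in the module.

```python
from math import exp

# Made-up observed counts for three groups
f_obs = [62, 33, 55]

k = len(f_obs)
n = sum(f_obs)
expected = n / k  # "all equal" expected frequency
dof = k - 1

# Chi-squared statistic by definition
stat = sum((obs - expected) ** 2 / expected for obs in f_obs)

# Valid for dof == 2 only: the chi-squared distribution with 2 degrees of
# freedom is an exponential distribution with mean 2
p_value = exp(-stat / 2)

print(f"chi-square test, χ²({dof}, n = {n}) = {stat:.2f}, p = {p_value:.3f}")
```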

# Classes ===================================================================

# Analyze count data --------------------------------------------------------
class AnalyzeCounts:
    """The class to analyze count data.

    - Performs omnibus chi-squared and post-hoc pair-wise chi-squared test.
    - Compactly presents results of post-hoc test as compact letter display, CLD
      (shared CLD letters show no significant difference between groups).
    - Calculates percentages and their confidence intervals by using Goodman's
      method.
    - Creates summary of grouped values (group counts and percentages).
    - Plots results as bar plots with percentage labels.
    """

    def __init__(self, counts, by=None, counts_of=None):
        """
        Object initialization function.

        Args:
            counts (pandas.Series[int]): Count data to analyze.
            by (str, optional): Grouping variable name. Used to create labels.
                      If None, defaults to "Group"
            counts_of (str, optional): The thing that was counted.
                    This name is used for labels in plots and tables.
                    Defaults to `counts.name`.
        """
        assert isinstance(counts, pd.Series)

        # Set defaults
        if by is None:
            by = "Group"

        if counts_of is None:
            counts_of = counts.name

        # Set attributes: user inputs or defaults
        self.counts = counts
        self.counts_of = counts_of
        self.by = by

        # Set attributes: created/calculated
        self.n_label = f"n_{counts_of}"  # Create label for counts

        # Set attributes: results to be calculated
        self.results_are_calculated = False
        self.omnibus = None
        self.n_ci_and_cld = None
        self.descriptive_stats = None

    def fit(self):
        """Perform the calculations and save the results."""
        # Alias attributes
        counts = self.counts
        by = self.by
        n_label = self.n_label

        # Omnibus test: perform and save the results
        self.omnibus = test_chi_square_gof(counts)

        # Post-hoc (pairwise chi-square): perform
        posthoc_p = my.pairwise_chisq_gof_test(counts)
        posthoc_cld = cld.make_cld(posthoc_p, output_gr_var=by)

        # Confidence interval: calculate
        ci = (
            ci_proportion_multinomial(counts, method="goodman", n_label=n_label)
            .rename_axis(by)
            .reset_index()
        )

        # Make sure datasets are mergeable
        ci[by] = ci[by].astype(str)
        posthoc_cld[by] = posthoc_cld[by].astype(str)

        # Merge results
        n_ci_and_cld = pd.merge(ci, posthoc_cld, on=by)

        # Format percentages and counts
        vars = ["percent", "ci_lower", "ci_upper"]
        n_ci_and_cld[vars] = n_ci_and_cld[vars].apply(my.format_percent)

        # Save results
        self.n_ci_and_cld = n_ci_and_cld

        # Descriptive statistics: calculate
        to_format = ["min", "max", "range", "mean", "median", "std", "mad"]

        def format_0f(x):
            return [f"{i:,.0f}" for i in x]

        summary_count = my.summarize_numeric(ci[n_label])
        summary_count[to_format] = summary_count[to_format].apply(format_0f)

        summary_perc = my.summarize_numeric(ci["percent"])
        summary_perc[to_format] = summary_perc[to_format].apply(my.format_percent)
        # Save results
        self.descriptive_stats = pd.concat([summary_count, summary_perc])

        # Initialization status
        self.results_are_calculated = True

        # Output
        return self

    def print(
        self,
        omnibus: bool = True,
        posthoc: bool = True,
        descriptives: bool = True,
    ):
        """Print numeric results.

        Args:
            omnibus (bool, optional): Flag to print omnibus test results.
                                      Defaults to True.
            posthoc (bool, optional): Flag to print post-hoc test results.
                                      Defaults to True.
            descriptives (bool, optional): Flag to print descriptive statistics.
                                      Defaults to True.

        Raises:
            Exception: if calculations with `.fit()` method were not performed.
        """
        if not self.results_are_calculated:
            raise Exception("No results. Run `.fit()` first.")

        # Omnibus test
        if omnibus:
            print("Omnibus (chi-squared) test results:")
            print(self.omnibus, "\n")

        # Post-hoc and CI
        if posthoc:
            print(
                f"Counts of {self.counts_of} with 95% CI "
                "and post-hoc (pairwise chi-squared) test results:"
            )
            print(self.n_ci_and_cld, "\n")

        # Descriptive statistics: display
        if descriptives:
            print(f"Descriptive statistics of group ({self.by}) counts:")
            print(self.descriptive_stats, "\n")

    def display(
        self,
        omnibus: bool = True,
        posthoc: bool = True,
        descriptives: bool = True,
    ):
        """Display numeric results in Jupyter Notebooks.

        Args:
            omnibus (bool, optional): Flag to display omnibus test results.
                                      Defaults to True.
            posthoc (bool, optional): Flag to display post-hoc test results.
                                      Defaults to True.
            descriptives (bool, optional): Flag to display descriptive statistics.
                                      Defaults to True.

        Raises:
            Exception: if calculations with `.fit()` method were
            not performed.
        """
        if not self.results_are_calculated:
            raise Exception("No results. Run `.fit()` first.")

        # Omnibus test
        if omnibus:
            my.display_collapsible(self.omnibus, "Omnibus (chi-squared) test results")

        # Post-hoc and CI
        if posthoc:
            my.display_collapsible(
                self.n_ci_and_cld.style.format({self.n_label: "{:,.0f}"}),
                f"Counts of {self.counts_of} with 95% CI and post-hoc "
                "(pairwise chi-squared) test results",
            )

        # Descriptive statistics: display
        if descriptives:
            my.display_collapsible(
                self.descriptive_stats,
                f"Descriptive statistics of group ({self.by}) counts",
            )

    def plot(self, xlabel=None, ylabel=None, **kwargs):
        """Plot analysis results.

        Args:
            xlabel (str, None, optional): X axis label.
                    Defaults to None: autogenerated label.
            ylabel (str, None, optional): Y axis label.
                    Defaults to None: autogenerated label.
            **kwargs: further arguments passed to `my.plot_counts_with_labels()`

        Raises:
            Exception: if calculations with `.fit()` method were
            not performed.

        Returns:
            matplotlib.axes object
        """
        if not self.results_are_calculated:
            raise Exception("No results. Run `.fit()` first.")

        # Plot
        if xlabel is None:
            xlabel = self.by.capitalize()

        if ylabel is None:
            ylabel = f"Number of {self.counts_of}"

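        # NOTE: keyword names (`x_lab`, `y_lab`, `y`) are assumed from the
        # docstring of the count-plotting helper.
        ax = my.plot_counts_with_labels(
            self.n_ci_and_cld,
            x_lab=xlabel,
            y_lab=ylabel,
            y=self.n_label,
            **kwargs,
        )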

        return ax

# Analyze numeric groups ------------------------------------------------------
class AnalyzeNumericGroups:
    """Class to analyze numeric/continuous data by groups.

    - Calculates means per group and their confidence intervals using
        t distribution.
    - Performs omnibus (Kruskal-Wallis) and post-hoc (Conover-Iman) tests.
    - Compactly presents results of post-hoc test as compact letter display, CLD
      (shared CLD letters show no significant difference between groups).
      NOTE: for CLD calculations, R is required.
    - Creates summary of grouped values (descriptive statistics of group means).
    - Plots results as points with 95% confidence interval error bars.
    """

    def __init__(self, data, y: str, by: str):
        """Initialize the class.

        Args:
            y (str): Name of numeric/continuous (dependent) variable.
            by (str): Name of grouping (independent) variable.
            data (pandas.DataFrame): data frame with variables indicated in
                `y` and `by`.
        """
        assert isinstance(data, pd.DataFrame)

        # Set attributes: user inputs
        self.data = data
        self.y = y
        self.by = by

        # Set attributes: results to be calculated
        self.results_are_calculated = False
        self.omnibus = None
        self.ci_and_cld = None
        self.descriptive_stats = None

    def fit(self):
        """Perform the calculations and save the results."""
        # Aliases:
        data = self.data
        y = self.y
        by = self.by

        # Omnibus test: Kruskal-Wallis test
        omnibus = pg.kruskal(data=data, dv=y, between=by)
        omnibus["p-unc"] = my.format_p(omnibus["p-unc"][0])

        self.omnibus = omnibus

        # Confidence intervals
        ci_raw = data.groupby(by)[y].apply(
            lambda x: [np.mean(x), *sms.DescrStatsW(x).tconfint_mean()]
        )
        ci = pd.DataFrame(
            ci_raw.to_list(),
            index=ci_raw.index,
            columns=["mean", "ci_lower", "ci_upper"],
        ).reset_index()

        # Post-hoc test: Conover-Iman test
        posthoc_p_matrix = sp.posthoc_conover(
            data, val_col=y, group_col=by, p_adjust="holm"
        )
        posthoc_p_df = posthoc_p_matrix.stack().to_df("p.adj", ["group1", "group2"])
        posthoc_cld = my.convert_pairwise_p_to_cld(posthoc_p_df, output_gr_var=by)

        # Make sure datasets are mergeable
        ci[by] = ci[by].astype(str)
        posthoc_cld[by] = posthoc_cld[by].astype(str)

        self.ci_and_cld = pd.merge(posthoc_cld, ci, on=by)

        # Descriptive statistics of means
        self.descriptive_stats = my.summarize_numeric(ci["mean"])

        # Results are present
        self.results_are_calculated = True

        # Output:
        return self

    def print(
        self,
        omnibus: bool = True,
        posthoc: bool = True,
        descriptives: bool = True,
    ):
        """Print numeric results.

        Args:
            omnibus (bool, optional): Flag to print omnibus test results.
                                      Defaults to True.
            posthoc (bool, optional): Flag to print post-hoc test results.
                                      Defaults to True.
            descriptives (bool, optional): Flag to print descriptive statistics.
                                      Defaults to True.

        Raises:
            Exception: if calculations with `.fit()` method were
            not performed.
        """
        if not self.results_are_calculated:
            raise Exception("No results. Run `.fit()` first.")

        # Omnibus test
        if omnibus:
            print("Omnibus (Kruskal-Wallis) test results:")
            print(self.omnibus, "\n")

        # Post-hoc and CI
        if posthoc:
            print(
                "Post-hoc (Conover-Iman) test results as CLD and "
                "Confidence intervals (CI):",
            )
            print(self.ci_and_cld, "\n")

        # Descriptive statistics
        if descriptives:
            print(f"Descriptive statistics of group ({self.by}) means:")
            print(self.descriptive_stats, "\n")

    def display(
        self,
        omnibus: bool = True,
        posthoc: bool = True,
        descriptives: bool = True,
    ):
        """Display numeric results in Jupyter Notebooks.

        Args:
            omnibus (bool, optional): Flag to display omnibus test results.
                                      Defaults to True.
            posthoc (bool, optional): Flag to display post-hoc test results.
                                      Defaults to True.
            descriptives (bool, optional): Flag to display descriptive statistics.
                                      Defaults to True.

        Raises:
            Exception: if calculations with `.fit()` method were
            not performed.
        """
        if not self.results_are_calculated:
            raise Exception("No results. Run `.fit()` first.")

        # Omnibus test
        if omnibus:
            my.display_collapsible(
                self.omnibus, "Omnibus (Kruskal-Wallis) test results"
            )

        # Post-hoc and CI
        if posthoc:
            my.display_collapsible(
                self.ci_and_cld,
                "Post-hoc (Conover-Iman) test results as CLD and "
                "Confidence intervals (CI)",
            )

        # Descriptive statistics of means
        if descriptives:
            my.display_collapsible(
                self.descriptive_stats,
                f"Descriptive statistics of group ({self.by}) means",
            )

    def plot(self, title=None, xlabel=None, ylabel=None, **kwargs):
        """Plot the results

        Args:
            xlabel (str, None, optional): X axis label.
                    Defaults to None: capitalized value of `by`.
            ylabel (str, None, optional): Y axis label.
                    Defaults to None: capitalized value of `y`.
            title (str, None, optional): The title of the plot.
                    Defaults to None.

        Returns:
            Tuple with matplotlib figure and axis objects (fig, ax).
        """
        if not self.results_are_calculated:
            raise Exception("No results. Run `.fit()` first.")

        # Aliases:
        ci = self.ci_and_cld
        by = self.by
        y = self.y

        # Create figure and axes
        fig, ax = plt.subplots()

        # Construct plot
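        x = ci.iloc[:, 0]

        # NOTE: the marker format and passing `**kwargs` here are assumptions
        ax.errorbar(
            x,
            ci["mean"],
            yerr=[ci["mean"] - ci["ci_lower"], ci["ci_upper"] - ci["mean"]],
            fmt="o",
            **kwargs,
        )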

        if xlabel is None:
            xlabel = by.capitalize()

        if ylabel is None:
            ylabel = y.capitalize()

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if title is not None:
            ax.set_title(title)
        ax.set_ylim([0, None])

        # Output
        return (fig, ax)

# Cross-tabulation ------------------------------------------------------------
class Crosstab:
    """Class for cross-tabulation analysis.

    Author: Vilmantas Gėgžna
    """


    def __init__(self, x=None, y=None, data=None, **kwargs):
        """Create a cross-tabulation from data frame.

        Args (option 1):
            data (pandas.DataFrame): Data frame
            x (str): Column name in `data`
            y (str): Column name in `data`

        Args (option 2):
            x (pandas.Series): Variable with numeric values.
            y (pandas.Series): Variable with numeric values.

        Args (common):
            **kwargs: other arguments to pandas.crosstab()
        """
        if data is None:
            # x and y are series objects
            self.xlabel = x.name
            self.ylabel = y.name
        else:
            # x and y are column names in data
            self.xlabel = x
            self.ylabel = y
            x = data[x]
            y = data[y]

        counts = pd.crosstab(x, y, **kwargs)

        self.crosstab = counts

        self.counts = counts
        self.row_percentage = round(
            counts.div(counts.sum(axis=1), axis=0) * 100, ndigits=1
        )
        self.column_percentage = round(
            counts.div(counts.sum(axis=0), axis=1) * 100, ndigits=1
        )
        self.total_percentage = round(counts.div(counts.sum().sum()) * 100, ndigits=1)

    def __call__(self) -> pd.DataFrame:
        """Return cross-tabulation."""
        return self.crosstab

    def print(self):
        """Print cross-tabulation."""
        print(self.crosstab)

    def display(self):
        """Display cross-tabulation in Jupyter Notebook."""
        display(self.crosstab)

    def heatmap(
        self,
        title: Optional[str] = None,
        xlabel: Optional[str] = None,
        ylabel: Optional[str] = None,
        vmax: Optional[int] = None,
        vmin: int = 0,
        cbar: bool = True,
        fmt: Union[str, dict] = "1d",
        annot_kws: dict = {"size": 10},
        cmap: str = "RdYlBu",
        **kwargs,
    ) -> plt.Axes:
        """Plot a Cross-Tabulation as a heatmap.

        Args:
            title (str, optional): Title of the plot.
            xlabel (str, optional): Label for the x-axis. If None (default),
                                    the stored variable name will be used.
            ylabel (str, optional): Label for the y-axis. If None (default),
                                    the stored variable name will be used.
            vmax (int, optional): Maximum value for color scale.
                                If not provided, the maximum frequency in the
                                cross-tabulation will be calculated.
            vmin (int, optional): Minimum value for color scale. Defaults to 0.
            cbar (bool, optional): Whether to show the color bar. Defaults to True.
            fmt (str or dict, optional): String formatting code for annotations.
                                        Defaults to "1d". Can also be a dictionary
                                        of format codes for specific columns.
            annot_kws (dict, optional): Additional keyword arguments for annotations.
                                        Defaults to {"size": 10}.
            cmap (str, optional): Colormap to use. Defaults to "RdYlBu".
            **kwargs: Additional keyword arguments to be passed to the underlying
                    seaborn heatmap function.

        Returns:
            plt.Axes: The created Axes object.
        """
        crosstab = self.crosstab

        # Visualize

        if vmax is None:
            vmax = crosstab.max().max()

        if xlabel is None:
            xlabel = self.xlabel

        if ylabel is None:
            ylabel = self.ylabel

        ax = sns.heatmap(

        if title is not None:

        return ax

    def mosaic(
        self, title: str = None, xlabel: str = None, ylabel: str = None, **kwargs
        """Plot a Cross-Tabulation as a mosaic plot.

            title (str, optional): Title of the plot.
            xlabel (str, optional): Label for the x-axis. If None (default),
                                    the name will be used.
            ylabel (str, optional): Label for the y-axis. If None (default),
                                    the name will be used.
            **kwargs: Additional keyword arguments to be passed to the underlying

            plt.Figure: The created Figure object.
        crosstab = self.crosstab

        if xlabel is None:
            xlabel = self.xlabel

        if ylabel is None:
            ylabel = self.ylabel

        fig, _ = mosaic(crosstab.T.unstack().to_dict(), title=title, **kwargs)

        return fig

    def barplot(
        title: str = None,
        xlabel: str = None,
        ylabel: str = "AUTO",
        rot: int = 0,
        normalize: str = "none",
    ) -> plt.Axes:
        Plot a cross-tabulation as a barplot.

            title (str, optional): Title of the plot.
            xlabel (str, optional): Label for the x-axis.
                Defaults to None (use the default value).
            ylabel (str, optional): Label for the y-axis.
                Defaults to "AUTO" (results in either "Count" or "Percentage").
            rot (int, optional): Rotation angle for x-axis labels. Defaults to 0.
            normalize (str, optional): Whether to show absolute counts ("none"),
                row percentages ("rows"), column percentages ("cols"), or
                overall percentages ("all"). Defaults to "none".
            **kwargs: Additional arguments to be passed to the underlying

            plt.Axes: The created Axes object.
        cross_t = self.crosstab

        if xlabel is None:
            xlabel = self.xlabel

        if normalize == "rows":
            # Row percentage
            cross_p = round(cross_t.div(cross_t.sum(axis=1), axis=0) * 100, ndigits=1)
        elif normalize == "cols":
            # Column percentage
            cross_p = round(cross_t.div(cross_t.sum(axis=0), axis=1) * 100, ndigits=1)
        elif normalize == "all":
            # Overall percentage
            cross_p = round(cross_t.div(cross_t.sum().sum()) * 100, ndigits=1)
            # Absolute counts
            cross_p = cross_t

        if ylabel == "AUTO":
            if normalize in ("rows", "cols", "all"):
                ylabel = "Percentage"
                ylabel = "Count"

        # Visualize
        ax =
            ec="black", title=title, rot=rot, xlabel=xlabel, ylabel=ylabel, **kwargs
        return ax

    def mosaic_dict(self):
        return self.crosstab.T.unstack().to_dict()

    def mosaic_go(
        self, title: str = None, xlabel: str = None, ylabel: str = None, **kwargs
        """Plot a Cross-Tabulation as a Marimekko (mosaic) chart using Plotly.

            title (str, optional): Title of the plot.
            xlabel (str, optional): Label for the x-axis. If None (default),
                                    the name will be used.
            ylabel (str, optional): Label for the y-axis. If None (default),
                                    the name will be used.
            **kwargs: Additional keyword arguments to be passed to the Plotly go.Figure().

            go.Figure: The created Marimekko chart Figure.
        crosstab = self.crosstab

        if xlabel is None:
            xlabel = self.xlabel

        if ylabel is None:
            ylabel = self.ylabel

        years = list(crosstab.columns)
        marker_colors = {
            col: "rgb({}, {}, {})".format(*np.random.randint(0, 256, 3))
            for col in crosstab.index

        fig = go.Figure()

        for idx in crosstab.index:
            dff = crosstab.loc[idx:idx]
            widths = np.array(dff.values[0])

                    x=np.cumsum(widths) - widths,
                    text=["{:.2f}%".format(x) for x in dff.values[0]],

            tickvals=np.cumsum(widths) - widths,
            ticktext=["%s<br>%d" % (l, w) for l, w in zip(years, widths)],

        fig.update_xaxes(range=[0, sum(widths)])
        fig.update_yaxes(range=[0, 100])

            title_text=title if title else "Merimekko Chart",
            uniformtext=dict(mode="hide", minsize=10),

        return fig

def get_mutual_information(
    data: pd.DataFrame,
    target: str,
    drop: list[str] = [],
    precision: int = 3,
    random_state: int = None,
    """Get mutual information scores for classification problem.

        data (pd.DataFrame): Dataframe with target column.
        target (str): Target column name for classification task.
        drop (list[str], optional): Columns to drop. Defaults to [].
        precision (int, optional): Number of decimal places for rounding.
            Defaults to 3.
        random_state (int, optional): Random state for mutual information
            calculation. It is needed as some random noise is added to break ties. Defaults to None.

        pd.DataFrame: Mutual information scores.
    # Copy the data and drop unnecessary columns
    X = data.dropna().drop(columns=drop)
    y = X.pop(target)

    # Label encoding for categorical columns
    for colname in X.select_dtypes(["object", "category"]):
        X[colname], _ = X[colname].factorize()

    # Identify discrete features
    discrete_features = [is_integer_dtype(i) for i in X.dtypes]

    # Calculate mutual information scores
    mutual_info = mutual_info_classif(
        X, y, discrete_features=discrete_features, random_state=random_state

    # Create a DataFrame for the scores
    mi_scores = pd.DataFrame(
        {"variable_1": target, "variable_2": X.columns, "mutual_info": mutual_info}

    # Style the DataFrame and return
    styled_scores = (
        mi_scores.sort_values("mutual_info", ascending=False).pipe(my.use_numeric_index)
        #   .reset_index(drop=True)
        .style.format({"mutual_info": f"{{:.{precision}f}}"})
    )"BrBG", subset=["mutual_info"])

    return styled_scores

def get_pointbiserial_corr_scores(data, target: str):
    """Get point-biserial correlation scores for numeric variables.

    Pairwise missing values are removed.

        data (pd.DataFrame): Dataframe with target column.
        target (str): Target column name.

        pd.DataFrame: Point-biserial correlation scores.
    non_target_numeric = data.select_dtypes("number").drop(columns=target).columns

    def pointbiserialr_wo_pairwise_na(x, y):
        df = pd.DataFrame(zip(x, y)).dropna()
        return df.shape[0], *sps.pointbiserialr(df.iloc[:, 0], df.iloc[:, 1])

    cor_data = [
        (i, *pointbiserialr_wo_pairwise_na(data[target], data[i]))
        for i in non_target_numeric
    cor_data = pd.DataFrame(cor_data, columns=["variable_2", "n", "r_pb", "p"])
    cor_data.insert(0, "variable_1", target)

    return (
        cor_data.sort_values("r_pb", ascending=False, key=abs)
        .assign(p_adj=lambda x: [my.format_p(i, add_p=False) for i in p_adjust(x.p)[1]])
        .assign(p=lambda x: x.p.apply(my.format_p, add_p=False))
        .style.format({"r_pb": "{:.3f}"})
        .background_gradient(cmap="Blues", subset=["n"])
        .bar(vmin=-1, vmax=1, cmap="BrBG", subset=["r_pb"])
"""Various functions for machine learning related tasks."""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import functions.fun_utils as my  # Custom module

from IPython.display import display

from sklearn.metrics import (

def get_metric_abbreviation_and_sign(scoring: str):
    """Internal function to parse a scoring string and return
    its sign and metric abbreviation (in this order).
    sign = -1 if scoring.startswith("neg") else 1

    if scoring == "neg_root_mean_squared_error":
        metric = "RMSE"
    elif scoring == "balanced_accuracy":
        metric = "BAcc"
    elif scoring == "r2":
        metric = "R²"
    elif scoring == "f1":
        metric = "F1"
        metric = scoring
    return sign, metric

# Feature selection
def sfs_get_score(sfs_object, k_features):
    """Return performance score achieved with certain number of features.

        sfs_object: result of function do_sfs_lin_reg()
        k_features (int): number of features.
    md = round(np.median(sfs_object.get_metric_dict()[k_features]["cv_scores"]), 3)
    return {
        "k_features": k_features,
        "mean_score": round(sfs_object.get_metric_dict()[k_features]["avg_score"], 3),
        "median_score": md,
        "sd_score": round(sfs_object.get_metric_dict()[k_features]["std_dev"], 3),

def sfs_plot_results(sfs_object, sub_title="", ref_y=None):
    """Plot results from SFS object

      sfs_object: object with SFS results.
      sub_title (str): second line of title.
      ref_y (float): Y coordinate of reference line.

    scoring = sfs_object.get_params()["scoring"]

    sign, metric = get_metric_abbreviation_and_sign(scoring)

    if sfs_object.forward:
        sfs_plot_title = "Forward Feature Selection"
        sfs_plot_title = "Backward Feature Elimination"

    fig, ax = plt.subplots(1, 2, sharey=True)

    xlab = "Number of predictors included"

    if ref_y is not None:
        ax[0].axhline(y=ref_y, color="darkred", linestyle="--", lw=0.5)
        ax[1].axhline(y=ref_y, color="darkred", linestyle="--", lw=0.5)

    avg_score = [
        (int(i), sign * c["avg_score"]) for i, c in sfs_object.subsets_.items()

    averages = pd.DataFrame(avg_score, columns=["k_features", "avg_score"])

            title=f"Average {metric}",

    cv_scores = {int(i): sign * c["cv_scores"] for i, c in sfs_object.subsets_.items()}
            title=f"{metric} in CV splits",


    if not sfs_object.forward:

    main_title = f"{sfs_plot_title} with {}-fold CV " + f"\n{sub_title}"


    # Print results
    if not sfs_object.interrupted_:
        if sfs_object.is_parsimonious:
            note = "[Parsimonious]"
            k_selected = f"k = {len(sfs_object.k_feature_names_)}"
            score_at_k = f"avg. {metric} = {sign * sfs_object.k_score_:.3f}"
            note_2 = "Smallest number of predictors at best ± 1 SE score"
            note = "[Best]"
            if sign < 0:
                best = averages.nsmallest(1, "avg_score")
                best = averages.nlargest(1, "avg_score")
            # `best` holds a single row, so take scalars explicitly
            k_selected = f"k = {int(best.k_features.iloc[0])}"
            score_at_k = f"avg. {metric} = {float(best.avg_score.iloc[0]):.3f}"
            note_2 = "Number of predictors at best score"

        print(f"{k_selected}, {score_at_k} {note}\n({note_2})")

def sfs_list_features(sfs_result):
    """List features in the order in which they were added.
    The current implementation works correctly with forward selection only.

        sfs_result (SFS object)

    scoring = sfs_result.get_params()["scoring"]

    sign, metric = get_metric_abbreviation_and_sign(scoring)

    feature_dict = sfs_result.get_metric_dict()
    lst = [[*feature_dict[i]["feature_names"]] for i in feature_dict]
    feature = []

    if sfs_result.forward:
        for x, y in zip(lst[0::], lst[1::]):
        res = pd.DataFrame({
            "added_feature": [*lst[0], *feature],
            "metric": metric,
            "score": [sign * feature_dict[i]["avg_score"] for i in feature_dict],
        for x, y in zip(lst[0::], lst[1::]):
        res = pd.DataFrame({
            "feature": [*feature, *lst[-1]],
            "metric": metric,
            "score": [sign * feature_dict[i]["avg_score"] for i in feature_dict],

    return (
        res.assign(score_improvement=lambda x: sign * x.score.diff())
        .assign(score_percentage_change=lambda x: sign * x.score.pct_change() * 100)

# Functions for classification
def get_classification_scores(model, X, y):
    """Calculate scores of classification performance for a model.

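    The following metrics are calculated:

    - Sample size (n)
    - No information rate
    - Accuracy
    - Balanced accuracy (BAcc)
    - Balanced accuracy adjusted for chance (BAcc_01):
      random guessing scores 0 and perfect prediction scores 1
    - F1 score for the positive (F1) and the negative (F1_neg) class
    - True positive rate (TPR) and true negative rate (TNR)
    - Positive (PPV) and negative (NPV) predictive value
    - Cohen's kappa
    - ROC AUC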

        model (object): Scikit-learn classifier.
        X (array-like): Predictor variables.
        y (array-like): Target variable.

        dict: Dictionary with classification scores.

    y_pred = model.predict(X)
    y_proba = model.predict_proba(X)[:, 1]
    return {
        "n": len(y),
        "No_info_rate": max(y.mean(), 1 - y.mean()),
        "Accuracy": accuracy_score(y, y_pred),
        "BAcc": balanced_accuracy_score(y, y_pred),
        "BAcc_01": balanced_accuracy_score(y, y_pred, adjusted=True),
        "F1": f1_score(y, y_pred, pos_label=1),
        "F1_neg": f1_score(y, y_pred, pos_label=0),
        "TPR": recall_score(y, y_pred, pos_label=1),
        "TNR": recall_score(y, y_pred, pos_label=0),
        "PPV": precision_score(y, y_pred, pos_label=1),
        "NPV": precision_score(y, y_pred, pos_label=0),
        "Kappa": cohen_kappa_score(y, y_pred),
        "ROC_AUC": roc_auc_score(y, y_proba),

def print_classification_scores(
    title="--- All data ---",
    """Print classification scores for a set of models.

        models (dictionary): A dictionary of models.
        X: predictor variables.
        y: target variable.
        title (str, optional): Title to print. Defaults to "--- All data ---".
        color (str, optional): Text highlight color. Defaults to "green".
        precision (int, optional): Number of digits after the decimal point.
        sort_by (str, optional): Column name to sort by.
            Defaults to "No_info_rate".
        "No information rate: ",
        round(pd.Series(y).value_counts(normalize=True).max(), precision),
                name: get_classification_scores(model, X, y)
                for name, model in models.items()
        .sort_values(sort_by, ascending=False)
        .apply(my.highlight_max, color=color)

2 Exploration and Pre-Processing

2.1 Dataset Description

In this project, the stroke prediction dataset (version 1) was used. Both the dataset and its description are available on Kaggle. The dataset contains the following variables:

  1. id: Unique identifier.
  2. gender: “Male”, “Female” or “Other”.
  3. age: Age of the patient.
  4. hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension.
  5. heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease.
  6. ever_married: “No” or “Yes”.
  7. work_type: “children”, “Govt_job”, “Never_worked”, “Private” or “Self-employed”.
  8. Residence_type: “Rural” or “Urban”.
  9. avg_glucose_level: Average glucose level in blood.
  10. bmi: Body mass index.
  11. smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown”.
  12. stroke: 1 if the patient had a stroke or 0 if not.


2.2 Import Data

The data file is in CSV format:

!head -n 5 data/healthcare-dataset-stroke-data.csv
9046,Male,67,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
51676,Female,61,0,0,Yes,Self-employed,Rural,202.21,N/A,never smoked,1
31112,Male,80,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1

Read data and unify column name style:

file = "./data/healthcare-dataset-stroke-data.csv"
data_all = pd.read_csv(file).rename(columns=my.to_snake_case)
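The helper my.to_snake_case belongs to the author's custom module and is not shown in this report. The following is only a minimal sketch of what such a renaming function might do; the name to_snake_case and its exact rules are assumptions made for illustration.

```python
import re

import pandas as pd


def to_snake_case(name: str) -> str:
    """Hypothetical stand-in for my.to_snake_case() (not the author's code)."""
    # Insert "_" between a lowercase letter or digit and an uppercase letter
    name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", name.strip())
    # Replace spaces and hyphens with "_" and convert to lowercase
    return re.sub(r"[\s\-]+", "_", name).lower()


# An empty data frame with column names in mixed style
df = pd.DataFrame(columns=["id", "Residence_type", "avg_glucose_level"])
df = df.rename(columns=to_snake_case)
print(list(df.columns))
```

Passing a function to rename(columns=...) applies it to every column name, which is why Residence_type appears as residence_type in the tables below.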

2.3 Initial Inspection

The initial inspection of the whole dataset and inference on the target variable revealed that:

  • The dataset has 5110 rows and 12 columns (11 explanatory and 1 target variable).
  • No duplicate rows (patient IDs) were found.
  • Some binary No/Yes variables have values “Yes” and “No”, some – 1 and 0. This can be unified.
  • Column bmi has explicit missing values and in smoking_status missing values are coded as “Unknown”.
  • The target:
    • The target variable is binary.
    • The classes of the target variable are imbalanced (95.1% of cases belong to the negative class, Figure 2.1) and the difference between the class sample sizes is significant (chi-square test, p < 0.001).
    • Based on this dataset, Wilson’s 95% confidence interval indicates that between 4.4% and 5.4% of stroke cases can be expected in the population (Table 2.1). This interval does not include the incidence rate of 11% reported in the introduction of this project. This may be because the dataset is not representative of the population or because the incidence rate was calculated using a different method.
  • Variables hypertension and heart_disease are also imbalanced (90.3% and 94.6% of cases belong to the negative class, respectively). Still, this does not require any action.
  • There is only a single case with value “Other” in the gender column. This case will be removed from further analysis as not representative.
  • Class “Never_worked” in the work_type column has only 22 cases. This must be investigated further, as very small sample sizes may lead to unreliable results.
  • For many variables, more memory-efficient data types can be used.
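As a rough illustration of the last point, the sketch below compares memory usage before and after converting a text column to the “category” data type and an integer 0/1 column to int8. The toy data frame is made up for this example and is not the project data.

```python
import pandas as pd

# Toy data (made up): one text column and one 0/1 integer column
df = pd.DataFrame(
    {
        "gender": ["Female", "Male", "Female", "Male"] * 250,
        "hypertension": [0, 1, 0, 0] * 250,
    }
)
before = df.memory_usage(deep=True).sum()

# Use more compact data types
compact = df.assign(
    gender=df["gender"].astype("category"),
    hypertension=df["hypertension"].astype("int8"),
)
after = compact.memory_usage(deep=True).sum()

print(f"{before} bytes -> {after} bytes")
```

In the project itself, this conversion is delegated to klib.convert_datatypes (see the pre-processing steps).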

More detailed results are presented below.

Initial Inspection

Number of rows and columns:

# For later use (number of rows in the original dataset):
n_initial = data_all.shape[0]
(5110, 12)

A few rows of data:

id gender age hypertension heart_disease ever_married work_type residence_type avg_glucose_level bmi smoking_status stroke
0 9046 Male 67.00 0 1 Yes Private Urban 228.69 36.60 formerly smoked 1
1 51676 Female 61.00 0 0 Yes Self-employed Rural 202.21 NaN never smoked 1
2 31112 Male 80.00 0 1 Yes Private Rural 105.92 32.50 never smoked 1
3 60182 Female 49.00 0 0 Yes Private Urban 171.23 34.40 smokes 1
4 1665 Female 79.00 1 0 Yes Self-employed Rural 174.12 24.00 never smoked 1

General information about the columns:

an.col_info(data_all, style=True)
  column data_type memory_size n_unique p_unique n_missing p_missing n_dominant p_dominant dominant
1 id int64 40.9 kB 5,110 100.0% 0 0% 1 <0.1% 9,046
2 gender object 317.7 kB 3 0.1% 0 0% 2,994 58.6% Female
3 age float64 40.9 kB 104 2.0% 0 0% 102 2.0% 78.000000
4 hypertension int64 40.9 kB 2 <0.1% 0 0% 4,612 90.3% 0
5 heart_disease int64 40.9 kB 2 <0.1% 0 0% 4,834 94.6% 0
6 ever_married object 304.8 kB 2 <0.1% 0 0% 3,353 65.6% Yes
7 work_type object 333.4 kB 5 0.1% 0 0% 2,925 57.2% Private
8 residence_type object 316.8 kB 2 <0.1% 0 0% 2,596 50.8% Urban
9 avg_glucose_level float64 40.9 kB 3,979 77.9% 0 0% 6 0.1% 93.880000
10 bmi float64 40.9 kB 418 8.2% 201 3.9% 41 0.8% 28.700000
11 smoking_status object 342.8 kB 4 0.1% 0 0% 1,892 37.0% never smoked
12 stroke int64 40.9 kB 2 <0.1% 0 0% 4,861 95.1% 0

In this type of .col_info() tables:

  • dominant is the most frequent value;
  • p_ is a percentage of certain values;
  • n_ is a number of certain values;
  • in data_type, numeric data types (float and integer) are highlighted in blue, and the “category” data type is in green;
  • in n_unique, binary variables are in a different shade of green, and columns with a high number of unique values (\(> 10\)) are highlighted in blue;
  • in n_missing and p_missing, zero values are in grey;
  • in p_dominant, percentages above 90% are in orange;
  • errors or extremely suspicious values are highlighted in red.

Frequency tables for variables with \(\leq 10\) unique values:

an.summarize_discrete(data_all, sort="index", max_n_unique=10)
gender n percent
Female 2,994 58.6%
Male 2,115 41.4%
Other 1 <0.1%
hypertension n percent
0 4,612 90.3%
1 498 9.7%
heart_disease n percent
0 4,834 94.6%
1 276 5.4%
ever_married n percent
No 1,757 34.4%
Yes 3,353 65.6%
work_type n percent
Govt_job 657 12.9%
Never_worked 22 0.4%
Private 2,925 57.2%
Self-employed 819 16.0%
children 687 13.4%
residence_type n percent
Rural 2,514 49.2%
Urban 2,596 50.8%
smoking_status n percent
Unknown 1,544 30.2%
formerly smoked 885 17.3%
never smoked 1,892 37.0%
smokes 789 15.4%
stroke n percent
0 4,861 95.1%
1 249 4.9%

General info on numeric variables:

cols_numeric = ["age", "avg_glucose_level", "bmi"]
data_all[cols_numeric].describe().T.style.format({"count": "{:.0f}"}, precision=1)
  count mean std min 25% 50% 75% max
age 5110 43.2 22.6 0.1 25.0 45.0 61.0 82.0
avg_glucose_level 5110 106.1 45.3 55.1 77.2 91.9 114.1 271.7
bmi 4909 28.9 7.9 10.3 23.5 28.1 33.1 97.6
Statistical Inference

The following inferential procedures were used:

  • chi-square goodness-of-fit test (testing the hypothesis that the group sizes are equal);
  • Wilson’s 95% confidence intervals (CI) of proportions.

Note: The population here is the whole target group of clients represented by the dataset.

# Calculate counts and percentages
target_value_counts = data_all.stroke.value_counts()
target_frequences = (
    an.ci_proportion_binomial(target_value_counts, method="wilson")
    .rename({"index": "Stroke"}, axis=1)

cols = ["percent", "ci_lower", "ci_upper"]
target_frequences[cols] = target_frequences[cols].apply(my.counts_to_percentages)
    "The difference between stroke and non-stroke group sizes is significant \n("
    + an.test_chi_square_gof(target_value_counts)
    + ")."
The difference between stroke and non-stroke group sizes is significant 
(chi-square test, χ²(1, n = 5110) = 4162.53, p < 0.001).
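The helper an.test_chi_square_gof() is part of the author's custom module. As a cross-check under the assumption that it performs a standard goodness-of-fit test against equal expected frequencies, the same statistic can be obtained with scipy.stats.chisquare from the group counts in the frequency table above:

```python
from scipy.stats import chisquare

# Group sizes from the frequency table: no stroke, stroke
counts = [4861, 249]

# By default, chisquare() tests against equal expected frequencies
result = chisquare(counts)
chi2, p_value = result.statistic, result.pvalue

# Ratio of the group sizes (no stroke : stroke)
ratio = counts[0] / counts[1]

print(f"chi-square(1, n = {sum(counts)}) = {chi2:.2f}, p < 0.001: {p_value < 0.001}")
print(f"ratio = {ratio:.2f}")
```

With two groups, the test has one degree of freedom, which matches the reported χ²(1, n = 5110).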

Ratio of frequencies of the two groups:

round(target_value_counts[0] / target_value_counts[1], 2)
Code of the figure
# Plot
ax = my.plot_counts_with_labels(target_frequences, rot=0)
chi_sq_rez_1 = an.test_chi_square_gof(target_value_counts, output="short")

# Get the limits of the x-axis and y-axis
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# Set the position of the annotation
x_pos = xlim[1] - 0.05 * (xlim[1] - xlim[0])  # 5% offset from the right edge
y_pos = ylim[1] - 0.05 * (ylim[1] - ylim[0])  # 5% offset from the top edge

# Add the annotation to the top right corner
    xy=(x_pos, y_pos),
Fig. 2.1. The frequencies of target variable values: “0” means “did not have stroke”, “1” means “had stroke”. The chi-square goodness-of-fit test of the hypothesis that the group sizes are equal was significant. This indicates that the target variable is imbalanced: there were approximately 19.5 times more people without stroke than with stroke in the dataset.
target_frequences.columns = ["Stroke", "n", "%", "95% CI (lower)", "95% CI (upper)"]"index")
Table 2.1. Stroke and non-stroke group sizes, percentages and 95% confidence intervals (CI) of the percentages.
Stroke n % 95% CI (lower) 95% CI (upper)
0 4,861 95.1% 94.6% 95.6%
1 249 4.9% 4.4% 5.4%

2.4 Pre-Processing: Group-Independent Steps

2.4.1 Feature Engineering Principles

In this project, the dataset was expanded with additional variables by using feature engineering techniques based on domain knowledge, on the EDA of the training set, and on other common practices.


Extensive exploratory data analysis (EDA) was conducted exclusively on the training data, and feature engineering (FE) was informed by the insights gained from that EDA. Several rounds of FE and EDA were carried out iteratively, with the first round of EDA taking place before FE. However, to avoid redundancy and to streamline the report, only the final results are presented here. Therefore, the analysis may appear to follow a linear progression, and some parts may seem to be arranged in an unconventional order, whereas in reality the process was iterative.

  1. stroke_incidence is a non-linear function of the patient’s age. A plot of stroke incidence rates was found in the scientific literature (see Figure 2.2). The Y-axis values (incidence) of the red points (except the first one at <20 and the last one at 95+ years) were digitized from the plot using an online tool called WebPlotDigitizer, and the centers of the age intervals were used as the X-axis values. These data points were fitted to an equation of the form \(incidence = a \cdot age^b\), and the resulting function was used to calculate the values of a new variable, the expected stroke incidence.

    The calculation of stroke_incidence is implemented as function get_stroke_incidence().

Fig. 2.2. Trend of stroke incidence by age and sex (1998-2017). Source: Akyea et al. 2021
from functools import partial

import numpy as np
from scipy.optimize import curve_fit


def power_function(x, a, b):
    """Power function y = a * x^b (reconstructed from the equation in the text)."""
    return a * np.power(x, b)


# Approximate values of the red dots in Figure 2.2 (excluding the first dot
# at age <20 and the last dot at age 95+), digitized with WebPlotDigitizer.
# The X values are the centers of the age intervals.
ages = np.array([22, 27, 32, 37, 42, 47, 52, 57, 62, 67, 72, 77, 82, 87, 92])
incidence = np.array(
    [5, 12, 13, 24, 38, 63, 98, 145, 215, 309, 436, 582, 712, 858, 1000]
)

# Fit the power function to the data
params, covariance = curve_fit(power_function, ages, incidence)

# Extract the fitted parameters
a_fit, b_fit = params

# Make the function for the fitted curve
# NOTE: this is the main function that calculates "stroke incidence"
get_stroke_incidence = partial(power_function, a=a_fit, b=b_fit)

Let’s illustrate the results graphically:

# Create a range of age values for plotting
age_range = np.linspace(0, 100, 100)
# Calculate the corresponding y values using the fitted equation
incidence_fit = get_stroke_incidence(age_range)

# Plot the data points and the fitted curve
plt.scatter(ages, incidence, color="red", label="Data")
plt.plot(age_range, incidence_fit, color="tab:blue", label="Model (fitted curve)")
plt.ylabel("Incidence of Stroke per 100'000")
plt.title(f"Fitted Equation: y = {a_fit:.2e} * age^{b_fit:.3f}")
  2. stroke_risk_trend is one more non-linear mathematical function based on age. This source states that after the age of 55, the chance of stroke doubles every 10 years, so a variable that mimics this trend was created. An age threshold (e.g., 55 years) is selected; the risk trend value is kept constant at 1 before the threshold and doubles every 10 years after it. This function is implemented as get_stroke_risk_trend(). Two thresholds will be tested:
    • as variable stroke_risk_55, where the threshold of 55 years is based on the literature, and
    • as variable stroke_risk_40, where the threshold of 40 years was chosen based on the EDA results of the training data.
    The graphical illustration of stroke_risk_trend:
risk_trend = get_stroke_risk_trend(age_range, age_threshold=40)

# Plot the data points and the fitted curve
plt.plot(age_range, risk_trend, label="Model (threshold = 40 years)", color="limegreen")
plt.ylabel("Stroke Risk Trend")
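The definition of get_stroke_risk_trend() is not shown in this section. The doubling rule described for stroke_risk_trend could be sketched as below; the function name stroke_risk_trend and the continuous (rather than stepwise) doubling are assumptions made for illustration, so the author's implementation may differ in detail.

```python
import numpy as np


def stroke_risk_trend(age, age_threshold=55):
    """Illustrative stand-in for get_stroke_risk_trend() (an assumption).

    Returns 1 up to the threshold; afterwards the value doubles
    with every additional 10 years of age.
    """
    age = np.asarray(age, dtype=float)
    # Years lived past the threshold (0 for ages below the threshold)
    years_over = np.clip(age - age_threshold, 0, None)
    return 2.0 ** (years_over / 10)


# Threshold of 40 years: ages 30 and 40 give 1, then 2 at 50 and 4 at 60
print(stroke_risk_trend([30, 40, 50, 60], age_threshold=40))
```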
  3. health_risk_score is a numerical measure that summarizes an individual’s health risk as the number of the following four risk factors that are present (it ranges from 0 to 4):
    • Hypertension presence (No = 0 or Yes = 1).
    • Heart disease presence (No = 0 or Yes = 1).
    • Overweight or obesity status (1 if BMI > 25, 0 otherwise).
    • Extreme diabetic condition (1 if average glucose level is above 165, 0 otherwise).

The chosen thresholds of BMI and average glucose level are based on the EDA results of the training data.
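The four-factor sum described for health_risk_score can be sketched with pandas as follows. The function below and the toy data are illustrative assumptions; in particular, treating a missing BMI as “not overweight” is a choice made for this sketch only.

```python
import numpy as np
import pandas as pd


def health_risk_score(df: pd.DataFrame) -> pd.Series:
    """Count how many of the four risk factors are present (0 to 4)."""
    return (
        df["hypertension_01"]
        + df["heart_disease_01"]
        # Comparisons with NaN are False, so a missing BMI contributes 0
        + (df["bmi"] > 25).astype(int)
        + (df["avg_glucose_level"] > 165).astype(int)
    )


# Toy patients (made-up values)
patients = pd.DataFrame(
    {
        "hypertension_01": [0, 1, 1],
        "heart_disease_01": [0, 0, 1],
        "bmi": [22.0, np.nan, 31.5],
        "avg_glucose_level": [90.0, 170.0, 200.0],
    }
)
patients["health_risk_score"] = health_risk_score(patients)
print(patients)
```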

Additional variables were created by using common feature engineering techniques. Next, only the principles will be described (with a few examples), and the details can be found in the code.

  1. Continuous variables were discretized into categories:
    • by using domain knowledge (e.g., BMI categories);
    • by using the results of EDA (e.g., an additional average glucose level category “Extreme diabetic”, which seemed to form a separate cluster in a plot, or the derived variable avg_glucose_gr_above_165_01, which is 1 if the glucose level is above 165 and 0 otherwise);
    • arbitrarily (e.g., age categories).
  2. Binary 0/1 variables were created from categorical variables, e.g.:
    • residence_type (Urban/Rural) was converted into residence_is_urban_01 (0 if Rural, 1 if Urban).
  3. Binary missing value indicators were created for variables with missing values.
  4. Categories of ordinal variables were encoded as integers.
  5. The square of the age variable was added.
  6. Several interaction variables were created. Several strategies were used:
    • by multiplying two continuous variables (e.g., age and avg_glucose_level);
    • by multiplying a continuous and a categorical variable represented as 0/1 (e.g., age and smoking_status which is 0 if “never smoked” and 1 otherwise);
    • by multiplying a continuous and a categorical variable represented as -1/1 (e.g., age and gender which is -1 for females and +1 for males);
    • etc.
  7. A ratio of two continuous variables was created (e.g., bmi and avg_glucose_level).
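The interaction and ratio strategies listed above can be sketched on a toy data frame. The new column names (age_x_glucose, age_x_smoking, age_x_gender, bmi_to_glucose) are hypothetical and were chosen for this example only; the names used in the project code may differ.

```python
import pandas as pd

# Toy data (made-up values)
df = pd.DataFrame(
    {
        "age": [30.0, 60.0],
        "avg_glucose_level": [90.0, 180.0],
        "bmi": [22.5, 30.0],
        "gender": ["Female", "Male"],
        "smoking_status": ["never smoked", "smokes"],
    }
)

features = df.assign(
    # Continuous x continuous
    age_x_glucose=lambda d: d["age"] * d["avg_glucose_level"],
    # Continuous x categorical coded as 0/1 (0 if "never smoked", 1 otherwise)
    age_x_smoking=lambda d: (
        d["age"] * (d["smoking_status"] != "never smoked").astype(int)
    ),
    # Continuous x categorical coded as -1/+1 (-1 for females, +1 for males)
    age_x_gender=lambda d: d["age"] * d["gender"].map({"Female": -1, "Male": 1}),
    # Ratio of two continuous variables
    bmi_to_glucose=lambda d: d["bmi"] / d["avg_glucose_level"],
)
print(features.T)
```

The -1/+1 coding lets the interaction term change sign between the two groups, whereas the 0/1 coding switches the term off for the reference group.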

2.4.2 Pre-Processing Steps

In the group-independent pre-processing, the following steps were performed:

  1. Remove the only case with gender “Other”.
  2. Convert binary variables to both
    • 0/1 format (for modeling; the names of these features end in _01);
    • No/Yes format (for EDA; after the EDA step, these variables will be removed as redundant).
  3. Merge work types “children” and “Never_worked” into a single category “Never worked”.
  4. Fix or change the names of some categories to make them more informative or consistent.
  5. Move the target variable to the beginning of the dataset and order the remaining variables alphabetically.
  6. Choose more efficient data types automatically via klib.convert_datatypes.
  7. Remove variable id as not informative.
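A few of these steps (1, 3, 5 and 7) can be sketched on a toy data frame as shown below. This is a simplified illustration with made-up rows, not the project's pre-processing code, which is given in the next subsection.

```python
import pandas as pd

# Toy data (made-up rows) with the original category names
raw = pd.DataFrame(
    {
        "id": [1, 2, 3, 4],
        "gender": ["Male", "Female", "Other", "Female"],
        "work_type": ["children", "Never_worked", "Private", "Govt_job"],
        "stroke": [0, 1, 0, 0],
    }
)

clean = (
    # Step 1: remove the case with gender "Other"
    raw.query("gender != 'Other'")
    # Step 3: merge "children" and "Never_worked" into "Never worked"
    .assign(
        work_type=lambda d: d["work_type"].replace(
            {"children": "Never worked", "Never_worked": "Never worked"}
        )
    )
    # Step 7: remove the non-informative identifier
    .drop(columns="id")
)

# Step 5: target first, remaining columns in alphabetical order
other_cols = sorted(col for col in clean.columns if col != "stroke")
clean = clean[["stroke", *other_cols]]
print(clean)
```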

2.4.3 Pre-Processing Code


Variables whose names end in _01 are binary variables with 0 for “No” and 1 for “Yes”. Their categorical No/Yes counterparts (which have no special pattern in the variable name) will be removed after EDA.

# Fix/Change data types ----------------------------------------------------
# Define datatypes
dtype_no_yes = pd.CategoricalDtype(categories=["No", "Yes"], ordered=True)

categories_smoking_status_1 = ["never smoked", "formerly smoked", "smokes"]
categories_smoking_status_2 = ["never smoked", "formerly smoked", "smokes", "Unknown"]
dtype_smoking_1 = pd.CategoricalDtype(
    categories=categories_smoking_status_1, ordered=True
dtype_smoking_2 = pd.CategoricalDtype(
    categories=categories_smoking_status_2, ordered=True

work_categories_0 = [
    "Never worked",
    "Government Sector",
    "Private Sector",
    "Self Employed",
work_categories = [
    "Never worked",
    "Government Sector",
    "Private Sector",
    "Self Employed",
dtype_work_type_0 = pd.CategoricalDtype(categories=work_categories_0)
dtype_work_type = pd.CategoricalDtype(categories=work_categories)

# Transform dataset
# fmt: off
data_all = (
    # Remove the single case with gender "Other"
    data_all.query("gender != 'Other'").assign(
        stroke_01=lambda df: df["stroke"],
        stroke=lambda df: my.convert_01_to_no_yes(df["stroke"]),
        age_square=lambda df: df["age"] ** 2,
        age_55_plus=lambda df: my.convert_bool_to_01(df["age"] >= 55),
        age_group=lambda df: pd.cut(
            bins=[0, 18, 35, 55, 1000],
            labels=["Child", "Young Adult", "Adult", "Elderly"],
        age_group_num=lambda df: df["age_group"].replace({
            "Child": 0,
            "Young Adult": 1,
            "Adult": 2,
            "Elderly": 3,
        # Non-linear functions based on age
        stroke_risk_40=lambda x: get_stroke_risk_trend(x["age"], age_threshold=40),
        stroke_risk_55=lambda x: get_stroke_risk_trend(x["age"], age_threshold=55),
        stroke_incidence=lambda x: get_stroke_incidence(x["age"]),
        # Further variables
        # Glucose concentration thresholds
        avg_glucose_gr_medical=lambda df: pd.cut(
            bins=[0, 100, 125, 1000],
            labels=["Normal", "Prediabetic", "Diabetic"],
        avg_glucose_gr_medical_num=lambda df: df["avg_glucose_gr_medical"].replace({
            "Normal": 0,
            "Prediabetic": 1,
            "Diabetic": 2,
        avg_glucose_gr_medical_165=lambda df: pd.cut(
            bins=[0, 100, 125, 165, 1000],
            labels=["Normal", "Prediabetic", "Diabetic", "Extreme diabetic"],
        avg_glucose_gr_medical_165_num=lambda df: df["avg_glucose_gr_medical_165"].replace({
            "Normal": 0,
            "Prediabetic": 1,
            "Diabetic": 2,
            "Extreme diabetic": 3,
        avg_glucose_is_diabetic_01=lambda df: (
            my.convert_bool_to_01(df["avg_glucose_level"] > 125)

        avg_glucose_gr_165=lambda df: pd.cut(
            df["avg_glucose_level"], bins=[0, 165, 1000]
        avg_glucose_gr_above_165_01=lambda df: (
                "(0, 165]": 0,
                "(165, 1000]": 1,
        bmi_is_unknown_01=lambda df: my.convert_bool_to_01(df["bmi"].isna()),
        # BMI categories:
        bmi_group=lambda x: pd.cut(
            bins=[0, 18.5, 25, 30, 10000],
            labels=["Underweight", "Normal", "Overweight", "Obese"],
        bmi_group_num=lambda df: (
                "Underweight": 0,
                "Normal": 1,
                "Overweight": 2,
                "Obese": 3,
        bmi_overweight_or_obese_01=lambda df: (
            my.convert_bool_to_01(df["bmi"] >= 25).fillna(0)
        bmi_normal_or_underweight_01=lambda df: (
            my.convert_bool_to_01(df["bmi"] < 25).fillna(0)
        gender=lambda df: df["gender"].astype("category"),
        gender_is_male_01=lambda df: my.convert_bool_to_01(df["gender"]=="Male"),
        residence_type=lambda df: df["residence_type"].astype("category"),
        residence_is_urban_01=lambda df: my.convert_bool_to_01(df["residence_type"] == "Urban"),
        hypertension_01=lambda df: df["hypertension"],
        hypertension=lambda df: my.convert_01_to_no_yes(df["hypertension"]),
        heart_disease_01=lambda df: df["heart_disease"],
        heart_disease=lambda df: my.convert_01_to_no_yes(df["heart_disease"]),
        ever_married=lambda df: df["ever_married"].astype(dtype_no_yes),
        ever_married_01=lambda df: my.convert_no_yes_to_01(df["ever_married"]),
        smoking_status_2=lambda df: df["smoking_status"].astype(dtype_smoking_2),
        smoking_status=lambda df: df["smoking_status"].astype(dtype_smoking_1),
        smoking_status_is_unknown_01=lambda df: (
        work_type_original=lambda df: (
                "children": "Children",
                "Never_worked": "Never worked",
                "Govt_job": "Government Sector",
                "Private": "Private Sector",
                "Self-employed": "Self Employed",
        work_type=lambda df: (
            .replace({"Children": "Never worked", "Never worked": "Never worked"})

        # Health risk score
        health_risk_score=lambda x: (
            + x["heart_disease_01"]
            + (x["bmi"] >= 25).astype(int) # Is overweight or obese
            + (x["avg_glucose_level"] > 165).astype(int) # Is extreme diabetic
        # Interaction terms
        age_bmi_interaction=lambda x: x["age"] * x["bmi"],
        age_gender_interaction=lambda x: (
            x["age"] * x["gender"].replace({"Male": 1, "Female": -1}).astype(int)
        age_heart_disease_interaction=lambda x: x["age"] * x["heart_disease_01"],
        age_hypertension_interaction=lambda x: x["age"] * x["hypertension_01"],
        age_smoking_interaction=lambda x: (
            x["age"] * (x["smoking_status"] != "never smoked")

        bmi_heart_disease_interaction=lambda x: x["bmi"] * x["heart_disease_01"],
        bmi_hypertension_interaction=lambda x: x["bmi"] * x["hypertension_01"],
        bmi_smoking_interaction=lambda x: (
            x["bmi"] * (x["smoking_status"] != "never smoked")
        hypertension_heart_disease_interaction_01=lambda x: (
            x["hypertension_01"] * x["heart_disease_01"]

        # Ratios
        avg_glucose_bmi_ratio=lambda x: x["avg_glucose_level"] / x["bmi"],
    # Use more efficient data types
# fmt: on

# Sort columns -------------------------------------------------------------
# Sort required columns and remove unnecessary ones
first_cols = ["stroke", "stroke_01"]
remove_cols = ["id"]
column_order = [
        [col for col in data_all.columns if col not in [*first_cols, *remove_cols]]

data_all = data_all[column_order]
n_updated = data_all.shape[0]

2.5 Create Training, Validation and Test Sets

To prevent data leakage and to get more rigorous estimates of model performance, the dataset was split into train, validation and test sets in the ratio 70:15:15. The split was stratified by the target variable to take class imbalance into account.

  • the train set is used to gain more insights on the data, train and tune models,
  • the validation set is used to test the performance of the candidate models,
  • the test set is used to evaluate the final model.
# The example to split data into 3 datasets
data_train, data_vt = train_test_split(
    data_all, test_size=0.30, random_state=22, stratify=data_all.stroke
)

data_validation, data_test = train_test_split(
    data_vt, test_size=0.5, random_state=22, stratify=data_vt.stroke
)

n_train, n_validation, n_test = (
    data_train.shape[0],
    data_validation.shape[0],
    data_test.shape[0],
)
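
As a rough check of what such a two-step split produces, the following minimal sketch repeats the same two calls on a synthetic data frame and summarizes the size and class balance of each part. The `toy` data, its 5% positive rate and the `split_summary` name are assumptions made for illustration; they are not the project data.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the real data: 1,000 rows, 5% positive class
toy = pd.DataFrame({
    "x": np.arange(1000),
    "stroke": ["Yes"] * 50 + ["No"] * 950,
})

# Step 1: 70% train, 30% held out
toy_train, toy_vt = train_test_split(
    toy, test_size=0.30, random_state=22, stratify=toy["stroke"]
)
# Step 2: split the held-out part in half -> 15% validation, 15% test
toy_validation, toy_test = train_test_split(
    toy_vt, test_size=0.5, random_state=22, stratify=toy_vt["stroke"]
)

# Size and positive-class rate of each part
parts = {"train": toy_train, "validation": toy_validation, "test": toy_test}
split_summary = pd.DataFrame({
    "n": {name: len(part) for name, part in parts.items()},
    "stroke_rate": {
        name: (part["stroke"] == "Yes").mean() for name, part in parts.items()
    },
})
print(split_summary)
```

In this toy case the parts contain 700, 150 and 150 rows, and stratification keeps the stroke rate of each part close to the overall 5%.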

Actual sample sizes in these sets are:

Code of the flowchart
# Create the flowchart
n_excluded = n_initial - n_updated
Fig. 2.3. Sample size in different sets. A single sample with gender value “Other” was removed.

2.6 EDA on Training Set

The initial inspection of the whole dataset (Section 2.3) was used to get an overview of the data and to catch obvious anomalies and discrepancies. To gain deeper insights without leaking data (and so keep model performance estimates reliable), further investigation was based on the training set only. It suggested some pre-processing ideas that were implemented earlier (Section 2.4).

Note. This section contains “quick and dirty” exploratory plots whose purpose is to provide insight into the data and to spot trends and discrepancies, not to be publication-ready. This means:

  • Plots might not be extremely pretty;
  • Colors in different plots might not match the same groups;
  • Plots may lack captions;
  • Legend position might not be optimal;
  • Other imperfections might be introduced/not fixed.

2.6.1 General EDA

Besides the target variable (stroke), hypertension, heart disease and their derivatives exhibit a high class imbalance. As it is expected to have a higher proportion of healthy individuals compared to those with specific medical conditions, no further actions need to be taken in this regard.

The indicator for missing BMI values is also imbalanced, but it is fortunate that there is not a significant percentage of missing values in this variable. Furthermore, the columns related to BMI display the same missing value pattern, as anticipated. In contrast, the smoking_status variable exhibits a different pattern (as illustrated in Figure 2.4).

The plots in the sweetviz report indicate that many variables, including age, the trend of stroke incidence, average glucose level, and certain BMI groups, are associated with the target variable to some extent. However, it appears that gender and residence type either show no significant associations with the target variable or exhibit extremely weak associations.

From the technical side, the variables now use more memory-efficient data types (Table 2.2).

EDA: General info on columns, data types and values

Note that instead of int64 and float64, more efficient data types were chosen (mainly via klib.convert_datatypes()).
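
The effect of such a conversion can be illustrated with plain pandas. This is a minimal sketch on synthetic data, not the klib.convert_datatypes() call used in the project; the `raw`, `compact` and `memory` names and the column values are assumptions for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 10_000
raw = pd.DataFrame({
    "stroke_01": rng.integers(0, 2, size=n),           # 64-bit integers by default
    "age": rng.uniform(0, 82, size=n).round(),         # 64-bit floats by default
    "gender": rng.choice(["Female", "Male"], size=n),  # strings
})

compact = raw.assign(
    stroke_01=pd.to_numeric(raw["stroke_01"], downcast="integer"),  # -> int8
    age=pd.to_numeric(raw["age"], downcast="float"),                # -> float32
    gender=raw["gender"].astype("category"),
)

# Memory used by each column (in bytes) before and after the conversion
memory = pd.DataFrame({
    "before": raw.memory_usage(deep=True, index=False),
    "after": compact.memory_usage(deep=True, index=False),
})
print(compact.dtypes)
print(memory)
```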

an.col_info(data_train, style=True)
Table 2.2. Summary of the variables in the dataset after pre-processing.
  column data_type memory_size n_unique p_unique n_missing p_missing n_dominant p_dominant dominant
1 stroke category 3.8 kB 2 0.1% 0 0% 3,402 95.1% No
2 stroke_01 int8 3.6 kB 2 0.1% 0 0% 3,402 95.1% 0
3 age float32 14.3 kB 104 2.9% 0 0% 73 2.0% 78.000000
4 age_55_plus int8 3.6 kB 2 0.1% 0 0% 2,339 65.4% 0
5 age_bmi_interaction float32 14.3 kB 2,834 79.3% 133 3.7% 5 0.1% 1,895.400024
6 age_gender_interaction float32 14.3 kB 202 5.6% 0 0% 42 1.2% -50.000000
7 age_group category 4.0 kB 4 0.1% 0 0% 1,237 34.6% Elderly
8 age_group_num int8 3.6 kB 4 0.1% 0 0% 1,237 34.6% 3
9 age_heart_disease_interaction float32 14.3 kB 40 1.1% 0 0% 3,393 94.9% 0.000000
10 age_hypertension_interaction float32 14.3 kB 57 1.6% 0 0% 3,228 90.3% 0.000000
11 age_smoking_interaction float32 14.3 kB 105 2.9% 0 0% 1,337 37.4% 0.000000
12 age_square float32 14.3 kB 104 2.9% 0 0% 73 2.0% 6,084.000000
13 avg_glucose_bmi_ratio float32 14.3 kB 3,437 96.1% 133 3.7% 2 0.1% 2.317313
14 avg_glucose_gr_165 category 3.7 kB 2 0.1% 0 0% 3,128 87.5% (0, 165]
15 avg_glucose_gr_above_165_01 int8 3.6 kB 2 0.1% 0 0% 3,128 87.5% 0
16 avg_glucose_gr_medical category 3.9 kB 3 0.1% 0 0% 2,200 61.5% Normal
17 avg_glucose_gr_medical_165 category 4.0 kB 4 0.1% 0 0% 2,200 61.5% Normal
18 avg_glucose_gr_medical_165_num int8 3.6 kB 4 0.1% 0 0% 2,200 61.5% 0
19 avg_glucose_gr_medical_num int8 3.6 kB 3 0.1% 0 0% 2,200 61.5% 0
20 avg_glucose_is_diabetic_01 int8 3.6 kB 2 0.1% 0 0% 2,866 80.1% 0
21 avg_glucose_level float32 14.3 kB 2,998 83.8% 0 0% 5 0.1% 91.849998
22 bmi float32 14.3 kB 388 10.9% 133 3.7% 31 0.9% 26.700001
23 bmi_group category 4.0 kB 4 0.1% 133 3.7% 1,341 37.5% Obese
24 bmi_group_num float16 7.2 kB 4 0.1% 133 3.7% 1,341 37.5% 3.000000
25 bmi_heart_disease_interaction float32 14.3 kB 116 3.2% 133 3.7% 3,281 91.8% 0.000000
26 bmi_hypertension_interaction float32 14.3 kB 180 5.0% 133 3.7% 3,136 87.7% 0.000000
27 bmi_is_unknown_01 int8 3.6 kB 2 0.1% 0 0% 3,443 96.3% 0
28 bmi_normal_or_underweight_01 int8 3.6 kB 2 0.1% 0 0% 2,465 68.9% 0
29 bmi_overweight_or_obese_01 int8 3.6 kB 2 0.1% 0 0% 2,332 65.2% 1
30 bmi_smoking_interaction float32 14.3 kB 352 9.8% 133 3.7% 1,308 36.6% 0.000000
31 ever_married category 3.8 kB 2 0.1% 0 0% 2,323 65.0% Yes
32 ever_married_01 int8 3.6 kB 2 0.1% 0 0% 2,323 65.0% 1
33 gender category 3.8 kB 2 0.1% 0 0% 2,131 59.6% Female
34 gender_is_male_01 int8 3.6 kB 2 0.1% 0 0% 2,131 59.6% 0
35 health_risk_score int8 3.6 kB 5 0.1% 0 0% 1,881 52.6% 1
36 heart_disease category 3.8 kB 2 0.1% 0 0% 3,393 94.9% No
37 heart_disease_01 int8 3.6 kB 2 0.1% 0 0% 3,393 94.9% 0
38 hypertension category 3.8 kB 2 0.1% 0 0% 3,228 90.3% No
39 hypertension_01 int8 3.6 kB 2 0.1% 0 0% 3,228 90.3% 0
40 hypertension_heart_disease_interaction_01 int8 3.6 kB 2 0.1% 0 0% 3,531 98.7% 0
41 residence_is_urban_01 int8 3.6 kB 2 0.1% 0 0% 1,811 50.6% 1
42 residence_type category 3.8 kB 2 0.1% 0 0% 1,811 50.6% Urban
43 smoking_status category 3.9 kB 3 0.1% 1,080 30.2% 1,337 37.4% never smoked
44 smoking_status_2 category 4.0 kB 4 0.1% 0 0% 1,337 37.4% never smoked
45 smoking_status_is_unknown_01 int8 3.6 kB 2 0.1% 0 0% 2,496 69.8% 0
46 stroke_incidence float32 14.3 kB 104 2.9% 0 0% 73 2.0% 562.276306
47 stroke_risk_40 float32 14.3 kB 43 1.2% 0 0% 1,580 44.2% 1.000000
48 stroke_risk_55 float32 14.3 kB 28 0.8% 0 0% 2,393 66.9% 1.000000
49 work_type category 4.0 kB 4 0.1% 0 0% 2,026 56.7% Private Sector
50 work_type_original category 4.1 kB 5 0.1% 0 0% 2,026 56.7% Private Sector

In this type of .col_info() tables:

  • dominant is the most frequent value;
  • p_ is a percentage of certain values;
  • n_ is a number of certain values;
  • in data_type, numeric data types (float and integer) are highlighted in blue, and the “category” data type is in green;
  • in n_unique, binary variables are in a different shade of green, and columns with a high number of unique values (\(> 10\)) are highlighted in blue;
  • in n_missing and p_missing, zero values are in grey;
  • in p_dominant, percentages above 90% are in orange;
  • errors or extremely suspicious values (if any) are highlighted in red.
EDA: Patterns of missing values

Only the columns with missing values are shown here.
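
One simple way to check which columns share the same missing-value pattern is to group them by their missingness mask. The sketch below uses a small hypothetical data frame (`toy`) and plain pandas; it is an illustration of the idea, not the code behind the plot in this report.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    "bmi": [22.0, np.nan, 31.0, np.nan, 27.5],
    "age": [30, 61, 45, 70, 12],
    "smoking_status": ["smokes", "never smoked", None, None, None],
})
toy["age_bmi_interaction"] = toy["age"] * toy["bmi"]  # inherits NaN from bmi

na_mask = toy.isna()
na_mask = na_mask.loc[:, na_mask.any()]  # keep only columns with missing values

# Columns with an identical missingness pattern end up in the same group
patterns = {}
for col in na_mask.columns:
    patterns.setdefault(tuple(na_mask[col]), []).append(col)
groups = list(patterns.values())
print(groups)
```

Here the BMI-derived column falls into the same group as bmi, while smoking_status forms a group of its own.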

EDA: Data Profiling Report for Training Data (sweetviz)

Sweetviz performs data profiling with respect to the target variable: light blue bars indicate the distribution of a predictor variable, and dark blue lines with points indicate the distribution (mean) of the target variable in each class/range of values of the predictor variable.

Pay attention to the “Associations” button in the report. Note that for numeric variables Pearson’s correlation is calculated, which does not represent non-linear relationships well.

if do_eda:
    report = sweetviz.analyze(
        [data_train, "Training Data"],
EDA: Frequency tables

List values and their counts for variables with fewer than 10 unique values. These results partially duplicate the sweetviz report, but give the flexibility to order classes not only by count but also by index (in some cases this may give more insight).

an.summarize_discrete(data_all, sort="index")
stroke n percent
No 4,860 95.1%
Yes 249 4.9%
stroke_01 n percent
0 4,860 95.1%
1 249 4.9%
age_55_plus n percent
0 3,330 65.2%
1 1,779 34.8%
age_group n percent
Child 856 16.8%
Young Adult 988 19.3%
Adult 1,486 29.1%
Elderly 1,779 34.8%
age_group_num n percent
0 856 16.8%
1 988 19.3%
2 1,486 29.1%
3 1,779 34.8%
avg_glucose_gr_165 n percent
(0, 165] 4,465 87.4%
(165, 1000] 644 12.6%
avg_glucose_gr_above_165_01 n percent
0 4,465 87.4%
1 644 12.6%
avg_glucose_gr_medical n percent
Normal 3,131 61.3%
Prediabetic 979 19.2%
Diabetic 999 19.6%
avg_glucose_gr_medical_165 n percent
Normal 3,131 61.3%
Prediabetic 979 19.2%
Diabetic 355 6.9%
Extreme diabetic 644 12.6%
avg_glucose_gr_medical_165_num n percent
0 3,131 61.3%
1 979 19.2%
2 355 6.9%
3 644 12.6%
avg_glucose_gr_medical_num n percent
0 3,131 61.3%
1 979 19.2%
2 999 19.6%
avg_glucose_is_diabetic_01 n percent
0 4,110 80.4%
1 999 19.6%
bmi_group n percent
Underweight 337 6.9%
Normal 1,242 25.3%
Overweight 1,409 28.7%
Obese 1,920 39.1%
bmi_group_num n percent
0.000000 337 6.9%
1.000000 1,242 25.3%
2.000000 1,409 28.7%
3.000000 1,920 39.1%
bmi_is_unknown_01 n percent
0 4,908 96.1%
1 201 3.9%
bmi_normal_or_underweight_01 n percent
0 3,530 69.1%
1 1,579 30.9%
bmi_overweight_or_obese_01 n percent
0 1,780 34.8%
1 3,329 65.2%
ever_married n percent
No 1,756 34.4%
Yes 3,353 65.6%
ever_married_01 n percent
0 1,756 34.4%
1 3,353 65.6%
gender n percent
Female 2,994 58.6%
Male 2,115 41.4%
gender_is_male_01 n percent
0 2,994 58.6%
1 2,115 41.4%
health_risk_score n percent
0 1,535 30.0%
1 2,654 51.9%
2 688 13.5%
3 211 4.1%
4 21 0.4%
heart_disease n percent
No 4,833 94.6%
Yes 276 5.4%
heart_disease_01 n percent
0 4,833 94.6%
1 276 5.4%
hypertension n percent
No 4,611 90.3%
Yes 498 9.7%
hypertension_01 n percent
0 4,611 90.3%
1 498 9.7%
hypertension_heart_disease_interaction_01 n percent
0 5,045 98.7%
1 64 1.3%
residence_is_urban_01 n percent
0 2,513 49.2%
1 2,596 50.8%
residence_type n percent
Rural 2,513 49.2%
Urban 2,596 50.8%
smoking_status n percent
never smoked 1,892 53.1%
formerly smoked 884 24.8%
smokes 789 22.1%
smoking_status_2 n percent
never smoked 1,892 37.0%
formerly smoked 884 17.3%
smokes 789 15.4%
Unknown 1,544 30.2%
smoking_status_is_unknown_01 n percent
0 3,565 69.8%
1 1,544 30.2%
work_type n percent
Never worked 709 13.9%
Government Sector 657 12.9%
Private Sector 2,924 57.2%
Self Employed 819 16.0%
work_type_original n percent
Children 687 13.4%
Never worked 22 0.4%
Government Sector 657 12.9%
Private Sector 2,924 57.2%
Self Employed 819 16.0%

2.6.2 ROC Analysis

ROC analysis has 2 purposes in this EDA:

  • to evaluate how well the classes of the target variable can be separated using the values of a numeric variable. Here the AUC score is used: AUC=1 indicates perfect separation between classes, while AUC=0.5 corresponds to random guessing;
  • to find the threshold that separates the classes best (by maximizing the balanced accuracy score).
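
Both steps can be sketched with scikit-learn as follows. This is a minimal illustration on synthetic data, not the implementation of the violinplot_with_roc_results() helper used in this report; the simulated `score` values and the class sizes are assumptions.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

rng = np.random.default_rng(42)
# Synthetic stand-in: 950 negatives (lower values), 50 positives (higher values)
y_true = np.r_[np.zeros(950, dtype=int), np.ones(50, dtype=int)]
score = np.r_[rng.normal(42, 22, 950), rng.normal(68, 12, 50)]

# 1) Separability of the classes
auc = roc_auc_score(y_true, score)

# 2) Threshold that maximizes balanced accuracy.
# Each ROC point corresponds to the rule "predict positive if score >= threshold".
fpr, tpr, thresholds = roc_curve(y_true, score)
balanced_accuracy = (tpr + (1 - fpr)) / 2  # mean of TPR and TNR
best = np.argmax(balanced_accuracy)

print(f"AUC = {auc:.3f}")
print(f"Best threshold = {thresholds[best]:.1f}, BAcc = {balanced_accuracy[best]:.3f}")
```

Balanced accuracy is used here because, with about 5% positive cases, plain accuracy would be dominated by the negative class.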

The most important results here:

  1. Among age, stroke_incidence, stroke_risk_40, and stroke_risk_55, the first three reach AUC=0.83, while the last one performs worst with AUC=0.79.
  2. The optimal age threshold separating stroke and non-stroke cases is 54 years (in the FE part it was rounded to 55 to be consistent with the literature).

More ROC results are presented in the following (more appropriate) sections.

EDA and ROC analysis of age, stroke risk and stroke incidence variables
gr = my.convert_no_yes_to_01(data_train.stroke)
y = data_train.age

violinplot_with_roc_results(gr, y)
gr = my.convert_no_yes_to_01(data_train.stroke)
y = data_train.stroke_incidence

violinplot_with_roc_results(gr, y)
gr = my.convert_no_yes_to_01(data_train.stroke)
y = data_train.stroke_risk_40

violinplot_with_roc_results(gr, y)
gr = my.convert_no_yes_to_01(data_train.stroke)
y = data_train.stroke_risk_55

violinplot_with_roc_results(gr, y)

2.6.3 Hypertension, Heart Disease, Stroke, and Age

In the training set, there are far fewer people with hypertension (Figure 2.5) and heart disease (Figure 2.7) than without them. Yet, the proportion of people who experienced a stroke is higher among those with hypertension and heart disease (Figure 2.6, Figure 2.8). The numbers of cases with hypertension, heart disease or stroke appear to increase with age, but hypertension tends to start at a younger age than the other two conditions (Figure 2.10, Figure 2.11, Figure 2.12).

EDA: Hypertension and heart disease


Code of the figure
    kind="bar", title="Hypertension distribution"
Fig. 2.5. Hypertension distribution.
crosstab_hypertension = an.Crosstab("hypertension", "stroke", data=data_train)
display_crosstab(crosstab_hypertension, percentage="row")
Counts % (row)
stroke No Yes No Yes
No 3102 126 96.10 3.90
Yes 300 48 86.20 13.80
Code of the figure
crosstab_hypertension.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.6. Stroke incidence by hypertension status.

Heart disease

Code of the figure
    kind="bar", title="Heart disease distribution"
Fig. 2.7. Heart disease distribution.
crosstab_hd = an.Crosstab("heart_disease", "stroke", data=data_train)
Counts % (column)
stroke No Yes No Yes
No 3246 147 95.40 84.50
Yes 156 27 4.60 15.50
Code of the figure
crosstab_hd.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.8. Stroke incidence by heart disease status.
EDA: Age-related trends
Code of the figure
sns.histplot(x="age", data=data_train, multiple="stack", binwidth=5);
Fig. 2.9. Age distribution.
Code of the figure
sns.histplot(x="age", hue="hypertension", data=data_train, multiple="fill", binwidth=5)
Fig. 2.10. Hypertension status by age.
Code of the figure
sns.histplot(x="age", hue="heart_disease", data=data_train, multiple="fill", binwidth=5)
Fig. 2.11. Heart disease status by age.
Code of the figure
sns.histplot(x="age", hue="stroke", data=data_train, multiple="fill", binwidth=5)
Fig. 2.12. Stroke status by age.

2.6.4 Average Glucose Level

The distribution of average glucose level suggests that there are 2 clusters of patients (Figure 2.13, Figure 2.16, Figure 2.17), and the normalized distribution indicates that stroke rates in these two groups differ (Figure 2.14). ROC analysis suggests that the optimal threshold to separate stroke and non-stroke cases is 165.3 (Figure 2.15), which lies between the two clusters. In the FE part, the threshold was rounded to 165 and the group above 165 was called “Extreme diabetic”. Comparing the regular diabetic (avg. glucose > 125) and extreme diabetic (> 165) groups, the latter has a higher stroke rate (see Figure 2.18 and Figure 2.19).

EDA: Average glucose concentration related trends
Code of the figure
    binrange=(50, 280),
Fig. 2.13. Distribution of glucose levels in the training set.
Code of the figure
    binrange=(50, 280),
Fig. 2.14. Normalized distribution of glucose levels in the training set.
gr = my.convert_no_yes_to_01(data_train.stroke)
y = data_train.avg_glucose_level

violinplot_with_roc_results(gr, y, linecolor="black")
Fig. 2.15. Glucose level distribution for stroke and non-stroke groups with ROC analysis results.
Code of the figure
ax = sns.scatterplot(
    x="avg_glucose_level", y="bmi", hue="stroke", data=data_train, alpha=0.3
plt.axvline(125, color="tab:blue", linestyle="--")
ax.text(123, 100, "Non-diabetic", rotation=90, ha="right", va="top", color="tab:blue")
ax.text(127, 100, "Diabetic", rotation=90, ha="left", va="top", color="red")

plt.axvline(165, color="black", linestyle="--")
    "Optimal threshold\n(from ROC analysis)",
Fig. 2.16. Glucose level and BMI distribution. The “Optimal threshold” is the threshold that separates stroke and non-stroke groups the best (according to ROC analysis).
Code of the figure
ax = sns.scatterplot(
    x="age", y="avg_glucose_level", hue="stroke", data=data_train, alpha=0.3

ax.axhline(y=165, color="black", linestyle="--")
    "Optimal threshold (based on ROC analysis)",
Fig. 2.17. Glucose level and age distribution.
crosstab = an.Crosstab("avg_glucose_is_diabetic_01", "stroke", data=data_train)
display_crosstab(crosstab, percentage="row")
Counts % (row)
stroke No Yes No Yes
0 2762 104 96.40 3.60
1 640 70 90.10 9.90
Code of the figure
crosstab.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.18. Distribution of diabetic (glucose level > 125) vs. remaining groups.
crosstab = an.Crosstab("avg_glucose_gr_165", "stroke", data=data_train)
display_crosstab(crosstab, percentage="row")
Counts % (row)
stroke No Yes No Yes
(0, 165] 3018 110 96.50 3.50
(165, 1000] 384 64 85.70 14.30
Code of the figure
crosstab.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.19. Distribution of extreme diabetic (glucose level > 165) vs. remaining groups.

2.6.5 BMI

The analysis of the BMI variable shows that the most common BMI values are in the overweight range (Figure 2.20). If age is taken into account, underweight is most common among children, normal weight among teenagers and young adults, and obesity among adults between 30 and 80 years old (Figure 2.24 and Figure 2.25). The normalized BMI distribution indicates that stroke cases are slightly more common in the overweight and obese groups (Figure 2.21). The ROC analysis suggests that the optimal BMI threshold is 25.6 (Figure 2.22), which is in line with the literature, where overweight starts at BMI=25.

It is essential to emphasize that stroke is considerably more frequent among individuals with missing BMI values than among those with known BMI (19.5% vs. 4.3%, as shown in Figure 2.26). This observation suggests that the missingness of BMI values likely arises after the occurrence of a stroke and is not a causal factor for stroke. Consequently, it qualifies as a misleading predictor and will be excluded from the modeling process.

EDA: BMI-related trends
Code for BMI labels
def bmi_labels(ax, y_txt: float = 28):
    # Add vertical dashed lines
    thresholds = [18.5, 25, 30]
    for threshold in thresholds:
        ax.axvline(x=threshold, color="darkred", linestyle="--")

    # Define label positions and text values
    label_positions = [14, 23, 28.5, 34]
    label_values = ["underweight", "normal", "overweight", "obese"]

    # Add text labels with transparent backgrounds in a loop
    for x, label in zip(label_positions, label_values):
Code of the figure
ax = sns.histplot(
    x="bmi", hue="stroke", data=data_train, multiple="stack", binwidth=1.5
Fig. 2.20. Distribution of BMI values for stroke and no stroke groups.
Code of the figure
ax = sns.histplot(x="bmi", hue="stroke", data=data_train, multiple="fill", binwidth=3)
bmi_labels(ax, y_txt=0.2)
Fig. 2.21. Normalized distribution of BMI values for stroke and no stroke groups.
Code of the figure
data_train_wo_na = data_train.dropna(subset=["bmi"])
gr = my.convert_no_yes_to_01(data_train_wo_na.stroke)
y = data_train_wo_na.bmi

violinplot_with_roc_results(gr, y)
Fig. 2.22. BMI distribution for stroke and no stroke groups with ROC analysis results.
Code of the figure
sns.scatterplot(x="age", y="bmi", hue="stroke", data=data_train, alpha=0.3);
Fig. 2.23. Relationship between BMI and age.
Code of the figure
sns.histplot(x="age", hue="bmi_group", data=data_train, multiple="fill", binwidth=5)
Fig. 2.24. BMI distribution by age (proportions).
Code of the figure
sns.histplot(x="age", hue="bmi_group", data=data_train, multiple="dodge", binwidth=10)
Fig. 2.25. BMI distribution by age (counts).
crosstab_bmi_na = an.Crosstab("bmi_is_unknown_01", "stroke", data=data_train)
display_crosstab(crosstab_bmi_na, percentage="row")
Counts % (row)
stroke No Yes No Yes
0 3295 148 95.70 4.30
1 107 26 80.50 19.50
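
The direction of normalization matters when reading such a table. The following minimal sketch uses plain pandas (not the an.Crosstab helper) and the counts copied from the table above to compute both row and column percentages; the `counts`, `row_pct` and `col_pct` names are illustrative.

```python
import pandas as pd

# Counts copied from the table above (rows: bmi_is_unknown_01, columns: stroke)
counts = pd.DataFrame(
    {"No": [3295, 107], "Yes": [148, 26]},
    index=pd.Index([0, 1], name="bmi_is_unknown_01"),
)
counts.columns.name = "stroke"

# Row percentages: stroke rate within each missingness group
row_pct = counts.div(counts.sum(axis=1), axis=0).mul(100).round(1)
# Column percentages: share of missing BMI within each stroke group
col_pct = counts.div(counts.sum(axis=0), axis=1).mul(100).round(1)

print(row_pct)
print(col_pct)
```

Read by rows, 19.5% of the records with missing BMI are stroke cases (vs. 4.3% of those with known BMI); read by columns, 14.9% of stroke cases have missing BMI (vs. 3.1% of non-stroke cases).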
Code of the figure
    x="age", hue="bmi_is_unknown_01", data=data_train, multiple="fill", binwidth=5
Fig. 2.27. Distribution of BMI missing values by age.

2.6.6 Gender, Social Factors, and Smoking

In the training set, there are more females than males (Figure 2.28). However, gender differences in stroke incidence rates appear to be negligible (Figure 2.29), as are differences in residence types (Figure 2.32).

Individuals who have ever been married are more likely to have a stroke (Figure 2.35). However, this marital status is closely related to age (Figure 2.36): the majority of people who have been married are over 30 years old and older people are more likely to have a stroke (Figure 2.12).

Similarly, varying stroke incidence rates are observed across different work types (Figure 2.38), as well as smoking statuses (Figure 2.46). However, the “Never worked” category is more common among children and teenagers, while the proportion of self-employed individuals increases with age (Figure 2.39). Likewise, the proportion of people who have never smoked is higher at an extremely young age, and there is a proportional increase in former smokers at older ages (Figure 2.48).

The missingness of smoking status is related to younger age (Figure 2.45), and this trend seems almost reversed compared to marriage status (Figure 2.36).

What is more, the work type categories “Never worked” and “children” (Figure 2.37) from the original form of the variable were merged into “Never worked” (Figure 2.39), as both occur mainly at younger ages.

EDA: Gender


Code of the figure
data_train.value_counts("gender").plot(kind="bar", title="Gender distribution")
Fig. 2.28. Gender distribution.
crosstab_gender = an.Crosstab("gender", "stroke", data=data_train)
display_crosstab(crosstab_gender, percentage="row")
Counts % (row)
stroke No Yes No Yes
Female 2031 100 95.30 4.70
Male 1371 74 94.90 5.10
Code of the figure
crosstab_gender.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.29. Stroke incidence by gender.
Code of the figure
sns.histplot(x="age", hue="gender", data=data_train, multiple="fill", binwidth=5)
Fig. 2.30. Gender distribution by age.
EDA: Residence type

Residence type

Code of the figure
    kind="bar", title="Residence type distribution"
Fig. 2.31. Residence type distribution.
crosstab_residence = an.Crosstab("residence_type", "stroke", data=data_train)
display_crosstab(crosstab_residence, percentage="row")
Counts % (row)
stroke No Yes No Yes
Rural 1686 79 95.50 4.50
Urban 1716 95 94.80 5.20
Code of the figure
crosstab_residence.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.32. Stroke incidence by residence type.
Code of the figure
    x="age", hue="residence_type", data=data_train, multiple="fill", binwidth=5
Fig. 2.33. Residence type by age.
EDA: Marriage status
crosstab_married = an.Crosstab("ever_married", "stroke", data=data_train)
Counts % (column)
stroke No Yes No Yes
No 1233 20 36.20 11.50
Yes 2169 154 63.80 88.50
Code of the figure
Fig. 2.34. Stroke incidence by marriage status (counts).
Code of the figure
crosstab_married.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.35. Stroke incidence by marriage status (proportions).
Code of the figure
sns.histplot(x="age", hue="ever_married", data=data_train, multiple="fill", binwidth=5)
Fig. 2.36. Marriage status by age.
EDA: Work types (before merging classes)
Code of the figure
Fig. 2.37. Work types by age (before merging “children” and “Never worked” classes).
EDA: Work types (final)
crosstab_work = an.Crosstab("work_type", "stroke", data=data_train)
display_crosstab(crosstab_work, percentage="row")
Counts % (row)
stroke No Yes No Yes
Never worked 506 2 99.60 0.40
Government Sector 449 21 95.50 4.50
Private Sector 1925 101 95.00 5.00
Self Employed 522 50 91.30 8.70
Code of the figure
crosstab_work.barplot(normalize="rows", stacked=True, width=1, rot=0);
Fig. 2.38. Stroke incidence by work types.
Code of the figure
Fig. 2.39. Work types by age (proportions).
Code of the figure
Fig. 2.40. Work types by age.
EDA: Smoking status
crosstab_smoking_2 = an.Crosstab("smoking_status_2", "stroke", data=data_train)
display_crosstab(crosstab_smoking_2, percentage="row")
Counts % (row)
stroke No Yes No Yes
never smoked 1272 65 95.10 4.90
formerly smoked 571 48 92.20 7.80
smokes 514 26 95.20 4.80
Unknown 1045 35 96.80 3.20
Code of the figure
crosstab_smoking_2.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.41. Stroke incidence by smoking status (explicit missing values).
Code of the figure
Fig. 2.42. Smoking status (explicit missing values) by age (proportions).
Code of the figure
Fig. 2.43. Smoking status (explicit missing values) by age.
crosstab_smoking_na = an.Crosstab("smoking_status_is_unknown_01", "stroke", data=data_train)
display_crosstab(crosstab_smoking_na, percentage="row")
Counts % (row)
stroke No Yes No Yes
0 2357 139 94.40 5.60
1 1045 35 96.80 3.20
Code of the figure
crosstab_smoking_na.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.44. Stroke incidence by the missingness of smoking status.
Code of the figure
Fig. 2.45. Missingness of smoking status by age (proportions).
crosstab_smoking = an.Crosstab("smoking_status", "stroke", data=data_train)
display_crosstab(crosstab_smoking, percentage="row")
Counts % (row)
stroke No Yes No Yes
never smoked 1272 65 95.10 4.90
formerly smoked 571 48 92.20 7.80
smokes 514 26 95.20 4.80
Code of the figure
crosstab_smoking.barplot(normalize="rows", stacked=True, width=1);
Fig. 2.46. Stroke incidence by smoking status.
Code of the figure
Fig. 2.47. Smoking status by age.
Code of the figure
    hue_order=["never smoked", "smokes", "formerly smoked"],
Fig. 2.48. Smoking status by age (different order of categories).

2.6.7 Engineered Features

A higher value of the engineered feature “health risk score”, which summarizes the presence of 4 medical conditions, seems to reflect higher stroke incidence (Figure 2.49). Older people tend to have higher health risk scores too (Figure 2.50). Higher values of the stroke incidence and stroke risk trend features are associated with a higher stroke rate as well (Figure 2.51, Figure 2.52, and Figure 2.53).
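
A score of this kind can be sketched as a sum of four binary indicators. In the minimal example below, the heart disease, BMI (>= 25) and glucose (> 165) terms follow the feature engineering code of this report, while hypertension as the fourth condition is an assumption, and the `toy` data frame is hypothetical.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    "hypertension_01": [0, 1, 1, 0],
    "heart_disease_01": [0, 0, 1, 0],
    "bmi": [22.0, 31.5, 27.0, np.nan],
    "avg_glucose_level": [90.0, 120.0, 210.0, 180.0],
})

toy["health_risk_score"] = (
    toy["hypertension_01"]                          # assumed fourth condition
    + toy["heart_disease_01"]
    + (toy["bmi"] >= 25).astype(int)                # overweight or obese
    + (toy["avg_glucose_level"] > 165).astype(int)  # "extreme diabetic"
)
print(toy)
```

Note that a missing BMI compares as False and therefore contributes 0, so the score itself has no missing values.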

EDA: Health risk score
Code of the figure
    binrange=(-0.5, 5.5),
Fig. 2.49. Stroke incidence by health risk score (proportions).
Code of the figure
    x="age", hue="health_risk_score", data=data_train, multiple="fill", binwidth=5
Fig. 2.50. Health risk score by age.
EDA: Non-linear functions of age
Code of the figure
    x="stroke_incidence", hue="stroke", data=data_train, multiple="fill", binwidth=100
Fig. 2.51. Stroke status by stroke incidence trend value.
Code of the figure
    x="stroke_risk_40", hue="stroke", data=data_train, multiple="fill", binwidth=2
Fig. 2.52. Stroke status by stroke risk trend value (threshold=40).
Code of the figure
    x="stroke_risk_55", hue="stroke", data=data_train, multiple="fill", binwidth=.6
Fig. 2.53. Stroke status by stroke risk trend value (threshold=55).

Stroke risk and stroke incidence show similar trends with age. However, their dependence on each other is not linear.

sns_plot = sns.scatterplot(x="age", y="stroke_risk_40", data=data_train)
sns_plot.set_ylim(0, None);
sns_plot = sns.scatterplot(x="age", y="stroke_incidence", data=data_train)
sns_plot.set_ylim(0, None);
sns_plot = sns.scatterplot(x="stroke_incidence", y="stroke_risk_40", data=data_train)
sns_plot.set_ylim(0, None);
EDA: Interaction features
sns_plot = sns.scatterplot(
sns_plot = sns.scatterplot(x="age", y="age_bmi_interaction", data=data_train);
sns_plot = sns.scatterplot(
    x="age", y="age_gender_interaction", hue="gender", data=data_train

2.6.8 Relationships Between Variables

Mutual information (MI) scores are highest between the target and variables derived from age and lowest for BMI-based variables (Table 2.3). Point-biserial correlation is strongest between the target and stroke_incidence as well as stroke_risk_40 variables (Table 2.4).

While the correlation coefficient matrix shows linear relationships (both positive and negative) between variables (Figure 2.54), the matrix of absolute correlation coefficients, with rows and columns ordered according to hierarchical clustering results, reveals 7 groups of highly inter-correlated variables (Figure 2.55).

EDA: Mutual Information (target vs. predictors)

Mutual information shows the strength of association between two variables (in this case, between the target and each of the remaining variables).

Advantages of mutual information:

  • Shows strength of relationship between the target and the remaining variables.
  • Captures not only linear relationships but also non-linear ones.

Drawbacks of mutual information:

  • Does not show direction of the relationship;
  • Calculation algorithm includes randomness to break ties, so the results might be slightly different each time.
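Both points can be illustrated with a minimal sketch on synthetic data: a U-shaped dependence that Pearson correlation misses is picked up by mutual information, and a fixed seed makes the estimate repeatable. The sketch uses mutual_info_classif from scikit-learn, not the project's own helper function; the data and variable names are made up for the example.

```python
import numpy as np
from sklearn.feature_selection import mutual_info_classif

rng = np.random.default_rng(0)
n = 2000
x_nonlinear = rng.uniform(-1, 1, n)  # related to the target non-linearly
x_noise = rng.uniform(-1, 1, n)  # unrelated to the target
y = (np.abs(x_nonlinear) > 0.5).astype(int)  # U-shaped dependence

X = np.column_stack([x_nonlinear, x_noise])

# Pearson correlation is close to zero for the U-shaped relationship
r = np.corrcoef(x_nonlinear, y)[0, 1]

# Mutual information detects it; a fixed random_state makes the
# nearest-neighbor based estimate repeatable between runs
mi_1 = mutual_info_classif(X, y, random_state=31)
mi_2 = mutual_info_classif(X, y, random_state=31)

print(f"Pearson r (x_nonlinear vs. y): {r:.3f}")
print(f"MI scores: {mi_1.round(3)}")
print(f"Same result with the same seed: {np.allclose(mi_1, mi_2)}")
```

Without a fixed random_state, the scores would typically differ slightly between runs, which is the drawback listed above.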
    data_train.dropna(), target="stroke_01", drop=["stroke"], random_state=31
Table 2.3. Mutual information between the target variable and the features.
  variable_1 variable_2 mutual_info
1 stroke_01 age_gender_interaction 0.036
2 stroke_01 age_square 0.036
3 stroke_01 age 0.035
4 stroke_01 stroke_risk_40 0.034
5 stroke_01 stroke_incidence 0.033
6 stroke_01 age_bmi_interaction 0.030
7 stroke_01 stroke_risk_55 0.029
8 stroke_01 age_group_num 0.025
9 stroke_01 age_group 0.025
10 stroke_01 age_55_plus 0.021
11 stroke_01 health_risk_score 0.016
12 stroke_01 age_hypertension_interaction 0.014
13 stroke_01 bmi_hypertension_interaction 0.011
14 stroke_01 avg_glucose_gr_medical_165 0.011
15 stroke_01 avg_glucose_gr_medical_165_num 0.011
16 stroke_01 avg_glucose_gr_165 0.010
17 stroke_01 avg_glucose_gr_above_165_01 0.010
18 stroke_01 hypertension_01 0.009
19 stroke_01 hypertension 0.009
20 stroke_01 avg_glucose_level 0.008
21 stroke_01 avg_glucose_bmi_ratio 0.007
22 stroke_01 avg_glucose_gr_medical 0.006
23 stroke_01 avg_glucose_gr_medical_num 0.006
24 stroke_01 avg_glucose_is_diabetic_01 0.006
25 stroke_01 age_smoking_interaction 0.006
26 stroke_01 heart_disease_01 0.004
27 stroke_01 heart_disease 0.004
28 stroke_01 work_type_original 0.004
29 stroke_01 work_type 0.004
30 stroke_01 age_heart_disease_interaction 0.003
31 stroke_01 ever_married 0.003
32 stroke_01 ever_married_01 0.003
33 stroke_01 bmi_group 0.001
34 stroke_01 bmi_overweight_or_obese_01 0.001
35 stroke_01 bmi_normal_or_underweight_01 0.001
36 stroke_01 hypertension_heart_disease_interaction_01 0.001
37 stroke_01 smoking_status 0.001
38 stroke_01 smoking_status_2 0.001
39 stroke_01 gender 0.000
40 stroke_01 gender_is_male_01 0.000
41 stroke_01 residence_is_urban_01 0.000
42 stroke_01 residence_type 0.000
43 stroke_01 smoking_status_is_unknown_01 0.000
44 stroke_01 bmi_group_num 0.000
45 stroke_01 bmi 0.000
46 stroke_01 bmi_heart_disease_interaction 0.000
47 stroke_01 bmi_smoking_interaction 0.000
48 stroke_01 bmi_is_unknown_01 0.000
EDA: Point-Biserial Correlation (target vs. predictors)

The table below shows the relationship between the binary target variable and the remaining numeric variables, expressed as the point-biserial correlation. This coefficient quantifies how much a continuous variable differs between the two groups defined by the binary variable.

Details. The correlation coefficient (denoted as r_pb in the table below) was calculated, along with the corresponding p-values. Additionally, p-values adjusted for multiple comparisons using the Holm-Šídák method were computed. The correlation analysis results are summarized in Table 2.4.
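The calculation described above can be sketched roughly as follows. This is an illustration on synthetic data, not the project's helper function: it assumes SciPy's pointbiserialr for the coefficient and uses a small hand-written Holm-Šídák step-down adjustment (the function name holm_sidak and all data are hypothetical).

```python
import numpy as np
from scipy import stats


def holm_sidak(p_values):
    """Holm-Šídák step-down adjustment of p-values."""
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)
    # The smallest p-value gets exponent m, the largest gets exponent 1
    adjusted_sorted = 1 - (1 - p[order]) ** np.arange(m, 0, -1)
    # Enforce monotonicity of the adjusted values and cap them at 1
    adjusted_sorted = np.minimum(np.maximum.accumulate(adjusted_sorted), 1.0)
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted
    return adjusted


rng = np.random.default_rng(1)
n = 500
target = rng.integers(0, 2, n)  # synthetic binary target
features = {
    "related": target + rng.normal(0, 1, n),  # mean differs between groups
    "unrelated": rng.normal(0, 1, n),  # same distribution in both groups
}

results = {}
for name, x in features.items():
    r_pb, p = stats.pointbiserialr(target, x)
    results[name] = (r_pb, p)

p_adj = holm_sidak([p for _, p in results.values()])

for (name, (r_pb, p)), pa in zip(results.items(), p_adj):
    print(f"{name:>10}: r_pb={r_pb:.3f}, p={p:.3g}, p_adj={pa:.3g}")
```

Adjusted p-values are never smaller than the raw ones, which is why some coefficients in a table of this kind can lose significance after correction.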

an.get_pointbiserial_corr_scores(data_train, target="stroke_01")
Table 2.4. Point-biserial correlation (r_pb) between the target variable and the numeric/binary features. The Holm-Šídák correction for multiple comparisons is used (p_adj).
  variable_1 variable_2 n r_pb p p_adj
1 stroke_01 stroke_incidence 3576 0.286 <0.001 <0.001
2 stroke_01 stroke_risk_40 3576 0.282 <0.001 <0.001
3 stroke_01 age_square 3576 0.273 <0.001 <0.001
4 stroke_01 stroke_risk_55 3576 0.271 <0.001 <0.001
5 stroke_01 age 3576 0.244 <0.001 <0.001
6 stroke_01 age_55_plus 3576 0.232 <0.001 <0.001
7 stroke_01 age_bmi_interaction 3443 0.214 <0.001 <0.001
8 stroke_01 age_group_num 3576 0.206 <0.001 <0.001
9 stroke_01 health_risk_score 3576 0.166 <0.001 <0.001
10 stroke_01 avg_glucose_gr_above_165_01 3576 0.166 <0.001 <0.001
11 stroke_01 age_hypertension_interaction 3576 0.157 <0.001 <0.001
12 stroke_01 avg_glucose_level 3576 0.137 <0.001 <0.001
13 stroke_01 hypertension_01 3576 0.136 <0.001 <0.001
14 stroke_01 bmi_hypertension_interaction 3443 0.136 <0.001 <0.001
15 stroke_01 bmi_is_unknown_01 3576 0.134 <0.001 <0.001
16 stroke_01 age_smoking_interaction 3576 0.131 <0.001 <0.001
17 stroke_01 avg_glucose_gr_medical_165_num 3576 0.131 <0.001 <0.001
18 stroke_01 bmi_heart_disease_interaction 3443 0.116 <0.001 <0.001
19 stroke_01 avg_glucose_is_diabetic_01 3576 0.116 <0.001 <0.001
20 stroke_01 ever_married_01 3576 0.112 <0.001 <0.001
21 stroke_01 age_heart_disease_interaction 3576 0.111 <0.001 <0.001
22 stroke_01 heart_disease_01 3576 0.107 <0.001 <0.001
23 stroke_01 avg_glucose_gr_medical_num 3576 0.103 <0.001 <0.001
24 stroke_01 bmi_normal_or_underweight_01 3576 -0.082 <0.001 <0.001
25 stroke_01 bmi_group_num 3443 0.074 <0.001 <0.001
26 stroke_01 avg_glucose_bmi_ratio 3443 0.073 <0.001 <0.001
27 stroke_01 bmi 3443 0.056 0.001 0.008
28 stroke_01 smoking_status_is_unknown_01 3576 -0.050 0.003 0.021
29 stroke_01 hypertension_heart_disease_interaction_01 3576 0.044 0.008 0.046
30 stroke_01 bmi_overweight_or_obese_01 3576 0.026 0.120 0.472
31 stroke_01 residence_is_urban_01 3576 0.018 0.285 0.738
32 stroke_01 gender_is_male_01 3576 0.010 0.559 0.914
33 stroke_01 bmi_smoking_interaction 3443 0.007 0.699 0.914
34 stroke_01 age_gender_interaction 3576 -0.006 0.732 0.914
EDA: Pearson Correlation (matrix)

This correlation matrix shows strength of linear relationship between numeric variables.

Code of the figure
data_num = data_train.select_dtypes("number")
corr_coefs = data_num.corr(method="pearson").fillna(0)

g = sns.clustermap(
    annot_kws={"size": 8},
    figsize=(13, 10),
    cbar_pos=(0.94, 0.91, 0.03, 0.1),
    cbar_kws={"location": "right"},
    dendrogram_ratio=(0.075, 0),

    "Pearson Correlation (Matrix with Hierarchical Clustering)",
Fig. 2.54. Matrix of Pearson correlation coefficients. Hierarchical clustering is used to group variables with similar correlation patterns.
EDA: Pearson Correlation (matrix of absolute values)
Code of the figure
data_num = data_train.select_dtypes("number")
corr_coefs = data_num.corr(method="pearson").fillna(0).abs()

g = sns.clustermap(
    annot_kws={"size": 8},
    figsize=(13, 10),
    cbar_pos=(0.94, 0.91, 0.03, 0.1),
    cbar_kws={"location": "right"},
    dendrogram_ratio=(0.075, 0),

    "Absolute Values of Pearson Correlation (Matrix with Hierarchical Clustering)",
Fig. 2.55. Matrix of absolute values of Pearson correlation coefficients. Hierarchical clustering is used to group variables with similar correlation patterns. Absolute values are used to make it easier to identify variable clusters that share correlation patterns.
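Groups of inter-correlated variables can also be extracted programmatically instead of being read off the figure. The sketch below is one possible approach under my own assumptions, not the method behind Figure 2.55: it converts absolute correlations into distances (1 − |r|), applies average-linkage clustering from SciPy, and cuts the tree at |r| = 0.75. As far as I know, sns.clustermap by default clusters the rows of the matrix using Euclidean distance, so the grouping may differ from the one seen in the figure. The data and column names are synthetic.

```python
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

rng = np.random.default_rng(2)
n = 300
a = rng.normal(size=n)
b = rng.normal(size=n)
df = pd.DataFrame(
    {
        "a": a,
        "a_neg": -a + rng.normal(scale=0.1, size=n),  # strong negative correlation with a
        "b": b,
        "b_plus": b + rng.normal(scale=0.1, size=n),  # strong positive correlation with b
    }
)

abs_corr = df.corr(method="pearson").abs()

# Turn |r| into a distance: 0 for perfectly (anti)correlated variables
distance = 1 - abs_corr.to_numpy()
np.fill_diagonal(distance, 0)
condensed = squareform(distance, checks=False)

# Average-linkage clustering; cut the tree at distance 0.25 (|r| = 0.75)
tree = linkage(condensed, method="average")
labels = fcluster(tree, t=0.25, criterion="distance")

clusters = pd.Series(labels, index=abs_corr.columns, name="cluster")
print(clusters)
```

Using absolute values means that strongly negatively correlated variables (a and a_neg here) land in the same group, which matches the motivation given in the caption of Figure 2.55.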

3 Modeling

The modeling procedure had several main stages:

  1. Remove redundant variables (e.g., binary No/Yes variables that have 0/1 counterparts);
  2. Remove highly correlated variables (to make calculations in the next steps more efficient);
  3. Pre-tune models (to rule out clearly poor configurations);
  4. Perform sequential feature selection (select the best-performing feature combinations for 3 model candidates);
  5. Fine-tune hyperparameters (and select the best model);
  6. Evaluate the final model on the test set.

F1 score was used as the main metric for model evaluation as it takes class imbalance into account. Other metrics were also considered as additional indicators.
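To illustrate why accuracy alone is misleading here, consider a trivial model that always predicts "no stroke". The sketch below uses the class counts of the training set (3,402 negatives out of 3,576 rows, see Table 3.1) and scikit-learn metric functions; the trivial model itself is only an illustration and is not part of the analysis.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Class counts of the training set: 3,402 negatives and 174 positives
y_true = np.array([0] * 3402 + [1] * 174)
y_pred = np.zeros_like(y_true)  # trivial model: always predict "no stroke"

accuracy = accuracy_score(y_true, y_pred)
bacc = balanced_accuracy_score(y_true, y_pred)
# zero_division=0: no positive predictions, so precision is defined as 0
f1 = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy: {accuracy:.3f}, BAcc: {bacc:.3f}, F1: {f1:.3f}")
```

The accuracy of this useless model equals the no-information rate (0.951), while its F1 score is 0 and its balanced accuracy is 0.5, so the latter two metrics expose the failure to detect any stroke case.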

3.1 Feature Selection: Filtering

3.1.1 Identify Correlated Variable Groups

The initial phase of feature selection involves identifying variables that are uninformative or contain redundant information. To achieve this, the “smart correlation” algorithm implemented in the feature-engine package will be employed. This algorithm is expected to identify clusters of correlated variables similar to the ones illustrated in Figure 2.55.

Correlated feature selection:

  1. Detect groups of correlated features based on Pearson’s correlation (\(|r| > 0.75\)).
  2. Select the best feature from each group based on Random Forest classification performance (ROC AUC).
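The two steps above can be sketched in plain pandas and scikit-learn. This is a simplified, hand-written illustration (greedy grouping followed by single-feature cross-validated ROC AUC), not the feature-engine implementation; the function names, the synthetic data, and its column names are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score


def correlated_groups(X, threshold=0.75):
    """Greedily group columns whose absolute Pearson correlation exceeds threshold."""
    abs_corr = X.corr(method="pearson").abs()
    groups, assigned = [], set()
    for col in abs_corr.columns:
        if col in assigned:
            continue
        partners = [
            other
            for other in abs_corr.columns
            if other != col
            and other not in assigned
            and abs_corr.loc[col, other] > threshold
        ]
        if partners:
            group = [col, *partners]
            groups.append(group)
            assigned.update(group)
    return groups


def best_by_auc(X, y, group, estimator, cv=3):
    """Score each feature of a group alone and return the best one."""
    scores = {
        feature: cross_val_score(
            estimator, X[[feature]], y, cv=cv, scoring="roc_auc"
        ).mean()
        for feature in group
    }
    return max(scores, key=scores.get), scores


# Synthetic stand-in data
rng = np.random.default_rng(3)
n = 600
signal = rng.normal(size=n)
X_demo = pd.DataFrame(
    {
        "x1": signal,
        "x1_noisy": signal + rng.normal(scale=0.3, size=n),
        "x2": rng.normal(size=n),
    }
)
y_demo = (signal + rng.normal(scale=0.5, size=n) > 1).astype(int)

rf_demo = RandomForestClassifier(n_estimators=30, random_state=20)
groups = correlated_groups(X_demo, threshold=0.75)
for group in groups:
    best, scores = best_by_auc(X_demo, y_demo, group, rf_demo)
    print(group, "->", best, {k: round(v, 3) for k, v in scores.items()})
```

Features that are not strongly correlated with any other feature (x2 here) do not enter any group and are therefore kept.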
# Identify correlated variable sets -----------------------------------------

X_train = (
    data_train.drop(columns=["stroke", "stroke_01"])
    .sample(frac=1, axis=1, random_state=20)
y_train = data_train["stroke_01"]

X_train = (

# Random forest for feature selection
rf = RandomForestClassifier(n_estimators=30, random_state=20, n_jobs=-1)

# Correlation selector
sel = SmartCorrelatedSelection(
    variables=None,  # All numerical variables

# Apply the selector, y_train);

List sets of correlated features.

[{'age', 'age_55_plus', 'age_bmi_interaction', 'age_group_num', 'age_square'},
 {'age_gender_interaction', 'gender_is_male_01'},
 {'stroke_incidence', 'stroke_risk_40', 'stroke_risk_55'},
 {'age_smoking_interaction', 'bmi_smoking_interaction'}]

Based on the EDA results, I manually updated the list of correlated feature groups. I ranked the features in each group by the performance of the Random Forest classifier and highlighted the best-performing feature of each group in green.

# Mimic the output of the feature selector by additionally calculating ROC AUC

correlated_feature_sets = [
        "avg_glucose_bmi_ratio",  # Added
    {'age_gender_interaction', 'gender_is_male_01'},
    {"age_smoking_interaction", "bmi_smoking_interaction"},

# Select each group of correlated features
i = 0
for group in correlated_feature_sets:
    i += 1
    rez = []

    # Build random forest with cross validation for each feature
    for feature in group:
        model = cross_validate(

        rez.append((f"Group {i}", model["test_score"].mean(), feature))
        pd.DataFrame(rez, columns=["Group", "AUC", "Feature"])
        .sort_values("AUC", ascending=False)
        .style.highlight_max(color="green", axis=0, subset="AUC")
        .set_properties(**{"width": "6em"})
  Group AUC Feature
1 Group 1 0.818 age_square
3 Group 1 0.818 age
4 Group 1 0.818 stroke_incidence
6 Group 1 0.817 stroke_risk_40
0 Group 1 0.801 stroke_risk_55
2 Group 1 0.782 age_group_num
5 Group 1 0.756 age_55_plus
7 Group 1 0.654 age_bmi_interaction
  Group AUC Feature
3 Group 2 0.627 avg_glucose_gr_above_165_01
0 Group 2 0.616 avg_glucose_gr_medical_165_num
1 Group 2 0.607 avg_glucose_is_diabetic_01
4 Group 2 0.587 avg_glucose_gr_medical_num
2 Group 2 0.532 avg_glucose_bmi_ratio
5 Group 2 0.522 avg_glucose_level
  Group AUC Feature
1 Group 3 0.594 hypertension_01
0 Group 3 0.575 age_hypertension_interaction
2 Group 3 0.523 bmi_hypertension_interaction
  Group AUC Feature
1 Group 4 0.605 bmi_group_num
0 Group 4 0.588 bmi_normal_or_underweight_01
2 Group 4 0.550 bmi
3 Group 4 0.529 bmi_overweight_or_obese_01
  Group AUC Feature
1 Group 5 0.554 heart_disease_01
2 Group 5 0.510 age_heart_disease_interaction
0 Group 5 0.487 bmi_heart_disease_interaction
  Group AUC Feature
0 Group 6 0.767 age_gender_interaction
1 Group 6 0.464 gender_is_male_01
  Group AUC Feature
1 Group 7 0.719 age_smoking_interaction
0 Group 7 0.551 bmi_smoking_interaction

The tables above let me decide which correlated variables to remove from the dataset. Based on these results, I decided to keep only the best-ranked variable in each group, with the following exceptions:

  • From group 1, I decided to keep the first 4 variables as they perform similarly well.
  • I also decided to keep both variables from group 6.

3.1.2 Remove Variables

The whole list of variables to remove at this stage is as follows:

cols_to_drop = [
    # Remove categoricals as duplicates of
    # other numeric variable with similar name
    # Remove categoricals as duplicates of
    # other numeric variable with different name
    # Remove based on domain knowledge and EDA results:
    # Remove as inter-correlated to other variables:

data_train_2 = data_train.drop(columns=cols_to_drop)

Nineteen variables remain in the dataset after this step (Table 3.1).

(3576, 19)
an.col_info(data_train_2, style=True)
Table 3.1. List of variables used in modeling.
  column data_type memory_size n_unique p_unique n_missing p_missing n_dominant p_dominant dominant
1 stroke_01 int8 3.6 kB 2 0.1% 0 0% 3,402 95.1% 0
2 age float32 14.3 kB 104 2.9% 0 0% 73 2.0% 78.000000
3 age_gender_interaction float32 14.3 kB 202 5.6% 0 0% 42 1.2% -50.000000
4 age_smoking_interaction float32 14.3 kB 105 2.9% 0 0% 1,337 37.4% 0.000000
5 age_square float32 14.3 kB 104 2.9% 0 0% 73 2.0% 6,084.000000
6 avg_glucose_gr_above_165_01 int8 3.6 kB 2 0.1% 0 0% 3,128 87.5% 0
7 bmi_group_num float16 7.2 kB 4 0.1% 133 3.7% 1,341 37.5% 3.000000
8 ever_married_01 int8 3.6 kB 2 0.1% 0 0% 2,323 65.0% 1
9 gender_is_male_01 int8 3.6 kB 2 0.1% 0 0% 2,131 59.6% 0
10 health_risk_score int8 3.6 kB 5 0.1% 0 0% 1,881 52.6% 1
11 heart_disease_01 int8 3.6 kB 2 0.1% 0 0% 3,393 94.9% 0
12 hypertension_01 int8 3.6 kB 2 0.1% 0 0% 3,228 90.3% 0
13 hypertension_heart_disease_interaction_01 int8 3.6 kB 2 0.1% 0 0% 3,531 98.7% 0
14 residence_is_urban_01 int8 3.6 kB 2 0.1% 0 0% 1,811 50.6% 1
15 smoking_status category 3.9 kB 3 0.1% 1,080 30.2% 1,337 37.4% never smoked
16 smoking_status_is_unknown_01 int8 3.6 kB 2 0.1% 0 0% 2,496 69.8% 0
17 stroke_incidence float32 14.3 kB 104 2.9% 0 0% 73 2.0% 562.276306
18 stroke_risk_40 float32 14.3 kB 43 1.2% 0 0% 1,580 44.2% 1.000000
19 work_type category 4.0 kB 4 0.1% 0 0% 2,026 56.7% Private Sector

3.1.3 Prepare for Modeling

Now, the training and validation sets will be prepared for the modeling process. This involves separating the target variable from the remaining features.

target = "stroke_01"

X_train = data_train_2.drop(target, axis=1)
y_train = data_train_2[target].to_numpy()

# data_validation_2 = pre_processing.transform(data_validation)
X_validation = data_validation.drop(target, axis=1)
y_validation = data_validation[target].to_numpy()

3.2 Pre-Processing: Group-Dependent Steps

Group-dependent pre-processing steps are implemented as a pipeline. These steps are repeated during the cross-validation (CV) procedure to prevent data leakage. The steps are illustrated in the schematic below. The main highlights are:

  1. There are 3 parallel pre-processing pipelines: for binary (selected by name pattern _01), numeric, and categorical variables.
  2. The numeric pipeline consists of:
    • imputation of missing values with median;
    • scaling with standard scaler.
  3. The categorical pipeline consists of:
    • imputation of missing values with most frequent value;
    • one-hot encoding.
  4. In the binary pipeline variables are just passed through.
# Group-dependent pre-processing steps
# that will be performed before each model re-fitting

# Select numeric variables by data type and name pattern
select_numeric_nonbinary = make_column_selector(
    dtype_include="number", pattern="^(?!.*_01$).*$"
# Select binary variables by data type and name pattern
select_binary_01 = make_column_selector(dtype_include="number", pattern="_01$")
# Select categorical variables by data type
select_categorical = make_column_selector(dtype_include=["object", "category"])

# Create the pipelines for each of the 3 groups of variables
numeric_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]

categorical_transformer = Pipeline(
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(sparse_output=False)),

group_dependent_preprocessor = ColumnTransformer(
        ("numeric", numeric_transformer, select_numeric_nonbinary),
        ("binary_01", "passthrough", select_binary_01),
        ("categorical", categorical_transformer, select_categorical),

[Schematic of group_dependent_preprocessor: a ColumnTransformer with three parallel branches: numeric (median imputation, standard scaling), binary_01 (passthrough), and categorical (most-frequent imputation, one-hot encoding).]

3.3 Model Pre-Tuning

In this section, a few hyperparameters for 8 ML algorithms (logistic regression, k-nearest neighbors, support vector machine for classification (SVC), random forest, Gaussian Naive Bayes, XGBoost, LightGBM and CatBoost) will be pre-tuned and their performance will be compared. This is a preparation step before sequential feature selection.

Two lists of models, along with their corresponding hyperparameter grids or distributions, have been created. Class imbalance will be addressed by adjusting class weights, and for most of the models, there is an automatic option for this.
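As a rough illustration of what the automatic option computes, the sketch below derives the "balanced" class weights from the class counts of the training set (3,402 negatives and 174 positives) using compute_class_weight from scikit-learn. It also computes the negative-to-positive ratio, which is a commonly used starting value for the scale_pos_weight parameter of XGBoost; treating it as such is my assumption, not a statement from the analysis.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Class counts of the training set: 3,402 negatives and 174 positives
y = np.array([0] * 3402 + [1] * 174)

# "balanced" weights: n_samples / (n_classes * class_count)
weights = compute_class_weight(
    class_weight="balanced", classes=np.array([0, 1]), y=y
)

# Negative-to-positive ratio
neg_pos_ratio = (y == 0).sum() / (y == 1).sum()

print(f"Balanced class weights: {weights.round(3)}")
print(f"Negative/positive ratio: {neg_pos_ratio:.2f}")
```

The ratio is about 19.6, which lies inside the range of 1 to 30 searched for scale_pos_weight in the boosting search space defined further below.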

Why have these models been pre-tuned using two different strategies? For models with fewer parameters to tune, we employ grid search, as it offers reproducibility. However, for models with numerous parameters, grid search becomes impractical, leading us to use Bayesian optimization. Unfortunately, I have observed that the results of “OptunaSearchCV” are not reproducible on my computer, even when a random seed is set.
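The reproducibility of grid search can be checked with a small sketch like the one below. It is an illustration on synthetic data (generated with make_classification) with a deliberately small grid; the pipeline, the grid values, and the helper run_search are assumptions made for the example and are not the ones used in the project.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic imbalanced data (roughly 5% positives) as a stand-in
X_demo, y_demo = make_classification(
    n_samples=1000, n_features=8, weights=[0.95], random_state=1
)

pipeline = Pipeline(
    steps=[
        ("scaler", StandardScaler()),
        ("classifier", LogisticRegression(max_iter=1000, class_weight="balanced")),
    ]
)
# The "classifier__" prefix routes the parameter to the pipeline step
param_grid = {"classifier__C": [0.01, 0.1, 1, 10]}


def run_search():
    """Exhaustively evaluate every grid point with 5-fold CV."""
    search = GridSearchCV(pipeline, param_grid, cv=5, scoring="f1")
    return search.fit(X_demo, y_demo)


first, second = run_search(), run_search()
print(first.best_params_, round(first.best_score_, 3))
print("Identical results:", first.best_params_ == second.best_params_)
```

Every candidate is evaluated on the same deterministic folds, so repeated runs give the same winner; a sampling-based optimizer only has this property if all of its sources of randomness are controlled.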

The results are summarized in Table 3.2 (training set) and Table 3.3 (validation set). Based on the validation set results, most models perform similarly (F1 score between 0.2 and 0.3), with the exception of kNN, which performs markedly worse (F1 = 0.082). Yet, as the dataset is small, the calculations will not take long, so all models will be considered in the next step.

Define search spaces:

# The list of classifiers along with their respective hyperparameter grids.
# Boosting algorithms are excluded here.
classifiers = [
        "Logistic Regression",
        LogisticRegression(random_state=1, max_iter=1000, class_weight="balanced"),
        {"classifier__C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]},
    ("Naive Bayes", GaussianNB(), {}),
        {"classifier__n_neighbors": [3, 5, 7, 10, 12, 15, 17, 20]},
        SVC(random_state=1, probability=True, class_weight="balanced"),
                "classifier__kernel": ["linear"],
                "classifier__C": [0.01, 0.1, 1, 10, 100],
                "classifier__kernel": ["rbf"],
                "classifier__C": [0.01, 0.1, 1, 10, 100],
                "classifier__gamma": [0.01, 0.1, 1, 10, 100],
        "Random Forest",
        RandomForestClassifier(random_state=1, class_weight="balanced"),
            "classifier__n_estimators": [50, 100, 200, 300],
            "classifier__max_depth": [3, 5, 7, 9, None],
# The list of classifiers along with their respective hyperparameter distributions.
# Boosting algorithms only.
classifiers_boost = [
        XGBClassifier(random_state=1, enable_categorical=False),
            "classifier__n_estimators": IntDistribution(50, 1000, step=50),
            "classifier__max_depth": IntDistribution(1, 12),
            "classifier__scale_pos_weight": FloatDistribution(1, 30),
            "classifier__min_child_weight": IntDistribution(1, 12),
            "classifier__gamma": FloatDistribution(0, 1),
            "classifier__reg_alpha": FloatDistribution(0, 1),
            "classifier__subsample": FloatDistribution(0.1, 1),
            "classifier__colsample_bytree": FloatDistribution(0.1, 1),
            "classifier__n_estimators": IntDistribution(50, 1000, step=50),
            "classifier__max_depth": IntDistribution(1, 12),
            "classifier__boosting_type": CategoricalDistribution(["gbdt"]),
            "classifier__lambda_l1": FloatDistribution(1e-8, 10.0, log=True),
            "classifier__lambda_l2": FloatDistribution(1e-8, 10.0, log=True),
            "classifier__num_leaves": IntDistribution(2, 256),
            "classifier__feature_fraction": FloatDistribution(0.4, 1.0),
            "classifier__bagging_fraction": FloatDistribution(0.4, 1.0),
            "classifier__bagging_freq": IntDistribution(1, 7),
            "classifier__min_child_samples": IntDistribution(5, 100),
            random_state=1, auto_class_weights="Balanced", verbose=False
            "classifier__n_estimators": IntDistribution(50, 1000, step=50),
            "classifier__depth": IntDistribution(1, 12),
            "classifier__colsample_bylevel": FloatDistribution(0.1, 1),
            "classifier__boosting_type": CategoricalDistribution(["Ordered", "Plain"]),
            "classifier__bootstrap_type": CategoricalDistribution(
                ["Bayesian", "Bernoulli", "MVS"]
            "classifier__random_strength": FloatDistribution(1e-4, 10.0, log=True),
            "classifier__l2_leaf_reg": FloatDistribution(1e-8, 10.0, log=True),
            # if param["bootstrap_type"] == "Bayesian":
            "classifier__bagging_temperature": FloatDistribution(0, 10),
            # if param["bootstrap_type"] == "Bernoulli":
            "classifier__subsample": FloatDistribution(0.1, 1),


@my.cache_results(dir_cache + "01_1_pre_tuned_models_nonboost.pickle")
def pre_tune_models_nonboost():
    """The following code is wrapped into a function for result caching.

    Select the best model for each classifier.

    # Create a dictionary to store pre-tuned models
    pre_tuned_models = {}

    # Iterate over the classifiers and perform hyperparameter tuning
    # using cross-validation
    for name, classifier, param_grid in classifiers:
        # Create the pipeline with the preprocessor, and the classifier
        pipeline = Pipeline(
                ("preprocessor", group_dependent_preprocessor),
                ("classifier", classifier),

        # Perform hyperparameter tuning using cross-validation
        print(f"\nClassifier: {name}")
        grid_search = GridSearchCV(
            pipeline, param_grid, cv=5, scoring="f1", n_jobs=-1, verbose=1
        ), y_train)

        # Get the best model
        pre_tuned_models[name] = grid_search.best_estimator_
    return pre_tuned_models

pre_tuned_models_nonboost = pre_tune_models_nonboost()
# Duration: 2m 59.8s
@my.cache_results(dir_cache + "01_2_pre_tuned_models_boost.pickle")
def pre_tune_models_boost():
    """The following code is wrapped into a function for result caching.

    Select the best model for each classifier.

    # Create a dictionary to store pre-tuned models
    pre_tuned_models = {}

    # Iterate over the classifiers and perform hyperparameter tuning
    # using cross-validation
    for name, classifier, param_candidates in classifiers_boost:
        # Create the pipeline with the preprocessor, and the classifier
        pipeline = Pipeline(
                ("preprocessor", group_dependent_preprocessor),
                ("classifier", classifier),

        # Perform hyperparameter tuning using cross-validation
        print(f"\nClassifier: {name}")
        optuna_search = OptunaSearchCV(
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
  , y_train)

        # Get the best model
        pre_tuned_models[name] = optuna_search.best_estimator_
    return pre_tuned_models

pre_tuned_models_boost = pre_tune_models_boost()
# Duration: 10m 2.6s
OptunaSearchCV is experimental (supported from v0.17.0). The interface can change in the future.
[I 2023-09-24 00:10:56,751] Finished hyperparameter search!
[I 2023-09-24 00:10:56,754] Refitting the estimator using 3576 samples...
[I 2023-09-24 00:10:56,933] Finished refitting! (elapsed time: 0.178 sec.)
OptunaSearchCV is experimental (supported from v0.17.0). The interface can change in the future.
[I 2023-09-24 00:10:56,940] A new study created in memory with name: no-name-ebda560e-6e38-4835-b168-feeda0770dc6
[I 2023-09-24 00:10:56,941] Searching the best hyperparameters using 3576 samples...

Classifier: CatBoost
[I 2023-09-24 00:15:14,043] Trial 11 finished with value: 0.21604680624078915 and parameters: {'classifier__n_estimators': 550, 'classifier__depth': 11, 'classifier__colsample_bylevel': 0.6546053454479358, 'classifier__boosting_type': 'Ordered', 'classifier__bootstrap_type': 'MVS', 'classifier__random_strength': 0.04096476364867308, 'classifier__l2_leaf_reg': 3.8527617379512423e-08, 'classifier__bagging_temperature': 4.090692965310868, 'classifier__subsample': 0.6080936683714078}. Best is trial 28 with value: 0.23330949817428506.
[I 2023-09-24 00:15:16,987] Trial 2 finished with value: 0.21154397423230478 and parameters: {'classifier__n_estimators': 850, 'classifier__depth': 6, 'classifier__colsample_bylevel': 0.7158719129254822, 'classifier__boosting_type': 'Ordered', 'classifier__bootstrap_type': 'MVS', 'classifier__random_strength': 0.02975194172914774, 'classifier__l2_leaf_reg': 5.484403942638211e-06, 'classifier__bagging_temperature': 5.095075058846753, 'classifier__subsample': 0.713887897451381}. Best is trial 28 with value: 0.23330949817428506.
[I 2023-09-24 00:15:35,647] Trial 38 finished with value: 0.21694881401255364 and parameters: {'classifier__n_estimators': 950, 'classifier__depth': 2, 'classifier__colsample_bylevel': 0.22001906303171076, 'classifier__boosting_type': 'Ordered', 'classifier__bootstrap_type': 'MVS', 'classifier__random_strength': 0.00012914292622773072, 'classifier__l2_leaf_reg': 0.0003413538157867929, 'classifier__bagging_temperature': 5.364696994314535, 'classifier__subsample': 0.2961910275741255}. Best is trial 28 with value: 0.23330949817428506.
[I 2023-09-24 00:15:35,648] Finished hyperparameter search!
[I 2023-09-24 00:15:35,650] Refitting the estimator using 3576 samples...
[I 2023-09-24 00:15:38,216] Finished refitting! (elapsed time: 2.564 sec.)
best_models = {**pre_tuned_models_nonboost, **pre_tuned_models_boost}

Classification performance of the pre-tuned models. The best score in each column is highlighted in orange (training set) or green (validation set):

    "--- Train ---",
--- Train ---
No information rate:  0.951
Table 3.2. Classification scores for the train set. The rows are sorted by F1 score. The best values in each column are highlighted.
  n No_info_rate Accuracy BAcc BAcc_01 F1 F1_neg TPR TNR PPV NPV Kappa ROC_AUC
LightGBM 3576 0.951 0.869 0.931 0.862 0.425 0.926 1.000 0.862 0.270 1.000 0.378 0.967
kNN 3576 0.951 0.957 0.612 0.225 0.345 0.978 0.230 0.995 0.690 0.962 0.328 0.966
XGBoost 3576 0.951 0.886 0.744 0.488 0.334 0.938 0.586 0.902 0.233 0.977 0.284 0.885
Random Forest 3576 0.951 0.785 0.808 0.616 0.274 0.874 0.833 0.782 0.164 0.989 0.210 0.895
Naive Bayes 3576 0.951 0.802 0.776 0.552 0.269 0.886 0.747 0.805 0.164 0.984 0.205 0.840
CatBoost 3576 0.951 0.741 0.807 0.613 0.248 0.844 0.879 0.734 0.145 0.992 0.180 0.881
Logistic Regression 3576 0.951 0.770 0.767 0.535 0.244 0.864 0.764 0.770 0.145 0.985 0.177 0.841
SVC 3576 0.951 0.770 0.762 0.523 0.241 0.864 0.753 0.770 0.144 0.984 0.174 0.840
    best_models, X_validation, y_validation, "--- Validation ---", sort_by="F1"
--- Validation ---
No information rate:  0.952
Table 3.3. Classification scores for the validation set. The rows are sorted by F1 score. The best values in each column are highlighted.
  n No_info_rate Accuracy BAcc BAcc_01 F1 F1_neg TPR TNR PPV NPV Kappa ROC_AUC
XGBoost 766 0.952 0.873 0.664 0.328 0.248 0.931 0.432 0.896 0.174 0.969 0.192 0.825
CatBoost 766 0.952 0.731 0.807 0.615 0.243 0.837 0.892 0.723 0.140 0.992 0.174 0.826
Random Forest 766 0.952 0.768 0.762 0.525 0.239 0.863 0.757 0.768 0.142 0.984 0.172 0.823
SVC 766 0.952 0.768 0.750 0.499 0.233 0.863 0.730 0.770 0.138 0.982 0.165 0.827
LightGBM 766 0.952 0.817 0.699 0.397 0.231 0.896 0.568 0.830 0.145 0.974 0.167 0.801
Logistic Regression 766 0.952 0.765 0.735 0.471 0.224 0.862 0.703 0.768 0.133 0.981 0.156 0.824
Naive Bayes 766 0.952 0.778 0.717 0.433 0.220 0.871 0.649 0.785 0.133 0.978 0.152 0.829
kNN 766 0.952 0.941 0.520 0.040 0.082 0.970 0.054 0.986 0.167 0.954 0.059 0.590
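To make the column names of the tables above more concrete, the sketch below recomputes the scores from a confusion matrix. The four counts are my own reconstruction: they are consistent with the kNN row of Table 3.3 but are not reported in the source. BAcc_01 is interpreted here as balanced accuracy rescaled so that chance level is 0 (TPR + TNR − 1), which matches the tabulated values but is also an assumption.

```python
# Confusion-matrix counts consistent with the kNN row of Table 3.3
# (reconstructed from the reported scores, not taken from the source)
tp, fn, fp, tn = 2, 35, 10, 719
n = tp + fn + fp + tn  # 766 validation cases

tpr = tp / (tp + fn)  # true positive rate (recall)
tnr = tn / (tn + fp)  # true negative rate
ppv = tp / (tp + fp)  # positive predictive value (precision)
npv = tn / (tn + fn)  # negative predictive value
accuracy = (tp + tn) / n
no_info_rate = max(tp + fn, tn + fp) / n  # always predict the majority class
bacc = (tpr + tnr) / 2
bacc_01 = tpr + tnr - 1  # assumed definition: chance level rescaled to 0
f1 = 2 * tp / (2 * tp + fp + fn)
f1_neg = 2 * tn / (2 * tn + fn + fp)  # F1 with the negative class as "positive"

# Cohen's kappa: observed agreement corrected for chance agreement
p_chance = ((tp + fp) * (tp + fn) + (tn + fn) * (tn + fp)) / n**2
kappa = (accuracy - p_chance) / (1 - p_chance)

scores = {
    "No_info_rate": no_info_rate, "Accuracy": accuracy, "BAcc": bacc,
    "BAcc_01": bacc_01, "F1": f1, "F1_neg": f1_neg, "TPR": tpr,
    "TNR": tnr, "PPV": ppv, "NPV": npv, "Kappa": kappa,
}
for name, value in scores.items():
    print(f"{name:>12}: {value:.3f}")
```

With these counts, accuracy (0.941) stays below the no-information rate (0.952) while only 2 of the 37 positive cases are detected, which is why kNN ranks last by F1.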

3.4 Sequential Feature Selection (SFS)

In this section, Sequential Feature Selection (SFS) is performed, and its results are compared.

  1. This subsection defines the function to carry out SFS.
    • One advantage of the selected SFS algorithm implementation is that it allows us to obtain and inspect intermediate results, enabling better decision-making.
    • One drawback is that the algorithm fails when a pipeline is used instead of a classifier, requiring data preprocessing before SFS (which increases the likelihood of data leakage).
  2. In the following subsections, SFS will be conducted for each model, including both forward and backward selections.
  3. In the final subsection on SFS, the results of the SFS are summarized and compared (see Table 3.20).
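For readers unfamiliar with the technique, the idea of forward selection with inspectable intermediate results can be sketched with a short hand-rolled loop. This is not the SFS implementation used in this report; the function forward_sfs, the synthetic data, and the chosen estimator are hypothetical and serve only as an illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def forward_sfs(estimator, X, y, scoring="f1", cv=5):
    """Greedy forward selection that records the score of every subset size."""
    remaining = list(range(X.shape[1]))
    selected, history = [], []
    while remaining:
        # Score each remaining feature when added to the current subset
        candidates = {
            j: cross_val_score(
                estimator, X[:, selected + [j]], y, scoring=scoring, cv=cv
            ).mean()
            for j in remaining
        }
        best = max(candidates, key=candidates.get)
        selected.append(best)
        remaining.remove(best)
        history.append(
            {
                "n_features": len(selected),
                "features": tuple(selected),
                "score": candidates[best],
            }
        )
    return history


X_demo, y_demo = make_classification(
    n_samples=500, n_features=6, n_informative=3, weights=[0.9], random_state=0
)
model = LogisticRegression(max_iter=1000, class_weight="balanced")

sfs_history = forward_sfs(model, X_demo, y_demo)
for step in sfs_history:
    print(step["n_features"], step["features"], round(step["score"], 3))
```

Keeping the full history makes it possible to choose a smaller subset whose score is close to the maximum. Backward selection works analogously, starting from all features and removing one feature per step.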
X_preprocessed = group_dependent_preprocessor.fit_transform(X_train)

def SFS(classifier, forward=True):
    """Perform Sequential Feature Selector for the given classifier

        classifier (str): Name of the classifier.
        forward (bool, optional): Whether to use forward selection.
            Defaults to True.
    estimator = best_models[classifier]["classifier"]

    return SequentialFeatureSelector(
    ).fit(X_preprocessed, y_train)

3.4.1 SFS for Logistic Regression

@my.cache_results(dir_cache + "02_1_sfs_lr_res_f1_forward.pickle")
def do_sfs_lr_f():
    return SFS("Logistic Regression", forward=True)

sfs_lr_res_forward = do_sfs_lr_f()
# Duration: 7.1s
Example of SFS output
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  23 out of  23 | elapsed:    3.2s finished
Features: 1/23[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
@my.cache_results(dir_cache + "02_1_sfs_lr_res_f1_backward.pickle")
def do_sfs_lr_b():
    return SFS("Logistic Regression", forward=False)

sfs_lr_res_backward = do_sfs_lr_b()
# Duration: 4.3s
ml.sfs_plot_results(sfs_lr_res_forward, "Logistic Regression")
Fig. 3.1. Sequential Feature Selection using Logistic Regression classifier.
k = 2, avg. F1 = 0.269 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.4. Sequential Feature Selection using Logistic Regression classifier. In column added_feature (forward selection) or feature (backward elimination), the row of interest and all rows above it together give the combination of features included in the model. Columns score_improvement and score_percentage_change show the absolute and the relative (%) change in score compared with the row above.
  added_feature metric score score_improvement score_percentage_change
1 stroke_risk_40 F1 0.264 nan nan
2 health_risk_score F1 0.269 0.005 2.085
3 residence_is_urban_01 F1 0.270 0.001 0.362
4 hypertension_01 F1 0.271 0.001 0.436
5 heart_disease_01 F1 0.272 0.001 0.311
6 hypertension_heart_disease_interaction_01 F1 0.272 0.000 0.000
7 work_type_Private Sector F1 0.272 0.000 0.000
8 work_type_Government Sector F1 0.272 0.000 0.017
9 work_type_Never worked F1 0.271 -0.001 -0.523
10 smoking_status_smokes F1 0.271 0.000 0.163
11 gender_is_male_01 F1 0.271 0.000 0.000
12 work_type_Self Employed F1 0.269 -0.002 -0.700
13 smoking_status_is_unknown_01 F1 0.268 -0.002 -0.631
14 smoking_status_formerly smoked F1 0.263 -0.004 -1.652
15 avg_glucose_gr_above_165_01 F1 0.265 0.001 0.541
16 smoking_status_never smoked F1 0.264 -0.001 -0.201
17 age_gender_interaction F1 0.261 -0.003 -1.162
18 ever_married_01 F1 0.261 0.000 0.106
19 stroke_incidence F1 0.256 -0.006 -2.148
20 bmi_group_num F1 0.253 -0.003 -1.105
21 age_smoking_interaction F1 0.248 -0.005 -1.823
22 age_square F1 0.241 -0.007 -2.849
23 age F1 0.240 -0.001 -0.495
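The two derived columns of Table 3.4 can be recomputed from the score column alone. The following is a minimal sketch in plain Python, assuming that each value is the change relative to the previous step; `add_improvement_columns` is a hypothetical helper name, not part of the project's code. The input uses the rounded F1 scores of the first four rows of Table 3.4, so the percentages it prints differ slightly from the published ones, which were presumably computed from unrounded scores.

```python
import math


def add_improvement_columns(scores):
    """Return (score, improvement, percentage_change) for each SFS step.

    The first step has no predecessor, so both derived values are NaN,
    mirroring the `nan` entries in the first row of the tables.
    """
    rows = []
    previous = None
    for score in scores:
        if previous is None:
            improvement = math.nan
            pct_change = math.nan
        else:
            improvement = score - previous
            pct_change = 100 * improvement / previous
        rows.append((score, improvement, pct_change))
        previous = score
    return rows


# Rounded F1 scores of the first four rows of Table 3.4.
f1_scores = [0.264, 0.269, 0.270, 0.271]

for step, (score, diff, pct) in enumerate(add_improvement_columns(f1_scores), start=1):
    print(f"{step:>2}  F1={score:.3f}  improvement={diff:+.3f}  change={pct:+.3f}%")
```

A negative improvement means that the newly added (or newly retained) feature lowered the cross-validated score compared with the previous step.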
ml.sfs_plot_results(sfs_lr_res_backward, "Logistic Regression")
Fig. 3.2. Sequential Feature Selection using Logistic Regression classifier.
k = 2, avg. F1 = 0.269 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.5. Sequential Feature Selection using Logistic Regression classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 stroke_risk_40 F1 0.264 nan nan
22 health_risk_score F1 0.269 0.005 2.085
21 work_type_Self Employed F1 0.269 -0.000 -0.150
20 avg_glucose_gr_above_165_01 F1 0.265 -0.004 -1.544
19 bmi_group_num F1 0.265 -0.000 -0.011
18 hypertension_01 F1 0.267 0.002 0.751
17 age_gender_interaction F1 0.263 -0.003 -1.296
16 work_type_Private Sector F1 0.264 0.001 0.203
15 smoking_status_formerly smoked F1 0.261 -0.003 -1.106
14 age_smoking_interaction F1 0.261 -0.000 -0.060
13 gender_is_male_01 F1 0.261 0.000 0.093
12 work_type_Government Sector F1 0.260 -0.001 -0.208
11 smoking_status_is_unknown_01 F1 0.261 0.001 0.257
10 hypertension_heart_disease_interaction_01 F1 0.261 0.000 0.000
9 smoking_status_never smoked F1 0.260 -0.000 -0.180
8 smoking_status_smokes F1 0.260 -0.000 -0.007
7 work_type_Never worked F1 0.260 -0.000 -0.103
6 heart_disease_01 F1 0.259 -0.001 -0.534
5 residence_is_urban_01 F1 0.257 -0.002 -0.596
4 ever_married_01 F1 0.255 -0.002 -0.927
3 stroke_incidence F1 0.248 -0.007 -2.595
2 age_square F1 0.241 -0.007 -2.849
1 age F1 0.240 -0.001 -0.495
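The "[Parsimonious]" label under the figures refers to the smallest number of predictors whose score is within one standard error (SE) of the best score. The rule can be sketched as follows. This is an illustration under stated assumptions, not the project's implementation: `parsimonious_k` is a hypothetical function, the SE is assumed to be the standard error of the cross-validation scores at the best subset size, and the fold scores are made-up toy values.

```python
import math
import statistics


def parsimonious_k(cv_scores_by_k):
    """Pick the smallest k whose mean score is within 1 SE of the best mean.

    cv_scores_by_k maps the number of features k to the list of
    cross-validation scores obtained with that many features.
    """
    means = {k: statistics.mean(s) for k, s in cv_scores_by_k.items()}
    best_k = max(means, key=means.get)
    best_scores = cv_scores_by_k[best_k]
    # Standard error of the mean score at the best k.
    std_err = statistics.stdev(best_scores) / math.sqrt(len(best_scores))
    threshold = means[best_k] - std_err
    return min(k for k, m in means.items() if m >= threshold)


# Made-up fold scores for illustration only (not taken from the report).
toy_scores = {
    1: [0.20, 0.22, 0.21, 0.19, 0.23],
    2: [0.26, 0.28, 0.27, 0.25, 0.29],
    3: [0.27, 0.28, 0.27, 0.26, 0.29],
    4: [0.25, 0.27, 0.26, 0.24, 0.28],
}

print(parsimonious_k(toy_scores))
```

In this toy example the best mean score occurs at k = 3, but k = 2 is already within one SE of it, so the smaller model is preferred.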

3.4.2 SFS for Naive Bayes

@my.cache_results(dir_cache + "02_2_sfs_nb_res_f1_forward.pickle")
def do_sfs_nb_f():
    return SFS("Naive Bayes", forward=True)

sfs_nb_res_forward = do_sfs_nb_f()
# Duration: 2.5s
@my.cache_results(dir_cache + "02_2_sfs_nb_res_f1_backward.pickle")
def do_sfs_nb_b():
    return SFS("Naive Bayes", forward=False)

sfs_nb_res_backward = do_sfs_nb_b()
# Duration: 2.7s
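The `my.cache_results` decorator used in these cells belongs to the author's helper module, and its implementation is not shown in this section. A minimal sketch of how such a pickle-based cache could work is given below; the function body, the demo file name and the placeholder computation are assumptions made for illustration only.

```python
import os
import pickle
import tempfile
from functools import wraps


def cache_results(file_path):
    """Hypothetical stand-in for `my.cache_results`: cache a function's
    return value in a pickle file and reuse it on later calls."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if os.path.exists(file_path):
                # Cache hit: load the stored result instead of recomputing.
                with open(file_path, "rb") as f:
                    return pickle.load(f)
            result = func(*args, **kwargs)
            with open(file_path, "wb") as f:
                pickle.dump(result, f)
            return result

        return wrapper

    return decorator


# Usage with a cheap placeholder computation and a temporary cache directory.
demo_dir = tempfile.mkdtemp()
calls = []


@cache_results(os.path.join(demo_dir, "demo_result.pickle"))
def slow_computation():
    calls.append(1)  # Record every real execution.
    return {"k": 3, "score": 0.5}


first = slow_computation()   # Computed and written to disk.
second = slow_computation()  # Loaded from the pickle file.
print(first == second, len(calls))
```

With a cache of this kind, the durations noted in the comments apply only to the first run; later renders of the report would read the stored results.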
ml.sfs_plot_results(sfs_nb_res_forward, "Naive Bayes")
Fig. 3.3. Sequential Feature Selection using Naive Bayes classifier.
k = 4, avg. F1 = 0.300 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.6. Sequential Feature Selection using Naive Bayes classifier. For details, refer to the description of Table 3.4.
  added_feature metric score score_improvement score_percentage_change
1 stroke_risk_40 F1 0.195 nan nan
2 age F1 0.270 0.075 38.144
3 avg_glucose_gr_above_165_01 F1 0.294 0.024 9.076
4 smoking_status_never smoked F1 0.300 0.006 2.056
5 residence_is_urban_01 F1 0.301 0.001 0.301
6 smoking_status_smokes F1 0.303 0.002 0.611
7 age_gender_interaction F1 0.301 -0.002 -0.648
8 work_type_Private Sector F1 0.301 0.000 0.025
9 work_type_Government Sector F1 0.299 -0.003 -0.833
10 gender_is_male_01 F1 0.299 0.000 0.131
11 smoking_status_is_unknown_01 F1 0.295 -0.004 -1.294
12 work_type_Self Employed F1 0.299 0.003 1.163
13 hypertension_heart_disease_interaction_01 F1 0.295 -0.004 -1.282
14 ever_married_01 F1 0.287 -0.008 -2.725
15 bmi_group_num F1 0.288 0.001 0.491
16 smoking_status_formerly smoked F1 0.281 -0.008 -2.693
17 stroke_incidence F1 0.275 -0.005 -1.882
18 age_smoking_interaction F1 0.280 0.004 1.597
19 health_risk_score F1 0.284 0.005 1.682
20 hypertension_01 F1 0.283 -0.001 -0.518
21 work_type_Never worked F1 0.271 -0.012 -4.297
22 heart_disease_01 F1 0.272 0.001 0.301
23 age_square F1 0.267 -0.004 -1.507
ml.sfs_plot_results(sfs_nb_res_backward, "Naive Bayes")
Fig. 3.4. Sequential Feature Selection using Naive Bayes classifier.
k = 3, avg. F1 = 0.308 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
    .style.apply(my.highlight_rows_by_index, values=[21, 22, 23], axis=1)
Table 3.7. Sequential Feature Selection using Naive Bayes classifier. For details, refer to the description of Table 3.4. The highlighted rows indicate the features of the best Naive Bayes model chosen for further investigation.
  feature metric score score_improvement score_percentage_change
23 stroke_risk_40 F1 0.195 nan nan
22 stroke_incidence F1 0.267 0.071 36.530
21 avg_glucose_gr_above_165_01 F1 0.308 0.041 15.483
20 hypertension_heart_disease_interaction_01 F1 0.296 -0.012 -3.897
19 smoking_status_formerly smoked F1 0.302 0.006 1.881
18 work_type_Self Employed F1 0.290 -0.011 -3.769
17 work_type_Never worked F1 0.286 -0.004 -1.362
16 health_risk_score F1 0.287 0.000 0.092
15 age_square F1 0.287 0.001 0.261
14 residence_is_urban_01 F1 0.289 0.001 0.519
13 bmi_group_num F1 0.291 0.002 0.621
12 smoking_status_never smoked F1 0.291 0.001 0.270
11 work_type_Government Sector F1 0.291 -0.001 -0.270
10 work_type_Private Sector F1 0.291 -0.000 -0.020
9 age_smoking_interaction F1 0.287 -0.004 -1.327
8 hypertension_01 F1 0.285 -0.002 -0.617
7 ever_married_01 F1 0.279 -0.006 -2.085
6 heart_disease_01 F1 0.280 0.001 0.220
5 age_gender_interaction F1 0.279 -0.000 -0.060
4 smoking_status_smokes F1 0.280 0.000 0.122
3 gender_is_male_01 F1 0.279 -0.001 -0.340
2 smoking_status_is_unknown_01 F1 0.274 -0.004 -1.562
1 age F1 0.267 -0.007 -2.548
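As described for Table 3.4, a model with k predictors consists of the feature in the k-th row from the top together with all rows above it. The sketch below applies this reading to the first five rows of Table 3.7; `features_up_to` is a hypothetical helper name, not part of the project's code.

```python
def features_up_to(rows, n_features):
    """Return the feature names and the score of the model with `n_features` predictors.

    `rows` lists (feature, score) pairs in table order, top to bottom,
    so the first `n_features` rows form the requested combination.
    """
    if not 1 <= n_features <= len(rows):
        raise ValueError("n_features is outside the table range")
    selected = rows[:n_features]
    return [name for name, _ in selected], selected[-1][1]


# First five rows of Table 3.7 (Naive Bayes, backward elimination).
table_3_7_top = [
    ("stroke_risk_40", 0.195),
    ("stroke_incidence", 0.267),
    ("avg_glucose_gr_above_165_01", 0.308),
    ("hypertension_heart_disease_interaction_01", 0.296),
    ("smoking_status_formerly smoked", 0.302),
]

features, f1 = features_up_to(table_3_7_top, 3)
print(features, f1)
```

For k = 3 this returns the three highlighted features and the F1 score of 0.308 reported under Fig. 3.4.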

3.4.3 SFS for kNN

@my.cache_results(dir_cache + "02_3_sfs_knn_res_f1_forward.pickle")
def do_sfs_knn_f():
    return SFS("kNN", forward=True)

sfs_knn_res_forward = do_sfs_knn_f()
# Duration: 20.2s
@my.cache_results(dir_cache + "02_3_sfs_knn_res_f1_backward.pickle")
def do_sfs_knn_b():
    return SFS("kNN", forward=False)

sfs_knn_res_backward = do_sfs_knn_b()
# Duration: 15.8s
ml.sfs_plot_results(sfs_knn_res_forward, "k Nearest Neighbors (kNN)")
Fig. 3.5. Sequential Feature Selection using the k-nearest neighbors (kNN) classifier.
k = 11, avg. F1 = 0.143 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.8. Sequential Feature Selection using kNN classifier. For details, refer to the description of Table 3.4.
  added_feature metric score score_improvement score_percentage_change
1 bmi_group_num F1 0.071 nan nan
2 age_gender_interaction F1 0.102 0.031 44.348
3 work_type_Government Sector F1 0.103 0.001 0.778
4 age_smoking_interaction F1 0.124 0.021 20.810
5 ever_married_01 F1 0.132 0.008 6.547
6 gender_is_male_01 F1 0.132 0.000 0.000
7 work_type_Never worked F1 0.132 0.000 0.000
8 age F1 0.122 -0.010 -7.223
9 stroke_incidence F1 0.123 0.000 0.239
10 hypertension_01 F1 0.121 -0.002 -1.293
11 smoking_status_formerly smoked F1 0.143 0.022 18.349
12 smoking_status_smokes F1 0.145 0.001 0.806
13 health_risk_score F1 0.146 0.002 1.252
14 hypertension_heart_disease_interaction_01 F1 0.152 0.006 3.939
15 smoking_status_is_unknown_01 F1 0.156 0.004 2.686
16 smoking_status_never smoked F1 0.148 -0.008 -5.385
17 age_square F1 0.133 -0.015 -10.229
18 avg_glucose_gr_above_165_01 F1 0.126 -0.007 -5.434
19 heart_disease_01 F1 0.128 0.003 2.271
20 residence_is_urban_01 F1 0.114 -0.014 -11.091
21 work_type_Self Employed F1 0.114 -0.000 -0.005
22 stroke_risk_40 F1 0.088 -0.026 -22.501
23 work_type_Private Sector F1 0.079 -0.009 -10.615
ml.sfs_plot_results(sfs_knn_res_backward, "k Nearest Neighbors (kNN)")
Fig. 3.6. Sequential Feature Selection using the k-nearest neighbors (kNN) classifier.
k = 4, avg. F1 = 0.168 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.9. Sequential Feature Selection using kNN classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 bmi_group_num F1 0.071 nan nan
22 age_smoking_interaction F1 0.072 0.002 2.653
21 health_risk_score F1 0.133 0.060 83.194
20 hypertension_01 F1 0.168 0.036 26.901
19 work_type_Government Sector F1 0.167 -0.001 -0.461
18 smoking_status_smokes F1 0.151 -0.016 -9.727
17 work_type_Private Sector F1 0.164 0.013 8.421
16 residence_is_urban_01 F1 0.148 -0.016 -9.836
15 age_gender_interaction F1 0.149 0.002 1.088
14 stroke_risk_40 F1 0.153 0.003 2.161
13 smoking_status_formerly smoked F1 0.142 -0.010 -6.684
12 smoking_status_is_unknown_01 F1 0.137 -0.006 -4.161
11 smoking_status_never smoked F1 0.137 0.001 0.669
10 avg_glucose_gr_above_165_01 F1 0.142 0.004 3.104
9 gender_is_male_01 F1 0.142 0.000 0.000
8 ever_married_01 F1 0.143 0.002 1.089
7 hypertension_heart_disease_interaction_01 F1 0.134 -0.009 -6.217
6 heart_disease_01 F1 0.130 -0.004 -3.296
5 age F1 0.124 -0.006 -4.441
4 work_type_Never worked F1 0.124 0.000 0.000
3 age_square F1 0.122 -0.002 -1.762
2 stroke_incidence F1 0.114 -0.008 -6.324
1 work_type_Self Employed F1 0.070 -0.044 -38.750
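Forward selection and backward elimination are both greedy searches, so they can pass through different feature subsets and reach different scores, as Tables 3.8 and 3.9 show for kNN. The project's `SFS` helper is not shown in this section; the sketch below only illustrates the general idea in plain Python. The function `sequential_selection`, the feature names and the additive `toy_score` are made up for illustration, whereas in the report the score is a cross-validated F1 of the fitted classifier.

```python
def sequential_selection(features, score_fn, forward=True):
    """Greedy sequential feature selection.

    Returns a list of (subset, score) pairs, one per visited subset size.
    forward=True grows the subset from empty to all features;
    forward=False shrinks it from all features down to one.
    """
    features = list(features)
    if not features:
        return []
    selected = [] if forward else list(features)
    history = []
    if not forward:
        history.append((tuple(selected), score_fn(selected)))
    target_size = len(features) if forward else 1
    while len(selected) != target_size:
        if forward:
            # Try adding each unused feature.
            candidates = [selected + [f] for f in features if f not in selected]
        else:
            # Try dropping each feature currently in the subset.
            candidates = [[g for g in selected if g != f] for f in selected]
        selected = max(candidates, key=score_fn)
        history.append((tuple(selected), score_fn(selected)))
    return history


# Toy contributions of four hypothetical features (illustration only).
toy_gain = {"age": 0.20, "glucose": 0.08, "bmi": 0.01, "noise": -0.03}


def toy_score(subset):
    return round(sum(toy_gain[f] for f in subset), 3)


for subset, score in sequential_selection(toy_gain, toy_score, forward=True):
    print(len(subset), subset, score)
```

Because the toy score is additive, both directions visit the same subsets here. With a real classifier the features interact, which is why the forward and backward tables in this section list different orders.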

3.4.4 SFS for SVC

@my.cache_results(dir_cache + "02_4_sfs_svc_res_f1_forward.pickle")
def do_sfs_svc_f():
    return SFS("SVC", forward=True)

sfs_svc_res_forward = do_sfs_svc_f()
# Duration: 17m 44.4s
@my.cache_results(dir_cache + "02_4_sfs_svc_res_f1_backward.pickle")
def do_sfs_svc_b():
    return SFS("SVC", forward=False)

sfs_svc_res_backward = do_sfs_svc_b()
# Duration: 19m 15.5s
ml.sfs_plot_results(sfs_svc_res_forward, "SVM for classification (SVC)")
Fig. 3.7. Sequential Feature Selection using support vector machine for classification (SVC).
k = 3, avg. F1 = 0.280 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
    .style.apply(my.highlight_rows_by_index, values=[1, 2, 3], axis=1)
Table 3.10. Sequential Feature Selection using SVC classifier. For details, refer to the description of Table 3.4. The highlighted rows indicate the features of the best SVC model chosen for further investigation.
  added_feature metric score score_improvement score_percentage_change
1 stroke_risk_40 F1 0.262 nan nan
2 health_risk_score F1 0.275 0.013 4.990
3 age_smoking_interaction F1 0.280 0.005 1.916
4 age_gender_interaction F1 0.284 0.004 1.297
5 gender_is_male_01 F1 0.284 0.000 0.149
6 hypertension_heart_disease_interaction_01 F1 0.284 0.000 0.000
7 work_type_Government Sector F1 0.284 -0.000 -0.107
8 smoking_status_smokes F1 0.283 -0.000 -0.119
9 smoking_status_formerly smoked F1 0.282 -0.001 -0.288
10 work_type_Private Sector F1 0.283 0.001 0.198
11 residence_is_urban_01 F1 0.281 -0.002 -0.716
12 heart_disease_01 F1 0.280 -0.001 -0.299
13 smoking_status_is_unknown_01 F1 0.281 0.001 0.188
14 work_type_Self Employed F1 0.279 -0.002 -0.754
15 work_type_Never worked F1 0.278 -0.001 -0.364
16 smoking_status_never smoked F1 0.277 -0.000 -0.052
17 avg_glucose_gr_above_165_01 F1 0.279 0.001 0.424
18 hypertension_01 F1 0.277 -0.002 -0.654
19 bmi_group_num F1 0.271 -0.006 -2.214
20 ever_married_01 F1 0.268 -0.003 -1.131
21 stroke_incidence F1 0.257 -0.010 -3.808
22 age_square F1 0.247 -0.011 -4.224
23 age F1 0.234 -0.012 -5.024
ml.sfs_plot_results(sfs_svc_res_backward, "SVM for classification (SVC)")
Fig. 3.8. Sequential Feature Selection using support vector machine for classification (SVC).
k = 3, avg. F1 = 0.278 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.11. Sequential Feature Selection using SVC classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 stroke_risk_40 F1 0.262 nan nan
22 avg_glucose_gr_above_165_01 F1 0.269 0.007 2.812
21 hypertension_01 F1 0.278 0.009 3.282
20 stroke_incidence F1 0.277 -0.000 -0.171
19 ever_married_01 F1 0.277 -0.000 -0.003
18 smoking_status_formerly smoked F1 0.280 0.003 0.903
17 gender_is_male_01 F1 0.280 0.000 0.000
16 age_gender_interaction F1 0.279 -0.001 -0.231
15 work_type_Never worked F1 0.281 0.002 0.603
14 work_type_Self Employed F1 0.280 -0.001 -0.299
13 work_type_Private Sector F1 0.280 0.000 0.121
12 residence_is_urban_01 F1 0.281 0.001 0.357
11 hypertension_heart_disease_interaction_01 F1 0.281 0.000 0.000
10 work_type_Government Sector F1 0.281 0.000 0.000
9 smoking_status_smokes F1 0.281 0.000 0.000
8 heart_disease_01 F1 0.281 -0.001 -0.251
7 smoking_status_is_unknown_01 F1 0.280 -0.000 -0.160
6 smoking_status_never smoked F1 0.279 -0.001 -0.230
5 bmi_group_num F1 0.274 -0.006 -1.999
4 age_smoking_interaction F1 0.268 -0.006 -2.190
3 health_risk_score F1 0.257 -0.011 -3.926
2 age_square F1 0.247 -0.011 -4.224
1 age F1 0.234 -0.012 -5.024

3.4.5 SFS for RF

@my.cache_results(dir_cache + "02_5_sfs_rf_res_f1_forward.pickle")
def do_sfs_rf_f():
    return SFS("Random Forest", forward=True)

sfs_rf_res_forward = do_sfs_rf_f()
# Duration: 4m 43.9s
@my.cache_results(dir_cache + "02_5_sfs_rf_res_f1_backward.pickle")
def do_sfs_rf_b():
    return SFS("Random Forest", forward=False)

sfs_rf_res_backward = do_sfs_rf_b()
# Duration: 5m 2.4s
ml.sfs_plot_results(sfs_rf_res_forward, "Random Forest (RF)")
Fig. 3.9. Sequential Feature Selection using Random Forest classifier.
k = 5, avg. F1 = 0.259 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.12. Sequential Feature Selection using Random Forest classifier. For details, refer to the description of Table 3.4.
  added_feature metric score score_improvement score_percentage_change
1 stroke_risk_40 F1 0.227 nan nan
2 age_smoking_interaction F1 0.253 0.026 11.485
3 smoking_status_smokes F1 0.253 0.000 0.119
4 hypertension_01 F1 0.255 0.001 0.424
5 age_square F1 0.259 0.005 1.785
6 stroke_incidence F1 0.258 -0.001 -0.485
7 work_type_Never worked F1 0.252 -0.005 -2.111
8 smoking_status_never smoked F1 0.252 -0.000 -0.149
9 work_type_Self Employed F1 0.253 0.001 0.420
10 smoking_status_formerly smoked F1 0.251 -0.002 -0.821
11 hypertension_heart_disease_interaction_01 F1 0.252 0.001 0.335
12 smoking_status_is_unknown_01 F1 0.249 -0.003 -1.076
13 age F1 0.248 -0.001 -0.370
14 residence_is_urban_01 F1 0.247 -0.002 -0.622
15 gender_is_male_01 F1 0.248 0.001 0.546
16 work_type_Private Sector F1 0.246 -0.002 -0.647
17 heart_disease_01 F1 0.244 -0.002 -0.817
18 ever_married_01 F1 0.248 0.004 1.578
19 bmi_group_num F1 0.244 -0.004 -1.738
20 avg_glucose_gr_above_165_01 F1 0.242 -0.002 -0.838
21 health_risk_score F1 0.234 -0.007 -3.094
22 age_gender_interaction F1 0.238 0.004 1.731
23 work_type_Government Sector F1 0.241 0.003 1.126
ml.sfs_plot_results(sfs_rf_res_backward, "Random Forest (RF)")
Fig. 3.10. Sequential Feature Selection using Random Forest classifier.
k = 4, avg. F1 = 0.260 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.13. Sequential Feature Selection using Random Forest classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 age_gender_interaction F1 0.222 nan nan
22 health_risk_score F1 0.252 0.030 13.680
21 stroke_risk_40 F1 0.253 0.001 0.307
20 bmi_group_num F1 0.260 0.007 2.798
19 stroke_incidence F1 0.254 -0.006 -2.195
18 work_type_Self Employed F1 0.249 -0.005 -2.020
17 work_type_Private Sector F1 0.251 0.002 0.660
16 smoking_status_formerly smoked F1 0.246 -0.004 -1.702
15 work_type_Never worked F1 0.249 0.002 0.825
14 age F1 0.252 0.004 1.596
13 gender_is_male_01 F1 0.247 -0.005 -2.010
12 heart_disease_01 F1 0.252 0.005 1.872
11 smoking_status_never smoked F1 0.253 0.001 0.572
10 ever_married_01 F1 0.252 -0.001 -0.571
9 age_square F1 0.258 0.006 2.507
8 hypertension_01 F1 0.257 -0.001 -0.514
7 work_type_Government Sector F1 0.253 -0.004 -1.497
6 smoking_status_is_unknown_01 F1 0.254 0.000 0.187
5 hypertension_heart_disease_interaction_01 F1 0.249 -0.005 -2.004
4 age_smoking_interaction F1 0.250 0.001 0.382
3 smoking_status_smokes F1 0.244 -0.005 -2.190
2 residence_is_urban_01 F1 0.246 0.002 0.754
1 avg_glucose_gr_above_165_01 F1 0.241 -0.005 -1.940

3.4.6 SFS for XGBoost

@my.cache_results(dir_cache + "02_6_sfs_xgb_res_f1_forward.pickle")
def do_sfs_xgb_f():
    return SFS("XGBoost", forward=True)

sfs_xgb_res_forward = do_sfs_xgb_f()
# Duration: 56.3s
@my.cache_results(dir_cache + "02_6_sfs_xgb_res_f1_backward.pickle")
def do_sfs_xgb_b():
    return SFS("XGBoost", forward=False)

sfs_xgb_res_backward = do_sfs_xgb_b()
# Duration: 56.2s
ml.sfs_plot_results(sfs_xgb_res_forward, "XGBoost")
Fig. 3.11. Sequential Feature Selection using XGBoost classifier.
k = 3, avg. F1 = 0.314 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
    .style.apply(my.highlight_rows_by_index, values=[1, 2, 3], axis=1)
Table 3.14. Sequential Feature Selection using XGBoost classifier. For details, refer to the description of Table 3.4. The highlighted rows indicate the features of the best XGBoost model chosen for further investigation.
  added_feature metric score score_improvement score_percentage_change
1 age_gender_interaction F1 0.252 nan nan
2 avg_glucose_gr_above_165_01 F1 0.287 0.036 14.140
3 stroke_risk_40 F1 0.314 0.027 9.231
4 work_type_Private Sector F1 0.322 0.009 2.718
5 age F1 0.316 -0.007 -2.148
6 smoking_status_is_unknown_01 F1 0.316 0.001 0.241
7 work_type_Never worked F1 0.319 0.003 0.988
8 bmi_group_num F1 0.316 -0.004 -1.098
9 stroke_incidence F1 0.315 -0.001 -0.238
10 gender_is_male_01 F1 0.314 -0.001 -0.447
11 age_square F1 0.312 -0.001 -0.448
12 smoking_status_smokes F1 0.311 -0.001 -0.469
13 hypertension_heart_disease_interaction_01 F1 0.312 0.001 0.280
14 work_type_Self Employed F1 0.309 -0.002 -0.732
15 residence_is_urban_01 F1 0.310 0.001 0.176
16 work_type_Government Sector F1 0.307 -0.003 -0.885
17 heart_disease_01 F1 0.306 -0.002 -0.537
18 smoking_status_formerly smoked F1 0.300 -0.005 -1.710
19 smoking_status_never smoked F1 0.298 -0.003 -0.906
20 age_smoking_interaction F1 0.283 -0.015 -4.911
21 ever_married_01 F1 0.288 0.005 1.637
22 health_risk_score F1 0.286 -0.001 -0.512
23 hypertension_01 F1 0.291 0.005 1.795
ml.sfs_plot_results(sfs_xgb_res_backward, "XGBoost")
Fig. 3.12. Sequential Feature Selection using XGBoost classifier.
k = 3, avg. F1 = 0.302 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.15. Sequential Feature Selection using XGBoost classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 age_square F1 0.249 nan nan
22 avg_glucose_gr_above_165_01 F1 0.294 0.045 17.863
21 smoking_status_smokes F1 0.302 0.008 2.646
20 work_type_Self Employed F1 0.303 0.002 0.629
19 hypertension_heart_disease_interaction_01 F1 0.311 0.008 2.483
18 work_type_Never worked F1 0.307 -0.004 -1.190
17 age_gender_interaction F1 0.318 0.011 3.426
16 age F1 0.315 -0.003 -0.924
15 bmi_group_num F1 0.314 -0.001 -0.209
14 residence_is_urban_01 F1 0.314 0.000 0.038
13 gender_is_male_01 F1 0.310 -0.004 -1.279
12 smoking_status_formerly smoked F1 0.305 -0.006 -1.879
11 smoking_status_never smoked F1 0.305 0.001 0.200
10 smoking_status_is_unknown_01 F1 0.302 -0.003 -1.117
9 age_smoking_interaction F1 0.295 -0.007 -2.241
8 health_risk_score F1 0.297 0.003 0.852
7 work_type_Private Sector F1 0.300 0.003 0.968
6 hypertension_01 F1 0.297 -0.003 -1.104
5 stroke_incidence F1 0.296 -0.001 -0.242
4 heart_disease_01 F1 0.300 0.004 1.226
3 stroke_risk_40 F1 0.296 -0.004 -1.419
2 work_type_Government Sector F1 0.291 -0.005 -1.638
1 ever_married_01 F1 0.291 0.000 0.170

3.4.7 SFS for LightGBM

@my.cache_results(dir_cache + "02_7_sfs_lgbm_res_f1_forward.pickle")
def do_sfs_lgbm_f():
    return SFS("LightGBM", forward=True)

sfs_lgbm_res_forward = do_sfs_lgbm_f()
# Duration: 1m 58.4s
@my.cache_results(dir_cache + "02_7_sfs_lgbm_res_f1_backward.pickle")
def do_sfs_lgbm_b():
    return SFS("LightGBM", forward=False)

sfs_lgbm_res_backward = do_sfs_lgbm_b()
# Duration: 2m 28.4s
ml.sfs_plot_results(sfs_lgbm_res_forward, "LightGBM")
Fig. 3.13. Sequential Feature Selection using LightGBM classifier.
k = 11, avg. F1 = 0.255 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.16. Sequential Feature Selection using LightGBM classifier. For details, refer to the description of Table 3.4.
  added_feature metric score score_improvement score_percentage_change
1 stroke_risk_40 F1 0.219 nan nan
2 health_risk_score F1 0.232 0.013 5.723
3 smoking_status_never smoked F1 0.242 0.011 4.598
4 ever_married_01 F1 0.245 0.002 0.982
5 hypertension_01 F1 0.246 0.001 0.398
6 smoking_status_smokes F1 0.250 0.005 1.841
7 hypertension_heart_disease_interaction_01 F1 0.250 0.000 0.000
8 age_square F1 0.245 -0.005 -1.892
9 age_gender_interaction F1 0.246 0.000 0.193
10 bmi_group_num F1 0.249 0.003 1.421
11 residence_is_urban_01 F1 0.255 0.006 2.442
12 age_smoking_interaction F1 0.250 -0.005 -2.002
13 age F1 0.257 0.007 2.872
14 avg_glucose_gr_above_165_01 F1 0.258 0.001 0.200
15 smoking_status_is_unknown_01 F1 0.257 -0.001 -0.505
16 stroke_incidence F1 0.257 -0.000 -0.009
17 smoking_status_formerly smoked F1 0.249 -0.008 -3.131
18 work_type_Government Sector F1 0.251 0.002 0.770
19 work_type_Never worked F1 0.253 0.002 0.817
20 gender_is_male_01 F1 0.252 -0.001 -0.255
21 heart_disease_01 F1 0.251 -0.001 -0.523
22 work_type_Private Sector F1 0.243 -0.008 -3.071
23 work_type_Self Employed F1 0.250 0.007 2.830
ml.sfs_plot_results(sfs_lgbm_res_backward, "LightGBM")
Fig. 3.14. Sequential Feature Selection using LightGBM classifier.
k = 9, avg. F1 = 0.260 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.17. Sequential Feature Selection using LightGBM classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 age_square F1 0.217 nan nan
22 health_risk_score F1 0.238 0.022 9.980
21 smoking_status_smokes F1 0.227 -0.012 -4.896
20 bmi_group_num F1 0.243 0.017 7.349
19 stroke_risk_40 F1 0.244 0.000 0.194
18 residence_is_urban_01 F1 0.248 0.004 1.791
17 work_type_Government Sector F1 0.250 0.002 0.697
16 smoking_status_is_unknown_01 F1 0.256 0.006 2.473
15 work_type_Never worked F1 0.260 0.003 1.363
14 ever_married_01 F1 0.250 -0.009 -3.499
13 work_type_Private Sector F1 0.252 0.001 0.592
12 gender_is_male_01 F1 0.248 -0.004 -1.448
11 age F1 0.255 0.007 2.842
10 avg_glucose_gr_above_165_01 F1 0.259 0.004 1.396
9 heart_disease_01 F1 0.254 -0.005 -1.932
8 hypertension_01 F1 0.256 0.002 0.792
7 age_smoking_interaction F1 0.250 -0.006 -2.182
6 smoking_status_formerly smoked F1 0.253 0.002 0.922
5 age_gender_interaction F1 0.251 -0.002 -0.790
4 smoking_status_never smoked F1 0.248 -0.003 -1.219
3 work_type_Self Employed F1 0.245 -0.002 -0.956
2 stroke_incidence F1 0.250 0.005 1.882
1 hypertension_heart_disease_interaction_01 F1 0.250 0.000 0.000

3.4.8 SFS for CatBoost

@my.cache_results(dir_cache + "02_8_sfs_cat_res_f1_forward.pickle")
def do_sfs_cat_f():
    return SFS("CatBoost", forward=True)

sfs_cat_res_forward = do_sfs_cat_f()
# Duration: 5m 35.1s
@my.cache_results(dir_cache + "02_8_sfs_cat_res_f1_backward.pickle")
def do_sfs_cat_b():
    return SFS("CatBoost", forward=False)

sfs_cat_res_backward = do_sfs_cat_b()
# Duration: 8m 56.0s
ml.sfs_plot_results(sfs_cat_res_forward, "CatBoost")
Fig. 3.15. Sequential Feature Selection using CatBoost classifier.
k = 4, avg. F1 = 0.239 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.18. Sequential Feature Selection using CatBoost classifier. For details, refer to the description of Table 3.4.
  added_feature metric score score_improvement score_percentage_change
1 age_gender_interaction F1 0.211 nan nan
2 avg_glucose_gr_above_165_01 F1 0.232 0.020 9.622
3 age_smoking_interaction F1 0.238 0.006 2.682
4 ever_married_01 F1 0.239 0.001 0.320
5 smoking_status_never smoked F1 0.237 -0.002 -0.789
6 work_type_Government Sector F1 0.239 0.002 0.996
7 smoking_status_is_unknown_01 F1 0.236 -0.003 -1.131
8 work_type_Private Sector F1 0.237 0.001 0.234
9 health_risk_score F1 0.238 0.001 0.283
10 hypertension_heart_disease_interaction_01 F1 0.236 -0.002 -0.783
11 work_type_Self Employed F1 0.238 0.002 0.894
12 residence_is_urban_01 F1 0.239 0.001 0.405
13 work_type_Never worked F1 0.233 -0.006 -2.542
14 heart_disease_01 F1 0.240 0.007 3.187
15 smoking_status_smokes F1 0.236 -0.004 -1.556
16 smoking_status_formerly smoked F1 0.234 -0.002 -0.914
17 gender_is_male_01 F1 0.234 -0.000 -0.051
18 stroke_incidence F1 0.233 -0.001 -0.338
19 stroke_risk_40 F1 0.232 -0.002 -0.671
20 age_square F1 0.231 -0.001 -0.443
21 hypertension_01 F1 0.231 -0.000 -0.007
22 age F1 0.232 0.001 0.360
23 bmi_group_num F1 0.229 -0.003 -1.295
ml.sfs_plot_results(sfs_cat_res_backward, "CatBoost")
Fig. 3.16. Sequential Feature Selection using CatBoost classifier.
k = 6, avg. F1 = 0.243 [Parsimonious]
(Smallest number of predictors at best ± 1 SE score)
List of features
Table 3.19. Sequential Feature Selection using CatBoost classifier. For details, refer to the description of Table 3.4.
  feature metric score score_improvement score_percentage_change
23 avg_glucose_gr_above_165_01 F1 0.205 nan nan
22 stroke_risk_40 F1 0.229 0.024 11.881
21 age_smoking_interaction F1 0.236 0.008 3.311
20 smoking_status_is_unknown_01 F1 0.240 0.004 1.606
19 smoking_status_never smoked F1 0.239 -0.001 -0.416
18 hypertension_01 F1 0.243 0.004 1.753
17 smoking_status_smokes F1 0.242 -0.002 -0.777
16 smoking_status_formerly smoked F1 0.241 -0.000 -0.087
15 residence_is_urban_01 F1 0.239 -0.002 -0.840
14 work_type_Self Employed F1 0.239 0.000 0.052
13 work_type_Government Sector F1 0.237 -0.002 -0.867
12 heart_disease_01 F1 0.233 -0.004 -1.701
11 stroke_incidence F1 0.233 0.000 0.030
10 bmi_group_num F1 0.233 -0.000 -0.032
9 age_gender_interaction F1 0.235 0.001 0.629
8 work_type_Private Sector F1 0.233 -0.002 -0.665
7 ever_married_01 F1 0.234 0.001 0.255
6 health_risk_score F1 0.231 -0.003 -1.080
5 hypertension_heart_disease_interaction_01 F1 0.234 0.003 1.159
4 age F1 0.240 0.006 2.686
3 age_square F1 0.238 -0.003 -1.105
2 work_type_Never worked F1 0.232 -0.005 -2.253
1 gender_is_male_01 F1 0.229 -0.004 -1.577

3.4.9 Summary of SFS Results

This section summarizes the results of the SFS analysis for each classifier (Table 3.20).


These classifiers were selected as candidates for the final evaluation:

  1. XGBoost (Figure 3.11, Table 3.14),
  2. Naive Bayes (Figure 3.4, Table 3.7), and
  3. Support Vector Machine (SVC; Figure 3.7, Table 3.10).
Code of the figure
data_str = """
Classifier             |  SFS type   | SFS duration | Number of selected predictors | F1 score | Selection method  | Selected for further analysis
Logistic Regression    |  Forward    |       7.1s   |  2                            | 0.269    |  Parsimonious¹    |
Logistic Regression    |  Backward   |       4.3s   |  2                            | 0.269    |  Parsimonious¹    |
Naive Bayes            |  Forward    |       2.5s   |  4                            | 0.300    |  Parsimonious¹    |
Naive Bayes            |  Backward   |       2.7s   |  3                            | 0.308    |  Parsimonious¹    | Yes
k Nearest Neighbors    |  Forward    |      20.2s   | 11                            | 0.143    |  Parsimonious¹    |
k Nearest Neighbors    |  Backward   |      15.8s   |  4                            | 0.168    |  Parsimonious¹    |
Support Vector Machine |  Forward    |  17m 44.4s   |  3                            | 0.280    |  Parsimonious¹    | Yes
Support Vector Machine |  Backward   |  19m 15.5s   |  3                            | 0.278    |  Parsimonious¹    |
Random Forest          |  Forward    |   4m 43.9s   |  5                            | 0.259    |  Parsimonious¹    |
Random Forest          |  Backward   |   5m  2.4s   |  4                            | 0.260    |  Parsimonious¹    |
XGBoost                |  Forward    |      56.3s   |  3                            | 0.314    |  Parsimonious¹    | Yes
XGBoost                |  Backward   |      56.2s   |  3                            | 0.308    |  Parsimonious¹    |
LightGBM               |  Forward    |   1m 58.4s   |  11                           | 0.255    |  Parsimonious¹    |
LightGBM               |  Backward   |   2m 28.4s   |  9                            | 0.260    |  Parsimonious¹    |
CatBoost               |  Forward    |   5m 35.1s   |  4                            | 0.239    |  Parsimonious¹    |
CatBoost               |  Backward   |   8m 56.0s   |  6                            | 0.243    |  Parsimonious¹    |
"""

# Read the data into a pandas DataFrame
df = pd.read_csv(StringIO(data_str), sep="|")

# Strip leading/trailing spaces from column names
df.columns = df.columns.str.strip()

# Display the DataFrame with the F1 scores highlighted
# (the start of this call is missing in the source; `background_gradient` is assumed)
(
    df.style.background_gradient(
        cmap="RdBu", subset=["F1 score"], vmin=0.24, vmax=0.30
    ).format(precision=3, na_rep="")
)
Table 3.20. Summary of the forward and backward Sequential Feature Selection results: the number of selected predictors and the F1 score of the optimal model for each classifier type. F1 scores are highlighted based on the score: highest values are in dark blue, lowest values are in dark red.
  Classifier SFS type SFS duration Number of selected predictors F1 score Selection method Selected for further analysis
0 Logistic Regression Forward 7.1s 2 0.269 Parsimonious¹
1 Logistic Regression Backward 4.3s 2 0.269 Parsimonious¹
2 Naive Bayes Forward 2.5s 4 0.300 Parsimonious¹
3 Naive Bayes Backward 2.7s 3 0.308 Parsimonious¹ Yes
4 k Nearest Neighbors Forward 20.2s 11 0.143 Parsimonious¹
5 k Nearest Neighbors Backward 15.8s 4 0.168 Parsimonious¹
6 Support Vector Machine Forward 17m 44.4s 3 0.280 Parsimonious¹ Yes
7 Support Vector Machine Backward 19m 15.5s 3 0.278 Parsimonious¹
8 Random Forest Forward 4m 43.9s 5 0.259 Parsimonious¹
9 Random Forest Backward 5m 2.4s 4 0.260 Parsimonious¹
10 XGBoost Forward 56.3s 3 0.314 Parsimonious¹ Yes
11 XGBoost Backward 56.2s 3 0.308 Parsimonious¹
12 LightGBM Forward 1m 58.4s 11 0.255 Parsimonious¹
13 LightGBM Backward 2m 28.4s 9 0.260 Parsimonious¹
14 CatBoost Forward 5m 35.1s 4 0.239 Parsimonious¹
15 CatBoost Backward 8m 56.0s 6 0.243 Parsimonious¹

¹ Parsimonious method (the smallest number of features within 1 standard error from the best performance score) was used to select the number of features.
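
The parsimonious rule from the footnote can be sketched in a few lines. This is a minimal illustration and not the code used in this project: the function name `select_parsimonious_k` and the scores and standard errors below are hypothetical, and the one-standard-error band is assumed to be measured from the best score using that score's own standard error.

```python
import numpy as np


def select_parsimonious_k(mean_scores, std_errors):
    """Return the smallest number of features whose mean score is
    within one standard error of the best mean score.

    Position i of both inputs corresponds to a model with i + 1 features.
    """
    mean_scores = np.asarray(mean_scores, dtype=float)
    std_errors = np.asarray(std_errors, dtype=float)

    best = int(np.argmax(mean_scores))
    # Lower edge of the "best score minus 1 SE" band
    threshold = mean_scores[best] - std_errors[best]

    # The first (smallest) model that reaches the band
    candidates = np.flatnonzero(mean_scores >= threshold)
    return int(candidates[0]) + 1


# Hypothetical cross-validated F1 scores for models with 1 to 5 features
mean_f1 = [0.20, 0.25, 0.29, 0.30, 0.295]
se_f1 = [0.02, 0.02, 0.02, 0.02, 0.02]

k_selected = select_parsimonious_k(mean_f1, se_f1)
print(k_selected)
```

With these made-up numbers, the best score belongs to the 4-feature model, but the 3-feature model already falls inside the band, so the smaller model is preferred.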

3.5 Hyperparameter Tuning

In this section:

  1. data processing pipelines will be created for the selected models;
  2. the models that have hyperparameters to tune will be fine-tuned.

3.5.1 Define Pipeline for Naive Bayes

Naive Bayes does not have any hyperparameters to tune. So only a pipeline is defined here.

# Selected features
nb_required_features = [

# Pipeline
nb_pipeline = Pipeline(
                ("selector", ColumnSelector(nb_required_features)),
                ("transformer", clone(group_dependent_preprocessor)),
        ("classifier", GaussianNB()),
), y_train)
# 0.2s
                                                                   <sklearn.compose._column_transformer.make_column_selector object at 0x0000029F3FDB0110>),
                                                                   <sklearn.compose._column_transformer.make_column_selector object at 0x0000029F3FEDE5D0>)],
                ('classifier', GaussianNB())])

3.5.2 Tune SVC model

In this subsection, a Support Vector Machine for Classification (SVC) will be tuned.

Note: Hyperparameter tuning using OptunaSearchCV completed most of the trials within 10 minutes. However, a few trials remained stuck for hours. I interrupted the last tuning attempt after approximately 422 minutes (Figure 3.17). I made several such attempts, and some of them did not finish even after 600 minutes (10 hours). Luckily, Optuna prints out intermediate results, so from these attempts I selected the hyperparameters that performed best.
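
One possible way to keep a single SVC fit from running for hours is to cap the number of solver iterations with the `max_iter` parameter of scikit-learn's `SVC`. The sketch below was not used in this project: it runs on synthetic data, and the iteration limit and the other hyperparameter values are arbitrary assumptions. A capped fit may stop before convergence, so its scores should be treated with caution.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.svm import SVC

# Synthetic, imbalanced stand-in for the real training data
X_demo, y_demo = make_classification(
    n_samples=200,
    n_features=3,
    n_informative=3,
    n_redundant=0,
    weights=[0.9, 0.1],
    random_state=1,
)

# max_iter=-1 (the default) means "no limit"; a positive value is a hard cap
capped_svc = SVC(
    kernel="poly",
    degree=2,
    C=100,
    gamma=1.0,
    class_weight="balanced",
    max_iter=500,  # arbitrary limit chosen for illustration
    random_state=1,
)

with warnings.catch_warnings():
    # Hitting the cap raises a ConvergenceWarning, which is expected here
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    capped_svc.fit(X_demo, y_demo)

predictions = capped_svc.predict(X_demo)
print(predictions.shape)
```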

# Selected features
svc_required_features = [

# To avoid running a hyperparameter tuning step that may get stuck, I added a
# flag to skip it. If needed, you may rerun the search and stop it after the
# required amount of time.
to_tune_svc = False

if to_tune_svc:
    # Pipeline
    svc_pipeline = Pipeline(
                    ("selector", ColumnSelector(svc_required_features)),
                    ("transformer", clone(group_dependent_preprocessor)),
                SVC(random_state=1, probability=True, class_weight="balanced"),

    # Tune SVC model and cache the results

    @my.cache_results(dir_cache + "03_1_tune_svc.pickle")
    def tune_svc():
        # Hyperparameter space
        svc_params = {
            "classifier__kernel": CategoricalDistribution(["linear", "rbf", "poly"]),
            "classifier__C": FloatDistribution(1e-3, 100, log=True),
            "classifier__gamma": FloatDistribution(1e-3, 100, log=True),
            "classifier__degree": IntDistribution(1, 6),

        # Define the search
        search = OptunaSearchCV(

        # Perform hyperparameter tuning using cross-validation
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
  , y_train)

        return search

    tuning_results_svc = tune_svc()
# Duration: interrupted after 422m 8.7s

# Pipeline with the best hyperparameters
svc_pipeline = Pipeline(
                ("selector", ColumnSelector(svc_required_features)),
                ("transformer", clone(group_dependent_preprocessor)),
                    "kernel": "poly",
                    "C": 39.77955198439039,
                    "gamma": 0.002965591796461689,
                    "degree": 2,
), y_train)
                                                                   <sklearn.compose._column_transformer.make_column_selector object at 0x0000029F3B41EF10>)],
                 SVC(C=39.77955198439039, class_weight='balanced', degree=2,
                     gamma=0.002965591796461689, kernel='poly',
                     probability=True, random_state=1))])
Fig. 3.17. One of the attempts in which OptunaSearchCV got stuck on several unfinished trials for hours.

The best SVC model:

[I 2023-09-21 09:58:14,387] Trial 183 finished with value: 0.2994436969318081 and parameters: {'classifier__kernel': 'poly', 'classifier__C': 39.77955198439039, 'classifier__gamma': 0.002965591796461689, 'classifier__degree': 2}. 

Best is trial 183 with value: 0.2994436969318081.
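
To make the selected kernel more concrete, the sketch below evaluates a degree-2 polynomial kernel, K(x, z) = (gamma * <x, z> + coef0)^degree, with the tuned `gamma` and `degree`. It is only an illustration: the two input rows are hypothetical standardized feature vectors, and `coef0` is assumed to be the SVC default of 0.0 because the printed model does not list it. The manual formula is compared with scikit-learn's `polynomial_kernel` helper.

```python
import numpy as np
from sklearn.metrics.pairwise import polynomial_kernel

# Tuned values reported for trial 183
gamma = 0.002965591796461689
degree = 2
coef0 = 0.0  # assumed: SVC default, not shown in the printed model

# Hypothetical standardized rows (3 features each)
X_demo = np.array(
    [
        [1.0, 0.5, -1.2],
        [0.3, -0.7, 2.0],
    ]
)

# Kernel matrix computed directly from the formula
kernel_manual = (gamma * (X_demo @ X_demo.T) + coef0) ** degree

# The same matrix computed by scikit-learn
kernel_sklearn = polynomial_kernel(X_demo, degree=degree, gamma=gamma, coef0=coef0)

print(np.allclose(kernel_manual, kernel_sklearn))
```

Note that `polynomial_kernel` uses `coef0=1` by default, whereas `SVC` uses `coef0=0.0`, so the value is passed explicitly here.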

3.5.3 Tune XGBoost

Next, the XGBoost model will be tuned.

# Selected features
xgb_required_features = [

# Pipeline
xgb_pipeline = Pipeline(
                ("selector", ColumnSelector(xgb_required_features)),
                ("transformer", clone(group_dependent_preprocessor)),
        ("classifier", best_models["XGBoost"]["classifier"]),

# Tune xgboost model and cache the results

@my.cache_results(dir_cache + "03_2_tune_xgb.pickle")
def tune_xgb():
    # Hyperparameters space
    xgb_params = {
        "classifier__n_estimators": IntDistribution(10, 1000, step=10),
        "classifier__max_depth": IntDistribution(1, 10),
        "classifier__scale_pos_weight": FloatDistribution(1, 60),
        "classifier__min_child_weight": IntDistribution(1, 20),
        "classifier__gamma": FloatDistribution(1e-8, 1, log=True),
        "classifier__reg_alpha": FloatDistribution(1e-8, 1, log=True),
        "classifier__reg_lambda": FloatDistribution(1e-8, 1, log=True),
        "classifier__subsample": FloatDistribution(0.05, 1),
        "classifier__colsample_bytree": FloatDistribution(0.05, 1, log=True),
        "classifier__learning_rate": FloatDistribution(1e-3, 1, log=True),

    # Define the search
    search = OptunaSearchCV(

    # Perform hyperparameter tuning using cross-validation
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning), y_train)

    return search

tuning_results_xgb = tune_xgb()
# Duration: 3m 50.9s
                                    'classifier__reg_alpha': FloatDistribution(high=1.0, log=True, low=1e-08, step=None),
                                    'classifier__reg_lambda': FloatDistribution(high=1.0, log=True, low=1e-08, step=None),
                                    'classifier__scale_pos_weight': FloatDistribution(high=60.0, log=False, low=1.0, step=None),
                                    'classifier__subsample': FloatDistribution(high=1.0, log=False, low=0.05, step=None)},
               random_state=1, scoring='f1', timeout=1000, verbose=1)

The best XGBoost model:

[I 2023-09-24 09:33:03,196] Trial 194 finished with value: 0.3076357369723567 and parameters: {'classifier__n_estimators': 960, 'classifier__max_depth': 5, 'classifier__scale_pos_weight': 7.087738003648041, 'classifier__min_child_weight': 9, 'classifier__gamma': 0.35508303168648314, 'classifier__reg_alpha': 0.4127588744223943, 'classifier__reg_lambda': 7.520697059910392e-05, 'classifier__subsample': 0.7755901359590227, 'classifier__colsample_bytree': 0.5273351309040184, 'classifier__learning_rate': 0.0033382168701707937}.

Best is trial 194 with value: 0.3076357369723567.

3.5.4 Comparison of Tuned Models

This subsection summarizes the results of the tuned Naive Bayes, XGBoost, and SVC models. Table 3.21 displays scores for the training set without cross-validation, while Table 3.22 shows scores for the validation set. It appears that both Naive Bayes and XGBoost exhibit signs of overfitting: their F1 scores are considerably higher in the training set (0.306 and 0.341, respectively) than in the validation set (0.185 and 0.131). Conversely, SVC demonstrates more consistent F1 scores, with values of 0.293 in the training set and 0.238 in the validation set, which are closer to each other than those of the other two classifiers.


SVC was selected as the best model for this dataset.
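
The comparison can be reproduced from the F1 columns of Table 3.21 and Table 3.22. The snippet below is a small illustration that copies those six values by hand; the variable names are mine and are not part of the project code.

```python
import pandas as pd

# F1 scores copied from Table 3.21 (train) and Table 3.22 (validation)
f1_scores = pd.DataFrame(
    {
        "train": [0.341, 0.306, 0.293],
        "validation": [0.131, 0.185, 0.238],
    },
    index=["XGBoost", "Naive Bayes", "SVC"],
)

# Drop in F1 when moving from the training set to the validation set
f1_scores["gap"] = (f1_scores["train"] - f1_scores["validation"]).round(3)

# Model with the smallest train-validation gap
smallest_gap_model = f1_scores["gap"].idxmin()

print(f1_scores)
print(smallest_gap_model)
```

SVC has both the smallest gap and the highest validation F1 score, which is consistent with the choice made above.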

final_candidates = {
    "Naive Bayes": nb_pipeline,
    "SVC": svc_pipeline,
    "XGBoost": tuning_results_xgb.best_estimator_,
    "--- Train ---",
--- Train ---
No information rate:  0.951
Table 3.21. Hyperparameter tuning results for the training set. The best values in each column are highlighted. Note: The F1 scores here and above differ as scores in this table are calculated without cross-validation.
  n No_info_rate Accuracy BAcc BAcc_01 F1 F1_neg TPR TNR PPV NPV Kappa ROC_AUC
XGBoost 3576 0.951 0.891 0.744 0.487 0.341 0.941 0.580 0.907 0.242 0.977 0.293 0.873
Naive Bayes 3576 0.951 0.886 0.711 0.422 0.306 0.938 0.517 0.905 0.217 0.973 0.255 0.839
SVC 3576 0.951 0.868 0.723 0.447 0.293 0.927 0.563 0.884 0.198 0.975 0.239 0.798
    "--- Validation ---",
--- Validation ---
No information rate:  0.952
Table 3.22. Hyperparameter tuning results for the validation set. The best values in each column are highlighted.
  n No_info_rate Accuracy BAcc BAcc_01 F1 F1_neg TPR TNR PPV NPV Kappa ROC_AUC
SVC 766 0.952 0.858 0.669 0.337 0.238 0.922 0.459 0.878 0.160 0.970 0.179 0.783
Naive Bayes 766 0.952 0.873 0.600 0.200 0.185 0.931 0.297 0.903 0.134 0.962 0.127 0.817
XGBoost 766 0.952 0.862 0.555 0.111 0.131 0.925 0.216 0.894 0.094 0.957 0.068 0.784

4 Final Model

4.1 Preparation

4.1.1 Prepare Data for Final Model Evaluation

  • First, merge training and validation data to form a dataset that will be used for creating the model.
  • Second, separate predictors and target variable.

Final training set = training set + validation set

target = "stroke_01"

# Train + validation set for training model
data_train_final = pd.concat([data_train, data_validation])
X_train_final = data_train_final.drop(target, axis=1)
y_train_final = data_train_final[target].to_numpy()

# Test set for testing model
X_test = data_test.drop(target, axis=1)
y_test = data_test[target].to_numpy()

4.1.2 Create Final Pipeline

Now, the final pre-processing and prediction pipeline will be created and fitted:

  • First, define the final pipeline.
  • Second, fit the pipeline to the training (training + validation) data.
# Final pipeline
pipeline_features_in = ["age", "smoking_status", "health_risk_score"]
# Note:
# From the variables "age", "health_risk_score", and "smoking_status",
# FeatureEngineer creates "stroke_risk_40", "health_risk_score",
# and "age_smoking_interaction".
pre_processor = Pipeline(
    steps=[
        ("selector", ColumnSelector(pipeline_features_in)),
        ("feature_engineer", FeatureEngineer()),
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ]
)

# Best hyperparameters found in Section 3.5.2
classifier = SVC(
    kernel="poly",
    C=39.77955198439039,
    gamma=0.002965591796461689,
    degree=2,
    random_state=1,
    probability=True,
    class_weight="balanced",
)

final_pipeline = Pipeline([
    ("preprocessor", pre_processor),
    ("classifier", classifier),
])
final_pipeline.fit(X_train_final, y_train_final)

                                  ColumnSelector(keep=['age', 'smoking_status',
                                 ('feature_engineer', FeatureEngineer()),
                                 ('imputer', SimpleImputer(strategy='median')),
                                 ('scaler', StandardScaler())])),
                 SVC(C=39.77955198439039, class_weight='balanced', degree=2,
                     gamma=0.002965591796461689, kernel='poly',
                     probability=True, random_state=1))])

4.2 Performance on Test Data

The performance of the final model on the test data is as follows: F1 (of the positive group) is 32.1%, balanced accuracy is 73.7%, and ROC AUC is 0.801. Performance on the training set is slightly lower than on the test set, which is unusual, but the trends in both sets are similar (see Table 4.1 and compare Figure 4.1 with Figure 4.2). This indicates that the model captures trends and not noise.
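
As a quick consistency check, the F1 score can be recomputed from the PPV (precision) and TPR (recall) values reported in Table 4.1, since F1 is their harmonic mean. The helper name `f1_from_ppv_tpr` is hypothetical, and the inputs are the rounded table values, so agreement is expected only to about three decimals.

```python
def f1_from_ppv_tpr(ppv, tpr):
    """F1 score as the harmonic mean of precision (PPV) and recall (TPR)."""
    return 2 * ppv * tpr / (ppv + tpr)


# Rounded PPV and TPR values copied from Table 4.1
f1_train = f1_from_ppv_tpr(ppv=0.190, tpr=0.531)
f1_test = f1_from_ppv_tpr(ppv=0.222, tpr=0.579)

print(round(f1_train, 3), round(f1_test, 3))
```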

        "SVC (train)": ml.get_classification_scores(
            final_pipeline, X_train_final, y_train_final
        "SVC (test)": ml.get_classification_scores(final_pipeline, X_test, y_test),
Table 4.1. Final model performance on final training and test sets.
  n No_info_rate Accuracy BAcc BAcc_01 F1 F1_neg TPR TNR PPV NPV Kappa ROC_AUC
SVC (train) 4342 0.951 0.867 0.708 0.415 0.280 0.927 0.531 0.884 0.190 0.974 0.224 0.798
SVC (test) 767 0.950 0.879 0.737 0.473 0.321 0.933 0.579 0.894 0.222 0.976 0.269 0.801
y_pred_final = final_pipeline.predict(X_train_final)

plot_confusion_matrices(y_train_final, y_pred_final);
Fig. 4.1. Confusion matrix for the final model (final training set): absolute counts (left), true-label normalized proportions (center), and predicted-label normalized proportions (right).
y_pred = final_pipeline.predict(X_test)
plot_confusion_matrices(y_test, y_pred);
Fig. 4.2. Confusion matrix for the final model (test set): absolute counts (left), true-label normalized proportions (center), and predicted-label normalized proportions (right).

4.3 Feature Importance

In this section, the importance of each predictor is evaluated using SHAP values. In general, the stroke_risk_40 variable is 2 times more influential than each of the remaining variables (Figure 4.3). Yet, in some cases, health_risk_score has more influence than the other predictors (see the right-most red dots in Figure 4.4). The highest values of each of the 3 predictors lead to the positive outcome (i.e., stroke), while moderate and low values either do not have much influence or lead to the negative outcome (i.e., no stroke; see Figure 4.4).
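
A common way to summarize overall importance from SHAP values is the mean absolute SHAP value per feature. The sketch below illustrates that calculation with NumPy only. The matrix of SHAP values is entirely hypothetical (it was constructed so that the first feature comes out twice as influential as the others) and does not reproduce the values behind the figures in this section.

```python
import numpy as np

feature_names = ["stroke_risk_40", "health_risk_score", "age_smoking_interaction"]

# Hypothetical SHAP values: rows are patients, columns are features
shap_matrix = np.array(
    [
        [0.20, -0.10, 0.10],
        [-0.40, 0.20, -0.20],
        [0.30, -0.15, 0.15],
    ]
)

# Global importance: average magnitude of each feature's contribution
mean_abs_shap = np.abs(shap_matrix).mean(axis=0)
importance = dict(zip(feature_names, mean_abs_shap.round(3)))

# How much more influential the first feature is than the second
ratio = mean_abs_shap[0] / mean_abs_shap[1]

print(importance)
print(ratio)
```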

Code to calculate SHAP values
@my.cache_results(dir_cache + "04_1_shap_values.pickle")
def get_shap(pipeline, X_train, y_train, X_test):
    """Calculate SHAP values"""
    preprocessor = clone(pipeline["preprocessor"])
    model = clone(pipeline["classifier"])

    d_train = preprocessor.fit_transform(X_train)
    d_test = preprocessor.transform(X_test)
    model.fit(d_train, y_train)
    explainer = shap.Explainer(model.predict_proba, d_test)
    shap_values = explainer(d_test)
    return shap_values

shap_values = get_shap(final_pipeline, X_train_final, y_train_final, X_test)
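Code of the figure
# The start of this call is missing in the source; `shap.plots.bar` is assumed
shap.plots.bar(shap_values[:, :, 1], max_display=30);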
Fig. 4.3. General feature importance for the final model: variables that were provided to the classifier.
Code of the figure
shap.plots.beeswarm(shap_values[:, :, 1], max_display=30);
No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored
Fig. 4.4. Influence of feature values on the prediction in the final model: variables that were provided to the classifier.

4.4 Classification Model for Production

A pre-processing pipeline and the model trained on the entire dataset (including training, validation, and test data) will be saved to a file. This file will be utilized in the production environment to predict a patient’s likelihood of experiencing a stroke. In the production environment, the input variables will be:

  • age (in years);
  • health_risk_score (integer from 0 to 4);
  • smoking_status (one of the following: never smoked, smokes, formerly smoked, unknown).

The variables required by the model will be created within the pipeline.

The model used in production will be trained on the entire dataset (including training, validation, and test data).

target = "stroke_01"
production_pipeline = clone(final_pipeline).fit(
    data_all.drop(target, axis=1), data_all[target]
)

file = "model_to_deploy/final_pipeline.pkl"
with open(file, "wb") as f:
    joblib.dump(production_pipeline, f)

4.5 Deployment

To deploy the model, a Flask application was created. The code needed to deploy the application is available in the GitHub repository GegznaV/deploy-stroke-prediction. To get predictions, send a request to the application, providing the following information about the patient as a JSON object:

  • age (in years);
  • health_risk_score (integer from 0 to 5);
  • smoking_status (one of the following: never smoked, formerly smoked, smokes, Unknown).

You may access the application via online server or use a local development server.
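
The request body is a JSON object in which every field holds a list with one entry per patient. The sketch below shows one way such a body could be assembled in Python before sending it. The helper `build_payload` is hypothetical and is not part of the deployed application; it only checks that all lists have the same length and does not validate the values themselves.

```python
import json


def build_payload(age, health_risk_score, smoking_status):
    """Build the JSON request body for one or more patients."""
    columns = {
        "age": list(age),
        "health_risk_score": list(health_risk_score),
        "smoking_status": list(smoking_status),
    }
    # Every field must describe the same number of patients
    if len({len(values) for values in columns.values()}) != 1:
        raise ValueError("All fields must have the same number of entries.")
    return json.dumps(columns)


payload = build_payload(
    age=[30, 65, 84],
    health_risk_score=[1, 0, 3],
    smoking_status=["never smoked", "smokes", "never smoked"],
)
print(payload)
```

The resulting string can be passed to the `-d` option of curl or used as the body of a POST request in any HTTP client.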


The predictions of this model must not be used as medical advice. For medical advice, please, consult your physician.

4.5.1 Predictions via Online Server

Model was deployed on and accessible at (you should use this URL only with the available routes). You may test if server is up via route /test (link) and make predictions via route /api/predict.

The examples for testing the service use the curl command-line tool. To run the examples, curl must be installed.

curl -k


OK - server is up and running!

You may also use other ways to communicate with the server. Figure 4.5 demonstrates the request sent via Thunder Client extension for VSCode.

Fig. 4.5. Testing if server is running via request in VSCode extension Thunder Client.

To make predictions use, e.g.:

curl --ssl-no-revoke \
     -H 'Content-Type: application/json' \
     -d '{"age":[30], "health_risk_score":[1], "smoking_status":["never smoked"]}'

Response (manually re-formatted for better readability):

  "smoking_status":["never smoked"]

To request predictions about several people at once use, e.g.:

curl --ssl-no-revoke \
     -H 'Content-Type: application/json' \
     -d '{"age":[30, 65, 84], "health_risk_score":[1, 0, 3], "smoking_status":["never smoked", "smokes", "never smoked"]}'

Response (again manually re-formatted):

  "smoking_status":["never smoked","smokes","never smoked"]

The same request sent via the Thunder Client extension in VSCode is shown in Figure 4.6.

Fig. 4.6. Making prediction via API request in VSCode extension Thunder Client.

4.5.2 Predictions Locally via Development Server

To deploy the app locally and test its responses, download the contents of the GitHub repository (see above) to your working directory and run the following commands in the terminal:


To test if the server is running, use:

curl -X GET

To make predictions, use, e.g.:

curl -X POST \
     -H 'Content-Type: application/json' \
     -d '{"age":[30], "health_risk_score":[1], "smoking_status":["never smoked"]}'
curl -X POST \
     -H 'Content-Type: application/json' \
     -d '{"age":[30, 65, 84], "health_risk_score":[1, 0, 3], "smoking_status":["never smoked", "smokes", "never smoked"]}'

5 Final Remarks

Notes and ideas for improvement:

  1. Before feature engineering, a baseline model could be created to see if the feature engineering actually improves the model.
  2. To avoid additional noise and make models more comparable under cross-validation, it would be better to use a single cross-validation object for all models. Despite the fact that I used random seeds, I became suspicious about this after I noticed that the results of Optuna’s OptunaSearchCV are not reproducible even with the same random seed.
  3. Investigate more thoroughly why some CatBoost optimization trials failed. I have already found that the cross-entropy loss function is not compatible with the options that take data imbalance into account.
  4. A better space of SVC hyperparameters could be defined, as some combinations of hyperparameters result in extremely long calculation times yet (in many cases) do not improve the performance. This could be investigated further.
  5. SVC predictions via the .predict() and .predict_proba() methods are not always in agreement with each other. This is a known issue (e.g., link 1 and link 2), and I found out about it only after the model was deployed. Due to lack of time, I did not look deeply into the issue, but it must be investigated further to decide whether the predictions are really reliable.
  6. More extensive feature engineering could be performed (e.g., interactions between age and additional smoking statuses; log-transformation and other mathematical transformations could be tried for predictors).
  7. Feature importance could have been calculated not only for the features that were used in the model but also for those that could provide better understanding for doctors and patients. For example, this could include features such as patient’s age, health risk score, smoking status, or even separate components of the health risk score.
  8. The final production pipeline is quite vulnerable to errors (e.g., to wrong user input). This should be improved by adding more checks, disallowing forbidden values in the user input, and creating informative error messages. Yet, the main purpose of this project was prototyping and not the robustness of the production pipeline.
  9. Some plots and other results were not commented in the text. This should be improved. Yet, I decided to keep those plots/results for data monitoring purposes.
  10. The code could be better organized. Functions that are not used in this project could be removed from the external files.
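
Regarding remark 2, a shared cross-validation object could look like the sketch below. It is an illustration on synthetic data and was not run as part of this project: the two classifiers, the number of folds, and the seed are arbitrary choices. The point is that one `StratifiedKFold` instance with a fixed `random_state` gives every model exactly the same folds, so the score differences are not caused by different splits.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import GaussianNB

# Synthetic, imbalanced stand-in for the real training data
X_demo, y_demo = make_classification(
    n_samples=500, n_features=5, weights=[0.9, 0.1], random_state=1
)

# One splitter object reused for all models
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)

models = {
    "Naive Bayes": GaussianNB(),
    "Logistic Regression": LogisticRegression(
        class_weight="balanced", max_iter=1000
    ),
}

# Each model is evaluated on identical folds
cv_scores = {
    name: cross_val_score(model, X_demo, y_demo, cv=cv, scoring="f1")
    for name, model in models.items()
}

for name, scores in cv_scores.items():
    print(name, scores.round(3))
```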