Building an AI for Leukemia Detection: An End-to-End Guide

Introduction

Leukemia is a serious blood cancer, and early detection improves patient outcomes. Advances in machine learning now make it possible to build systems that flag leukemic cells automatically from microscope images. In this blog post, I walk through building such a leukemia detection system from scratch, from data acquisition to model deployment.

Data Acquisition and Processing

Data Source

This project uses the ISBI 2019 C-NMC (Classification of Normal vs. Malignant Cells) leukemia dataset provided by Anubha Gupta et al. The dataset contains microscopic images of acute lymphoblastic leukemia (ALL) cells and normal hematopoietic (HEM) cells. It can be accessed through The Cancer Imaging Archive (TCIA), and a copy is also available on Kaggle, which is what the download step below uses.

Data Processing

Data processing forms the foundation of this project. We use Python for data manipulation, with key steps including:

  1. Downloading and Unzipping the Dataset: Using the Kaggle API, we download the dataset and unzip it within Google Colab. Here’s the essential code:
# Install Kaggle library
!pip install kaggle

# Upload Kaggle credentials
from google.colab import files
files.upload()

# Configure Kaggle API
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the dataset
!kaggle datasets download -d andrewmvd/leukemia-classification

# Unzip the dataset
!unzip leukemia-classification.zip
  2. Image Vectorization: Convert each image into a numerical vector that a machine learning model can consume. Rather than growing a pandas DataFrame row by row, which is slow, we collect the flattened pixel arrays in a Python list and stack them once with NumPy. Here is the core code for the ALL class:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from skimage import io

# Set data path (the training data is split into fold subfolders)
path_training = '/content/C-NMC_Leukemia/training_data'
folds = sorted(
    os.path.join(path_training, d)
    for d in os.listdir(path_training)
    if os.path.isdir(os.path.join(path_training, d))
)

# Process ALL class images: flatten each image into one row, capped at max_images
data = []
img_count = 0
max_images = 1000

for fold in folds:
    path_all = os.path.join(fold, 'all')
    if os.path.exists(path_all) and img_count < max_images:
        img_files = sorted(os.listdir(path_all))  # sorted for a reproducible subset
        for img_file in tqdm(img_files, desc=f"Processing ALL: {path_all}"):
            if img_count >= max_images:
                break
            img_path = os.path.join(path_all, img_file)
            img = io.imread(img_path)
            row = img.flatten()  # H x W x C pixels -> 1-D vector
            data.append(row)
            img_count += 1

# Stack all rows at once (much faster than appending to a DataFrame one row at a time)
X = np.vstack(data)
df = pd.DataFrame(X)
df['Class'] = 'ALL'

Similarly, we process HEM class images and combine both datasets for subsequent model training.
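The HEM step is the same loop pointed at the HEM subfolders. One way to avoid duplicating it is a small helper, sketched here under two assumptions: the HEM images live in folders named hem next to the all folders, and HEM uses the same 1,000-image cap as ALL. load_class_images and build_training_table are hypothetical names, and the read_image parameter (defaulting to skimage.io.imread) exists only so the function can be exercised without real images.

```python
import os
import numpy as np
import pandas as pd


def load_class_images(path_training, class_dir, label, max_images=1000, read_image=None):
    """Flatten up to max_images images from every <fold>/<class_dir> folder into a labeled DataFrame."""
    if read_image is None:
        from skimage import io  # same reader as the ALL loop
        read_image = io.imread
    folds = sorted(
        os.path.join(path_training, d)
        for d in os.listdir(path_training)
        if os.path.isdir(os.path.join(path_training, d))
    )
    rows = []
    for fold in folds:
        class_path = os.path.join(fold, class_dir)
        if not os.path.isdir(class_path):
            continue
        for img_file in sorted(os.listdir(class_path)):
            if len(rows) >= max_images:
                break
            # Each image becomes one flattened row of pixel values
            rows.append(read_image(os.path.join(class_path, img_file)).flatten())
    df = pd.DataFrame(np.vstack(rows))
    df['Class'] = label
    return df


def build_training_table(path_training, max_per_class=1000, read_image=None):
    """Stack ALL and HEM rows into one table (assumes HEM folders are named 'hem')."""
    df_all = load_class_images(path_training, 'all', 'ALL', max_per_class, read_image)
    df_hem = load_class_images(path_training, 'hem', 'HEM', max_per_class, read_image)
    return pd.concat([df_all, df_hem], ignore_index=True)
```

In the notebook, df_combined_train = build_training_table('/content/C-NMC_Leukemia/training_data') would produce the combined table that the training code below expects.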

Model Training and Optimization

Model Selection

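We choose the Random Forest algorithm for this classification task. Random Forest is an ensemble of decision trees that generalizes well with little tuning and handles a very large number of input features, which makes it a reasonable baseline for flattened pixel data. It does, however, ignore the spatial structure of the images that convolutional networks exploit.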

Model Training

We hold out a stratified 20% of the processed dataset as a test set and train the Random Forest on the remaining 80%. Here is the key code for model training:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Separate features and labels (df_combined_train holds the ALL and HEM rows stacked together)
X = df_combined_train.drop('Class', axis=1)
y = df_combined_train['Class']

# Split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Initialize and train the Random Forest classifier
rf_model = RandomForestClassifier(
    n_estimators=200,
    class_weight='balanced',
    min_samples_leaf=1,
    random_state=42,
    n_jobs=-1
)
rf_model.fit(X_train, y_train)
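In the threshold code that follows, the threshold is evaluated on X_test, the same split whose metrics are reported; if the threshold is also chosen by looking at those results, the reported sensitivity and specificity become optimistic. One way to avoid this, sketched here assuming a 60/20/20 split is acceptable, is to carve out a separate validation set for threshold tuning. three_way_split is a hypothetical helper built on scikit-learn's train_test_split.

```python
from sklearn.model_selection import train_test_split


def three_way_split(X, y, val_size=0.2, test_size=0.2, random_state=42):
    """Stratified train/validation/test split; sizes are fractions of the full dataset."""
    # First hold out the final test set, which stays untouched until the end
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    # Rescale val_size so it is a fraction of the full data, not of the remainder
    val_fraction = val_size / (1.0 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_fraction, random_state=random_state, stratify=y_rest
    )
    return X_train, X_val, X_test, y_train, y_val, y_test
```

With this split, the model would be fit on X_train, the threshold chosen from predictions on X_val, and X_test used once for the final numbers.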

Model Optimization

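By default, the classifier predicts whichever class is more probable, which amounts to a 0.5 cutoff on the ALL probability in this two-class setting. We replace that cutoff with a tuned threshold: lowering it raises sensitivity at the cost of specificity, and raising it does the opposite. Here is the code for applying a custom threshold and measuring the result: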

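import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Obtain prediction probabilities (one column per class, ordered as rf_model.classes_)
y_pred_proba = rf_model.predict_proba(X_test)

# Determine which column holds the ALL probability
class_labels = rf_model.classes_
all_index = np.where(class_labels == 'ALL')[0][0]

# Apply a custom threshold: predict ALL when its probability is at least 0.35
threshold = 0.35
y_pred = (y_pred_proba[:, all_index] >= threshold).astype(int)
y_pred = np.where(y_pred == 1, 'ALL', 'HEM')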

# Calculate performance metrics
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy:.4f}")

# Calculate sensitivity and specificity
tn, fp, fn, tp = confusion_matrix(y_test == 'ALL', y_pred == 'ALL').ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
print(f"Sensitivity (True Positive Rate): {sensitivity:.4f}")
print(f"Specificity (True Negative Rate): {specificity:.4f}")

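Repeating this calculation over a range of thresholds shows how sensitivity and specificity trade off, and the final threshold should follow from the application: a screening tool, for example, should miss as few ALL cells as possible. Ideally, that choice is made on a validation split rather than on the test set used to report the final metrics.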
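The threshold analysis can be automated by sweeping a grid of candidate thresholds. This sketch computes sensitivity and specificity with NumPy for each candidate, then picks either the most specific threshold that meets a minimum sensitivity or, if no floor is given (or none qualifies), the threshold that maximizes Youden's J statistic (sensitivity + specificity - 1). sweep_thresholds, pick_threshold, and the 0.05 to 0.95 grid are assumptions for illustration, not the procedure used in this post.

```python
import numpy as np


def sweep_thresholds(y_true_is_all, p_all, thresholds=None):
    """Return (threshold, sensitivity, specificity) for each candidate threshold."""
    y_true_is_all = np.asarray(y_true_is_all, dtype=bool)
    p_all = np.asarray(p_all, dtype=float)
    if thresholds is None:
        thresholds = np.round(np.linspace(0.05, 0.95, 19), 2)  # 0.05, 0.10, ..., 0.95
    rows = []
    for t in thresholds:
        pred_all = p_all >= t  # same rule as the custom-threshold code above
        tp = np.sum(pred_all & y_true_is_all)
        fn = np.sum(~pred_all & y_true_is_all)
        tn = np.sum(~pred_all & ~y_true_is_all)
        fp = np.sum(pred_all & ~y_true_is_all)
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        rows.append((float(t), float(sensitivity), float(specificity)))
    return rows


def pick_threshold(rows, min_sensitivity=None):
    """Most specific threshold meeting min_sensitivity; otherwise maximize Youden's J."""
    if min_sensitivity is not None:
        eligible = [r for r in rows if r[1] >= min_sensitivity]
        if eligible:
            return max(eligible, key=lambda r: r[2])[0]
    # Youden's J = sensitivity + specificity - 1
    return max(rows, key=lambda r: r[1] + r[2] - 1)[0]
```

For example, rows = sweep_thresholds(y_val == 'ALL', rf_model.predict_proba(X_val)[:, all_index]) followed by pick_threshold(rows, min_sensitivity=0.9) would return the most specific threshold that still catches at least 90% of ALL cells, with X_val and y_val standing for a held-out validation split.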

Application Development and Deployment

Application Development

Based on the trained model, we build a simple web application that lets users upload a cell image and receive a prediction. It uses the Gradio framework, which wraps a Python prediction function in an interactive browser interface. Here is the key code for the application:

import gradio as gr
import pickle
import numpy as np
import pandas as pd

# Load the model
with open('random_forest_leukemia_full.pkl', 'rb') as f:
    model_data = pickle.load(f)

rf_model = model_data['model']
threshold = model_data['threshold']
all_index = model_data['all_index']
class_labels = rf_model.classes_

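# Image preprocessing: flatten exactly as during training (one row of pixel values)
def preprocess_image(image):
    flattened = np.asarray(image).flatten()
    df = pd.DataFrame([flattened])
    return df

# Prediction function
def predict_leukemia(image):
    if image is None:
        return "Please upload an image."
    processed_image = preprocess_image(image)
    # The model only accepts images with the training resolution and channel count
    if processed_image.shape[1] != rf_model.n_features_in_:
        return (f"Unexpected image size: got {processed_image.shape[1]} pixel values, "
                f"expected {rf_model.n_features_in_}. Please upload an image with the "
                "same dimensions as the training images.")
    proba = rf_model.predict_proba(processed_image)
    all_probability = proba[0, all_index]
    prediction = "ALL" if all_probability >= threshold else "HEM"
    result = f"Prediction: {prediction}\n"
    result += f"Probability of ALL: {all_probability:.2%}\n"
    result += f"Threshold used: {threshold:.2f}\n\n"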
    if prediction == "ALL":
        result += "This sample shows characteristics of Acute Lymphoblastic Leukemia (ALL)."
    else:
        result += "This sample appears to show normal hematopoietic cells (HEM)."
    return result

# Create Gradio interface
with gr.Blocks(title="Leukemia Cell Classification") as demo:
    gr.Markdown("# Leukemia Cell Classification")
    with gr.Tab("Make Prediction"):
        gr.Markdown("Upload a microscopy image to determine if it shows ALL or HEM.")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="numpy", label="Upload Cell Image")
                predict_button = gr.Button("Predict", variant="primary")
            with gr.Column():
                prediction_result = gr.Textbox(label="Prediction Result", lines=8)
        predict_button.click(
            fn=predict_leukemia,
            inputs=input_image,
            outputs=prediction_result
        )
    # Other tab content...

demo.launch()
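The app expects random_forest_leukemia_full.pkl to contain a dictionary with the keys model, threshold, and all_index. The saving step is not shown above; a minimal sketch for the training notebook, assuming a plain pickle of that dictionary, could look like this (save_model_bundle is a hypothetical helper):

```python
import pickle


def save_model_bundle(rf_model, threshold, path='random_forest_leukemia_full.pkl'):
    """Pickle the fitted model together with the settings the Gradio app reads back."""
    class_labels = list(rf_model.classes_)
    bundle = {
        'model': rf_model,
        'threshold': float(threshold),
        'all_index': class_labels.index('ALL'),  # predict_proba column holding P(ALL)
    }
    with open(path, 'wb') as f:
        pickle.dump(bundle, f)
    return bundle
```

Calling save_model_bundle(rf_model, threshold) after the threshold step would produce the file the app loads. Pinning the same scikit-learn version in the training environment and in the deployed app avoids problems when the model is unpickled.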

Application Deployment

The application is deployed on Hugging Face Spaces, where users can open it in the browser through a shareable link. Here are screenshots of the deployed application:

Application Page 1
Application Page 2

Frequently Asked Questions (FAQ)

Q1: What is the accuracy of this project?

A1: On the test set, the model reaches an accuracy of 0.85, a sensitivity of 0.89, and a specificity of 0.74. It catches most ALL cells, but the lower specificity means roughly one in four normal (HEM) cells is flagged as ALL, so there is clear room for improvement.
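Accuracy, sensitivity, and specificity are linked through the share p of ALL images in the test set, because accuracy counts the correctly classified ALL images plus the correctly classified HEM images, divided by the total:

$$\text{Accuracy} = p \cdot \text{Sensitivity} + (1 - p) \cdot \text{Specificity}$$

Solving for p with the reported (rounded) values gives

$$p = \frac{\text{Accuracy} - \text{Specificity}}{\text{Sensitivity} - \text{Specificity}} = \frac{0.85 - 0.74}{0.89 - 0.74} \approx 0.73$$

For comparison, a balanced test split (p = 0.5), which the 1,000-image cap would produce if HEM is capped the same way as ALL, would give (0.89 + 0.74)/2 ≈ 0.815. The reported accuracy therefore appears to come from an evaluation set with more ALL than HEM images; because the metrics are rounded, p is only approximate.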

Q2: How can I obtain the dataset used in this project?

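A2: The dataset used here can be downloaded from Kaggle (andrewmvd/leukemia-classification), as in the download step above. The original C-NMC 2019 release is hosted on The Cancer Imaging Archive.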

Q3: Can I run this project on my own device?

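A3: Yes. Install Python and the required libraries (kaggle, numpy, pandas, scikit-image, tqdm, scikit-learn, and gradio), then run the code locally. The Colab-specific steps need local equivalents: google.colab.files.upload() only works inside Colab, and the commands prefixed with "!" must be run in a terminal without the "!". Also point path_training at the folder where you extracted the data instead of /content.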
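For a local run, the Colab-only credential and unzip cells can be replaced with plain Python. A minimal sketch, assuming kaggle.json has already been downloaded from your Kaggle account settings and the archive is fetched with the same kaggle datasets download command run in a terminal (install_kaggle_credentials and extract_dataset are hypothetical helpers; the permission change behaves as chmod 600 on Linux and macOS):

```python
import os
import shutil
import stat
import zipfile
from pathlib import Path


def install_kaggle_credentials(kaggle_json, kaggle_dir=Path.home() / '.kaggle'):
    """Local stand-in for the Colab upload + mkdir/cp/chmod cells."""
    kaggle_dir = Path(kaggle_dir)
    kaggle_dir.mkdir(parents=True, exist_ok=True)
    target = kaggle_dir / 'kaggle.json'
    shutil.copyfile(kaggle_json, target)
    os.chmod(target, stat.S_IRUSR | stat.S_IWUSR)  # owner read/write only, like chmod 600
    return target


def extract_dataset(zip_path, dest='.'):
    """Local stand-in for the unzip command: extract the downloaded archive into dest."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    return Path(dest)
```

Calling install_kaggle_credentials('kaggle.json') once, running kaggle datasets download -d andrewmvd/leukemia-classification in a terminal, and then calling extract_dataset('leukemia-classification.zip', 'data') would leave the images under data/C-NMC_Leukemia/training_data, which is where path_training should then point.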

Q4: How can I adjust the model’s sensitivity and specificity?

A4: You can modify the prediction threshold to adjust the model’s sensitivity and specificity. For example, lowering the threshold increases sensitivity but may decrease specificity, while raising the threshold increases specificity but may sacrifice some sensitivity.

Q5: Is this project suitable for actual medical diagnosis?

A5: It is important to note that this project is currently intended for research and educational purposes only and should not be used for actual medical diagnosis. The results are for reference only and cannot replace the diagnosis and treatment recommendations of professional doctors.

Conclusion and Outlook

This project successfully constructs a leukemia cell image classification system based on machine learning. From data acquisition and processing to model training and optimization, and finally to application development and deployment, the entire process demonstrates the potential of machine learning technology in the field of medical diagnosis.

However, the project still has some limitations, such as a limited dataset size and room for further model improvement. In the future, we can enhance and expand the project in the following ways:

  1. Data Augmentation: Expand the dataset using data augmentation techniques to improve the model’s generalization ability.
  2. Model Optimization: Experiment with more advanced deep learning models like Convolutional Neural Networks (CNNs) to further enhance classification performance.
  3. Multimodal Data Fusion: Combine other types of medical data (such as clinical indicators and genetic information) to build a more comprehensive diagnostic model.
  4. Clinical Validation: Collaborate with medical institutions to conduct large-scale clinical validation of the model and assess its effectiveness and reliability in real-world medical scenarios.
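Data augmentation could start with label-preserving geometric transforms. Cell images have no canonical orientation, so rotating by multiples of 90 degrees and mirroring should not change the class; that is an assumption worth checking against the data. A NumPy sketch that works on the flattened rows used throughout this post (dihedral_augment and augment_flattened are hypothetical helpers):

```python
import numpy as np


def dihedral_augment(img):
    """Return the 8 rotations/mirror images of an H x W x C image."""
    variants = []
    for k in range(4):
        rotated = np.rot90(img, k, axes=(0, 1))  # rotate by k * 90 degrees
        variants.append(rotated)
        variants.append(np.flip(rotated, axis=1))  # mirror left-right
    return variants


def augment_flattened(X, image_shape):
    """Expand each flattened image row into its 8 flattened variants."""
    rows = []
    for row in X:
        img = np.asarray(row).reshape(image_shape)
        rows.extend(v.reshape(-1) for v in dihedral_augment(img))
    return np.vstack(rows)
```

Each input row expands into 8 consecutive output rows, so the labels would be repeated to match with np.repeat(y_train, 8). Augmentation should be applied to the training split only, so that no transformed copy of a test image leaks into training; the eightfold expansion also multiplies memory use by eight.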

In conclusion, with the continuous development of AI technology, we can anticipate that future medical diagnosis will become more intelligent and precise, making a greater contribution to human health.