ModernBERT in Radiology Part 3: Fine-tuning a Classifier

In Part 3 of the ModernBERT in Radiology series, we will fine-tune a ModernBERT classifier to predict UMLS CUIs from a radiology report. It builds on our fine-tuning from Part 2 to produce a better classifier than the simple scikit-learn logistic regression from Part 1.

You can follow along with the associated Colab Notebook for Part 3🔥!

The ModernBERT in Radiology Series

Objective

Whereas Part 1 used the hidden state of the ModernBERT body to train a simple classifier, here we put a proper Hugging Face neural-network classification head on the ModernBERT body and fully fine-tune it on the unsloth/Radiology_mini dataset to perform multi-label classification from radiology text to UMLS CUIs (concept IDs). Finally, we will publish the result to Hugging Face as johnpaulett/ModernRadBERT-cui-classifier.

We will follow parts of the excellent Natural Language Processing with Transformers book, with code available on GitHub. Part 3 roughly follows Chapter 2 of that book. However, our problem deviates from the chapter since we are performing multi-label classification (i.e. each text can have one or more CUI labels).

WARNING: Since the cui concepts were generated via MedCAT, we are learning to reproduce MedCAT's predictions rather than human-annotated ground truth.

Code

See Colab for the full Notebook: https://colab.research.google.com/drive/11hpCvNb4g65Igcmz1ePyuqdNmWwYwj4g?usp=sharing

Setup

In Part 3, I used a Colab NVIDIA L4 GPU. We use the Hugging Face 🤗 transformers AutoModelForSequenceClassification class to load the pre-trained ModernBERT for full fine-tuning.

pip install datasets evaluate wandb triton
# flash-attn requires an Ampere-or-newer GPU (e.g. not a T4)
pip install flash-attn
# Until next transformers release (4.48.0)
pip install git+https://github.com/huggingface/transformers.git
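
If you are unsure whether your runtime supports flash-attn, a quick check of the GPU's compute capability helps (a minimal sketch; flash-attn needs compute capability >= 8.0, and Colab's L4 reports 8.9 while a T4 reports 7.5):

import torch

# flash-attn requires CUDA compute capability >= 8.0 (Ampere or newer)
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")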
model_id = (
    "answerdotai/ModernBERT-base"
    # answerdotai/ModernBERT-large
)
dataset_name = (
    # "eltorio/ROCOv2-radiology"
    "unsloth/Radiology_mini" # 0.33% of ROCOv2-radiology, for a quicker demo
)
push_to_hub = True
output_dir = "ModernRadBERT-cui-classifier"

Load & Transform the Dataset

See Part 1 for details on the dataset.

Load the dataset and re-split it, carving a validation set out of the original training split.

from datasets import load_dataset, DatasetDict
original_dataset = load_dataset(dataset_name)
print(f"Training Size: {original_dataset['train'].size_in_bytes / (1024 * 1024 * 1024):.2f} GB")

validation_size = int(0.15 * (len(original_dataset['train']) + len(original_dataset['test'])))

dataset = DatasetDict({
    # Shuffling twice with the same seed yields the same order, so the
    # train and validation selections below are disjoint
    'train': original_dataset['train'].shuffle(seed=42).select(range(validation_size, len(original_dataset['train']))),
    'validation': original_dataset['train'].shuffle(seed=42).select(range(validation_size)),
    # Keep the test set -- we'll hold it back for comparison between models
    'test': original_dataset['test']
})

# Drop the images -- we only classify the report text
dataset = dataset.remove_columns(['image'])

# 'dataset' now contains the train/validation/test splits
print(f"training set size: {len(dataset['train'])}")
print(f"validation set size: {len(dataset['validation'])}")
print(f"test set size: {len(dataset['test'])}")

Convert the caption into tokens using ModernBERT’s tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize_function(examples):
    return tokenizer(
        examples["caption"],
        padding="max_length",
        truncation=True,
        # ModernBERT raises the max sequence length to 8192 tokens, from 512 in BERT!
        # Our max len() of the captions in the train set is 934, so roughly 934/4 ~= 233,
        #  and further testing of the longest attention_mask shows this is actually 206.
        # Increasing too high will consume significant memory while we extract
        #  the hidden states for all the inputs.
        max_length=256,
      )

dataset = dataset.map(
    tokenize_function, batched=True
)
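
To back up the max_length=256 choice, we can measure the longest tokenized caption directly (a minimal sketch; per the comment above, the longest training caption comes out around 206 tokens):

# With padding="max_length", summing the attention_mask recovers the
# true (unpadded) token count for each caption
lengths = [sum(mask) for mask in dataset['train']['attention_mask']]
print(f"Longest tokenized caption: {max(lengths)} tokens")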

Since cui is a multi-label field (each report can carry several CUIs), we will use scikit-learn’s MultiLabelBinarizer to encode the labels as binary vectors:

from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np

mlb = MultiLabelBinarizer()
# Fit on the training CUIs only; CUIs that appear solely in validation/test
# will be ignored (with a warning) when transformed
mlb.fit(dataset['train']['cui'])

def transform_labels(example):
    # Transform single example's CUIs to binary vector
    binary_labels = mlb.transform([example['cui']])[0]  # [0] to get the single example's labels

    # Convert to float32 for BCEWithLogitsLoss
    example['labels'] = binary_labels.astype(np.float32).tolist()
    example['num_labels'] = sum(binary_labels)

    return example

dataset = dataset.map(
    transform_labels,
    desc="Transforming labels to binary vectors",
    num_proc=4,
)
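
Before training, it is worth checking the shape of the label space (illustrative; the exact counts depend on the split):

print(f"Distinct CUI classes: {len(mlb.classes_)}")
print(f"CUIs per example (first 5): {dataset['train']['num_labels'][:5]}")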

Train the Classifier

We use AutoModelForSequenceClassification with the number of distinct labels, making sure to set problem_type="multi_label_classification" so the model applies BCEWithLogitsLoss during training. Because answerdotai/ModernBERT-base ships without a classification head, transformers will warn that some weights are newly initialized; that is expected, since the head is exactly what we are about to train:

from transformers import AutoModelForSequenceClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(mlb.classes_),
    problem_type="multi_label_classification"
).to(device)
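
As a rough sanity check (ModernBERT-base is roughly 149M parameters, plus the small newly initialized head):

# Count trainable parameters; full fine-tuning updates all of them
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")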

Prepare the metrics (micro- and macro-averaged F1, exact match, Hamming loss) to compute at each evaluation epoch:

from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss

# We use sklearn's metrics instead of `evaluate`, because evaluate's f1 expects single-label predictions rather than multi-hot arrays
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Apply sigmoid activation and threshold at 0.5
    predictions = 1 / (1 + np.exp(-predictions))  # sigmoid
    predictions = (predictions > 0.5).astype(int)

    # Calculate micro-averaged metrics
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )

    # Calculate macro-averaged metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    # Calculate subset accuracy (exact match)
    subset_accuracy = accuracy_score(labels, predictions)

    # Calculate Hamming loss
    ham_loss = hamming_loss(labels, predictions)

    # Calculate per-label accuracy (element-wise)
    label_wise_accuracy = np.mean((predictions == labels).astype(float))

    results = {
        # Micro-averaged metrics
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1": f1_micro,

        # Macro-averaged metrics
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,

        # Other metrics
        "exact_match": subset_accuracy,
        "hamming_loss": ham_loss,
        "label_accuracy": label_wise_accuracy
    }

    return results
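
To see how these metrics behave, here is a tiny worked example on dummy logits and labels (not from the dataset). The first example misses one of its two positive labels, which gives a micro-F1 of 0.8 while exact match drops to 0.5:

# Two examples, three labels; the first example misses one positive label
dummy_logits = np.array([[2.0, -2.0, -0.5], [-1.0, 3.0, -0.5]])
dummy_labels = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.float32)
print(compute_metrics((dummy_logits, dummy_labels)))
# f1: 0.8, exact_match: 0.5, hamming_loss: ~0.167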

We then conduct 20 epochs of training:

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=20,

    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_strategy="steps",    # Log metrics every n steps
    logging_steps=100,           # Log every 100 steps
    eval_strategy="epoch",
    metric_for_best_model="f1",
    greater_is_better=True,      # higher f1 is better

    save_strategy="epoch",
    load_best_model_at_end=True,
    save_total_limit=3,          # Only keep the 3 most recent checkpoints

    push_to_hub=push_to_hub,

    report_to="none",            # Comment to enable wandb
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()

| Training Loss | Epoch | Step | Validation Loss | Precision Micro | Recall Micro | F1 | Precision Macro | Recall Macro | F1 Macro | Exact Match | Hamming Loss | Label Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.1371 | 1.0 | 205 | 0.1214 | 0.8169 | 0.6679 | 0.7350 | 0.4170 | 0.3481 | 0.3667 | 0.5681 | 0.0404 | 0.9596 |
| 0.0904 | 2.0 | 410 | 0.1054 | 0.8704 | 0.6833 | 0.7656 | 0.5391 | 0.3744 | 0.4106 | 0.6029 | 0.0351 | 0.9649 |
| 0.0458 | 3.0 | 615 | 0.1012 | 0.8316 | 0.7582 | 0.7932 | 0.5899 | 0.5157 | 0.5251 | 0.6580 | 0.0332 | 0.9668 |
| 0.0216 | 4.0 | 820 | 0.1134 | 0.8738 | 0.7044 | 0.7800 | 0.7129 | 0.4338 | 0.5071 | 0.6377 | 0.0333 | 0.9667 |
| 0.01 | 5.0 | 1025 | 0.1194 | 0.8382 | 0.7159 | 0.7723 | 0.6707 | 0.4817 | 0.5336 | 0.6290 | 0.0354 | 0.9646 |
| 0.0047 | 6.0 | 1230 | 0.1224 | 0.8721 | 0.7332 | 0.7967 | 0.6475 | 0.4692 | 0.5187 | 0.6638 | 0.0314 | 0.9686 |
| 0.0024 | 7.0 | 1435 | 0.1228 | 0.8540 | 0.7409 | 0.7934 | 0.7016 | 0.5071 | 0.5648 | 0.6725 | 0.0324 | 0.9676 |
| 0.0012 | 8.0 | 1640 | 0.1289 | 0.8744 | 0.7217 | 0.7907 | 0.7053 | 0.4852 | 0.5531 | 0.6609 | 0.0320 | 0.9680 |
| 0.0009 | 9.0 | 1845 | 0.1323 | 0.8765 | 0.7217 | 0.7916 | 0.7063 | 0.4831 | 0.5512 | 0.6667 | 0.0319 | 0.9681 |
| 0.0007 | 10.0 | 2050 | 0.1337 | 0.8765 | 0.7217 | 0.7916 | 0.7059 | 0.4809 | 0.5493 | 0.6609 | 0.0319 | 0.9681 |
| 0.0006 | 11.0 | 2255 | 0.1357 | 0.8744 | 0.7217 | 0.7907 | 0.7044 | 0.4809 | 0.5488 | 0.6609 | 0.0320 | 0.9680 |
| 0.0006 | 12.0 | 2460 | 0.1373 | 0.8701 | 0.7198 | 0.7878 | 0.7027 | 0.4805 | 0.5476 | 0.6638 | 0.0325 | 0.9675 |
| 0.0005 | 13.0 | 2665 | 0.1395 | 0.8684 | 0.7217 | 0.7883 | 0.6977 | 0.4827 | 0.5477 | 0.6638 | 0.0325 | 0.9675 |
| 0.0005 | 14.0 | 2870 | 0.1410 | 0.8701 | 0.7198 | 0.7878 | 0.7029 | 0.4815 | 0.5488 | 0.6580 | 0.0325 | 0.9675 |
| 0.0005 | 15.0 | 3075 | 0.1426 | 0.8644 | 0.7217 | 0.7866 | 0.6957 | 0.4818 | 0.5466 | 0.6551 | 0.0329 | 0.9671 |
| 0.0004 | 16.0 | 3280 | 0.1432 | 0.8670 | 0.7255 | 0.7900 | 0.6976 | 0.4872 | 0.5508 | 0.6580 | 0.0324 | 0.9676 |
| 0.0004 | 17.0 | 3485 | 0.1442 | 0.8687 | 0.7236 | 0.7895 | 0.6981 | 0.4849 | 0.5492 | 0.6580 | 0.0324 | 0.9676 |
| 0.0004 | 18.0 | 3690 | 0.1448 | 0.8670 | 0.7255 | 0.7900 | 0.6985 | 0.4872 | 0.5510 | 0.6580 | 0.0324 | 0.9676 |
| 0.0004 | 19.0 | 3895 | 0.1451 | 0.8647 | 0.7236 | 0.7879 | 0.6963 | 0.4849 | 0.5485 | 0.6580 | 0.0327 | 0.9673 |
| 0.0004 | 20.0 | 4100 | 0.1454 | 0.8664 | 0.7217 | 0.7874 | 0.6973 | 0.4836 | 0.5480 | 0.6580 | 0.0327 | 0.9673 |

Use the Model to Predict

Since load_best_model_at_end=True and metric_for_best_model="f1", the Trainer restores the checkpoint with the best micro-averaged F1 (epoch 6 above, F1 = 0.7967) before we predict:

def predict(text, threshold=0.5):
    # Tokenize input
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )

    # Move inputs to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(logits).cpu().numpy()

    # Get binary predictions
    predictions = (probs > threshold).astype(int)

    # Convert binary predictions back to labels
    predicted_labels = mlb.inverse_transform(predictions)[0]

    # Create dictionary of label probabilities
    label_probs = {
        label: float(prob)  # Convert to Python float for JSON serialization
        for label, prob in zip(mlb.classes_, probs[0])
        if prob > threshold
    }

    # Sort by probability
    label_probs = dict(sorted(label_probs.items(), key=lambda x: x[1], reverse=True))

    # cui_map (defined earlier in the full notebook) maps each CUI to a
    # human-readable "Name (CUI)" string for display
    print(f"{text}: {[cui_map.get(label) for label in list(predicted_labels)]}")

Note the Anterior-Posterior prediction in the second example below: the AP view is often associated with such studies but is not explicitly stated in the caption. Training on the full ROCOv2 dataset would likely perform better.

predict("CT of Chest with pneumothorax")
# CT of Chest with pneumothorax: ['X-Ray Computed Tomography (C0040405)']
predict("Abdomen x-ray with small bowel obstruction")
# Abdomen x-ray with small bowel obstruction: ['Abdomen (C0000726)', 'Plain x-ray (C1306645)', 'Anterior-Posterior (C1999039)']

Push to Hugging Face

Finally, we upload the fine-tuned weights to Hugging Face: https://huggingface.co/johnpaulett/ModernRadBERT-cui-classifier

trainer.push_to_hub()
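
One optional improvement, not in the original notebook: attach the CUI names to the model config before pushing, so downstream consumers see real CUIs instead of the generic LABEL_0...LABEL_N:

# Hypothetical addition: store index <-> CUI mappings in the config
# (run before trainer.push_to_hub() for it to reach the Hub)
model.config.id2label = {i: cui for i, cui in enumerate(mlb.classes_)}
model.config.label2id = {cui: i for i, cui in enumerate(mlb.classes_)}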

Conclusion

Fine-tuning a radiology transformer for classification on report text is powerful: even on this small demo dataset, 20 epochs of training reached a micro-averaged F1 of nearly 0.80 on the validation set.

Explore the Fine-Tuned Model

You can pull down this fine-tuned model (WARNING: it is trained on a small dataset as a demo, so do not use it for any real problems):

from transformers import pipeline

pipe = pipeline("fill-mask", model="johnpaulett/ModernRadBERT-cui-classifier")

Citations