ModernBERT in Radiology Part 3: Fine-tuning a Classifier
In Part 3 of the ModernBERT in Radiology series, we fine-tune a ModernBERT classifier to predict UMLS CUIs from a radiology report. It combines the classification task from Part 1 with the full fine-tuning approach from Part 2, aiming to produce a better classifier than the simple scikit-learn Logistic Regression from Part 1.
You can follow along with the associated Colab Notebook for Part 3🔥!
The ModernBERT in Radiology Series
Part 1: Simple Classifier using Hidden States.
Build a multi-label classifier using a simple scikit-learn Logistic Regression model on top of the pre-trained ModernBERT body.
Part 2: Fine-tuning a Masked Language Model (MLM).
ModernBERT is pretrained as a Masked Language Model; we perform a full fine-tuning on radiology text using AutoModelForMaskedLM.
Part 3: Fine-tuning for Classification. 👈 This Post
Combining Part 1 and Part 2, we build a ModernBERT classifier by fine-tuning the entire model using AutoModelForSequenceClassification.
Objective
Whereas Part 1 used the hidden states of the ModernBERT body to train a simple classifier, we put a proper Hugging Face neural network classifier head on the ModernBERT body and fine-tune the whole model on the unsloth/Radiology_mini dataset to perform multi-label classification from radiology text to UMLS CUIs (concept IDs). Finally, we publish the model to Hugging Face as johnpaulett/ModernRadBERT-cui-classifier.
We will follow parts of the excellent Natural Language Processing with Transformers book, whose code is available on GitHub. Part 3 follows Chapter 2 of the book. However, our problem deviates from that chapter because we perform multi-label classification (i.e., each text can have one or more CUI labels).
WARNING: Since the cui concepts were generated automatically by MedCAT, the model learns to reproduce MedCAT's predictions.
Code
See Colab for the full Notebook: https://colab.research.google.com/drive/11hpCvNb4g65Igcmz1ePyuqdNmWwYwj4g?usp=sharing
Setup
In Part 3, I used a Colab NVIDIA L4 GPU. We use Hugging Face 🤗 transformers' AutoModelForSequenceClassification to load the pre-trained ModernBERT for full fine-tuning.
pip install datasets evaluate wandb triton
# FlashAttention only works on Ampere or newer GPUs (e.g. not the T4)
pip install flash-attn
# Install transformers from source until the next release (4.48.0), which adds ModernBERT
pip install git+https://github.com/huggingface/transformers.git
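The flash-attn comment hinges on the GPU generation. As a minimal sketch, assuming FlashAttention 2's requirement of CUDA compute capability 8.0 or higher (Ampere or newer), the check below compares a (major, minor) capability tuple: the T4 is 7.5, while the L4 used here is 8.9. The helper name supports_flash_attention is hypothetical; the live check uses PyTorch's torch.cuda.get_device_capability() and only runs if PyTorch and a CUDA device are present.

```python
# Minimal sketch: can this GPU run FlashAttention 2?
# Assumption: FlashAttention 2 needs compute capability >= 8.0 (Ampere or newer).
MIN_FLASH_ATTN_CAPABILITY = (8, 0)


def supports_flash_attention(capability):
    """Return True if a (major, minor) compute capability is Ampere or newer."""
    return tuple(capability) >= MIN_FLASH_ATTN_CAPABILITY


# Known capabilities: T4 (Turing) 7.5, A100 (Ampere) 8.0, L4 (Ada) 8.9
for name, cap in [("T4", (7, 5)), ("A100", (8, 0)), ("L4", (8, 9))]:
    print(name, supports_flash_attention(cap))

# On a machine with PyTorch and a CUDA GPU, check the actual device
try:
    import torch

    if torch.cuda.is_available():
        print("This GPU:", supports_flash_attention(torch.cuda.get_device_capability()))
except ImportError:
    pass
```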
model_id = (
"answerdotai/ModernBERT-base"
# answerdotai/ModernBERT-large
)
dataset_name = (
# "eltorio/ROCOv2-radiology"
"unsloth/Radiology_mini" # 0.33% of ROCOv2-radiology, for a quicker demo
)
push_to_hub = True
output_dir = "ModernRadBERT-cui-classifier"
Load & Transform the Dataset
See Part 1 for details on the dataset.
Load the dataset and re-split it into train, validation, and test sets.
from datasets import load_dataset, DatasetDict
original_dataset = load_dataset(dataset_name)
print(f"Training Size: {original_dataset['train'].size_in_bytes / (1024 * 1024 * 1024):.2f} GB")
# Hold out 15% of all examples (train + test) as a validation set, carved from train
validation_size = int(0.15 * (len(original_dataset['train']) + len(original_dataset['test'])))
# Shuffle once so the train and validation slices are reproducible and disjoint
shuffled_train = original_dataset['train'].shuffle(seed=42)
dataset = DatasetDict({
    'train': shuffled_train.select(range(validation_size, len(shuffled_train))),
    'validation': shuffled_train.select(range(validation_size)),
    # Keep the test -- we'll hold this back for comparison between models
    'test': original_dataset['test'],
})
# Drop the images; only the caption text and CUIs are needed
dataset = dataset.remove_columns(['image'])
# dataset['train'], dataset['validation'] and dataset['test'] are now ready to use
print(f"training set size: {len(dataset['train'])}")
print(f"validation set size: {len(dataset['validation'])}")
print(f"test set size: {len(dataset['test'])}")
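The re-split boils down to simple index arithmetic: the validation size is 15% of all examples (train plus test), but those examples are carved only out of the shuffled training split, so the original test split stays untouched. A minimal pure-Python sketch with hypothetical sizes; it uses Python's random module in place of Dataset.shuffle, so the exact ordering differs from the real split, and resplit_indices is a hypothetical helper.

```python
import random


def resplit_indices(n_train, n_test, val_fraction=0.15, seed=42):
    """Return (train_idx, val_idx) carved from a shuffled training split.

    The validation size is a fraction of ALL examples (train + test), but it
    is taken only from the training split; the test split is never touched.
    """
    validation_size = int(val_fraction * (n_train + n_test))
    indices = list(range(n_train))
    random.Random(seed).shuffle(indices)  # one shuffle, reused for both slices
    return indices[validation_size:], indices[:validation_size]


# Hypothetical sizes for illustration
train_idx, val_idx = resplit_indices(n_train=1000, n_test=110)
print(len(train_idx), len(val_idx))  # 834 166
```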
Convert the caption into tokens using ModernBERT's tokenizer:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
def tokenize_function(examples):
    return tokenizer(
        examples["caption"],
        padding="max_length",
        truncation=True,
        # ModernBERT allows up to 8192 tokens, up from 512 in BERT!
        # The longest caption in the train set is 934 characters; at roughly
        # 4 characters per token that is ~233 tokens, and checking the longest
        # attention_mask shows it is actually 206 tokens.
        # Setting this too high consumes significant GPU memory, since every
        # example is padded to max_length.
        max_length=256,
    )
dataset = dataset.map(
    tokenize_function, batched=True
)
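With padding="max_length" and truncation=True, every example ends up exactly max_length tokens long, and the attention_mask marks real tokens (1) versus padding (0); counting the 1s is also how the longest real caption length can be read off. A simplified sketch, assuming right-padding and a placeholder pad ID of 0 (ModernBERT's real pad token ID differs); the token IDs and helper names are hypothetical.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Mimic padding="max_length", truncation=True for one sequence.

    Returns (input_ids, attention_mask): real tokens get 1, padding gets 0.
    pad_id=0 is a placeholder, not ModernBERT's actual pad token ID.
    """
    ids = token_ids[:max_length]
    mask = [1] * len(ids)
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding


def longest_real_length(attention_masks):
    """Longest unpadded sequence = largest number of 1s in any attention mask."""
    return max(sum(mask) for mask in attention_masks)


# Hypothetical token IDs (a real tokenizer also adds special tokens)
ids, mask = pad_or_truncate([101, 7, 8, 9, 102], max_length=8)
print(ids)   # [101, 7, 8, 9, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
masks = [mask, pad_or_truncate([101, 5, 102], max_length=8)[1]]
print(longest_real_length(masks))  # 5
```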
Since cui is multi-label (each caption can have several CUIs), we use scikit-learn's MultiLabelBinarizer to turn each list of CUIs into a binary vector:
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
mlb = MultiLabelBinarizer()
# Learn the CUI vocabulary from the training split only (fit returns the fitted binarizer)
mlb.fit(dataset['train']['cui'])
def transform_labels(example):
    # Transform a single example's CUIs to a binary vector;
    # CUIs never seen in training are ignored (scikit-learn warns about them)
    binary_labels = mlb.transform([example['cui']])[0]  # [0] to get the single example's labels
    # Convert to float32 for BCEWithLogitsLoss
    example['labels'] = binary_labels.astype(np.float32).tolist()
    example['num_labels'] = int(binary_labels.sum())
    return example
dataset = dataset.map(
    transform_labels,
    desc="Transforming labels to binary vectors",
    num_proc=4,
)
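Under the hood, MultiLabelBinarizer learns a sorted list of every CUI seen in the training split (mlb.classes_) and maps each example to a vector with a 1 in the position of each of its CUIs. CUIs that never appear in training are dropped at transform time, so such concepts in the validation or test splits cannot be predicted. A dependency-free sketch of the same encoding; the helper names and the "C_UNSEEN" label are hypothetical, while the other CUIs come from the predictions shown later in this post.

```python
def fit_classes(label_lists):
    """Sorted vocabulary of every label seen in training (like mlb.fit -> mlb.classes_)."""
    return sorted({label for labels in label_lists for label in labels})


def to_multi_hot(labels, classes):
    """Float 0/1 vector over `classes`; unknown labels are dropped (like mlb.transform)."""
    position = {label: i for i, label in enumerate(classes)}
    vector = [0.0] * len(classes)
    for label in labels:
        if label in position:
            vector[position[label]] = 1.0
    return vector


def from_multi_hot(vector, classes):
    """Map a 0/1 vector back to label names (like mlb.inverse_transform for one row)."""
    return [label for label, value in zip(classes, vector) if value == 1.0]


train_cuis = [["C0040405"], ["C0000726", "C1306645"]]
classes = fit_classes(train_cuis)
print(classes)                                          # ['C0000726', 'C0040405', 'C1306645']
print(to_multi_hot(["C1306645", "C0000726"], classes))  # [1.0, 0.0, 1.0]
print(to_multi_hot(["C_UNSEEN"], classes))              # [0.0, 0.0, 0.0]
```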
Train the Classifier
We use AutoModelForSequenceClassification with the number of distinct labels, making sure to set problem_type="multi_label_classification":
from transformers import AutoModelForSequenceClassification
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = AutoModelForSequenceClassification.from_pretrained(
model_id,
num_labels=len(mlb.classes_),
problem_type="multi_label_classification"
).to(device)
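Setting problem_type="multi_label_classification" makes the Transformers sequence-classification head compute its loss with torch.nn.BCEWithLogitsLoss (mean reduction), which treats each CUI as an independent yes/no decision instead of a softmax over all labels; that is why the labels were cast to float32. A pure-Python sketch of that mean-reduced loss using the numerically stable formulation; the function name is hypothetical.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over every (example, label) cell, from raw logits.

    Uses max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] but never overflows.
    """
    total, count = 0.0, 0
    for row_logits, row_targets in zip(logits, targets):
        for x, y in zip(row_logits, row_targets):
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


# An undecided model (all logits 0) pays log(2) per label...
print(round(bce_with_logits([[0.0, 0.0]], [[1.0, 0.0]]), 4))   # 0.6931
# ...while a confident, correct one pays very little
print(round(bce_with_logits([[4.0, -4.0]], [[1.0, 0.0]]), 4))  # 0.0181
```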
Define the metrics (including F1) to compute on the validation set at the end of each epoch:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
# We use sklearn's metrics instead of `evaluate`, since evaluate's default f1 expects
# a single label per example, not a multi-hot array
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Apply sigmoid activation and threshold at 0.5
    predictions = 1 / (1 + np.exp(-predictions))  # sigmoid
    predictions = (predictions > 0.5).astype(int)
    # Calculate micro-averaged metrics
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )
    # Calculate macro-averaged metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    # Calculate subset accuracy (exact match)
    subset_accuracy = accuracy_score(labels, predictions)
    # Calculate Hamming loss
    ham_loss = hamming_loss(labels, predictions)
    # Calculate per-label accuracy (element-wise); equals 1 - hamming_loss
    label_wise_accuracy = np.mean((predictions == labels).astype(float))
    results = {
        # Micro-averaged metrics
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1": f1_micro,
        # Macro-averaged metrics
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        # Other metrics
        "exact_match": subset_accuracy,
        "hamming_loss": ham_loss,
        "label_accuracy": label_wise_accuracy
    }
    return results
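Micro and macro averaging help explain the large gap between the F1 and F1 Macro columns in the training log. Micro averaging pools true positives, false positives, and false negatives across all CUIs before computing one F1, so common labels dominate; macro averaging computes F1 per CUI and then takes the unweighted mean, so every rare CUI the model never predicts contributes an F1 of 0. A hand-checkable sketch on a hypothetical 3-example, 3-label batch, mirroring zero_division=0; the helper names are hypothetical.

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); 0.0 when undefined, like zero_division=0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def multilabel_metrics(labels, preds):
    """Micro F1, macro F1, exact match, and Hamming loss for 0/1 label matrices."""
    n_examples, n_labels = len(labels), len(labels[0])
    counts = []  # (tp, fp, fn) for each label column
    for j in range(n_labels):
        pairs = [(y[j], p[j]) for y, p in zip(labels, preds)]
        tp = sum(1 for y, p in pairs if y == 1 and p == 1)
        fp = sum(1 for y, p in pairs if y == 0 and p == 1)
        fn = sum(1 for y, p in pairs if y == 1 and p == 0)
        counts.append((tp, fp, fn))
    # Micro: pool counts across all labels, then compute a single F1
    micro_f1 = f1_from_counts(*[sum(c[k] for c in counts) for k in range(3)])
    # Macro: F1 per label, then an unweighted mean (rare labels weigh as much as common ones)
    macro_f1 = sum(f1_from_counts(*c) for c in counts) / n_labels
    # Exact match: the whole label vector must be right
    exact_match = sum(y == p for y, p in zip(labels, preds)) / n_examples
    # Hamming loss: fraction of individual label cells that are wrong
    wrong = sum(a != b for y, p in zip(labels, preds) for a, b in zip(y, p))
    hamming = wrong / (n_examples * n_labels)
    return micro_f1, macro_f1, exact_match, hamming


labels = [[1, 0, 0], [1, 1, 0], [0, 0, 1]]
preds = [[1, 0, 0], [1, 0, 0], [0, 0, 0]]  # the two rarer labels are never predicted
print([round(m, 4) for m in multilabel_metrics(labels, preds)])
# [0.6667, 0.3333, 0.3333, 0.2222]
```

Here label_accuracy from compute_metrics would be 1 - 0.2222, which is why the Hamming Loss and Label Accuracy columns in the log always sum to 1.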
We then train for 20 epochs:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=20,
learning_rate=2e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
weight_decay=0.01,
logging_strategy="steps", # Log metrics every n steps
logging_steps=100, # Log every 100 steps
eval_strategy="epoch",
metric_for_best_model="f1",
greater_is_better=True, # higher f1 is better
save_strategy="epoch",
load_best_model_at_end=True,
save_total_limit=3, # Keep at most 3 checkpoints; the best one is retained since load_best_model_at_end=True
push_to_hub=push_to_hub,
report_to="none", # Comment out to enable wandb logging
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
compute_metrics=compute_metrics,
)
trainer.train()
Training Loss | Epoch | Step | Validation Loss | Precision Micro | Recall Micro | F1 | Precision Macro | Recall Macro | F1 Macro | Exact Match | Hamming Loss | Label Accuracy |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0.1371 | 1.0 | 205 | 0.1214 | 0.8169 | 0.6679 | 0.7350 | 0.4170 | 0.3481 | 0.3667 | 0.5681 | 0.0404 | 0.9596 |
0.0904 | 2.0 | 410 | 0.1054 | 0.8704 | 0.6833 | 0.7656 | 0.5391 | 0.3744 | 0.4106 | 0.6029 | 0.0351 | 0.9649 |
0.0458 | 3.0 | 615 | 0.1012 | 0.8316 | 0.7582 | 0.7932 | 0.5899 | 0.5157 | 0.5251 | 0.6580 | 0.0332 | 0.9668 |
0.0216 | 4.0 | 820 | 0.1134 | 0.8738 | 0.7044 | 0.7800 | 0.7129 | 0.4338 | 0.5071 | 0.6377 | 0.0333 | 0.9667 |
0.01 | 5.0 | 1025 | 0.1194 | 0.8382 | 0.7159 | 0.7723 | 0.6707 | 0.4817 | 0.5336 | 0.6290 | 0.0354 | 0.9646 |
0.0047 | 6.0 | 1230 | 0.1224 | 0.8721 | 0.7332 | 0.7967 | 0.6475 | 0.4692 | 0.5187 | 0.6638 | 0.0314 | 0.9686 |
0.0024 | 7.0 | 1435 | 0.1228 | 0.8540 | 0.7409 | 0.7934 | 0.7016 | 0.5071 | 0.5648 | 0.6725 | 0.0324 | 0.9676 |
0.0012 | 8.0 | 1640 | 0.1289 | 0.8744 | 0.7217 | 0.7907 | 0.7053 | 0.4852 | 0.5531 | 0.6609 | 0.0320 | 0.9680 |
0.0009 | 9.0 | 1845 | 0.1323 | 0.8765 | 0.7217 | 0.7916 | 0.7063 | 0.4831 | 0.5512 | 0.6667 | 0.0319 | 0.9681 |
0.0007 | 10.0 | 2050 | 0.1337 | 0.8765 | 0.7217 | 0.7916 | 0.7059 | 0.4809 | 0.5493 | 0.6609 | 0.0319 | 0.9681 |
0.0006 | 11.0 | 2255 | 0.1357 | 0.8744 | 0.7217 | 0.7907 | 0.7044 | 0.4809 | 0.5488 | 0.6609 | 0.0320 | 0.9680 |
0.0006 | 12.0 | 2460 | 0.1373 | 0.8701 | 0.7198 | 0.7878 | 0.7027 | 0.4805 | 0.5476 | 0.6638 | 0.0325 | 0.9675 |
0.0005 | 13.0 | 2665 | 0.1395 | 0.8684 | 0.7217 | 0.7883 | 0.6977 | 0.4827 | 0.5477 | 0.6638 | 0.0325 | 0.9675 |
0.0005 | 14.0 | 2870 | 0.1410 | 0.8701 | 0.7198 | 0.7878 | 0.7029 | 0.4815 | 0.5488 | 0.6580 | 0.0325 | 0.9675 |
0.0005 | 15.0 | 3075 | 0.1426 | 0.8644 | 0.7217 | 0.7866 | 0.6957 | 0.4818 | 0.5466 | 0.6551 | 0.0329 | 0.9671 |
0.0004 | 16.0 | 3280 | 0.1432 | 0.8670 | 0.7255 | 0.7900 | 0.6976 | 0.4872 | 0.5508 | 0.6580 | 0.0324 | 0.9676 |
0.0004 | 17.0 | 3485 | 0.1442 | 0.8687 | 0.7236 | 0.7895 | 0.6981 | 0.4849 | 0.5492 | 0.6580 | 0.0324 | 0.9676 |
0.0004 | 18.0 | 3690 | 0.1448 | 0.8670 | 0.7255 | 0.7900 | 0.6985 | 0.4872 | 0.5510 | 0.6580 | 0.0324 | 0.9676 |
0.0004 | 19.0 | 3895 | 0.1451 | 0.8647 | 0.7236 | 0.7879 | 0.6963 | 0.4849 | 0.5485 | 0.6580 | 0.0327 | 0.9673 |
0.0004 | 20.0 | 4100 | 0.1454 | 0.8664 | 0.7217 | 0.7874 | 0.6973 | 0.4836 | 0.5480 | 0.6580 | 0.0327 | 0.9673 |
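Because the run sets load_best_model_at_end=True with metric_for_best_model="f1" and greater_is_better=True, the Trainer should restore the checkpoint with the highest validation micro F1, not the final epoch. In the log, validation loss bottoms out at epoch 3 and then climbs while training loss heads toward zero (a sign of overfitting), yet micro F1 peaks at epoch 6. The Step column also follows from the batch size: assuming a single GPU and no gradient accumulation, each epoch is ceil(N / 8) steps, so 205 steps per epoch implies 1,633 to 1,640 training examples. The sketch copies the relevant columns from the table above.

```python
import math

# (epoch, validation_loss, micro_f1) copied from the training log above
history = [
    (1, 0.1214, 0.7350), (2, 0.1054, 0.7656), (3, 0.1012, 0.7932), (4, 0.1134, 0.7800),
    (5, 0.1194, 0.7723), (6, 0.1224, 0.7967), (7, 0.1228, 0.7934), (8, 0.1289, 0.7907),
    (9, 0.1323, 0.7916), (10, 0.1337, 0.7916), (11, 0.1357, 0.7907), (12, 0.1373, 0.7878),
    (13, 0.1395, 0.7883), (14, 0.1410, 0.7878), (15, 0.1426, 0.7866), (16, 0.1432, 0.7900),
    (17, 0.1442, 0.7895), (18, 0.1448, 0.7900), (19, 0.1451, 0.7879), (20, 0.1454, 0.7874),
]

# metric_for_best_model="f1" with greater_is_better=True: highest F1 wins
best_f1_epoch = max(history, key=lambda row: row[2])[0]
# For comparison: the epoch a validation-loss criterion would have picked
lowest_loss_epoch = min(history, key=lambda row: row[1])[0]
print(best_f1_epoch, lowest_loss_epoch)  # 6 3

# Training-set sizes consistent with 205 steps per epoch at batch size 8
batch_size, steps_per_epoch = 8, 205
sizes = [n for n in range(1, 5000) if math.ceil(n / batch_size) == steps_per_epoch]
print(sizes[0], sizes[-1], steps_per_epoch * 20)  # 1633 1640 4100
```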
Use the Model to Predict
def predict(text, threshold=0.5):
    # Make sure dropout is disabled for inference
    model.eval()
    # Tokenize input
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    # Move inputs to device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Apply sigmoid to get probabilities
    probs = torch.sigmoid(logits).cpu().numpy()
    # Get binary predictions
    predictions = (probs > threshold).astype(int)
    # Convert binary predictions back to labels
    predicted_labels = mlb.inverse_transform(predictions)[0]
    # Create dictionary of label probabilities
    label_probs = {
        label: float(prob)  # Convert to Python float for JSON serialization
        for label, prob in zip(mlb.classes_, probs[0])
        if prob > threshold
    }
    # Sort by probability
    label_probs = dict(sorted(label_probs.items(), key=lambda x: x[1], reverse=True))
    # return {
    #     "labels": list(predicted_labels),
    #     "probabilities": label_probs
    # }
    # cui_map (CUI -> display name) is not defined in this post; it is assumed to exist in the notebook
    print(f"{text}: {[cui_map.get(label, label) for label in predicted_labels]}")
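In the second example below, the model also predicts Anterior-Posterior, a projection that often accompanies these studies but is not explicitly stated in the text. Training on the full ROCOv2 dataset, rather than the 0.33% mini subset, would likely perform better.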
predict("CT of Chest with pneumothorax")
# CT of Chest with pneumothorax: ['X-Ray Computed Tomography (C0040405)']
predict("Abdomen x-ray with small bowel obstruction")
# Abdomen x-ray with small bowel obstruction: ['Abdomen (C0000726)', 'Plain x-ray (C1306645)', 'Anterior-Posterior (C1999039)']
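The threshold argument of predict controls a precision/recall trade-off: each CUI gets its own sigmoid probability, so lowering the threshold admits more labels (raising recall at the cost of precision) and raising it does the opposite. A minimal sketch of the same sigmoid, threshold, and sort steps, with hypothetical logits for three CUIs taken from the examples above; the helper names are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def labels_above_threshold(logits, classes, threshold=0.5):
    """Keep labels whose independent sigmoid probability exceeds `threshold`,
    sorted from most to least confident."""
    kept = [(label, sigmoid(x)) for label, x in zip(classes, logits) if sigmoid(x) > threshold]
    return sorted(kept, key=lambda item: item[1], reverse=True)


classes = ["C0000726", "C1306645", "C1999039"]  # Abdomen, Plain x-ray, Anterior-Posterior
logits = [2.0, 3.5, -0.2]  # hypothetical model outputs
for threshold in (0.5, 0.4):
    print(threshold, [label for label, _ in labels_above_threshold(logits, classes, threshold)])
# 0.5 ['C1306645', 'C0000726']
# 0.4 ['C1306645', 'C0000726', 'C1999039']
```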
Push to Hugging Face
Finally, we upload the fine-tuned weights to Hugging Face: https://huggingface.co/johnpaulett/ModernRadBERT-cui-classifier
trainer.push_to_hub()
Conclusion
Fine-tuning a transformer for classification on radiology report text is a powerful approach, with applications such as:
- Billing tasks such as extracting CPT and ICD codes
- Classifying specific findings, normal vs abnormal, scoring
- Quality assessments such as follow-up recommendations, critical result communication
Explore the Fine-Tuned Model
You can pull down this fine-tuned model (WARNING: it is trained on a small dataset as a demo, so do not use it for any real problems):
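from transformers import pipeline
pipe = pipeline(
    "text-classification",
    model="johnpaulett/ModernRadBERT-cui-classifier",
    # Load the tokenizer from the base model in case the fine-tuned repo does not include one
    tokenizer="answerdotai/ModernBERT-base",
    top_k=None,  # return a score for every label
)
pipe("Abdomen x-ray with small bowel obstruction")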
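The training code passes only num_labels to from_pretrained, without id2label, so the saved config likely falls back to generic label names (LABEL_0, LABEL_1, and so on), and a text-classification pipeline reports those names with their scores. Mapping them back to CUIs requires the mlb.classes_ order from training, which is not stored in the model. A hypothetical helper, assuming the pipeline output is a list of dicts with label and score keys; the scores and class order below are made up for illustration.

```python
def pipeline_scores_to_cuis(scores, classes, threshold=0.5):
    """Map generic LABEL_<i> names from a text-classification pipeline back to CUIs.

    Assumes output i of the model corresponds to classes[i], i.e. the same
    order as mlb.classes_ during training (not stored in the model itself).
    """
    results = {}
    for item in scores:
        index = int(item["label"].rsplit("_", 1)[1])  # "LABEL_12" -> 12
        if item["score"] > threshold:
            results[classes[index]] = item["score"]
    # Most confident CUI first
    return dict(sorted(results.items(), key=lambda kv: kv[1], reverse=True))


# Hypothetical pipeline output and class order
scores = [
    {"label": "LABEL_1", "score": 0.97},
    {"label": "LABEL_0", "score": 0.88},
    {"label": "LABEL_2", "score": 0.45},
]
classes = ["C0000726", "C1306645", "C1999039"]
print(pipeline_scores_to_cuis(scores, classes))
# {'C1306645': 0.97, 'C0000726': 0.88}
```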
Citations
- Warner, B., Chaffin, A., Clavié, B., Weller, O., Hallström, O., Taghadouini, S., Gallagher, A., Biswas, R., Ladhak, F., Aarsen, T., Cooper, N., Adams, G., Howard, J., & Poli, I. (2024). Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference. arXiv preprint arXiv:2412.13663.
- Ronan, L. M. (2024). ROCOv2-radiology [Dataset]. Hugging Face. https://doi.org/10.57967/hf/3489
- Tunstall, L., von Werra, L., & Wolf, T. (2022). Natural Language Processing with Transformers: Building Language Applications with Hugging Face. O’Reilly Media.