ModernBERT in Radiology Part 2: Fine-Tuning a Masked Language Model (MLM)
In Part 2 of the ModernBERT in Radiology series, we will fine-tune ModernBERT on the ROCOv2 radiology dataset for the Masked Language Modeling task, so that given an input of "ninth [MASK] fracture", the model predicts "ninth rib fracture".
You can follow along with the associated Colab Notebook for Part 2🔥!
The ModernBERT in Radiology Series
Part 1: Simple Classifier using Hidden States.
Build a multi-label classification using a simple scikit-learn Logistic Regression model on top of the pre-trained ModernBERT body.
Part 2: Fine-tuning a Masked Language Model (MLM). 👈 This Post
ModernBERT is pretrained as a Masked Language Model, but we perform a full fine-tuning on radiology text using AutoModelForMaskedLM.
Part 3: Fine-tuning for Classification.
Combining Part 1 and Part 2, we will build a ModernBERT classifier, fine-tuning the entire model using AutoModelForSequenceClassification.
ModernBERT as a Masked Language Model
ModernBERT is a Masked Language Model (MLM), meaning it was trained to predict which word makes sense in a masked sequence. For example, given the input "as an old capital city, [MASK], is the center of government in the United Kingdom", ModernBERT predicts that the [MASK] is most likely "London" (0.8625) or "Edinburgh" (0.0595). This task may look similar to that of decoder-only LLMs such as GPT or Llama. Unlike decoder-only models, ModernBERT is an encoder-only model with bidirectional attention, meaning it also uses the tokens after the [MASK] to make its prediction.
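As a quick illustration (a minimal sketch, not code from the post), the Hugging Face fill-mask pipeline can produce this kind of prediction with the pre-trained answerdotai/ModernBERT-base checkpoint we load later:
from transformers import pipeline

# Load the pre-trained ModernBERT checkpoint as a fill-mask pipeline
fill_mask = pipeline("fill-mask", model="answerdotai/ModernBERT-base")

# The model scores candidates for the [MASK] position using context on
# both sides of the mask (bidirectional attention)
for pred in fill_mask("As an old capital city, [MASK] is the center of government in the United Kingdom."):
    print(f"{pred['token_str'].strip()}: {pred['score']:.4f}")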
Interestingly, the original BERT model was trained on both the Masked Language Model (MLM) task and Next Sentence Prediction (NSP), but subsequent models like RoBERTa showed that the NSP task might not be as helpful. ModernBERT is trained only for MLM but is easily transferable to other NLP tasks like classification (as we will examine in Part 3).
Objective
We aim to fine-tune an MLM model on a subset of the eltorio/ROCOv2-radiology captions (radiology texts) to provide improved radiology-specific MLM predictions: given an input of "ninth [MASK] fracture", the model should predict "ninth rib fracture".
Code
See Colab for the full Notebook: https://colab.research.google.com/drive/1pARD1JpZR9qQDdtqFSWYdz2vuEJ04scB?usp=sharing
Setup
In Part 2, I used a Colab NVIDIA L4 GPU. We use the Hugging Face 🤗 Transformers AutoModelForMaskedLM class to load the pre-trained ModernBERT for full fine-tuning.
pip install datasets evaluate wandb triton
pip install umap-learn
# flash attention only works on ampere+ devices (i.e., not T4)
pip install flash-attn
# Until the next transformers release (4.48.0)
pip install git+https://github.com/huggingface/transformers.git
model_id = (
"answerdotai/ModernBERT-base"
# answerdotai/ModernBERT-large
)
dataset_name = (
# "eltorio/ROCOv2-radiology"
"unsloth/Radiology_mini" # 0.33% of ROCOv2-radiology, for a quicker demo
)
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id).to(device)
Explore ModernBERT
We will write a predict() helper that prints the top-5 predictions for the [MASK] token.
def predict(text, top_k=5):
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model(**inputs)
    # Find the position of the [MASK] token
    masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
    # Top K predictions for that position
    top_k_logits = torch.topk(outputs.logits[0, masked_index], top_k)
    top_k_token_ids = top_k_logits.indices
    top_k_tokens = tokenizer.batch_decode(top_k_token_ids)
    # Softmax over just the top-k logits (probabilities renormalized over the top-k)
    top_k_probabilities = torch.softmax(top_k_logits.values, dim=0)
    # Zip together
    predictions = list(zip(top_k_tokens, top_k_probabilities))
    print(text)
    for token, prob in predictions:
        print(f"  {token.strip()}: {prob:.4f}")
    print("")
The predictions from the out-of-the-box ModernBERT model, before fine-tuning, do not look great on radiology examples. I would expect "ninth rib fracture", "small bowel obstruction", and "pneumonia in the right lower lobe/lung" to rank more highly.
predict("ninth [MASK] fracture")
# ninth [MASK] fracture
# degree: 0.6848
# class: 0.0940
# base: 0.0830
# page: 0.0707
# grade: 0.0674
predict("pneumonia in the right lower [MASK]")
# pneumonia in the right lower [MASK]
# extremity: 0.3862
# jaw: 0.2338
# lobe: 0.1345
# limb: 0.1244
# lung: 0.1210
predict("small bowel [MASK]")
# small bowel [MASK]
# syndrome: 0.4146
# disease: 0.1788
# : 0.1566
# obstruction: 0.1396
# cancer: 0.1105
predict("small [MASK] obstruction")
# small [MASK] obstruction
# er: 0.6755
# : 0.1045
# intestine: 0.0778
# claim: 0.0727
# bowel: 0.0695
Load and Tokenize the Dataset
For speed of training, we're going to use unsloth/Radiology_mini, which is a 0.33% sample of eltorio/ROCOv2-radiology. For a real model, you would want to train on the full dataset.
from datasets import load_dataset
dataset = load_dataset(dataset_name)
We will use the caption field as our training text to simulate radiology reports.
dataset['train'][0]
# {'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=657x442>,
# 'image_id': 'ROCOv2_2023_train_054311',
# 'caption': 'Panoramic radiography shows an osteolytic lesion in the right posterior maxilla with resorption of the floor of the maxillary sinus (arrows).',
# 'cui': ['C1306645', 'C0037303']}
We will use the built-in ModernBERT tokenizer to transform the captions into tokens. return_special_tokens_mask=True is essential here: the tokenizer adds special tokens like "[CLS]" (start of the input sequence) and "[SEP]" (the boundary between sentences), and this option marks them in a new one-hot special_tokens_mask field in the tokenizer output, alongside input_ids and attention_mask.
def tokenize_function(examples):
    return tokenizer(
        examples["caption"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_special_tokens_mask=True  # Important for MLM
    )

tokenized_datasets = dataset.map(
    tokenize_function, batched=True,
    # For MLM, we can remove the original columns
    remove_columns=dataset['train'].column_names
)
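As a quick sanity check (not part of the original notebook), we can inspect one tokenized example and confirm the special tokens are flagged:
example = tokenized_datasets["train"][0]
print(example.keys())  # includes input_ids, attention_mask, special_tokens_mask

# [CLS] opens the sequence; [SEP] and [PAD] appear at the end of the
# 256-token padded sequence, and all of them get a 1 in special_tokens_mask
print(tokenizer.convert_ids_to_tokens(example["input_ids"][:12]))
print(example["special_tokens_mask"][:12])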
DataCollatorForLanguageModeling will randomly replace tokens in the captions with "[MASK]" during training, so we can train on the MLM objective.
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,             # Enable masking
    mlm_probability=0.15  # 15% of tokens will be masked
)
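To see the masking in action (a quick sketch, not from the original notebook), we can run a couple of tokenized examples through the collator: roughly 15% of the non-special tokens come back as [MASK] (or a random/unchanged token, per the standard BERT masking scheme), and the labels keep the original token ids only at those positions:
batch = data_collator([tokenized_datasets["train"][i] for i in range(2)])

# Masked tokens show up as [MASK] in the decoded text
print(tokenizer.decode(batch["input_ids"][0][:40]))

# labels are -100 everywhere except the masked positions,
# where they hold the original token ids
print(batch["labels"][0][:40])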
Training
We will do a simple training run (it could go on for more epochs, use the full eltorio/ROCOv2-radiology dataset, or add an early-stopping callback).
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# output_dir and push_to_hub are set earlier in the notebook
training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    # eval_strategy="steps",  # alternative to "epoch". Note: load_best_model_at_end needs the save and eval strategies to match
    # eval_steps=500,         # Evaluate every 500 steps
    logging_strategy="steps",  # Log metrics every n steps
    logging_steps=100,         # Log every 100 steps
    num_train_epochs=20,
    max_grad_norm=1.0,  # Gradient clipping prevents exploding gradients by scaling them down if they exceed this value
    warmup_ratio=0.1,   # Gradually increase the learning rate over the first 10% of steps
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    prediction_loss_only=True,  # Only need to compute loss for the MLM task
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # Lower loss is better
    save_total_limit=3,       # Keep only the 3 most recent checkpoints (the best one is always retained)
    push_to_hub=push_to_hub,
    report_to="none",  # Comment out to enable wandb
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    # Optional: add early stopping
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
Start the training (it takes ~15 minutes on an L4 GPU):
trainer.train()
We see the loss in each epoch:
Epoch Training Loss Validation Loss
1 1.869300 1.599592
2 1.696800 1.797290
3 1.718700 1.723163
4 1.651800 1.734311
5 1.500300 1.772700
6 1.334600 1.735698
7 1.402900 1.716372
8 1.276200 1.712302
9 1.244100 1.697829
10 1.201600 1.737356
11 1.188700 1.707611
12 1.020500 1.673580
13 1.077100 1.720857
14 1.060700 1.675280
15 0.909000 1.617196
16 0.925500 1.741842
17 0.867600 1.691358
18 0.853300 1.731034
19 0.845000 1.789315
20 0.869000 1.693557
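One additional way to read these numbers (my addition, not from the post) is perplexity, the exponential of the cross-entropy loss: the best validation loss above, 1.5996, corresponds to a perplexity of roughly exp(1.60) ≈ 4.9, i.e., the model is about as uncertain as choosing uniformly among ~5 candidate tokens at each masked position. Because load_best_model_at_end=True, evaluating after training reports the loss of the best checkpoint:
import math

# trainer.evaluate() re-runs evaluation and returns a dict with "eval_loss"
eval_results = trainer.evaluate()
print(f"Eval loss:  {eval_results['eval_loss']:.4f}")
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")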
Finally, we upload the fine-tuned weights to Hugging Face: https://huggingface.co/johnpaulett/ModernRadBERT-mlm
trainer.push_to_hub()
Use the Fine-Tuned Model
The results from the fine-tuned model look substantially better than those from the default, untuned model. Training on a larger, more representative dataset (e.g., the full eltorio/ROCOv2-radiology captions or real radiology reports) would likely improve them further.
predict("ninth [MASK] fracture")
# ninth [MASK] fracture
# rib: 0.3055
# tibial: 0.2377
# femoral: 0.2057
# oblique: 0.1461
# hip: 0.1050
predict("pneumonia in the right lower [MASK]")
# pneumonia in the right lower [MASK]
# lung: 0.7542
# lobe: 0.0845
# jaw: 0.0598
# extremity: 0.0596
# lip: 0.0419
predict("small bowel [MASK]")
# small bowel [MASK]
# obstruction: 0.4843
# resection: 0.2274
# .: 0.1142
# colon: 0.0947
# CT: 0.0793
predict("small [MASK] obstruction")
# small [MASK] obstruction
# bowel: 0.7788
# intestine: 0.1237
# airway: 0.0574
# vessel: 0.0247
# intestinal: 0.0154
Conclusion
The MLM task is mighty. While not generative like an LLM, an encoder-only model like the one we have built is much faster. Example use cases in radiology reporting:
- Check for likely speech recognition errors (see the sketch below)
- Suggest standardized terminology across radiologists' reports
- Ensure logical laterality and measurements
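As an illustration of the first use case, here is a minimal sketch (my own, not from the post) that masks each word of a dictated phrase in turn and flags words the fine-tuned model scores as unlikely. The johnpaulett/ModernRadBERT-mlm checkpoint is the one trained above; the 0.01 threshold and the one-word-per-token simplification are assumptions you would revisit for real use:
from transformers import pipeline

# Fine-tuned checkpoint from this post; top_k widened so we can look up
# the score the model assigns to the dictated word itself
fill = pipeline("fill-mask", model="johnpaulett/ModernRadBERT-mlm", top_k=50)

def flag_unlikely_words(phrase, threshold=0.01):
    """Mask each word in turn; flag it if the model gives it a low score."""
    words = phrase.split()
    for i, word in enumerate(words):
        masked = " ".join(words[:i] + [fill.tokenizer.mask_token] + words[i + 1:])
        # NOTE: assumes each word is a single token; multi-token words need more care
        scores = {p["token_str"].strip().lower(): p["score"] for p in fill(masked)}
        score = scores.get(word.lower().strip(".,"), 0.0)
        if score < threshold:
            print(f"Possible error: '{word}' (score {score:.4f}) in: {masked}")

flag_unlikely_words("ninth lib fracture")  # 'lib' likely misrecognized for 'rib'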
Explore the Fine-Tuned Model
You can pull down this fine-tuned model (WARNING: it is trained on a small dataset as a demo, so do not use it for any real problems):
from transformers import pipeline
pipe = pipeline("fill-mask", model="johnpaulett/ModernRadBERT-mlm")
Next Steps
Continue reading the ModernBERT in Radiology series:
- Part 3: Fine-tuning for Classification.
Subscribe to get notified when the following Parts of the ModernBERT in Radiology series are published.
Citations
- Warner, B., Chaffin, A., Clavié, B., Weller, O., Hallström, O., Taghadouini, S., Gallagher, A., Biswas, R., Ladhak, F., Aarsen, T., Cooper, N., Adams, G., Howard, J., & Poli, I. (2024). Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference. arXiv preprint arXiv:2412.13663.
- Ronan, L. M. (2024). ROCOv2-radiology [Dataset]. Hugging Face. https://doi.org/10.57967/hf/3489
- Tunstall, L. (2022). Natural Language Processing with Transformers: Building Language Applications with Hugging Face. O’Reilly Media.