ModernBERT in Radiology Part 2: Fine Tuning a Masked Language Model (MLM)

In Part 2 of the ModernBERT in Radiology series, we will fine-tune ModernBERT on radiology captions from the ROCOv2 dataset for the Masked Language Model (MLM) task, so that given the input "ninth [MASK] fracture", the model predicts "ninth rib fracture".

You can follow along with the associated Colab Notebook for Part 2🔥!

The ModernBERT in Radiology Series

ModernBERT as a Masked Language Model

ModernBERT is a Masked Language Model (MLM), meaning that it was trained to predict which word makes sense at a masked position in a sequence. For example, given the input "as an old capital city, [MASK], is the center of government in the United Kingdom", ModernBERT predicts “London” (0.8625) and “Edinburgh” (0.0595) for the [MASK]. This task may look similar to what decoder-only LLMs such as GPT or Llama do. Unlike decoder-only models, however, ModernBERT is an encoder-only model with bidirectional attention, meaning that it also uses the tokens after the [MASK] to make its prediction.
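
To make the difference concrete, here is a minimal sketch of the two attention patterns. The function names are illustrative and not part of any library, words stand in for real tokens, and a 1 means that the token in that row may attend to the token in that column.

```python
def causal_mask(n):
    """Decoder-style mask: position i may attend only to positions j <= i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """Encoder-style mask: every position may attend to every other position."""
    return [[1] * n for _ in range(n)]


tokens = ["ninth", "[MASK]", "fracture"]
mask_pos = tokens.index("[MASK]")

# Tokens visible from the [MASK] position under each attention pattern
visible_causal = [
    t for t, m in zip(tokens, causal_mask(len(tokens))[mask_pos]) if m
]
visible_bidirectional = [
    t for t, m in zip(tokens, bidirectional_mask(len(tokens))[mask_pos]) if m
]

print(visible_causal)         # ['ninth', '[MASK]']
print(visible_bidirectional)  # ['ninth', '[MASK]', 'fracture']
```

Only the bidirectional pattern lets the [MASK] position see "fracture", which is the word that makes "rib" a sensible prediction.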

Interestingly, the original BERT model was trained for both the Masked Language Model (MLM) task and Next Sentence Prediction (NSP), but subsequent models like RoBERTa showed that the NSP task might not be as helpful. ModernBERT is trained only for MLM but is easily transferable to other NLP tasks like Classification (as we will examine in Part 3).

Objective

We aim to fine-tune ModernBERT on a subset of the radiology image captions in eltorio/ROCOv2-radiology so that it gives better radiology-specific MLM predictions. Given the input ninth [MASK] fracture, the model should predict ninth rib fracture.

Code

See Colab for the full Notebook: https://colab.research.google.com/drive/1pARD1JpZR9qQDdtqFSWYdz2vuEJ04scB?usp=sharing

Setup

In Part 2, I used a Colab NVIDIA L4 GPU. We use Hugging Face 🤗 transformers AutoModelForMaskedLM to load the pre-trained ModernBERT for full fine-tuning.

pip install datasets evaluate wandb triton
pip install umap-learn
# flash attention only works on ampere+ devices (i.e., not T4)
pip install flash-attn
# Until the next transformers release (4.48.0)
pip install git+https://github.com/huggingface/transformers.git
model_id = (
    "answerdotai/ModernBERT-base"
    # answerdotai/ModernBERT-large
)
dataset_name = (
    # "eltorio/ROCOv2-radiology"
    "unsloth/Radiology_mini" # 0.33% of ROCOv2-radiology, for a quicker demo
)

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id).to(device)

Explore ModernBERT

We will define a predict() helper that prints the top-5 predictions for the [MASK] token.

def predict(text, top_k=5):
  inputs = tokenizer(text, return_tensors="pt").to(device)
  with torch.no_grad():  # Inference only, so skip gradient tracking
    outputs = model(**inputs)

  # Position of the first [MASK] token in the input
  masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)

  # Top K predictions
  top_k_logits = torch.topk(outputs.logits[0, masked_index], top_k)
  top_k_token_ids = top_k_logits.indices
  top_k_tokens = tokenizer.batch_decode(top_k_token_ids)
  # Softmax over the top-k logits only, so the probabilities are
  # relative to these k candidates, not to the full vocabulary
  top_k_probabilities = torch.softmax(top_k_logits.values, dim=0)

  # Zip together
  predictions = list(zip(top_k_tokens, top_k_probabilities))

  print(text)
  for token, prob in predictions:
    print(f"  {token.strip()}: {prob:.4f}")
  print("")

Before fine-tuning, the out-of-the-box ModernBERT model's predictions for radiology examples do not look great. I would expect “ninth rib fracture”, “small bowel obstruction”, and “pneumonia in the right lower lobe/lung” to rank more highly.

predict("ninth [MASK] fracture")
# ninth [MASK] fracture
#   degree: 0.6848
#   class: 0.0940
#   base: 0.0830
#   page: 0.0707
#   grade: 0.0674

predict("pneumonia in the right lower [MASK]")
# pneumonia in the right lower [MASK]
#   extremity: 0.3862
#   jaw: 0.2338
#   lobe: 0.1345
#   limb: 0.1244
#   lung: 0.1210

predict("small bowel [MASK]")
# small bowel [MASK]
#   syndrome: 0.4146
#   disease: 0.1788
#   : 0.1566
#   obstruction: 0.1396
#   cancer: 0.1105

predict("small [MASK] obstruction")
# small [MASK] obstruction
#   er: 0.6755
#   : 0.1045
#   intestine: 0.0778
#   claim: 0.0727
#   bowel: 0.0695
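
The probabilities in each list above sum to roughly 1 because predict() applies the softmax to the top-k logits only, so they are relative to the other top-k candidates rather than to the whole vocabulary. A minimal sketch with made-up logits (the values and the tiny six-token "vocabulary" are hypothetical) shows the difference:

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a tiny six-token vocabulary
logits = [4.0, 2.0, 1.5, 1.0, 0.5, 0.0]
top_k = 3

# Variant used by predict(): softmax over the top-k logits only
top_logits = sorted(logits, reverse=True)[:top_k]
relative_probs = softmax(top_logits)

# Alternative: softmax over the full vocabulary, then keep the top-k entries
full_probs = sorted(softmax(logits), reverse=True)[:top_k]

print(sum(relative_probs))  # sums to 1 (up to floating-point error)
print(sum(full_probs))      # below 1: the rest of the vocabulary keeps some mass
```

The ranking is the same either way; only the reported magnitudes differ, so the printed values are best read as relative scores among the five candidates.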

Load and Tokenize the Dataset

To keep training fast, we’re going to use unsloth/Radiology_mini, which is a 0.33% sample of eltorio/ROCOv2-radiology. For a real model, you would want to train on a much larger share of the data.

from datasets import load_dataset
dataset = load_dataset(dataset_name)

We will be using the caption field as our training dataset to simulate radiology reports.

dataset['train'][0]
# {'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=657x442>,
#  'image_id': 'ROCOv2_2023_train_054311',
#  'caption': 'Panoramic radiography shows an osteolytic lesion in the right posterior maxilla with resorption of the floor of the maxillary sinus (arrows).',
#  'cui': ['C1306645', 'C0037303']}

We will use the built-in ModernBERT tokenizer to transform the captions into tokens. return_special_tokens_mask=True is important here: the tokenizer adds special tokens like "[CLS]" (start of the input sequence) and "[SEP]" (the boundary between sentences), and it marks them in a binary special_tokens_mask in its output, alongside input_ids and attention_mask. The data collator uses this mask to avoid masking special tokens.

def tokenize_function(examples):
    return tokenizer(
        examples["caption"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_special_tokens_mask=True,  # Important for MLM
    )

tokenized_datasets = dataset.map(
    tokenize_function, batched=True,
    # For MLM, we can remove the original columns
    remove_columns=dataset['train'].column_names
)

DataCollatorForLanguageModeling randomly replaces tokens in each batch of captions with "[MASK]" at training time, so we can train for our MLM task.

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,  # Enable masking
    mlm_probability=0.15  # 15% of tokens will be masked
)
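
As a rough illustration of what the collator does with each batch, the sketch below applies BERT-style masking in plain Python. About 15% of the non-special tokens are selected; of those, 80% become the mask token, 10% become a random token, and 10% stay unchanged. Only the selected positions keep a label, and all others get -100 so that the loss ignores them. This is a simplified stand-in, not the library's implementation; the function name, token ids, and vocabulary size are hypothetical.

```python
import random

IGNORE_INDEX = -100  # label value that the cross-entropy loss skips


def mask_tokens(input_ids, special_tokens_mask, mask_token_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Return (masked_ids, labels) for one sequence using BERT-style masking."""
    rng = rng or random.Random()
    masked_ids = list(input_ids)
    labels = [IGNORE_INDEX] * len(input_ids)

    for i, (token_id, is_special) in enumerate(zip(input_ids, special_tokens_mask)):
        # Never select special tokens such as [CLS], [SEP], or [PAD]
        if is_special or rng.random() >= mlm_probability:
            continue
        labels[i] = token_id  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            masked_ids[i] = mask_token_id              # 80%: replace with [MASK]
        elif roll < 0.9:
            masked_ids[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged

    return masked_ids, labels


# Hypothetical ids: 1 = [CLS], 2 = [SEP], 0 = [PAD], 3 = [MASK]
ids = [1, 11, 12, 13, 14, 15, 16, 2, 0, 0]
special = [1, 0, 0, 0, 0, 0, 0, 1, 1, 1]
masked, labels = mask_tokens(ids, special, mask_token_id=3, vocab_size=100,
                             rng=random.Random(0))
print(masked)
print(labels)
```

Because the masking is redone for every batch, the model sees different masked positions for the same caption across epochs.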

Training

We will run a simple training setup. It could be extended with more epochs, the full eltorio/ROCOv2-radiology dataset, or an early stopping callback.

from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

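# output_dir (str) and push_to_hub (bool) are assumed to be defined earlier, as in the Colab notebook
training_args = TrainingArguments(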
    output_dir=output_dir,
    eval_strategy="epoch",
    # eval_strategy="steps",     # vs "epoch". Note: `load_best_model_at_end` needs the save and eval strategies to match
    # eval_steps=500,            # Evaluate every 500 steps
    logging_strategy="steps",    # Log metrics every n steps
    logging_steps=100,           # Log every 100 steps
    num_train_epochs=20,

    max_grad_norm=1.0,           # Gradient clipping prevents exploding gradients by scaling them down if they exceed this value
    warmup_ratio=0.1,            # Gradually increases learning rate

    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,

    prediction_loss_only=True,   # Only need to compute loss for MLM task

    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,     # Lower loss is better
    save_total_limit=3,          # Keep at most 3 checkpoints (the best one is always retained)

    push_to_hub=push_to_hub,

    report_to="none",            # Comment out to enable wandb
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    # Could put in an early stopping, if you want
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
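
To see what warmup_ratio=0.1 does, here is a sketch of a linear warmup followed by a linear decay, which is the Trainer's default scheduler type (lr_scheduler_type="linear"). The peak learning rate of 5e-5 is the TrainingArguments default, since the arguments above do not set one. The total of 1,000 steps and the function name are hypothetical, chosen only for illustration.

```python
import math


def linear_warmup_decay_lr(step, total_steps, peak_lr=5e-5, warmup_ratio=0.1):
    """Learning rate at `step`: linear warmup to peak_lr, then linear decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


# Hypothetical; the real value depends on dataset size, batch size, and epochs
total_steps = 1000
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_decay_lr(step, total_steps))
```

With these numbers, the learning rate climbs for the first 100 steps, peaks at 5e-5, and then falls linearly to zero by the final step.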

Start the training, which takes ~15 minutes on an L4 GPU:

trainer.train()

We see the loss in each epoch:

Epoch	Training Loss	Validation Loss
1	1.869300	1.599592
2	1.696800	1.797290
3	1.718700	1.723163
4	1.651800	1.734311
5	1.500300	1.772700
6	1.334600	1.735698
7	1.402900	1.716372
8	1.276200	1.712302
9	1.244100	1.697829
10	1.201600	1.737356
11	1.188700	1.707611
12	1.020500	1.673580
13	1.077100	1.720857
14	1.060700	1.675280
15	0.909000	1.617196
16	0.925500	1.741842
17	0.867600	1.691358
18	0.853300	1.731034
19	0.845000	1.789315
20	0.869000	1.693557
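
One common way to read these numbers is to convert the cross-entropy loss to perplexity (the exponential of the loss) and to look for the epoch with the lowest validation loss, which is the criterion that metric_for_best_model="eval_loss" with greater_is_better=False expresses. A small sketch using the validation losses from the table:

```python
import math

# Validation loss per epoch, copied from the table above
validation_loss = [
    1.599592, 1.797290, 1.723163, 1.734311, 1.772700,
    1.735698, 1.716372, 1.712302, 1.697829, 1.737356,
    1.707611, 1.673580, 1.720857, 1.675280, 1.617196,
    1.741842, 1.691358, 1.731034, 1.789315, 1.693557,
]

# Epochs are 1-indexed in the table
best_epoch = min(range(len(validation_loss)), key=validation_loss.__getitem__) + 1
best_loss = validation_loss[best_epoch - 1]
perplexity = math.exp(best_loss)

print(f"best epoch: {best_epoch}, loss: {best_loss:.4f}, perplexity: {perplexity:.2f}")
# best epoch: 1, loss: 1.5996, perplexity: 4.95
```

The lowest validation loss occurs at epoch 1 while the training loss keeps falling, a pattern that suggests overfitting on this small demo dataset.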

Finally, we upload the fine-tuned weights to Hugging Face: https://huggingface.co/johnpaulett/ModernRadBERT-mlm

trainer.push_to_hub()

Use the Fine-Tuned Model

The results with the fine-tuned model look substantially better than those with the default untuned model. Training on more data, and on data more representative of real radiology reports, would likely improve them further.

predict("ninth [MASK] fracture")
# ninth [MASK] fracture
#   rib: 0.3055
#   tibial: 0.2377
#   femoral: 0.2057
#   oblique: 0.1461
#   hip: 0.1050

predict("pneumonia in the right lower [MASK]")
# pneumonia in the right lower [MASK]
#   lung: 0.7542
#   lobe: 0.0845
#   jaw: 0.0598
#   extremity: 0.0596
#   lip: 0.0419

predict("small bowel [MASK]")
# small bowel [MASK]
#   obstruction: 0.4843
#   resection: 0.2274
#   .: 0.1142
#   colon: 0.0947
#   CT: 0.0793

predict("small [MASK] obstruction")
# small [MASK] obstruction
#   bowel: 0.7788
#   intestine: 0.1237
#   airway: 0.0574
#   vessel: 0.0247
#   intestinal: 0.0154

Conclusion

The MLM task is mighty. While not generative like an LLM, an encoder-only model like the one we have built is much faster. Example use cases in radiology reporting:

Explore the Fine-Tuned Model

You can pull down this fine-tuned model (WARNING: it is trained on a small dataset as a demo, so do not use it for any real problems):

from transformers import pipeline

pipe = pipeline("fill-mask", model="johnpaulett/ModernRadBERT-mlm")
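
The fill-mask pipeline returns a list of dictionaries, each with a score and a token_str among other keys. A small helper (format_fill_mask is a hypothetical name, not part of transformers) can print them in the same layout as predict(). The sample results below are placeholder values, not real model output.

```python
def format_fill_mask(results):
    """Format fill-mask pipeline results as indented 'token: score' lines."""
    return "\n".join(
        f"  {item['token_str'].strip()}: {item['score']:.4f}" for item in results
    )


# With the pipeline loaded above, the results would come from:
#   results = pipe("ninth [MASK] fracture", top_k=5)
# Placeholder results with the same shape, so the helper runs without a download:
results = [
    {"token_str": " rib", "score": 0.5},
    {"token_str": " hip", "score": 0.25},
]
print(format_fill_mask(results))
```

The pipeline computes its scores over the full vocabulary, so they will not match the values printed by predict(), which are normalized over the top-5 candidates only.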

Next Steps

Continue reading the ModernBERT in Radiology series:


Citations