Recreating IRIS: Universal Medical Image Segmentation via In-Context Learning

Project Overview: This blog post details our implementation of the IRIS (In-context Reference Image guided Segmentation) framework for universal medical image segmentation, as documented in our GitHub repository mister-weeden/abts25.

🏥 What is Medical Image Segmentation?

Medical image segmentation is a critical computer vision task that involves partitioning medical images into meaningful regions or structures. Unlike natural image segmentation, medical segmentation faces unique challenges:

  • High variability: Different imaging modalities (CT, MRI, PET) and anatomical structures
  • Class imbalance: Target organs often occupy small portions of the image
  • Complex shapes: Irregular anatomical boundaries and structures
  • Limited data: Annotated medical data is expensive and requires expert knowledge

Figure: IRIS Architecture Overview. Components shown: Input Image (3D medical volume), 3D UNet Encoder (multi-scale features), Task Encoding Module (reference-guided context, built from the Reference Image and Reference Mask), Query-based Decoder (cross-attention), and Output Mask (segmentation).

🧠 The IRIS Framework

IRIS (In-context Reference Image guided Segmentation) represents a paradigm shift from traditional task-specific models to a universal framework that can adapt to new segmentation tasks without fine-tuning. The key innovation is the decoupling of task definition from inference through in-context learning.

Core Components

🏗️ 3D UNet Encoder

Multi-scale feature extraction with residual blocks and instance normalization. Processes 3D medical volumes through 6 stages with channel progression: [32, 32, 64, 128, 256, 512].
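To make the encoder description concrete, here is a minimal PyTorch sketch of a six-stage 3D encoder built from residual blocks with instance normalization. It is not the repository's `Encoder3D`: the names `ResidualBlock3D` and `TinyEncoder3D`, the stride-2 downsampling between stages, the LeakyReLU activation, and the single block per stage are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class ResidualBlock3D(nn.Module):
    """Two 3x3x3 convolutions with instance norm and a residual shortcut (hypothetical)."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.norm1 = nn.InstanceNorm3d(out_ch, affine=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.InstanceNorm3d(out_ch, affine=True)
        self.act = nn.LeakyReLU(0.01)
        # Project the shortcut whenever the shape changes so the addition stays valid.
        self.shortcut = None
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)
        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        return self.act(out + identity)


class TinyEncoder3D(nn.Module):
    """Six-stage encoder; every stage after the first halves the resolution (assumed)."""

    def __init__(self, in_channels=1, channels=(32, 32, 64, 128, 256, 512)):
        super().__init__()
        stages = []
        prev = in_channels
        for i, ch in enumerate(channels):
            stages.append(ResidualBlock3D(prev, ch, stride=1 if i == 0 else 2))
            prev = ch
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        # Return one feature map per stage so a decoder can use them as skip connections.
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features
```

Under these assumptions, a `(1, 1, 64, 64, 64)` input yields six feature maps whose spatial size shrinks from 64 to 2 per axis. Inputs much smaller than that would reach a single voxel in the deepest stage, which instance normalization cannot handle.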

🎯 Task Encoding Module

Distills task-specific information from reference image-mask pairs into compact embeddings using foreground feature encoding and contextual feature encoding.

🔄 Query-Based Decoder

Leverages task embeddings to guide segmentation through cross-attention mechanisms, enabling flexible adaptation to different anatomical structures.
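The cross-attention step can be illustrated without any framework. The NumPy sketch below lets flattened query-image features attend to a set of task tokens using single-head scaled dot-product attention. The random projection matrices stand in for learned weights, and the shapes (64 voxels, 10 tokens, 16 channels) are arbitrary; none of this is taken from the repository's decoder.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(query_feats, task_tokens, w_q, w_k, w_v):
    """query_feats: (N, C) flattened voxel features; task_tokens: (M, C) task embedding."""
    q = query_feats @ w_q                       # (N, d) queries come from the image
    k = task_tokens @ w_k                       # (M, d) keys come from the task embedding
    v = task_tokens @ w_v                       # (M, d) values come from the task embedding
    scores = q @ k.T / np.sqrt(q.shape[-1])     # (N, M) scaled dot products
    weights = softmax(scores, axis=-1)          # each voxel distributes attention over tokens
    return weights @ v, weights


rng = np.random.default_rng(0)
channels, dim = 16, 16
query_feats = rng.normal(size=(4 * 4 * 4, channels))   # a 4x4x4 feature map, flattened
task_tokens = rng.normal(size=(10, channels))          # 10 tokens, matching num_tokens=10
w_q, w_k, w_v = (rng.normal(size=(channels, dim)) / np.sqrt(channels) for _ in range(3))

attended, weights = cross_attention(query_feats, task_tokens, w_q, w_k, w_v)
conditioned = query_feats + attended                   # residual update of the image features
```

Each row of `weights` sums to 1, so every voxel receives a convex combination of the projected task tokens. A multi-head version would split the channel dimension and run the same computation per head.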

💻 Technical Implementation

Technology Stack

Framework: PyTorch 1.x
Architecture: 3D Convolutional Neural Networks
Attention: Multi-head Cross-attention
Optimization: LAMB optimizer (a layer-wise adaptive variant of Adam)
Data Processing: 3D Medical Imaging
Augmentation: Spatial and Intensity transforms
Evaluation: DICE coefficient, Surface DICE
Storage: NIfTI format for medical volumes

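import torch.nn as nn

# Module paths follow the repository layout listed under "Directory Structure" (assumed).
from src.models.encoder_3d import Encoder3D
from src.models.task_encoding import TaskEncodingModule
from src.models.decoder_3d_fixed import QueryBasedDecoderFixed


class IRISModel(nn.Module):
    """
    IRIS model for universal medical image segmentation.
    Combines 3D UNet encoder, task encoding, and query-based decoder.
    """
    def __init__(self, in_channels=1, base_channels=32, embed_dim=512,
                 num_tokens=10, num_classes=1, num_heads=8):
        super().__init__()

        # Channels per encoder stage; with base_channels=32 this gives
        # [32, 32, 64, 128, 256, 512], the progression described above.
        encoder_channels = [base_channels * m for m in (1, 1, 2, 4, 8, 16)]

        # Image encoder (shared for reference and query)
        self.encoder = Encoder3D(in_channels, base_channels, num_blocks_per_stage=2)

        # Task encoding from reference examples
        self.task_encoder = TaskEncodingModule(
            encoder_channels[-1], embed_dim, num_tokens
        )

        # Query-based decoder with cross-attention
        self.decoder = QueryBasedDecoderFixed(
            encoder_channels, embed_dim, num_classes, num_heads
        )

    def forward(self, query_image, reference_image=None, reference_mask=None):
        # A task embedding can only be built from a reference image-mask pair
        if reference_image is None or reference_mask is None:
            raise ValueError(
                "reference_image and reference_mask are required to build the task embedding"
            )

        # Extract multi-scale features from query image
        query_features = self.encoder(query_image)

        # Generate task embedding from the deepest reference features and the mask
        ref_features = self.encoder(reference_image)
        task_embedding = self.task_encoder(ref_features[-1], reference_mask)

        # Decode with task guidance
        segmentation = self.decoder(query_features, task_embedding)
        return segmentation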

Key Architectural Innovations

🔧 Task Encoding Module

The task encoding module consists of two parallel streams:

  • Foreground Feature Encoding: High-resolution processing to preserve fine anatomical details
  • Contextual Feature Encoding: Learnable query tokens with cross-attention for global context
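The two streams can be sketched in a few lines of PyTorch. This is an illustrative reading of the description above, not the repository's `TaskEncodingModule`: the class name `TinyTaskEncoder`, the use of masked average pooling for the foreground stream, the trilinear upsampling to mask resolution, and the concatenation of both streams into one token sequence are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTaskEncoder(nn.Module):
    """Hypothetical two-stream task encoder: foreground pooling plus learnable query tokens."""

    def __init__(self, feat_channels=64, embed_dim=64, num_tokens=10, num_heads=8):
        super().__init__()
        self.proj = nn.Conv3d(feat_channels, embed_dim, kernel_size=1)
        self.query_tokens = nn.Parameter(torch.randn(1, num_tokens, embed_dim) * 0.02)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, ref_features, ref_mask):
        # ref_features: (B, C, d, h, w) low-resolution features of the reference image
        # ref_mask:     (B, 1, D, H, W) binary mask at full resolution
        feats = self.proj(ref_features)
        mask = ref_mask.float()

        # Foreground stream: upsample features to mask resolution so thin structures
        # are not lost, then average only over foreground voxels.
        feats_hr = F.interpolate(feats, size=mask.shape[2:], mode="trilinear", align_corners=False)
        fg_sum = (feats_hr * mask).sum(dim=(2, 3, 4))               # (B, E)
        fg_count = mask.sum(dim=(2, 3, 4)).clamp_min(1.0)           # (B, 1), avoids divide-by-zero
        foreground = (fg_sum / fg_count).unsqueeze(1)               # (B, 1, E)

        # Contextual stream: learnable tokens attend over every reference position.
        memory = feats.flatten(2).transpose(1, 2)                   # (B, d*h*w, E)
        queries = self.query_tokens.expand(feats.shape[0], -1, -1)  # (B, num_tokens, E)
        context, _ = self.attn(queries, memory, memory)             # (B, num_tokens, E)

        # Task embedding: one foreground token followed by the context tokens.
        return torch.cat([foreground, context], dim=1)              # (B, 1 + num_tokens, E)
```

With `num_tokens=10`, this sketch produces 11 tokens per reference pair. Clamping the foreground count keeps the output finite even when the reference mask is empty.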

Understanding DICE in Medical Segmentation

DICE Coefficient Formula

DICE Coefficient

$\mathrm{DICE} = \frac{2\,|P \cap G|}{|P| + |G|}$

P = Prediction, G = Ground Truth, ∩ = Intersection

DICE Loss = 1 – DICE Coefficient

Why DICE for Medical Images?

  • Class Imbalance Robust: Focuses on overlap rather than per-voxel accuracy, which is crucial when organs occupy small portions of images and background voxels would otherwise dominate the score
  • Overlap-Centric: Emphasizes intersection between prediction and ground truth
  • Range [0,1]: 0 = no overlap, 1 = perfect overlap
  • Differentiable: Can be used as a loss function for backpropagation

Visual Representation of DICE Calculation

  • Ground Truth: |G| = area of the ground truth
  • Prediction: |P| = area of the prediction
  • Intersection: |P ∩ G| = overlap
  • DICE = 2 × Overlap / (|P| + |G|)
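The formula is easy to check on a small example. The NumPy sketch below computes DICE for two overlapping 4×4 squares on an 8×8 grid. The function name `dice_coefficient` and the convention of returning 1.0 when both masks are empty are choices made here for illustration, not taken from the repository's `dice_loss.py`.

```python
import numpy as np


def dice_coefficient(pred, target):
    """DICE = 2 * |P ∩ G| / (|P| + |G|) for binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return 2.0 * intersection / denom


ground_truth = np.zeros((8, 8), dtype=bool)
ground_truth[2:6, 2:6] = True            # |G| = 16 pixels

prediction = np.zeros((8, 8), dtype=bool)
prediction[3:7, 3:7] = True              # |P| = 16 pixels, shifted by one pixel diagonally

# The squares share rows 3..5 and columns 3..5, so |P ∩ G| = 9.
dice = dice_coefficient(prediction, ground_truth)   # 2 * 9 / (16 + 16) = 0.5625
dice_loss = 1.0 - dice                               # 0.4375
```

A one-pixel diagonal shift already costs almost half the score for a 16-pixel structure, which is why DICE is sensitive for small organs.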

📚 Training Datasets

Our implementation was trained on a collection of 12 medical imaging datasets covering diverse anatomical structures and imaging modalities, including the following:

🫁 AMOS22

Modalities: CT, MRI
Cases: 500 CT + 100 MRI
Structures: 15 abdominal organs

🏥 BCV

Modality: CT
Cases: 30 abdominal scans
Structures: 13 abdominal organs

🦴 LiTS

Modality: CT
Cases: 131 training cases
Target: Liver and tumor segmentation

🫘 KiTS19

Modality: CT
Cases: 210 cases
Target: Kidney and tumor segmentation

💗 M&Ms

Modality: Cardiac MRI
Cases: 320 cases
Structures: LV, RV, myocardium

🧠 Brain Aging

Modality: T1 MRI
Cases: 213 scans
Target: Brain tissue segmentation

🏗️ Implementation Architecture

Directory Structure

abts25/
├── src/
│   ├── models/
│   │   ├── iris_model.py          # Main IRIS model implementation
│   │   ├── encoder_3d.py          # 3D UNet encoder
│   │   ├── decoder_3d.py          # Query-based decoder  
│   │   ├── decoder_3d_fixed.py    # Fixed decoder (channel mismatch resolved)
│   │   └── task_encoding.py       # Task encoding module
│   ├── losses/
│   │   └── dice_loss.py           # DICE loss implementations
│   ├── training/
│   │   └── [training scripts]
│   ├── evaluation/
│   │   └── [evaluation utilities]
│   └── data/
│       └── [data loading utilities]
├── configs/                       # Configuration files
├── checkpoints/                   # Model checkpoints
├── docs/                         # Documentation and paper materials
└── scripts/                      # Training and evaluation scripts

Key Features Implemented

✅ Core Capabilities

  • End-to-end Training: Fixed channel mismatch issues for complete gradient flow
  • Multiple Inference Modes: One-shot, ensemble, and memory bank strategies
  • 3D Volume Processing: Native support for medical 3D imaging
  • Multi-modal Support: Works across CT, MRI, PET imaging modalities
  • Flexible Task Encoding: Adapts to different anatomical structures
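The three inference modes differ mainly in where the task embedding comes from. The NumPy sketch below shows one plausible reading, which is an assumption rather than the repository's behavior: one-shot uses a single reference embedding, ensemble averages the embeddings of several references, and a memory bank keeps one running embedding per structure, updated by exponential moving average. The names `ensemble_embedding` and `TaskMemoryBank` and the momentum value are hypothetical.

```python
import numpy as np


def ensemble_embedding(embeddings):
    """Average task embeddings computed from several references of the same structure."""
    return np.mean(np.stack(embeddings, axis=0), axis=0)


class TaskMemoryBank:
    """Hypothetical per-structure store of task embeddings with EMA updates."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.bank = {}

    def update(self, key, embedding):
        if key not in self.bank:
            # First embedding seen for this structure initializes the entry.
            self.bank[key] = np.array(embedding, dtype=float)
        else:
            m = self.momentum
            self.bank[key] = m * self.bank[key] + (1.0 - m) * embedding

    def get(self, key):
        # Retrieval needs no reference image at inference time.
        return self.bank[key]


# One-shot: a single reference embedding is used as-is.
one_shot = np.ones((10, 8))

# Ensemble: several reference embeddings are averaged.
ensembled = ensemble_embedding([np.ones((10, 8)), 3.0 * np.ones((10, 8))])

# Memory bank: embeddings accumulate under a structure name.
bank = TaskMemoryBank(momentum=0.9)
bank.update("liver", np.zeros((10, 8)))
bank.update("liver", np.ones((10, 8)))
stored = bank.get("liver")
```

In all three cases the decoder would receive an embedding of the same shape, so the choice of mode does not require any change to the model itself.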

📈 Model Performance & Evaluation

Evaluation Metrics

Primary Metrics

  • DICE Coefficient: Spatial overlap measure (0-1, higher better)
  • Surface DICE: Boundary-based evaluation for surface accuracy
  • Hausdorff Distance: Maximum surface distance error
  • IoU (Jaccard Index): Intersection over Union measure

Note: Final ranking follows a rank-then-aggregate methodology: DICE and Surface DICE scores are first averaged across all cases, and the results are then combined through rank-based aggregation.
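Two of the listed metrics, IoU and Hausdorff distance, can be sketched with NumPy alone. The code below is a simplified illustration, not the repository's evaluation code: the Hausdorff distance is computed by brute force over all foreground voxels rather than over extracted surfaces, and the function names and `spacing` parameter are hypothetical.

```python
import numpy as np


def iou(pred, target):
    """Intersection over Union (Jaccard index) for binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return np.logical_and(pred, target).sum() / union


def hausdorff_distance(pred, target, spacing=1.0):
    """Symmetric Hausdorff distance between the foreground voxel sets (brute force)."""
    a = np.argwhere(pred) * spacing
    b = np.argwhere(target) * spacing
    if len(a) == 0 or len(b) == 0:
        raise ValueError("Hausdorff distance is undefined for an empty mask")
    # Pairwise Euclidean distances; memory grows as |A| * |B|, so keep the masks small.
    d = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    # Largest nearest-neighbor distance, taken in both directions.
    return max(d.min(axis=1).max(), d.min(axis=0).max())


ground_truth = np.zeros((8, 8), dtype=bool)
ground_truth[2:6, 2:6] = True
prediction = np.zeros((8, 8), dtype=bool)
prediction[3:7, 3:7] = True

jaccard = iou(prediction, ground_truth)                  # 9 / 23
hausdorff = hausdorff_distance(prediction, ground_truth) # sqrt(2), from the corner pixels
```

For binary masks, IoU and DICE are related by IoU = DICE / (2 − DICE), so the two rank predictions identically on a single case and differ only in scale.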

Training Configuration

# Model Configuration
INPUT_CHANNELS = 1  # Grayscale medical images
BASE_CHANNELS = 32  # Encoder base channels
EMBED_DIM = 512  # Task embedding dimension
NUM_TOKENS = 10  # Query tokens for task encoding
NUM_CLASSES = 1  # Binary segmentation
NUM_HEADS = 8  # Multi-head attention heads

# Training Setup
BATCH_SIZE = 4  # Limited by GPU memory for 3D volumes
LEARNING_RATE = 1e-4  # LAMB optimizer (a layer-wise adaptive variant of Adam)
MAX_ITERATIONS = 1000  # Episodic training iterations
LOSS_FUNCTION = "DiceLoss + CrossEntropyLoss"  # Combined loss
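The combined loss named in the configuration can be sketched as follows. This is a minimal version under stated assumptions, not the repository's `dice_loss.py`: because `NUM_CLASSES = 1`, the cross-entropy term is written as binary cross-entropy on logits, the two terms are weighted equally, and the class name `DiceBCELoss` and the `smooth` constant are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceBCELoss(nn.Module):
    """Soft DICE loss plus binary cross-entropy for single-channel logits."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        # logits, target: (B, 1, D, H, W); target contains 0/1 labels
        target = target.float()
        probs = torch.sigmoid(logits)

        # Soft DICE uses probabilities instead of hard masks so it stays differentiable.
        dims = (1, 2, 3, 4)
        intersection = (probs * target).sum(dim=dims)
        denom = probs.sum(dim=dims) + target.sum(dim=dims)
        dice = (2.0 * intersection + self.smooth) / (denom + self.smooth)
        dice_loss = 1.0 - dice.mean()

        # Voxel-wise cross-entropy gives smoother gradients early in training.
        bce_loss = F.binary_cross_entropy_with_logits(logits, target)
        return dice_loss + bce_loss
```

The smoothing constant keeps the DICE term defined when a patch contains no foreground, which is common with small organs and cropped 3D volumes.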

🚀 Getting Started

Installation & Usage

# Clone the repository
git clone https://github.com/mister-weeden/abts25
cd abts25

# Install dependencies
pip install -e .

# Download datasets
abts25_download_data

# Quick training test
python simple_train.py

# Full training on AMOS22 dataset
python train_amos22.py --data_dir src/data/amos --batch_size 4 --max_iterations 1000

# Evaluation
python evaluate_amos22.py --model_path checkpoints/model.pth

# Compute metrics
abts25_compute_metrics PREDICTIONS_FOLDER -num_processes 4

Model Inference

import torch

from src.models.iris_model import IRISModel, IRISInferenceFixed

# Load trained model
model = IRISModel(in_channels=1, embed_dim=512, num_classes=1)
model.load_state_dict(torch.load('checkpoints/model.pth'))

# Initialize inference engine
inference = IRISInferenceFixed(model, device='cuda')

# One-shot inference with reference
# (query_volume, ref_volume, and ref_mask are tensors prepared by the caller)
result = inference.one_shot_inference(
    query_image=query_volume,    # (1, 1, D, H, W)
    reference_image=ref_volume,  # (1, 1, D, H, W)
    reference_mask=ref_mask,     # (1, 1, D, H, W)
    apply_sigmoid=True,
    threshold=0.5
)

# Get segmentation prediction
prediction = result['prediction']        # Binary mask
probabilities = result['probabilities']  # Confidence scores

🔍 Technical Challenges & Solutions

❌ Challenge: Channel Mismatch

Problem: Original decoder had misaligned channels preventing end-to-end training

Solution: Implemented QueryBasedDecoderFixed with proper channel alignment and attention mechanisms

⚡ Challenge: Memory Efficiency

Problem: 3D medical volumes require significant GPU memory

Solution: Efficient 3D convolutions, gradient checkpointing, and optimized data loading
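Gradient checkpointing trades compute for memory by discarding a block's intermediate activations in the forward pass and recomputing them during backpropagation. The sketch below wraps one convolutional stage with `torch.utils.checkpoint.checkpoint`. The class `CheckpointedStage` is hypothetical and not taken from the repository, and the `use_reentrant=False` argument assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStage(nn.Module):
    """A 3D conv stage whose activations are recomputed in backward during training."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_ch, affine=True),
            nn.LeakyReLU(0.01),
        )

    def forward(self, x):
        if self.training:
            # Intermediate activations inside the block are not stored; they are
            # recomputed when gradients are needed.
            return checkpoint(self.block, x, use_reentrant=False)
        # At inference time there is nothing to save, so run the block directly.
        return self.block(x)
```

The output is numerically the same as without checkpointing. The cost is one extra forward pass through each wrapped block per training step, which is usually acceptable when 3D activations are what limits batch size.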

🎯 Challenge: Multi-scale Features

Problem: Medical structures vary greatly in size and shape

Solution: Multi-scale encoder with skip connections and attention-based feature fusion

🔄 Challenge: Task Generalization

Problem: Adapting to unseen anatomical structures

Solution: In-context learning with reference-guided task embeddings

📊 Results & Impact

🏆 Key Achievements

  • Universal Framework: Single model works across 12+ medical imaging datasets
  • No Fine-tuning Required: Adapts to new tasks through reference examples
  • State-of-the-art Performance: Competitive with task-specific models
  • Computational Efficiency: Task embeddings can be precomputed and reused
  • Clinical Applicability: Supports real-world medical imaging workflows

🔮 Future Work & Extensions

  • Multi-modal Fusion: Combining CT, MRI, and PET modalities
  • Temporal Segmentation: 4D medical image sequences
  • Interactive Refinement: User-guided segmentation improvements
  • Edge Deployment: Optimized models for clinical deployment
  • Uncertainty Quantification: Confidence measures for clinical decision-making

🎓 Conclusion

The IRIS framework represents a significant advancement in medical image segmentation, moving from task-specific models to a universal framework capable of adapting to new anatomical structures through in-context learning. Our implementation demonstrates that with careful architectural design and proper training strategies, a single model can achieve competitive performance across diverse medical imaging tasks.

Key Takeaways:

  • In-context learning enables flexible adaptation without retraining
  • Task encoding decouples task definition from inference
  • DICE loss is crucial for handling class imbalance in medical images
  • 3D architectures are essential for volumetric medical data
  • Cross-attention mechanisms effectively integrate task-specific guidance

Repository: mister-weeden/abts25

Paper: “Show and Segment: Universal Medical Image Segmentation via In-Context Learning”

For technical questions or collaboration opportunities, please open an issue on GitHub.