Recreating IRIS: Universal Medical Image Segmentation via In-Context Learning
What is Medical Image Segmentation?
Medical image segmentation is a critical computer vision task that involves partitioning medical images into meaningful regions or structures. Unlike natural image segmentation, medical segmentation faces unique challenges:
- High variability: Different imaging modalities (CT, MRI, PET) and anatomical structures
- Class imbalance: Target organs often occupy small portions of the image
- Complex shapes: Irregular anatomical boundaries and structures
- Limited data: Annotated medical data is expensive and requires expert knowledge
The IRIS Framework
IRIS (In-context Reference Image guided Segmentation) represents a paradigm shift from traditional task-specific models to a universal framework that can adapt to new segmentation tasks without fine-tuning. The key innovation is the decoupling of task definition from inference through in-context learning.
Core Components
3D UNet Encoder
Multi-scale feature extraction with residual blocks and instance normalization. Processes 3D medical volumes through 6 stages with channel progression: [32, 32, 64, 128, 256, 512].
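To make the channel progression concrete, the sketch below tabulates the feature-map shape after each stage. The channel widths come from the text; the stride-2 downsampling between stages and the helper name `encoder_feature_shapes` are assumptions made for illustration, not details confirmed by the implementation.

```python
# Channel widths are taken from the description above. The halving of the
# spatial size between stages is an ASSUMPTION for illustration only.
STAGE_CHANNELS = [32, 32, 64, 128, 256, 512]


def encoder_feature_shapes(input_size, channels=STAGE_CHANNELS):
    """Return (channels, D, H, W) for the output of each encoder stage.

    Assumes stage 0 keeps the input resolution and every later stage halves
    each spatial dimension, rounding up (as a stride-2 convolution with
    kernel size 3 and padding 1 would).
    """
    d, h, w = input_size
    shapes = []
    for stage, c in enumerate(channels):
        if stage > 0:
            d, h, w = (d + 1) // 2, (h + 1) // 2, (w + 1) // 2
        shapes.append((c, d, h, w))
    return shapes


for shape in encoder_feature_shapes((128, 128, 128)):
    print(shape)
```

Under these assumptions a 128x128x128 input ends as a 512-channel map of size 4x4x4, which is the deepest feature map handed to the task encoder.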
Task Encoding Module
Distills task-specific information from reference image-mask pairs into compact embeddings using foreground feature encoding and contextual feature encoding.
Query-Based Decoder
Leverages task embeddings to guide segmentation through cross-attention mechanisms, enabling flexible adaptation to different anatomical structures.
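The cross-attention step can be illustrated with a minimal single-head version in NumPy. This is a sketch under simplifying assumptions: the learned query/key/value projections and the multi-head split are omitted, and the function name `cross_attention` is hypothetical rather than part of the repository.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(query_feats, task_tokens):
    """Single-head scaled dot-product cross-attention.

    query_feats: (N, C) flattened query-image features, one row per voxel.
    task_tokens: (T, C) task-embedding tokens derived from the reference pair.
    Returns (attended, weights) with shapes (N, C) and (N, T): each voxel
    feature is replaced by a weighted mix of the task tokens.
    """
    scale = 1.0 / np.sqrt(query_feats.shape[-1])
    weights = softmax(query_feats @ task_tokens.T * scale, axis=-1)
    return weights @ task_tokens, weights


rng = np.random.default_rng(0)
query_feats = rng.normal(size=(4 * 4 * 4, 512))  # a 4x4x4 feature map with 512 channels
task_tokens = rng.normal(size=(10, 512))         # 10 tokens of dimension 512
attended, weights = cross_attention(query_feats, task_tokens)
print(attended.shape, weights.shape)
```

Each row of `weights` sums to 1, so every voxel distributes its attention over the task tokens.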
Technical Implementation
Technology Stack
Architecture: 3D Convolutional Neural Networks
Attention: Multi-head Cross-attention
Optimization: LAMB optimizer (a layer-wise adaptive variant of Adam)
Augmentation: Spatial and Intensity transforms
Evaluation: DICE coefficient, Surface DICE
Storage: NIfTI format for medical volumes
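As an illustration of the spatial and intensity transforms listed above, here is a minimal NumPy sketch. The specific transforms (random axis flip, random scale and shift) and their ranges are assumptions chosen for brevity; the actual augmentation pipeline may differ.

```python
import numpy as np


def augment_pair(image, mask, rng):
    """Apply one random spatial and one random intensity transform.

    image, mask: arrays of shape (D, H, W). The spatial transform is applied
    to both so they stay aligned; the intensity transform touches only the image.
    """
    # Spatial: flip along a randomly chosen axis with probability 0.5.
    if rng.random() < 0.5:
        axis = int(rng.integers(0, 3))
        image = np.flip(image, axis=axis)
        mask = np.flip(mask, axis=axis)

    # Intensity: random scale and shift (illustrative ranges).
    scale = rng.uniform(0.9, 1.1)
    shift = rng.uniform(-0.1, 0.1)
    image = image * scale + shift

    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)
image = rng.normal(size=(8, 16, 16)).astype(np.float32)
mask = (image > 1.0).astype(np.uint8)
aug_image, aug_mask = augment_pair(image, mask, rng)
print(aug_image.shape, int(aug_mask.sum()) == int(mask.sum()))
```

Flipping moves foreground voxels but never creates or removes them, which is why the mask sum is unchanged.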
import torch.nn as nn


class IRISModel(nn.Module):
    """
    IRIS model for universal medical image segmentation.
    Combines 3D UNet encoder, task encoding, and query-based decoder.
    """

    def __init__(self, in_channels=1, base_channels=32, embed_dim=512,
                 num_tokens=10, num_classes=1, num_heads=8):
        super().__init__()
        # Per-stage encoder widths: [32, 32, 64, 128, 256, 512] for base_channels=32
        # (assumed to match the stages produced by Encoder3D)
        encoder_channels = [base_channels * m for m in (1, 1, 2, 4, 8, 16)]

        # Image encoder (shared for reference and query)
        self.encoder = Encoder3D(in_channels, base_channels, num_blocks_per_stage=2)

        # Task encoding from reference examples
        self.task_encoder = TaskEncodingModule(encoder_channels[-1], embed_dim, num_tokens)

        # Query-based decoder with cross-attention
        self.decoder = QueryBasedDecoderFixed(encoder_channels, embed_dim, num_classes, num_heads)

    def forward(self, query_image, reference_image=None, reference_mask=None):
        # Extract features from query image
        query_features = self.encoder(query_image)

        # Generate task embedding from reference
        # (stays None when no reference is given; the decoder must handle that case)
        task_embedding = None
        if reference_image is not None:
            ref_features = self.encoder(reference_image)
            task_embedding = self.task_encoder(ref_features[-1], reference_mask)

        # Decode with task guidance
        segmentation = self.decoder(query_features, task_embedding)
        return segmentation
Key Architectural Innovations
Task Encoding Module
The task encoding module consists of two parallel streams:
- Foreground Feature Encoding: High-resolution processing to preserve fine anatomical details
- Contextual Feature Encoding: Learnable query tokens with cross-attention for global context
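One simple way to picture the foreground stream is masked average pooling: features are averaged only over voxels inside the reference mask. The sketch below is a simplified stand-in for that idea, not the module's actual implementation, and `masked_average_pool` is a hypothetical helper name.

```python
import numpy as np


def masked_average_pool(features, mask, eps=1e-6):
    """Average features over the foreground voxels of a binary mask.

    features: (C, D, H, W) reference feature map.
    mask:     (D, H, W) binary mask at the same resolution.
    Returns a (C,) vector summarizing the foreground region.
    """
    mask = mask.astype(features.dtype)
    total = (features * mask[None]).sum(axis=(1, 2, 3))
    return total / (mask.sum() + eps)


features = np.zeros((2, 4, 4, 4), dtype=np.float32)
features[0] = 1.0        # channel 0 is constant everywhere
features[1, :2] = 3.0    # channel 1 is 3 in the first two slices, 0 elsewhere
mask = np.zeros((4, 4, 4), dtype=np.uint8)
mask[:2] = 1             # foreground = first two slices

print(masked_average_pool(features, mask))  # approximately [1. 3.]
```

Background voxels contribute nothing to the result, so the pooled vector describes the target structure rather than the whole volume.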
Understanding DICE in Medical Segmentation
DICE Coefficient Formula
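For a predicted mask $P$ and a ground-truth mask $G$, the standard definition of the coefficient is shown below, together with the equivalent form in terms of true positives (TP), false positives (FP), and false negatives (FN) counted over voxels:

```latex
\mathrm{DICE}(P, G) = \frac{2\,|P \cap G|}{|P| + |G|}
                    = \frac{2\,\mathrm{TP}}{2\,\mathrm{TP} + \mathrm{FP} + \mathrm{FN}}
```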
Why DICE for Medical Images?
- Class Imbalance Robust: Focuses on overlap rather than accuracy, crucial when organs occupy small portions of images
- Overlap-Centric: Emphasizes intersection between prediction and ground truth
- Range [0,1]: 0 = no overlap, 1 = perfect overlap
- Differentiable: Can be used as a loss function for backpropagation
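The class-imbalance point can be checked with a small worked example. The sketch below computes DICE on binary masks with NumPy; the function name `dice_coefficient` and the smoothing term `eps` are illustrative choices, not taken from the repository.

```python
import numpy as np


def dice_coefficient(pred, target, eps=1e-7):
    """DICE overlap between two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)


# A small organ: 8 foreground voxels in a 1000-voxel volume.
target = np.zeros((10, 10, 10), dtype=np.uint8)
target[4:6, 4:6, 4:6] = 1

empty_pred = np.zeros_like(target)   # predicts background everywhere
shifted_pred = np.zeros_like(target)
shifted_pred[4:6, 4:6, 5:7] = 1      # same cube, shifted by one voxel

accuracy = (empty_pred == target).mean()
print(f"accuracy, all-background: {accuracy:.3f}")                              # 0.992
print(f"DICE, all-background:     {dice_coefficient(empty_pred, target):.3f}")  # 0.000
print(f"DICE, shifted cube:       {dice_coefficient(shifted_pred, target):.3f}")  # 0.500
```

A prediction that misses the organ entirely still scores 99.2% voxel accuracy, while its DICE is essentially zero.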
[Figure: visual representation of the DICE calculation]
Training Datasets
Our implementation was trained on a collection of 12 medical imaging datasets covering diverse anatomical structures and imaging modalities. Six of them are summarized below:
AMOS22
Modalities: CT, MRI
Cases: 500 CT + 100 MRI
Structures: 15 abdominal organs
BCV
Modality: CT
Cases: 30 abdominal scans
Structures: 13 abdominal organs
LiTS
Modality: CT
Cases: 131 training cases
Target: Liver and tumor segmentation
KiTS19
Modality: CT
Cases: 210 cases
Target: Kidney and tumor segmentation
M&Ms
Modality: Cardiac MRI
Cases: 320 cases
Structures: LV, RV, myocardium
Brain Aging
Modality: T1 MRI
Cases: 213 scans
Target: Brain tissue segmentation
Implementation Architecture
Directory Structure
abts25/
├── src/
│   ├── models/
│   │   ├── iris_model.py          # Main IRIS model implementation
│   │   ├── encoder_3d.py          # 3D UNet encoder
│   │   ├── decoder_3d.py          # Query-based decoder
│   │   ├── decoder_3d_fixed.py    # Fixed decoder (channel mismatch resolved)
│   │   └── task_encoding.py       # Task encoding module
│   ├── losses/
│   │   └── dice_loss.py           # DICE loss implementations
│   ├── training/
│   │   └── [training scripts]
│   ├── evaluation/
│   │   └── [evaluation utilities]
│   └── data/
│       └── [data loading utilities]
├── configs/                       # Configuration files
├── checkpoints/                   # Model checkpoints
├── docs/                          # Documentation and paper materials
└── scripts/                       # Training and evaluation scripts
Key Features Implemented
Core Capabilities
- End-to-end Training: Fixed channel mismatch issues for complete gradient flow
- Multiple Inference Modes: One-shot, ensemble, and memory bank strategies
- 3D Volume Processing: Native support for medical 3D imaging
- Multi-modal Support: Works across CT, MRI, PET imaging modalities
- Flexible Task Encoding: Adapts to different anatomical structures
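The ensemble inference mode can be pictured as averaging the foreground probabilities obtained from several reference pairs before thresholding. The sketch below assumes exactly that; `ensemble_predict` and `toy_predict` are hypothetical names, and the toy predictor only stands in for a real one-shot forward pass.

```python
import numpy as np


def ensemble_predict(predict_fn, query, references, threshold=0.5):
    """Average foreground probabilities over several reference pairs.

    predict_fn(query, ref_image, ref_mask) -> probability volume in [0, 1].
    references: list of (ref_image, ref_mask) tuples.
    Returns (mean probabilities, thresholded binary mask).
    """
    probs = np.mean([predict_fn(query, img, msk) for img, msk in references], axis=0)
    mask = (probs >= threshold).astype(np.uint8)
    return probs, mask


# Toy stand-in for the model: it simply echoes the reference mask as probabilities.
def toy_predict(query, ref_image, ref_mask):
    return ref_mask.astype(np.float32)


query = np.zeros((1, 2, 2), dtype=np.float32)
ref_masks = [
    np.array([1, 1, 0, 0]).reshape(1, 2, 2),
    np.array([1, 0, 0, 0]).reshape(1, 2, 2),
    np.array([1, 1, 1, 0]).reshape(1, 2, 2),
]
references = [(query, m) for m in ref_masks]

probs, mask = ensemble_predict(toy_predict, query, references)
print(mask.ravel())  # [1 1 0 0]
```

With three references, a voxel ends up in the final mask only when at least two of the three predictions mark it as foreground.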
Model Performance & Evaluation
Evaluation Metrics
Primary Metrics
- DICE Coefficient: Spatial overlap measure (0-1, higher better)
- Surface DICE: Boundary-based evaluation for surface accuracy
- Hausdorff Distance: Maximum surface distance error
- IoU (Jaccard Index): Intersection over Union measure
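Two of these metrics are easy to sketch directly. The example below computes IoU and a brute-force Hausdorff distance in NumPy. Note the simplification: this Hausdorff variant runs over all foreground voxels rather than extracted surfaces, and it is only practical for small masks. Function names are illustrative.

```python
import numpy as np


def iou(pred, target):
    """Intersection over Union (Jaccard index) of two binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    union = np.logical_or(pred, target).sum()
    return np.logical_and(pred, target).sum() / union if union else 1.0


def hausdorff_distance(pred, target):
    """Symmetric Hausdorff distance between foreground voxel sets, in voxels.

    Simplified: uses every foreground voxel, not only surface voxels.
    """
    a = np.argwhere(pred.astype(bool)).astype(float)
    b = np.argwhere(target.astype(bool)).astype(float)
    if len(a) == 0 or len(b) == 0:
        raise ValueError("both masks must contain foreground voxels")
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (|A|, |B|)
    return float(max(dists.min(axis=1).max(), dists.min(axis=0).max()))


target = np.zeros((10, 10, 10), dtype=np.uint8)
target[4:6, 4:6, 4:6] = 1
pred = np.zeros_like(target)
pred[4:6, 4:6, 5:7] = 1  # same cube shifted by one voxel

j = iou(pred, target)
print(f"IoU:           {j:.3f}")                 # 0.333
print(f"DICE from IoU: {2 * j / (1 + j):.3f}")   # 0.500
print(f"Hausdorff:     {hausdorff_distance(pred, target):.1f}")  # 1.0
```

The middle line uses the identity DICE = 2·IoU / (1 + IoU), so the two overlap metrics always rank predictions in the same order.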
Training Configuration
INPUT_CHANNELS = 1 # Grayscale medical images
BASE_CHANNELS = 32 # Encoder base channels
EMBED_DIM = 512 # Task embedding dimension
NUM_TOKENS = 10 # Query tokens for task encoding
NUM_CLASSES = 1 # Binary segmentation
NUM_HEADS = 8 # Multi-head attention heads
# Training Setup
BATCH_SIZE = 4 # Limited by GPU memory for 3D volumes
LEARNING_RATE = 1e-4 # LAMB optimizer (layer-wise adaptive variant of Adam)
MAX_ITERATIONS = 1000 # Episodic training iterations
LOSS_FUNCTION = "DiceLoss + CrossEntropyLoss" # Combined loss
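A combined loss of this kind can be sketched in NumPy for the binary case. The equal weighting of the two terms and all function names here are assumptions; the repository's `dice_loss.py` may weight or smooth the terms differently.

```python
import numpy as np


def soft_dice_loss(probs, target, eps=1e-6):
    """1 - soft DICE, computed on probabilities rather than hard masks."""
    intersection = (probs * target).sum()
    return 1.0 - (2.0 * intersection + eps) / (probs.sum() + target.sum() + eps)


def binary_cross_entropy(probs, target, eps=1e-7):
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    return float(-(target * np.log(probs) + (1 - target) * np.log(1 - probs)).mean())


def combined_loss(logits, target, dice_weight=1.0, ce_weight=1.0):
    """Weighted sum of soft DICE loss and binary cross-entropy (weights assumed)."""
    target = target.astype(np.float64)
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid for single-channel output
    return float(dice_weight * soft_dice_loss(probs, target)
                 + ce_weight * binary_cross_entropy(probs, target))


target = np.zeros((10, 10, 10), dtype=np.uint8)
target[4:6, 4:6, 4:6] = 1

good_logits = np.where(target == 1, 10.0, -10.0)  # confident and correct
bad_logits = -good_logits                         # confident and wrong

print(f"good prediction: {combined_loss(good_logits, target):.4f}")
print(f"bad prediction:  {combined_loss(bad_logits, target):.4f}")
```

The DICE term keeps the small foreground region from being drowned out by background voxels, while the cross-entropy term provides a per-voxel signal.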
Getting Started
Installation & Usage
git clone https://github.com/mister-weeden/abts25
cd abts25
# Install dependencies
pip install -e .
# Download datasets
abts25_download_data
# Quick training test
python simple_train.py
# Full training on AMOS22 dataset
python train_amos22.py --data_dir src/data/amos --batch_size 4 --max_iterations 1000
# Evaluation
python evaluate_amos22.py --model_path checkpoints/model.pth
# Compute metrics
abts25_compute_metrics PREDICTIONS_FOLDER -num_processes 4
Model Inference
# Load trained model
model = IRISModel(in_channels=1, embed_dim=512, num_classes=1)
model.load_state_dict(torch.load('checkpoints/model.pth'))
# Initialize inference engine
inference = IRISInferenceFixed(model, device='cuda')
# One-shot inference with reference
result = inference.one_shot_inference(
query_image=query_volume, # (1, 1, D, H, W)
reference_image=ref_volume, # (1, 1, D, H, W)
reference_mask=ref_mask, # (1, 1, D, H, W)
apply_sigmoid=True,
threshold=0.5
)
# Get segmentation prediction
prediction = result['prediction'] # Binary mask
probabilities = result['probabilities'] # Confidence scores
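The inputs above are expected in (1, 1, D, H, W) layout. A raw volume loaded as a (D, H, W) array can be brought into that shape as sketched below. The z-score normalization is an assumed preprocessing step and `to_model_input` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def to_model_input(volume):
    """Normalize a (D, H, W) volume and add batch and channel axes."""
    volume = volume.astype(np.float32)
    # Z-score normalization (ASSUMED preprocessing; match whatever training used).
    volume = (volume - volume.mean()) / (volume.std() + 1e-8)
    return volume[None, None]  # shape (1, 1, D, H, W)


rng = np.random.default_rng(0)
raw_volume = rng.integers(-1000, 1000, size=(8, 16, 16))  # stand-in for CT intensities
query_array = to_model_input(raw_volume)
print(query_array.shape, query_array.dtype)
```

The resulting array can be converted with `torch.from_numpy` and moved to the target device before being passed as `query_image`.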
Technical Challenges & Solutions
Challenge: Channel Mismatch
Problem: Original decoder had misaligned channels preventing end-to-end training
Solution: Implemented QueryBasedDecoderFixed with proper channel alignment and attention mechanisms
Challenge: Memory Efficiency
Problem: 3D medical volumes require significant GPU memory
Solution: Efficient 3D convolutions, gradient checkpointing, and optimized data loading
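A quick back-of-the-envelope calculation shows why memory is the bottleneck. The sketch below estimates the size of a single feature map; the 128x128x128 patch size is an assumption, while the 32 channels and batch size of 4 match the configuration in this write-up.

```python
def activation_bytes(channels, spatial, batch_size=1, bytes_per_value=4):
    """Bytes needed to store one feature map (float32 by default)."""
    d, h, w = spatial
    return batch_size * channels * d * h * w * bytes_per_value


# One 32-channel, full-resolution feature map for a 128^3 patch at batch size 4.
size = activation_bytes(channels=32, spatial=(128, 128, 128), batch_size=4)
print(f"{size / 2**30:.1f} GiB")  # 1.0 GiB
```

That is 1 GiB for a single activation tensor. Training keeps many such tensors alive for the backward pass, which is what gradient checkpointing trades against extra computation.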
Challenge: Multi-scale Features
Problem: Medical structures vary greatly in size and shape
Solution: Multi-scale encoder with skip connections and attention-based feature fusion
Challenge: Task Generalization
Problem: Adapting to unseen anatomical structures
Solution: In-context learning with reference-guided task embeddings
Results & Impact
Key Achievements
- Universal Framework: Single model works across 12+ medical imaging datasets
- No Fine-tuning Required: Adapts to new tasks through reference examples
- State-of-the-art Performance: Competitive with task-specific models
- Computational Efficiency: Task embeddings can be precomputed and reused
- Clinical Applicability: Supports real-world medical imaging workflows
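Precomputing and reusing task embeddings amounts to a small cache keyed by task. The sketch below shows the idea with a toy encoder. `TaskEmbeddingBank` and `toy_encode` are hypothetical names, and the real memory-bank strategy in the repository may work differently.

```python
import numpy as np


class TaskEmbeddingBank:
    """Cache task embeddings by task name so each reference is encoded once."""

    def __init__(self, encode_fn):
        self._encode_fn = encode_fn  # (ref_image, ref_mask) -> embedding
        self._bank = {}
        self.encode_calls = 0        # how many times the encoder actually ran

    def get(self, task_name, ref_image=None, ref_mask=None):
        if task_name not in self._bank:
            if ref_image is None or ref_mask is None:
                raise KeyError(f"no stored embedding for task {task_name!r}")
            self._bank[task_name] = self._encode_fn(ref_image, ref_mask)
            self.encode_calls += 1
        return self._bank[task_name]


# Toy stand-in for the task encoder: sums the image inside the mask.
def toy_encode(ref_image, ref_mask):
    return np.array([(ref_image * ref_mask).sum()])


bank = TaskEmbeddingBank(toy_encode)
ref_image = np.ones((2, 2, 2), dtype=np.float32)
ref_mask = np.zeros((2, 2, 2), dtype=np.float32)
ref_mask[0] = 1

first = bank.get("liver", ref_image, ref_mask)  # encodes and stores
second = bank.get("liver")                      # served from the cache
print(first, second is first, bank.encode_calls)
```

Every later query for the same task skips the reference pass through the encoder and pays only for the query encoder and decoder.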
Future Work & Extensions
- Multi-modal Fusion: Combining CT, MRI, and PET modalities
- Temporal Segmentation: 4D medical image sequences
- Interactive Refinement: User-guided segmentation improvements
- Edge Deployment: Optimized models for clinical deployment
- Uncertainty Quantification: Confidence measures for clinical decision-making
Conclusion
The IRIS framework represents a significant advancement in medical image segmentation, moving from task-specific models to a universal framework capable of adapting to new anatomical structures through in-context learning. Our implementation demonstrates that with careful architectural design and proper training strategies, a single model can achieve competitive performance across diverse medical imaging tasks.
- In-context learning enables flexible adaptation without retraining
- Task encoding decouples task definition from inference
- DICE loss is crucial for handling class imbalance in medical images
- 3D architectures are essential for volumetric medical data
- Cross-attention mechanisms effectively integrate task-specific guidance
Repository: mister-weeden/abts25
Paper: “Show and Segment: Universal Medical Image Segmentation via In-Context Learning”
For technical questions or collaboration opportunities, please open an issue on GitHub.