
Building an Attention Residual UNet for CT Liver Segmentation

2025-01-05


Medical image segmentation has evolved dramatically from manual tracing to fully automated deep learning pipelines. In this post, I walk through my implementation of an Attention Residual UNet for segmenting liver, hepatic vessels, and tumors from CT scans—covering the architecture decisions, data processing, training setup, and lessons learned.

Why This Matters

The liver reveals a lot about a patient's health through medical imaging. CT and MRI scans allow doctors to assess various conditions, but manual segmentation is slow, labor-intensive, and inconsistent across readers.

Automating this process can significantly improve diagnostic speed and consistency—which is exactly what this project aims to do.

The Evolution of Segmentation Methods

Before diving into my implementation, it helps to understand what came before:

Traditional Approaches

Thresholding segments regions based on pixel intensity—works when target structures have consistent values, but fails when liver tissue overlaps with surrounding organs on the intensity spectrum.
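To make this concrete, here is a minimal sketch of intensity thresholding; the Hounsfield-unit window values are illustrative placeholders, not a clinically validated liver range:

```python
import numpy as np

def threshold_segment(ct_hu, low=-100, high=200):
    """Binary mask of voxels whose intensity falls inside a window.

    `low`/`high` are illustrative HU bounds, not validated clinical values.
    """
    return (ct_hu >= low) & (ct_hu <= high)
```

The failure mode described above is visible immediately: any neighboring tissue whose HU values fall in the same window ends up in the mask too.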

Region growing starts from seed points and expands based on similarity criteria—requires manual initialization and struggles with inhomogeneous tissues.
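A toy region-growing pass might look like the following sketch: breadth-first search from a seed pixel, accepting 4-connected neighbors whose intensity is within a tolerance of the seed value (the tolerance and connectivity are illustrative choices):

```python
from collections import deque

import numpy as np

def region_grow(img, seed, tol=10):
    """BFS region growing on a 2D array from a (row, col) seed point."""
    mask = np.zeros(img.shape, dtype=bool)
    seed_val = img[seed]
    mask[seed] = True
    queue = deque([seed])
    while queue:
        y, x = queue.popleft()
        for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            ny, nx = y + dy, x + dx
            # accept in-bounds, unvisited pixels close to the seed intensity
            if (0 <= ny < img.shape[0] and 0 <= nx < img.shape[1]
                    and not mask[ny, nx]
                    and abs(int(img[ny, nx]) - int(seed_val)) <= tol):
                mask[ny, nx] = True
                queue.append((ny, nx))
    return mask
```

Both weaknesses noted above show up here: the result depends entirely on where the seed lands, and a single fixed tolerance cannot cope with tissue whose intensity drifts across the organ.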

The Deep Learning Revolution

CNNs changed everything by automatically learning spatial hierarchies. But vanilla CNNs weren't designed for dense prediction tasks like segmentation.

UNet (2015) introduced the encoder-decoder architecture with skip connections—suddenly, we could produce pixel-level predictions while preserving spatial details. This became the foundation for medical image segmentation.

ResNet solved the vanishing gradient problem in deep networks through residual connections—enabling much deeper architectures that could learn more complex features.

My Architecture: Attention Residual UNet

I combined all three innovations into a single architecture:

The UNet Backbone

The classic encoder-decoder structure:

Encoder Path: Multiple convolutional blocks, each with two conv layers + ReLU + batch norm, followed by max pooling that halves spatial dimensions. This captures increasingly abstract features.

Decoder Path: Mirrors the encoder with upsampling (transposed convolutions) to reconstruct spatial resolution. Each level combines upsampled features with corresponding encoder features via skip connections.

Skip Connections: The key insight—directly connect encoder blocks to decoder blocks at each level:

D_i = f(E_i) + g(D_{i+1})

This preserves high-resolution details crucial for precise boundaries.
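One decoder level can be sketched in PyTorch as follows. Note one assumption: the equation above adds f(E_i) and g(D_{i+1}), while this sketch fuses them by channel-wise concatenation, the choice made in the original UNet; swap in addition if your blocks expect matching channel counts:

```python
import torch
import torch.nn as nn

class DecoderLevel(nn.Module):
    """One decoder step: upsample, merge the encoder skip, then convolve."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        # transposed convolution doubles the spatial resolution
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch + skip_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, d, e):
        d = self.up(d)                               # upsample decoder features
        return self.conv(torch.cat([d, e], dim=1))   # fuse with encoder skip
```

Stacking a few of these, each paired with the encoder level at the same resolution, gives the full decoder path.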

Adding Residual Connections

I replaced standard conv blocks with residual blocks throughout:

y = x + F(x, {W_i})

The input x is added directly to the block's output, allowing gradients to flow through the shortcut. This stabilizes training in deeper networks and helps preserve fine anatomical details through the network's depth.
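A minimal residual block in this spirit, with the identity shortcut around two convolutions (channel counts here are placeholders):

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """y = x + F(x): two convs with an identity shortcut."""

    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # the shortcut lets gradients bypass the conv stack entirely
        return self.relu(x + self.body(x))
```

When the input and output channel counts differ, a 1×1 convolution on the shortcut is the usual fix.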

Attention Mechanisms

Here's where it gets interesting. I added spatial attention gates just before each skip connection in the decoder:

F_att = α ⊙ F

The attention map α is computed using scaled dot-product attention:

Attention(Q, K, V) = softmax(QK^T / √d_k) × V

This forces the network to focus on relevant regions (liver boundaries, vessels) while suppressing irrelevant background. The result: sharper segmentation boundaries and better handling of ambiguous regions.
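One way to realize this gating on 2D feature maps is sketched below: queries come from the decoder (gating) signal, keys and values from the encoder skip, with attention computed over spatial positions. The d_k projection size is illustrative, and the sketch assumes gate and skip share the same shape:

```python
import math

import torch
import torch.nn as nn

class SpatialAttentionGate(nn.Module):
    """Scaled dot-product attention over the spatial positions of a skip map."""

    def __init__(self, channels, d_k=32):
        super().__init__()
        self.q = nn.Conv2d(channels, d_k, kernel_size=1)
        self.k = nn.Conv2d(channels, d_k, kernel_size=1)
        self.scale = math.sqrt(d_k)

    def forward(self, gate, skip):
        b, c, h, w = skip.shape
        q = self.q(gate).flatten(2).transpose(1, 2)        # (b, hw, d_k)
        k = self.k(skip).flatten(2)                        # (b, d_k, hw)
        attn = torch.softmax(q @ k / self.scale, dim=-1)   # (b, hw, hw)
        v = skip.flatten(2).transpose(1, 2)                # (b, hw, c)
        # re-weighted skip features, reshaped back to a 2D map
        return (attn @ v).transpose(1, 2).reshape(b, c, h, w)
```

The hw × hw attention matrix grows quadratically with resolution, which is why gates like this are usually applied at the coarser levels of the decoder.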

Data Processing

I used the Medical Segmentation Decathlon dataset—a standardized benchmark for medical segmentation tasks.

Liver Dataset

Hepatic Vessel Dataset

The slice threshold was a practical choice—very large volumes significantly increase training time without proportional improvement in segmentation quality.

Training Setup

Configuration

| Parameter | Value |
| --- | --- |
| Batch size | 32 |
| Initial learning rate | 1×10⁻⁴ |
| Optimizer | Adam |
| LR scheduler | Reduce by 0.5 if no improvement for 3 epochs |
| Early stopping | Stop if validation loss doesn't improve for 5 epochs |
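Wired together, the configuration above looks roughly like this sketch; the `nn.Linear` model and the validation-loss sequence are toy stand-ins for the real network and validation loop:

```python
import torch

model = torch.nn.Linear(4, 2)  # toy stand-in for the UNet
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# halve the LR after 3 epochs without validation improvement
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3)

best, bad_epochs, patience = float("inf"), 0, 5
fake_val_losses = [0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]  # toy values
for val_loss in fake_val_losses:
    scheduler.step(val_loss)
    if val_loss < best - 1e-6:
        best, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # early stopping: 5 epochs without improvement
```

With the toy loss sequence, the scheduler halves the learning rate once before early stopping fires, which is exactly the interplay the table describes.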

Loss Function: Dice-Cross Entropy

I used a combined loss that balances two objectives:

Dice Loss: Directly optimizes the overlap metric we care about—the Dice coefficient:

Dice = (2 × |P ∩ T|) / (|P| + |T|)

Cross-Entropy: Provides stable gradients and handles class imbalance better during early training.
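A sketch of the combined loss, assuming integer label maps and an even 50/50 weighting between the two terms (the weighting and the smoothing epsilon are assumptions, not values from the post):

```python
import torch
import torch.nn.functional as F

def dice_ce_loss(logits, target, eps=1e-6, w_dice=0.5):
    """Weighted sum of soft Dice loss and cross-entropy.

    logits: (B, C, H, W) raw scores; target: (B, H, W) integer class labels.
    """
    ce = F.cross_entropy(logits, target)
    probs = torch.softmax(logits, dim=1)
    onehot = F.one_hot(target, logits.shape[1]).permute(0, 3, 1, 2).float()
    # per-class soft Dice over the whole batch
    inter = (probs * onehot).sum(dim=(0, 2, 3))
    denom = probs.sum(dim=(0, 2, 3)) + onehot.sum(dim=(0, 2, 3))
    dice = (2 * inter + eps) / (denom + eps)
    return w_dice * (1 - dice.mean()) + (1 - w_dice) * ce
```

Because the Dice term is averaged per class, rare classes such as tumor contribute as much to it as the dominant background, which is the imbalance-handling property the post relies on.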

Model Capacity

Features double at each encoder level, so the bottleneck reaches 128 or 256 features, depending on the configuration used for each task.

Results

After 25 epochs (liver) and 18 epochs (vessels), the models achieved:

| Task | Structure | Dice Score |
| --- | --- | --- |
| Liver segmentation | Liver | 0.88 |
| Liver segmentation | Tumor | 0.73 |
| Vessel segmentation | Hepatic vessels | 0.73 |
| Vessel segmentation | Tumor | 0.66 |

What Worked

The liver segmentation at 0.88 Dice is solid—the model captures boundaries accurately for this larger, well-defined structure. Training and validation loss converged smoothly, indicating good generalization.

What Didn't

Tumor segmentation struggled (0.66-0.73 Dice). The main culprit: class imbalance. Tumors occupy a tiny fraction of the total image volume compared to background. The model learns to predict "not tumor" everywhere because that's mostly correct.

Lessons Learned

Class Imbalance is Real

Even with Dice loss (which handles imbalance better than pure cross-entropy), small structures get overwhelmed. Future work should explore remedies such as focal loss, oversampling of tumor-containing patches, or stronger tumor-focused augmentation.
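One frequently used remedy here is the focal loss of Lin et al., which down-weights pixels the model already classifies confidently so that rare, hard pixels dominate the gradient. A minimal sketch, with the focusing parameter gamma=2 as a conventional default rather than a tuned value:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, gamma=2.0):
    """Focal loss: cross-entropy scaled by (1 - p_t)^gamma per pixel."""
    ce = F.cross_entropy(logits, target, reduction="none")
    pt = torch.exp(-ce)                    # probability of the true class
    return ((1 - pt) ** gamma * ce).mean()
```

Whether this beats the Dice-CE combination on this dataset is an open question for the follow-up experiments, not something the current results establish.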

The Cascaded Approach

Training separate models for liver (stage 1) and tumors within liver ROI (stage 2) could improve results by:
1. Reducing the search space for tumor detection
2. Allowing stage-specific hyperparameters
3. More efficient use of compute

Attention Helps, But Isn't Magic

Attention gates improved boundary sharpness and helped with ambiguous regions, but couldn't overcome fundamental class imbalance. They're a tool, not a solution.

Try It Yourself

I've deployed an interactive demo at /demos/ct-segmentation where you can:
- Upload NIfTI files
- Visualize 3D segmentation results
- See the model's predictions in real-time

The full code is available on GitHub.

What's Next

Potential improvements I'm exploring:
- nnU-Net: Self-configuring framework that automatically adapts to each dataset
- Transformer-based architectures: UNETR, Swin-UNETR for better long-range dependencies
- Semi-supervised learning: Leverage unlabeled CT scans to improve generalization

Medical imaging is a fascinating application of deep learning—high stakes, limited data, and real clinical impact. The gap between research benchmarks and production-ready systems is significant, but projects like this help bridge it one step at a time.