License: arXiv.org perpetual non-exclusive license
arXiv:2309.03493v4 [eess.IV] 05 Mar 2024

SAM3D: Segment Anything Model in Volumetric Medical Images

Abstract

Image segmentation remains a pivotal component of medical image analysis, aiding the extraction of critical information for precise diagnosis. With the advent of deep learning, automated image segmentation methods have risen to prominence, showing exceptional proficiency in processing medical imagery. Motivated by the Segment Anything Model (SAM), a foundation model known for its precision and robust generalization in segmenting 2D natural images, we introduce SAM3D, an adaptation tailored to 3D volumetric medical image analysis. Unlike current SAM-based methods, which segment volumetric data by converting the volume into separate 2D slices for individual analysis, SAM3D processes the entire 3D volume in a unified manner. Extensive experiments on multiple medical image datasets demonstrate that our network attains results competitive with other state-of-the-art methods in 3D medical segmentation while using significantly fewer parameters. Code and checkpoints are available at https://github.com/UARK-AICV/SAM3D.

Footnote: Equal contribution.

Index Terms—  3D Medical Segmentation, Foundation Model, Transfer Learning, Segment Anything Model

1 Introduction

Volumetric segmentation is crucial in medical image analysis, finding applications in pathology diagnosis, surgical planning, and computer-aided diagnosis. Volumetric medical images like CT, MRI, OCT, and DBT offer a 3D view of anatomical structures. Segmentation identifies regions of interest for better interpretation.

Deep learning methods, particularly UNet [1] and its variants [2, 3, 4], have made strides in 3D medical segmentation but still face limitations. Transformer-based models such as Vision Transformer (ViT) [5] and Swin-Unet [4] have shown promise in capturing long-range relationships, and models that combine CNNs and Transformers, such as TransUNet [3], UNETR [6], and HiFormer [7], have yielded promising results. However, these models prioritize precision, which increases their complexity and training time. Leveraging pretrained models offers an alternative. SAM, a Transformer-based model pretrained on large-scale datasets, has shown strong generalizability in segmentation tasks, and SAM-based models for medical images have consequently attracted growing interest.

This work introduces SAM3D, an architecture for volumetric medical segmentation, combining the SAM encoder and a lightweight 3D CNN decoder. Unlike traditional slice-by-slice processing, SAM3D extracts features across the entire volume, improving segmentation while maintaining simplicity and computational efficiency. Contributions include applying the SAM encoder to process 3D volumes, designing SAM3D for effective 3D medical segmentation, and validating its performance on various datasets, such as ACDC [8], Synapse [9], MSD BraTS [10], and MSD Lung [10]. SAM3D demonstrates competitive results, marking a novel approach to 3D volumetric imaging.

Fig. 1: Overall architecture of the proposed SAM3D. Given a volumetric image $I \in \mathbb{R}^{H\times W\times D}$, SAM3D first applies SAM to each of the $D$ slices individually, producing slice embeddings $F \in \mathbb{R}^{\frac{H}{16}\times\frac{W}{16}\times D\times 256}$. These embeddings are then decoded by a lightweight 3D decoder, yielding the segmentation prediction.

2 Related Works

Segmentation Methods using CNNs and Transformer: Various methods leverage a combination of CNNs and Transformer architectures for segmentation tasks. TransUNet [3] integrates CNNs and Transformer within a U-shaped architecture to capture both local and global information. Swin-Unet [4] replaces U-Net’s convolutional blocks with Swin Transformer blocks. nnUNet [2] introduces a self-adapting framework for 2D and 3D medical segmentation. MISSFormer [11] enhances hierarchical feature representation using Enhanced Transformer Blocks. TransDeepLab [12] combines Swin Transformer blocks with ASPP module and cross-contextual attention. HiFormer [7] introduces the Double-Level Fusion (DLF) module. UNETR [6] encodes 3D input patches with Transformers and combines feature extraction with a CNNs-based decoder. Swin UNETR [13] enhances UNETR with Swin Transformer blocks. nnFormer [14] interleaves CNNs and Transformer blocks with feature pyramids. UNETR++ [15] introduces the efficient paired-attention (EPA) module.

Segment Anything Model (SAM) in medical imaging: SAM [16] is a foundation model for natural image segmentation that can be guided by prompts. It comprises an image encoder, a prompt encoder, and a lightweight mask decoder, and is trained on promptable segmentation tasks. SAMed [17] adapts SAM to medical images using a series of fine-tuning strategies. MedSAM [18] retrains SAM on a union of medical image datasets.

In contrast to existing approaches that involve fine-tuning SAM and handling 3D images as sets of 2D slices, and unlike the conventional CNNs/Transformer-based methods that typically require large model designs, our proposed SAM3D effectively and efficiently harnesses SAM’s capabilities for 3D medical segmentation. It does so without the need for large model architectures or depending on slice-by-slice predictions. This approach enhances the model’s ability to perceive anatomical structures and capture global information.

3 Method

In this section, we introduce our model, SAM3D, and explain the rationale behind its simple design. Our goal is to leverage SAM without the need for extensive parameter retraining or complex task-specific modules.

Overall Architecture. SAM was trained on an extensive dataset comprising 11 million images and 1.1 billion masks, and it features a robust image encoder tailored to natural images. However, applying SAM directly to 3D medical images is challenging because of inherent domain differences. We posit that the SAM image encoder retains valuable low-level features, e.g. edges and boundaries, that are relevant across image domains.

In contrast to SAMed [17] and MedSAM [18], where all three components of SAM are fine-tuned, our approach freezes SAM’s image encoder and trains a new lightweight 3D decoder. SAM3D leverages SAM by first processing the volume slice by slice and then applying a lightweight 3D decoder to capture depth-wise relationships between slices. The overall architecture of SAM3D is depicted in Figure 1 and can be summarized as follows: a volumetric input $I \in \mathbb{R}^{H\times W\times D}$ is divided into $D$ 2D slices, each of dimension $H\times W$. Because SAM expects three-channel input, we replicate each single-channel slice three times along the channel axis to obtain slices of dimension $H\times W\times 3$. The pretrained SAM encoder processes these slices, generating 3D slice embeddings denoted as $F$. The depth-wise relationships among these slice embeddings are then captured by our proposed 3D decoder. Additionally, we remove the prompt encoder from SAM to ensure that feature extraction remains uninhibited across different modalities.

Encoder. SAM’s image encoder extracts robust low-level information. It is therefore plausible to address the well-known weak-boundary problem in medical images with features extracted by SAM’s image encoder. Formally, let $I \in \mathbb{R}^{H\times W\times D}$ be the input and let $Enc$ denote the slice encoder. We split $I$ into $D$ slices $I_i$ along the depth dimension, each of shape $3\times H\times W$, and feed them into $Enc$:

$$f_i = Enc(I_i), \quad \text{where } f_i \in \mathbb{R}^{\frac{H}{16}\times\frac{W}{16}\times 256} \tag{1}$$

We then stack these slice embeddings and transpose the result to obtain the final 3D slice embedding $F = [f_i]_{i=1}^{D}$, with $F \in \mathbb{R}^{\frac{H}{16}\times\frac{W}{16}\times D\times 256}$.
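The shape bookkeeping of the encoder stage can be sketched in a few lines of Python. This is only an illustration that takes the dimensions stated in Eq. (1) at face value; it does not run SAM. The helper name `slice_embedding_shape` and its parameters `patch_size` and `embed_dim` are hypothetical, and the sketch assumes that $H$ and $W$ are divisible by 16.

```python
from typing import Tuple


def slice_embedding_shape(
    volume_shape: Tuple[int, int, int],
    patch_size: int = 16,   # assumed downsampling factor, from H/16 and W/16 in Eq. (1)
    embed_dim: int = 256,   # assumed embedding width, from Eq. (1)
) -> Tuple[int, int, int, int]:
    """Return the shape of F for a volume of shape (H, W, D)."""
    height, width, depth = volume_shape
    if height % patch_size or width % patch_size:
        raise ValueError("H and W must be divisible by the patch size")

    # Each of the D slices is encoded independently into (H/16, W/16, 256).
    per_slice = (height // patch_size, width // patch_size, embed_dim)

    # Stacking the D slice embeddings gives (D, H/16, W/16, 256) ...
    stacked = (depth,) + per_slice

    # ... and the transpose moves depth next to the spatial axes:
    # (H/16, W/16, D, 256).
    return (stacked[1], stacked[2], stacked[0], stacked[3])


if __name__ == "__main__":
    # Volume sizes reported in the implementation details for ACDC and Synapse.
    print(slice_embedding_shape((160, 160, 14)))
    print(slice_embedding_shape((176, 176, 64)))
```

The depth axis is untouched by the slice encoder, which is why a separate 3D decoder is needed to relate neighbouring slices.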

Decoder. Because our decoder must handle 3D volumetric data, we cannot use SAM’s mask decoder, which is designed specifically for 2D natural images. Instead, we design a dedicated 3D decoder. However, building a 3D network from the Vision Transformer [5] and its variants can be resource-intensive, requiring significant computational power and increasing inference time, especially when $D$ is large. We therefore adopt a lightweight 3D decoder comprising four 3D convolutional blocks with skip connections [19] and a segmentation head, as illustrated in Figure 2.

Fig. 2: Architecture of the proposed lightweight 3D decoder.

Table 1: Quantitative results on the Synapse dataset. HD95 and DSC are averaged over all organs; the remaining columns report the DSC on individual abdominal organs.

| SAM-based | Network | Method | Params ↓ | Avg. HD95 ↓ | Avg. DSC ↑ | RKid | LKid | Spl | Gal | Sto | Pan | Aor | Liv |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| No | 2D | TransUNet [3] | 96.07M | 31.69 | 77.49 | 77.02 | 81.87 | 85.08 | 63.16 | 75.62 | 55.86 | 87.23 | 94.08 |
| No | 2D | Swin-Unet [4] | 27.17M | 21.55 | 79.13 | 79.61 | 83.28 | 90.66 | 66.53 | 76.60 | 56.58 | 85.47 | 94.29 |
| No | 2D | TransDeepLab [12] | 21.14M | 21.25 | 80.16 | 79.88 | 84.08 | 89.00 | 69.16 | 78.40 | 61.19 | 86.04 | 93.53 |
| No | 2D | HiFormer-S [7] | 23.25M | 18.85 | 80.29 | 64.84 | 82.39 | 91.03 | 73.29 | 78.07 | 60.84 | 85.63 | 94.22 |
| No | 2D | HiFormer-B [7] | 25.51M | 14.70 | 80.39 | 79.77 | 85.23 | 90.99 | 65.69 | 81.08 | 59.52 | 86.21 | 94.61 |
| No | 2D | HiFormer-L [7] | 29.52M | 19.14 | 80.69 | 78.37 | 84.23 | 90.44 | 68.61 | 82.03 | 60.77 | 87.03 | 94.07 |
| No | 3D | MISSFormer [11] | - | 18.20 | 81.96 | 82.00 | 85.21 | 91.92 | 68.65 | 80.81 | 65.67 | 86.99 | 94.41 |
| No | 3D | nnFormer [14] | 150.50M | 10.63 | 86.57 | 86.25 | 86.57 | 90.51 | 70.17 | 86.83 | 83.35 | 92.04 | 96.84 |
| No | 3D | UNETR [6] | 92.49M | 18.59 | 78.35 | 84.52 | 85.60 | 85.00 | 56.30 | 70.46 | 60.47 | 89.80 | 94.57 |
| No | 3D | UNETR++ [15] | 42.95M | 7.53 | 87.22 | 87.18 | 87.54 | 95.77 | 71.25 | 86.01 | 81.10 | 92.52 | 96.42 |
| Yes | 2D | SAMed [17] | 18.81M | 20.64 | 81.88 | 79.95 | 80.45 | 88.72 | 69.11 | 82.06 | 72.17 | 87.77 | 94.80 |
| Yes | 2D | SAMed_s [17] | 6.32M | 31.72 | 77.78 | 78.92 | 79.63 | 85.81 | 57.11 | 77.49 | 65.66 | 83.62 | 93.98 |
| Yes | 3D | SAM3D (Ours) | 1.88M | 17.87 | 79.56 | 85.64 | 86.31 | 84.29 | 49.81 | 76.11 | 69.32 | 89.57 | 95.42 |

Objective Function. We train our SAM3D network with a combination of the Dice loss and the cross-entropy loss, formulated as follows:

$$\mathcal{L}(Y,\hat{Y}) = -\sum_{n=1}^{N}\sum_{k=1}^{K}\left(\frac{2\times Y_{k,n}\hat{Y}_{k,n}}{Y_{k,n}^{2}+\hat{Y}_{k,n}^{2}} + Y_{k,n}\log\hat{Y}_{k,n}\right) \tag{2}$$

Here, $Y$ is the ground truth and $\hat{Y}$ is the segmentation predicted by SAM3D. $N$ is the number of classes, $K$ is the number of voxels, and $Y_{k,n}$ and $\hat{Y}_{k,n}$ are the ground truth and the prediction at voxel $k$ for class $n$, respectively.

Additionally, we employ deep supervision at multiple decoding stages. Specifically, the output features of each decoding stage pass through a segmentation block, consisting of one 3 × 3 × 3 and one 1 × 1 × 1 convolution layer, to generate the prediction for that stage. To compute the loss for a given stage, we down-sample the ground truth to match the prediction resolution. The final loss is then defined as:

$$\mathcal{L}_{total} = \sum_{l=1}^{L} \alpha_l \times \mathcal{L}_l \tag{3}$$

Here, $L$ is set to 3, representing the number of decoder layers, and $\alpha_l$ is the hyperparameter controlling the contribution of each resolution to the final loss. In practice, we set $\alpha_2 = \frac{\alpha_1}{2}$ and $\alpha_3 = \frac{\alpha_1}{4}$, with all $\alpha$ hyperparameters normalized to sum to 1.
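The weighting scheme can be worked out explicitly. The sketch below assumes that "normalized to 1" means the weights sum to 1, which together with the halving rule gives $\alpha_1 = 4/7$, $\alpha_2 = 2/7$, and $\alpha_3 = 1/7$. The function names are hypothetical.

```python
from fractions import Fraction
from typing import List, Sequence


def deep_supervision_weights(num_stages: int = 3) -> List[Fraction]:
    """Weights that halve at every stage and sum to 1 (assumed normalisation)."""
    if num_stages < 1:
        raise ValueError("need at least one decoding stage")
    raw = [Fraction(1, 2 ** stage) for stage in range(num_stages)]  # 1, 1/2, 1/4, ...
    total = sum(raw)
    return [weight / total for weight in raw]


def total_loss(stage_losses: Sequence[float], weights: Sequence[Fraction]) -> float:
    """Weighted sum of per-stage losses, as in Eq. (3)."""
    if len(stage_losses) != len(weights):
        raise ValueError("one weight per stage loss is required")
    return float(sum(float(w) * loss for w, loss in zip(weights, stage_losses)))


if __name__ == "__main__":
    alphas = deep_supervision_weights(3)
    print([str(a) for a in alphas])
    print(total_loss([0.30, 0.45, 0.60], alphas))
```

Exact fractions are used for the weights only to make the normalisation easy to verify by eye; a training loop would use floats.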

4 Experiments

A. Datasets and Evaluation Metrics.

Datasets: We conduct the experiments on four datasets: Multi-organ CT Segmentation (Synapse) [9], Automated Cardiac Diagnosis (ACDC) [8], Brain Tumor Segmentation (BraTS) [10], and Lung Tumor Segmentation (Lung) [10]. BraTS and Lung come from the Medical Segmentation Decathlon challenge (MSD) [10]. For a fair comparison, we follow the data splitting of previous works, e.g. nnFormer [14] and UNETR++ [15].

Table 2: Quantitative results on the ACDC dataset. The last three columns report the DSC on individual regions.

| Method | Params ↓ | Avg. DSC ↑ | RV | MYO | LV |
|---|---|---|---|---|---|
| TransUNet [3] | 96.07M | 89.71 | 88.86 | 84.54 | 95.73 |
| Swin-Unet [4] | 27.17M | 90.00 | 88.55 | 85.62 | 95.83 |
| UNETR [6] | 92.49M | 86.61 | 85.29 | 86.52 | 94.02 |
| MISSFormer [11] | - | 87.90 | 86.36 | 85.75 | 91.59 |
| nnFormer [14] | 150.5M | 92.06 | 90.94 | 89.58 | 95.65 |
| UNETR++ [15] | 66.80M | 92.83 | 91.89 | 90.61 | 96.00 |
| SAM3D (Ours) | 1.88M | 90.41 | 89.44 | 87.12 | 94.67 |

Table 3: Quantitative results on the Lung dataset.

| Method | Params ↓ | Avg. DSC ↑ |
|---|---|---|
| nnUNet [2] | - | 74.31 |
| Swin UNETR [13] | 62.83M | 75.55 |
| nnFormer [14] | 150.5M | 77.95 |
| UNETR [6] | 92.49M | 73.29 |
| UNETR++ [15] | 121.17M | 80.68 |
| SAM3D (Ours) | 1.88M | 71.42 |

Table 4: Quantitative results on the BraTS dataset.

| Method | Params ↓ | Avg. HD95 ↓ | Avg. DSC ↑ | WT HD95 ↓ | WT DSC ↑ | ET HD95 ↓ | ET DSC ↑ | TC HD95 ↓ | TC DSC ↑ |
|---|---|---|---|---|---|---|---|---|---|
| TransUNet [3] | 96.07M | 12.98 | 64.4 | 14.03 | 70.6 | 10.42 | 54.2 | 14.50 | 68.4 |
| UNETR [6] | 92.49M | 8.82 | 71.1 | 8.27 | 78.9 | 9.35 | 58.5 | 8.85 | 76.1 |
| nnFormer [14] | 150.5M | 4.05 | 86.4 | 3.80 | 91.3 | 3.87 | 81.8 | 4.49 | 86.0 |
| UNETR++ [15] | 42.65M | 5.85 | 77.7 | 4.79 | 91.2 | 4.22 | 78.5 | 6.78 | 78.4 |
| SAM3D (Ours) | 4.63M | 8.72 | 72.9 | 6.03 | 88.0 | 10.05 | 69.6 | 9.79 | 76.6 |

Table 5: Ablation study of the skip connections in our lightweight 3D decoder on the ACDC and Synapse datasets.

| Settings | Avg. DSC ↑ | RV | LV | MYO |
|---|---|---|---|---|
| w/o skip connection | 89.73 | 88.46 | 94.41 | 86.32 |
| w/ skip connection | 90.41 | 89.44 | 94.67 | 87.12 |

(a) ACDC dataset.

| Settings | Avg. HD95 ↓ | Avg. DSC ↑ | RKid | LKid | Spl | Gal | Sto | Pan | Aor | Liv |
|---|---|---|---|---|---|---|---|---|---|---|
| w/o skip connection | 25.87 | 79.33 | 84.68 | 85.20 | 85.26 | 50.55 | 75.07 | 68.83 | 90.10 | 94.98 |
| w/ skip connection | 17.87 | 79.56 | 85.64 | 86.31 | 84.29 | 49.81 | 76.11 | 69.32 | 89.57 | 95.42 |

(b) Synapse dataset.

Metrics: We evaluate the network’s accuracy using the Dice Similarity Coefficient (DSC) and the 95% Hausdorff Distance (HD95), while the network’s complexity is measured by the number of trainable parameters (#params). The HD95 is calculated based on the 95th percentile of the distances between the boundaries of predictions and ground truths.
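Both metrics can be illustrated on small voxel sets. The sketch below is a simplified stand-in rather than the evaluation code behind the reported numbers: masks are given as sets of voxel coordinates, and HD95 is computed under one common convention, namely the larger of the two directed 95th-percentile nearest-neighbour distances with a nearest-rank percentile. Other toolkits pool the two directions before taking the percentile, so values can differ slightly. All names are hypothetical, and `math.dist` requires Python 3.8 or later.

```python
import math
from typing import Collection, List, Set, Tuple

Voxel = Tuple[int, ...]


def dice_coefficient(pred: Set[Voxel], truth: Set[Voxel]) -> float:
    """DSC = 2|A ∩ B| / (|A| + |B|); two empty masks count as a perfect match (assumption)."""
    if not pred and not truth:
        return 1.0
    return 2.0 * len(pred & truth) / (len(pred) + len(truth))


def _directed_distances(src: Collection[Voxel], dst: Collection[Voxel]) -> List[float]:
    """Distance from every point in src to its nearest point in dst."""
    return [min(math.dist(a, b) for b in dst) for a in src]


def _percentile_95(values: List[float]) -> float:
    """Nearest-rank 95th percentile, using integer arithmetic for the rank."""
    ordered = sorted(values)
    rank = (95 * len(ordered) + 99) // 100  # ceil(0.95 * n)
    return ordered[rank - 1]


def hd95(pred_boundary: Collection[Voxel], truth_boundary: Collection[Voxel]) -> float:
    """Symmetric HD95 between two non-empty sets of boundary voxels."""
    if not pred_boundary or not truth_boundary:
        raise ValueError("both boundaries must be non-empty")
    forward = _percentile_95(_directed_distances(pred_boundary, truth_boundary))
    backward = _percentile_95(_directed_distances(truth_boundary, pred_boundary))
    return max(forward, backward)


if __name__ == "__main__":
    print(dice_coefficient({(0, 0), (0, 1)}, {(0, 0), (1, 0)}))
    print(hd95([(0, 0), (0, 1)], [(0, 0), (0, 3)]))
```

Using the 95th percentile instead of the maximum makes the distance less sensitive to a few outlier boundary voxels.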

B. Implementation Details.

Our model is implemented in Python 3.8.10 with the PyTorch library and trained on a single NVIDIA RTX 2080 Ti GPU with 11 GB of memory. Due to limited resources, we use the ViT-B version of SAM’s image encoder as our backbone. Rather than exhaustively searching for a training procedure tuned to each dataset, we adopt the general training strategy of nnFormer [14] and UNETR++ [15]: stochastic gradient descent (SGD) with a momentum of 0.99 and a weight decay of 3e-5. The learning rate schedule is defined as $lr = init\_lr \times \left(1 - \frac{epoch}{max\_epoch}\right)^{power}$, where $init\_lr$ = 1e-2, $power$ = 0.9, and $max\_epoch$ = 1000. One epoch consists of 250 iterations. For the ACDC, Synapse, BraTS, and Lung datasets, SAM3D is trained with 3D volume sizes of 160 × 160 × 14, 176 × 176 × 64, 64 × 64 × 64, and 192 × 192 × 34, respectively. We also use the same data augmentation techniques as these works, including rotation, scaling, brightness adjustment, gamma augmentation, and mirroring. The batch size is set to 4 for ACDC and 2 for Synapse, BraTS, and Lung.
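The polynomial learning rate schedule described above is straightforward to reproduce. The following minimal sketch uses the stated hyperparameters as defaults; the function name `poly_lr` is hypothetical, and how the value is fed to the optimizer is left out.

```python
def poly_lr(
    epoch: int,
    init_lr: float = 1e-2,
    max_epoch: int = 1000,
    power: float = 0.9,
) -> float:
    """lr = init_lr * (1 - epoch / max_epoch) ** power."""
    if not 0 <= epoch <= max_epoch:
        raise ValueError("epoch must lie in [0, max_epoch]")
    return init_lr * (1 - epoch / max_epoch) ** power


if __name__ == "__main__":
    for epoch in (0, 250, 500, 750, 1000):
        print(epoch, poly_lr(epoch))
```

The rate starts at `init_lr`, decays monotonically, and reaches zero at the final epoch.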

C. Performance Comparisons.

We compare SAM3D with recent SOTA methods, including CNN-based networks, e.g. nnFormer [14]; Transformer-based networks, e.g. TransUNet [3], Swin-Unet [4], TransDeepLab [12], HiFormer [7], MISSFormer [11], and UNETR [6]; and the SAM-based models SAMed and SAMed_s [17]. The comparisons are reported in Tables 1, 2, 3, and 4 in terms of both accuracy (HD95 and DSC) and network complexity (#params).

Synapse is a large dataset comprising eight abdominal organs, and the performance comparison is shown in Table 1. Among the evaluated models, UNETR++ (a Transformer-based model) achieves the best results with 42.95M parameters, while nnFormer ranks second with 150.5M parameters. Notably, SAMed_s achieves a DSC of 77.78% with a modest 6.32M parameters. Like our SAM3D, SAMed_s builds on SAM, but the two differ in how they process the volume: SAMed_s follows a straightforward slice-by-slice approach, while SAM3D also considers depth-wise information. Both models are parameter-efficient, yet SAM3D requires only 1.88M parameters compared with 6.32M for SAMed_s. Furthermore, SAM3D’s DSC is 1.78 points higher than that of SAMed_s (79.56% vs. 77.78%), demonstrating superior performance among lightweight SAM-based models.

While SAMed is evaluated only on the Synapse dataset, our SAM3D is also evaluated on several other datasets, covering cardiac, brain tumor, and lung segmentation. Tables 2 and 3 show that SAM3D remains competitive with SOTA CNN/Transformer-based networks on the ACDC and Lung datasets. For instance, SAM3D surpasses TransUNet on ACDC by 0.70 points of DSC (90.41% vs. 89.71%) while using roughly 50× fewer parameters (1.88M vs. 96.07M). Table 4 further shows that SAM3D is competitive with other leading models on the BraTS dataset despite its significantly lower parameter count. For example, SAM3D achieves a 1.8-point DSC improvement over UNETR (72.9% vs. 71.1%) while using about 20× fewer parameters (4.63M vs. 92.49M). Note that the MRI scans in BraTS contain four modalities, which explains why SAM3D has more parameters on BraTS (4.63M) than on the single-modality datasets (1.88M).
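The parameter and accuracy comparisons in this paragraph follow directly from Tables 2 and 4. The short check below recomputes them from the tabulated values; the dictionary layout and the helper name `reduction_factor` are hypothetical.

```python
# Values copied from Table 2 (ACDC) and Table 4 (BraTS): (params in millions, avg. DSC).
ACDC = {"TransUNet": (96.07, 89.71), "SAM3D": (1.88, 90.41)}
BRATS = {"UNETR": (92.49, 71.1), "SAM3D": (4.63, 72.9)}


def reduction_factor(baseline_params: float, model_params: float) -> float:
    """How many times fewer parameters the model has than the baseline."""
    return baseline_params / model_params


def compare(table: dict, baseline: str, model: str = "SAM3D") -> tuple:
    """Return (parameter reduction factor, DSC gain in points)."""
    base_params, base_dsc = table[baseline]
    model_params, model_dsc = table[model]
    return (
        round(reduction_factor(base_params, model_params), 1),
        round(model_dsc - base_dsc, 2),
    )


if __name__ == "__main__":
    print("ACDC vs. TransUNet:", compare(ACDC, "TransUNet"))
    print("BraTS vs. UNETR:", compare(BRATS, "UNETR"))
```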

Fig. 3 visually presents samples from the Synapse dataset. In this illustration, we compare our approach (in the third column) with the outcomes obtained from SAMed and SAMed_s, which represent SOTA in SAM-based methods for volumetric medical image segmentation. Despite the reduced trainable params, SAM3D exhibits superior segmentation performance compared to the other two methods.

Fig. 3: Qualitative comparison between our SAM3D (3rd column) and the other SAM-based volumetric segmentation models SAMed (4th column) and SAMed_s (5th column) on the Synapse dataset. SAMed and SAMed_s require 18.81M and 6.32M parameters, whereas our SAM3D needs only 1.88M.

D. Ablation Study.

To assess the impact of skip connections in our lightweight 3D decoder, we conducted an ablation study on the ACDC and Synapse datasets, as reported in Table 5. Skip connections improve the average DSC on both datasets (from 89.73% to 90.41% on ACDC and from 79.33% to 79.56% on Synapse) and reduce the HD95 on Synapse from 25.87 to 17.87, although a few individual organs (Spl, Gal, Aor) score slightly lower. We believe that these skip connections play a crucial role in preserving edge and boundary information from lower-level features, which enhances the precision of the segmentation.

5 Conclusion

In this study, we introduce SAM3D, an efficient and simple SAM-based model tailored to volumetric medical image segmentation. Our approach combines a pretrained SAM encoder with a lightweight 3D decoder. Through extensive experiments, we show that SAM3D competes effectively with current SOTA 3D neural networks and Transformer-based models while requiring significantly fewer parameters (up to roughly 50× fewer). Furthermore, SAM3D outperforms other lightweight networks in volumetric segmentation. As SAM has already made a substantial impact on natural image segmentation, our research extends its potential to medical image segmentation. We anticipate that this work will inspire future research and foster advances in the field of medical segmentation.

Discussion. In our experiments, we employed the smallest SAM variant, which uses a ViT-B backbone, primarily due to resource and time constraints. We hypothesize that the pretrained ViT-L and ViT-H models may yield even better results, and we encourage researchers to explore these options for this segmentation task.

Additionally, our simple decoder leaves room for developing a more complex architecture, which could potentially enhance the model’s performance. This presents a promising avenue for further research and development.

6 Acknowledgement

Nhat-Tan Bui and Ngan Le are supported by the National Science Foundation (NSF) under Award No OIA-1946391 RII Track-1, NSF 1920920 RII Track 2 FEC, NSF 2223793 EFRI BRAID, NSF 2119691 AI SUSTEIN, NSF 2236302. Minh-Triet Tran is sponsored by Vietnam National University Ho Chi Minh City (VNU-HCM) under grant number DS2020-42-01. Dinh-Hieu Hoang is funded by Vingroup Joint Stock Company and supported by the Domestic Master/ PhD Scholarship Programme of Vingroup Innovation Foundation (VINIF), Vingroup Big Data Institute (VINBIGDATA), code VINIF.2022.ThS.JVN.04.

References

  • [1] Olaf Ronneberger, Philipp Fischer, and Thomas Brox, “U-Net: Convolutional Networks for Biomedical Image Segmentation,” in MICCAI, 2015.
  • [2] Fabian Isensee, Paul Jaeger, Simon Kohl, Jens Petersen, and Klaus Maier-Hein, “nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation,” Nature Methods, vol. 18, 2021.
  • [3] Jieneng Chen, Yongyi Lu, Qihang Yu, Xiangde Luo, Ehsan Adeli, Yan Wang, Le Lu, Alan L. Yuille, and Yuyin Zhou, “TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation,” arXiv preprint arXiv:2102.04306, 2021.
  • [4] Hu Cao, Yueyue Wang, Joy Chen, Dongsheng Jiang, Xiaopeng Zhang, Qi Tian, and Manning Wang, “Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation,” in ECCVW, 2022.
  • [5] Alexey Dosovitskiy, Lucas Beyer, et al., “An image is worth 16x16 words: Transformers for image recognition at scale,” arXiv preprint arXiv:2010.11929, 2020.
  • [6] A. Hatamizadeh, Y. Tang, V. Nath, D. Yang, A. Myronenko, B. Landman, H. R. Roth, and D. Xu, “UNETR: Transformers for 3D Medical Image Segmentation,” in WACV, 2022.
  • [7] Moein Heidari, Amirhossein Kazerouni, Milad Soltany, Reza Azad, Ehsan Khodapanah Aghdam, Julien Cohen-Adad, and Dorit Merhof, “HiFormer: Hierarchical multi-scale representations using transformers for medical image segmentation,” in WACV, 2023.
  • [8] Olivier Bernard, Alain Lalande, et al., “Deep learning techniques for automatic MRI cardiac multi-structures segmentation and diagnosis: Is the problem solved?,” IEEE Transactions on Medical Imaging, 2018.
  • [9] Bennett Landman, Zhoubing Xu, J Iglesias, Martin Styner, T Langerak, and Arno Klein, “Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge,” in MICCAI Multi-Atlas Labeling Beyond Cranial Vault - Workshop Challenge, 2015.
  • [10] Amber L. Simpson, Michela Antonelli, et al., “A large annotated medical image dataset for the development and evaluation of segmentation algorithms,” arXiv preprint arXiv:1902.09063, 2019.
  • [11] Xiaohong Huang, Zhifang Deng, Dandan Li, and Xueguang Yuan, “MISSFormer: An Effective Medical Image Segmentation Transformer,” arXiv preprint arXiv:2109.07162, 2021.
  • [12] Reza Azad, Moein Heidari, Moein Shariatnia, Ehsan Khodapanah Aghdam, Sanaz Karimijafarbigloo, Ehsan Adeli, and Dorit Merhof, “TransDeepLab: Convolution-Free Transformer-based DeepLab v3+ for Medical Image Segmentation,” in PRIME, 2022.
  • [13] Ali Hatamizadeh, Vishwesh Nath, Yucheng Tang, Dong Yang, Holger Roth, and Daguang Xu, “Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images,” arXiv preprint arXiv:2201.01266, 2022.
  • [14] Hong-Yu Zhou, Jiansen Guo, Zhang Yinghao, Lequan Yu, Liansheng Wang, and Yizhou Yu, “nnFormer: Interleaved Transformer for Volumetric Segmentation,” arXiv preprint arXiv:2109.03201, 2021.
  • [15] Abdelrahman Shaker, Muhammad Maaz, Hanoona Rasheed, Salman Khan, Ming-Hsuan Yang, and Fahad Shahbaz Khan, “UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation,” arXiv preprint arXiv:2212.04497, 2022.
  • [16] Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, and Ross Girshick, “Segment Anything,” arXiv preprint arXiv:2304.02643, 2023.
  • [17] Kaidong Zhang and Dong Liu, “Customized Segment Anything Model for Medical Image Segmentation,” arXiv preprint arXiv:2304.13785, 2023.
  • [18] Jun Ma, Yuting He, Feifei Li, Lin Han, Chenyu You, and Bo Wang, “Segment anything in medical images,” arXiv preprint arXiv:2304.12306, 2023.
  • [19] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, “Deep Residual Learning for Image Recognition,” in CVPR, 2016.