Optimizing Vision Transformers for Medical Image Segmentation

For medical image semantic segmentation (MISS), Vision Transformers have emerged as strong alternatives to convolutional neural networks thanks to their inherent ability to capture long-range correlations. However, existing research uses off-the-shelf vision Transformer blocks based on linear projections and feature processing, which lack the spatial and local context needed to refine organ boundaries. Furthermore, Transformers do not generalize well on small medical imaging datasets and rely on large-scale pre-training due to limited inductive biases. To address these problems, we present CS-Unet, a compact and accurate Transformer network for MISS, which introduces convolutions in a multi-stage design to hierarchically enhance the spatial and local modeling ability of Transformers. This is mainly achieved by our well-designed Convolutional Swin Transformer (CST) block, which merges convolutions with Multi-Head Self-Attention and Feed-Forward Networks to provide inherent localized spatial context and inductive biases. Experiments demonstrate that CS-Unet without pre-training outperforms its counterparts by large margins on multi-organ and cardiac datasets with fewer parameters, achieving state-of-the-art performance. Our code is available on GitHub.


INTRODUCTION
Medical image semantic segmentation (MISS), which classifies image pixels with semantic organ labels (e.g., kidney and liver) for various imaging modalities, is one of the most fundamental problems in medical imaging. However, compared to natural scene images, MISS requires overcoming more challenges to create robust models. For instance, common benchmark datasets in MISS suffer from large deformations of organs under different image acquisition processes. In addition, the shortage of costly pixel-level annotations is another problem leading to a performance gap. To achieve efficient and effective segmentation, models are required not only to understand local semantic features well enough to capture subtle organ structures, but also to model global feature dependencies to capture the relationships among multiple organs.
UNet [1] and its variants [2][3][4][5][6] with Convolutional Neural Networks (CNNs) as the backbone have found huge success in MISS, as they are good at modelling local attributes inside their receptive field. However, the inherent locality of convolution operations restricts their ability to model long-range semantic dependencies within the image, and as a result the challenging boundaries of the whole organ may not be effectively segmented. Attention mechanisms alleviate this issue, but these tend to be 'single-head' mechanisms that only calculate pixel-level similarities, rather than 'multi-head' mechanisms able to capture patch-level patterns.

Thanks to EPSRC (EP/W01212X/1), Royal Society (RGS/R2/212199), UKRI 104690 (iCAIRD) and the China Scholarship Council for funding. 1. https://github.com/kathyliu579/CS-Unet
To alleviate the inherent flaws of CNNs, there has been a recent shift in the choice of architectures from CNNs to Vision Transformers (ViTs), due to their ability to model long-range semantic attributes among input tokens (embeddings of image patches) via a linearly projected Multi-Head Self-Attention (MHSA) operation and a Feed-Forward Network (FFN). Most early works [7][8][9] treat CNNs as a backbone and exploit the Transformer's desirable characteristics in their encoder. They tend to have high complexity, as they stack bulky Transformer blocks on top of convolutional feature extractors (large pre-trained CNNs, e.g., ResNet). Recent research [10][11][12][13][14][15][16][17] has moved towards using Transformers as the main stem of the entire segmentation architecture. Swin-Unet [12] is regarded as the first pure Transformer model. It keeps the familiar U-shape and adds hierarchical feature extraction using the shifted windows proposed by the Swin Transformer [18], which drastically reduces the quadratic complexity of traditional self-attention while achieving better performance.
However, most of these Transformers for MISS use off-the-shelf Transformer blocks from the computer vision community and only model and extract linear semantic relations via MHSA and FFN. This makes it challenging to precisely delineate organ boundaries due to the lack of spatial and local information, as shown in Figure 1.(d), even though it has little influence on detection and classification tasks. Besides, these methods require a large dataset to compensate for the lack of inductive biases such as translation equivariance [19], which may be degraded or even lost when fine-tuning on downstream tasks, showing less robustness on small datasets.
Keeping the current state of the literature in mind, our paper highlights the issues that today's Transformers in MISS face, followed by our contribution to alleviate those drawbacks. Most current Transformers are bulky and rely on weights pre-trained on classical vision tasks to be adapted for MISS. To the best of our knowledge, no existing study explores the effects of adding spatial locality inside Transformer blocks via convolutions for medical imaging. To this end, we first present an empirical analysis showing the need for spatial locality in pure Transformer-based MISS. Our insights show the effects on segmentation performance of introducing convolutions into Transformer blocks and of the multi-stage design of networks. We call the final model resulting from our experiments Convolutional Swin-Unet (CS-Unet). It is based on purely convolutional Transformer blocks, created to make Transformers model local information better and segment organ boundaries more accurately, while maintaining a low computational complexity. Experiments on CT and MRI datasets show that CS-Unet (24M parameters) trained from scratch outperforms Swin-Unet (27M) pre-trained on ImageNet by around 3% Dice score, achieving state-of-the-art performance.

METHOD
Most Transformer-based methods in MISS, i.e., encoder-decoder models with a standard U-shape, use a standard Transformer block containing linear projections and linear FFNs (essentially MLPs) to process the data. Creating effective image representations under such a regime therefore requires huge amounts of training data, as these blocks lack local spatial information.
The first pure-Transformer MISS model is Swin-Unet [12], which adopts Swin Transformer blocks [18] to add locality information to Transformers. The data representation created here is still inherently linear, as this block contains linear projections and linear feature processing. Next, we add convolutional projections to this Swin Transformer block structure. The projections follow the methodology proposed in [20], where tokens are first reshaped into a 2D token map, then processed by a depth-wise separable convolution with kernel size s implemented by: Depth-wise Conv → BatchNorm → Point-wise Conv. Finally, the tokens are flattened into the 1D token inputs x_i^{q/k/v} for the Q/K/V matrices. This can be formulated as:

x_i^{q/k/v} = Flatten(Conv2d(Reshape2D(x_i), s)),

where Conv2d denotes the depth-wise separable convolution above. Following this, when a 3×3 convolution is used for the FFN instead of an MLP to introduce more spatial context, we see the full effect of adding complete spatial locality to Transformers: the boundaries of the left and right kidneys and the spleen become greatly refined. The over-segmentation of the pancreas, however, gets worse (as shown in Figure 1.(b)). This is because the limited receptive field does not model the whole boundary of large organs effectively.
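As a concrete illustration, the [20]-style projection pipeline can be sketched shape-wise in numpy. This is a minimal sketch under stated assumptions: the weights are illustrative placeholders (identity kernels), not learned parameters, and `conv_projection` is our hypothetical helper name, not from the released code.

```python
import numpy as np

def conv_projection(tokens, H, W, s=3):
    """Sketch of the convolutional projection of [20]: reshape tokens to a
    2D map, apply a depth-wise conv (kernel s, padding s//2, stride 1),
    batch-norm, then a point-wise conv, and flatten back to tokens.
    tokens: (N, C) with N = H * W.  Weights here are illustrative."""
    N, C = tokens.shape
    x = tokens.reshape(H, W, C)               # 1D tokens -> 2D token map
    # depth-wise conv: one s x s kernel per channel (identity-initialized)
    dw = np.zeros((C, s, s)); dw[:, s // 2, s // 2] = 1.0
    pad = s // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    y = np.zeros_like(x)
    for i in range(H):
        for j in range(W):
            # per-channel spatial filtering (no channel mixing)
            y[i, j] = np.einsum('ijc,cij->c', xp[i:i + s, j:j + s, :], dw)
    # batch-norm per channel (inference-style, over spatial positions)
    y = (y - y.mean((0, 1))) / (y.std((0, 1)) + 1e-5)
    # point-wise (1x1) conv == channel-mixing matmul (identity here)
    y = y @ np.eye(C)
    return y.reshape(N, C)                    # flatten back to 1D tokens
```

The same routine, called three times with different weights, would produce the Q, K and V token inputs.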

Convolutional Swin Transformer (CST) Layer
We propose the CST layer to fully explore the spatial modeling ability of convolutions in MHSA and FFN. First, we propose a novel (shifted) window based convolutional multi-head self-attention ((S)W-CMSA) which extracts hierarchical semantic features while reducing computational costs, by combining the shifted windows mechanism with convolutional projection. Then, we replace the MLP with our novel depthwise separable feed-forward (DSF) module. From Figure 1.(a), we see that the Transformer model based on CST handles challenging organ boundaries more effectively. The CST layer is formulated as:

ẑ^l = (S)W-CMSA(LN(z^{l-1})) + z^{l-1},
z^l = DSF(LN(ẑ^l)) + ẑ^l,

where ẑ^l and z^l denote the outputs of the (S)W-CMSA module and the DSF module of the l-th block, respectively.
As shown in Figure 2, once tokens enter (S)W-CMSA, they are reshaped into a 2D token map and partitioned into windows. For each window, we use three depth-wise convolutions with kernel size s = 3×3, padding of 1 and stride of 1 to create our Q, K and V vectors via: x_i^{q/k/v} = Flatten(DepthConv(Window(Reshape(x_i)), s)).
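The window partitioning and its inverse can be sketched as follows. This is a minimal numpy illustration of the standard non-overlapping window split; the function names are ours, not from the released code.

```python
import numpy as np

def window_partition(x, ws):
    """Split a 2D token map (H, W, C) into non-overlapping ws x ws
    windows -> (num_windows, ws, ws, C)."""
    H, W, C = x.shape
    x = x.reshape(H // ws, ws, W // ws, ws, C)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, ws, ws, C)

def window_reverse(wins, ws, H, W):
    """Inverse of window_partition: (num_windows, ws, ws, C) -> (H, W, C)."""
    C = wins.shape[-1]
    x = wins.reshape(H // ws, W // ws, ws, ws, C)
    return x.transpose(0, 2, 1, 3, 4).reshape(H, W, C)
```

Each window would then be processed independently by the per-window depth-wise convolutions and attention, before `window_reverse` restores the 2D token map.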
CST differs from [20] in that we create a projection based on windows rather than the whole image, leading to more refined local features, as the kernels learnt for each window are now different. To better adapt to medical images with smaller data volumes, the point-wise convolutions are removed to avoid over-fitting. Furthermore, we replace Batch Normalization with Layer Normalization (LN), providing a performance boost. The token vectors are fed to MHSA as:

Attention(Q, K, V) = SoftMax(QK^T / √d + B)V,

where d represents the dimension of the query and key, and the values in B are the bias.
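The attention computation within one window can be sketched directly from the formula above. This is a single-head, single-window numpy illustration with an externally supplied bias matrix B; it omits the multi-head split and batching for clarity.

```python
import numpy as np

def softmax(a, axis=-1):
    # numerically stable softmax
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def window_attention(Q, K, V, B):
    """Attention(Q, K, V) = SoftMax(Q K^T / sqrt(d) + B) V within one window.
    Q, K, V: (num_tokens, d); B: (num_tokens, num_tokens) bias matrix."""
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d) + B) @ V
```

With Q = 0 and B = 0 the scores are uniform, so every output token is simply the mean of the value vectors, which is a quick sanity check of the formula.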
Then, we replace the linear layer after attention with a 3×3 depth-wise convolution that refines the attention output with more spatial information. We follow this by reversing the windows back to 2D token maps. This results in more robust estimations than the Swin Transformer [18] and removes our dependence on positional encoding.
Depthwise separable feed-forward (DSF) module. After computing (S)W-CMSA, the feature maps are fed into an FFN. Existing Transformers implement this module as an MLP: LN,d → Linear,4×d → GELU → Linear,d → RC, where d denotes the number of channels of a reshaped feature map and RC denotes the residual connection. We propose the DSF module as a choice of FFN that adds spatial context. We use three convolutions (one depth-wise and two point-wise) instead of the two linear layers to exploit the features between channels. In addition, we found that adding LN after the convolution gives better segmentation results. The DSF is implemented as: 7x7 Depth-wise Conv,d → LN,d → Point-wise Conv,4×d → GELU → Point-wise Conv,d → RC.
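The DSF pipeline above can be sketched shape-wise in numpy. This is a minimal sketch under stated assumptions: random illustrative weights, tanh-approximated GELU, and point-wise (1×1) convolutions expressed as channel matmuls; the function names are ours.

```python
import numpy as np

def layer_norm(x):
    # LN over the channel (last) dimension
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def depthwise_conv(x, k):
    """Depth-wise conv on an (H, W, C) map; k: (C, ks, ks), stride 1, 'same' pad."""
    H, W, C = x.shape
    ks = k.shape[-1]; pad = ks // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    out = np.empty_like(x)
    for i in range(H):
        for j in range(W):
            out[i, j] = np.einsum('ijc,cij->c', xp[i:i + ks, j:j + ks, :], k)
    return out

def dsf(x, k7, W1, W2):
    """DSF: 7x7 depth-wise conv,d -> LN,d -> point-wise conv,4d -> GELU
    -> point-wise conv,d -> residual connection (RC)."""
    y = layer_norm(depthwise_conv(x, k7))
    y = gelu(y @ W1) @ W2        # 1x1 convs as channel matmuls: d -> 4d -> d
    return x + y                  # residual connection
```

Note the channel expansion ratio (d → 4d → d) mirrors the standard Transformer MLP, so only the spatial 7×7 depth-wise stage is new relative to the MLP baseline.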

Overall Structure Design
CS-Unet keeps a symmetrical UNet shape. The input of our model is a 2D image slice with resolution H × W × 3, sampled from a 3D volume; H, W and 3 denote the height, width and number of channels of each input. On entering the encoder, images are passed through the convolutional token embedding to create a sequence embedding over overlapping patches of the image, after which CST and patch merging layers are applied. Extracted features are then processed by the model's bottleneck, which consists of two CST blocks. A symmetrical decoder then creates the final segmentation masks. In addition, skip convolution (SC) modules are added between corresponding feature pyramids of the encoder and decoder to compensate for the information lost through down-sampling. The overall architecture of the proposed CS-Unet is presented in Figure 3.

Encoder
The input image is first passed through the convolutional token embedding layer to create a sequence embedding with a resolution of H/4 × W/4 × C.

Convolutional Token Embedding layer. Existing models use a linear layer to split images into non-overlapping patches and drastically reduce the size of the image, e.g. by 75%, while increasing the channel dimension C. However, as the images are at their highest resolution H × W at the start of the encoder, using a linear layer to compress these features not only loses high-quality spatial and local information, but also increases the model size. Our embedding layer is instead implemented as four convolutions over overlapping patches that compress the features in stages, helping to introduce more spatial dependency between and inside the patches, while greatly reducing the parameters (by 6M; see the ablation in Table 3, method 1). Specifically, this layer is implemented as follows:

Decoder
Our decoder is symmetric to the encoder. The feature representation is created by enlarging the feature volume through a convolutional up-sampling module and then passing it through an SC module to compensate for the information lost through down-sampling. A CST layer then provides spatial context to the up-sampled features. After repeating the above process three times, the features are fed into the patch expansion layer, which up-samples by 4×, followed by a linear projection to fine-tune the final segmentation prediction. Specifically, the convolutional up-sampling module employs a strided deconvolution to 2× up-sample the feature maps and halve the channel dimension.

Skip Convolutions (SC) module. The high-resolution feature maps created by up-sampling are concatenated with shallow feature representations from the encoder, and then merged by an SC module. This further enriches both spatial and fine-grained information, while compensating for the information lost through down-sampling. It is implemented as: 3x3 s=1 Conv,d/2 → GELU → 3x3 s=1 Conv,d/2 → GELU.
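The SC module's two-convolution pipeline can be sketched in numpy. This is a minimal, loop-based sketch under stated assumptions: random illustrative weights, tanh-approximated GELU, and our own function names; the input is assumed to be the (H, W, d) concatenation of decoder and encoder features, reduced to d/2 channels.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def conv3x3(x, w):
    """3x3 conv, stride 1, 'same' padding; x: (H, W, Cin), w: (Cout, Cin, 3, 3)."""
    H, W, _ = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.empty((H, W, w.shape[0]))
    for i in range(H):
        for j in range(W):
            out[i, j] = np.einsum('ijc,ocij->o', xp[i:i + 3, j:j + 3, :], w)
    return out

def sc_module(x, w1, w2):
    """SC module: 3x3 s=1 Conv,d/2 -> GELU -> 3x3 s=1 Conv,d/2 -> GELU,
    applied to the concatenated (H, W, d) skip features."""
    return gelu(conv3x3(gelu(conv3x3(x, w1)), w2))
```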

EXPERIMENTS
We use two publicly available datasets to benchmark our method.

Implementation details
We train our models on a single Nvidia RTX 3090 GPU with 24GB memory. We use flipping and rotation augmentations on the training data. The input image size is 224×224. Pre-trained weights are used for the other methods where provided, while our model is trained for 300 epochs from a randomly initialized set of weights. A batch size of 24 and a combination of cross-entropy and Dice loss are used. Our model is optimized by AdamW [23] with a weight decay of 5e-4 for both datasets. The learning rates for Synapse and ACDC are 1e-3 and 5e-3, respectively. We start with a 10-epoch linear warmup. Layer Scale [24] with an initial value of 1e-6 is applied.
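The warmup schedule above can be sketched as a simple per-epoch rule. This is one plausible reading of the setup: a linear ramp over the first 10 epochs; the paper does not state the post-warmup decay rule, so the base rate is kept constant here.

```python
def lr_at_epoch(epoch, base_lr=1e-3, warmup_epochs=10):
    """Learning rate for a given (0-indexed) epoch: linear warmup from
    base_lr/warmup_epochs up to base_lr, then constant (assumed)."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    return base_lr
```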

Experimental Results
As shown in Table 1 and Table 2, our model consistently surpasses a variety of convolution-based and Transformer-based methods. CS-Unet outperforms Swin-Unet by 3.08% and 3.3% DSC on Synapse and ACDC, respectively. In addition, our method obtains the highest DSC for five organs on Synapse and two on ACDC, providing especially large boosts for challenging organs such as the gallbladder, pancreas and RV. Overall, compared to the pre-trained Swin-Unet (27M), nnFormer (158M) and TransUnet (96M), CS-Unet achieves the best performance without pre-training while being lightweight (24M), by introducing more local perception and inductive bias.
Figure 4 visualizes the segmentation results. In case 1, our method has a clear advantage in segmenting the pancreas, stomach and liver. CS-Unet is also more discriminative on the complex shape of the RV than other Transformer-based models in case 2, due to its better spatial-context modelling.

Ablation study
We explore the influence of the proposed modules on Synapse performance, as shown in Table 3. Swin-Unet trained from scratch is treated as the baseline (method 0), which cannot adapt to small datasets. Adding the convolutional token embedding (method 1) and convolutional projections (method 2), we observe large improvements of 8% and 9% DSC, which is competitive with the pre-trained Swin-Unet.

Fig. 1. Visualization of segmentation results of different methods trained from scratch on the Synapse dataset.
Figure 1 shows segmentation visualizations for the Synapse dataset. Swin-Unet trained from a random weight initialization (Figure 1.(d)) does not perform well: it fails to detect the spleen and misclassifies the left kidney as the right.

Figure 1.(c) shows the outputs of the resulting Unet trained with this block. It visually demonstrates how essential spatial locality is for low-level pixel-labelling tasks. Although the convolutional projection alleviates many of the problems posed by the linearity of Swin-Unet, there are still severe over-segmentations of the pancreas and liver, and extremely rough boundaries for the right kidney.

The embedding has a resolution of H/4 × W/4 × C (C = 96 in our experiments). It is fed to three main CST layers, each followed by a patch merging module that down-samples the feature map and doubles the number of channels. For example, at the first patch merging module, an input of size H/4 × W/4 × C is divided into four parts and concatenated along the C dimension to create a feature map of size H/8 × W/8 × 4C. A linear layer is then applied to this map to reduce the channel dimension by a factor of 2.
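The patch merging step can be sketched in numpy. This is a minimal illustration of the Swin-style merge described above; `W_reduce` stands in for the learned linear layer that halves the concatenated channel dimension.

```python
import numpy as np

def patch_merging(x, W_reduce):
    """Swin-style patch merging: gather each 2x2 neighborhood of an (H, W, C)
    map, concatenate along channels -> (H/2, W/2, 4C), then a linear layer
    (W_reduce: 4C x 2C) reduces the channel dimension by a factor of 2."""
    parts = [x[i::2, j::2, :] for i in (0, 1) for j in (0, 1)]
    y = np.concatenate(parts, axis=-1)   # (H/2, W/2, 4C)
    return y @ W_reduce                  # (H/2, W/2, 2C)
```

Applied to an H/4 × W/4 × C input, this yields the H/8 × W/8 × 2C map consumed by the next stage.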

Fig. 3. (a) Overall architecture of CS-Unet, (b) one CST layer, (c) convolutional token embedding, (d) DSF and (e) skip convolutions. d is the current number of channels; c is an arbitrary dimension.
Synapse multi-organ segmentation (Synapse): 18 cases (2,212 axial slices) are extracted for training, while the other 12 cases are used for testing. We report model performance evaluated with the average Dice Similarity Coefficient (DSC) and average Hausdorff Distance (HD) on eight abdominal organs. Automatic Cardiac Diagnosis Challenge (ACDC): ACDC [22] contains MRI images from 100 patients, with the right ventricle (RV), left ventricle (LV) and myocardium (MYO) to be segmented. Using the data splits proposed in [16], the dataset is split into 70 (1,930 axial slices), 10 and 20 cases for training, validation and testing, respectively. The evaluation metrics are average DSC (%) and HD (mm).

Table 1. Comparison with different models on Synapse. Gallbladder, left kidney, right kidney, pancreas and stomach are abbreviated as Gallb, Kid L, Kid R, Pancr and Stom.

Table 2. Experimental results on ACDC.