Abstract:The success of large-scale pre-trained models has established fine-tuning as a standard method for achieving significant improvements in downstream tasks. However, fine-tuning the entire parameter set of a pre-trained model is costly. Parameter-efficient transfer learning (PETL) has recently emerged as a cost-effective alternative for adapting pre-trained models to downstream tasks. Despite its advantages, the increasing model size and input resolution present challenges for PETL, as the training memory consumption is not reduced as effectively as the parameter usage. In this paper, we introduce Fine-grained Prompt Tuning plus (FPT+), a PETL method designed for high-resolution medical image classification, which significantly reduces memory consumption compared to other PETL methods. FPT+ performs transfer learning by training a lightweight side network and accessing pre-trained knowledge from a large pre-trained model (LPM) through fine-grained prompts and fusion modules. Specifically, we freeze the LPM and construct a learnable lightweight side network. The frozen LPM processes high-resolution images to extract fine-grained features, while the side network employs the corresponding down-sampled low-resolution images to minimize the memory usage. To enable the side network to leverage pre-trained knowledge, we propose fine-grained prompts and fusion modules, which collaborate to summarize information through the LPM's intermediate activations. We evaluate FPT+ on eight medical image datasets of varying sizes, modalities, and complexities. Experimental results demonstrate that FPT+ outperforms other PETL methods, using only 1.03% of the learnable parameters and 3.18% of the memory required for fine-tuning an entire ViT-B model. Our code is available at https://github.com/YijinHuang/FPT.
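The memory argument above can be made concrete with a rough token count. The sketch below is illustrative only: the abstract does not state the patch size or the two resolutions, so the 16-pixel patch size, the 512-pixel high-resolution input, and the 128-pixel low-resolution input are all assumptions, and the function names are hypothetical.

```python
def num_patch_tokens(image_size: int, patch_size: int = 16) -> int:
    """Patch tokens a ViT produces for a square image (class token excluded).

    patch_size=16 is an assumption; the abstract does not specify it.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side * side


def attention_entries(num_tokens: int) -> int:
    """Entries in one self-attention map, which grows quadratically with tokens."""
    return num_tokens * num_tokens


# Assumed resolutions, chosen only for illustration.
high_res_tokens = num_patch_tokens(512)
low_res_tokens = num_patch_tokens(128)

# How much larger one attention map is at the high resolution.
attention_ratio = attention_entries(high_res_tokens) / attention_entries(low_res_tokens)
print(high_res_tokens, low_res_tokens, attention_ratio)
```

Under these assumptions, the trainable side network handles far fewer tokens than the frozen LPM, which is consistent with the abstract's claim that running the side network on down-sampled images reduces training memory.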
Abstract:Parameter-efficient fine-tuning (PEFT) has been proposed as a cost-effective way to transfer pre-trained models to downstream tasks, avoiding the high cost of updating entire large-scale pre-trained models (LPMs). In this work, we present Fine-grained Prompt Tuning (FPT), a novel PEFT method for medical image classification. FPT significantly reduces memory consumption compared to other PEFT methods, especially in high-resolution contexts. To achieve this, we first freeze the weights of the LPM and construct a learnable lightweight side network. The frozen LPM takes high-resolution images as input to extract fine-grained features, while the side network is fed low-resolution images to reduce memory usage. To allow the side network to access pre-trained knowledge, we introduce fine-grained prompts that summarize information from the LPM through a fusion module. Important-token selection and preloading techniques are employed to further reduce training cost and memory requirements. We evaluate FPT on four medical datasets of varying sizes, modalities, and complexities. Experimental results demonstrate that FPT achieves performance comparable to fine-tuning the entire LPM while using only 1.8% of the learnable parameters and 13% of the memory cost of a ViT-B encoder with a 512 × 512 input resolution.
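The token selection step mentioned above could look roughly like the following. This is a minimal sketch, not the authors' implementation: the abstract does not say how token importance is scored, so the per-token `scores` input (for example, attention weights from the frozen model) and the function name `select_important_tokens` are assumptions made for illustration.

```python
from typing import List, Sequence, Tuple


def select_important_tokens(
    tokens: Sequence[str], scores: Sequence[float], keep: int
) -> Tuple[List[str], List[int]]:
    """Keep the `keep` highest-scoring tokens, preserving their original order.

    `scores` is a hypothetical per-token importance value; how it is
    computed is not specified in the abstract.
    """
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    if not 0 < keep <= len(tokens):
        raise ValueError("keep must be between 1 and the number of tokens")
    # Rank token indices by score, highest first, and take the top `keep`.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept_indices = sorted(ranked[:keep])
    return [tokens[i] for i in kept_indices], kept_indices


# Toy example with four placeholder tokens.
kept_tokens, kept_indices = select_important_tokens(
    ["t0", "t1", "t2", "t3"], [0.10, 0.70, 0.05, 0.15], keep=2
)
print(kept_tokens, kept_indices)
```

Discarding low-scoring tokens shrinks the sequence that later modules must process, which is one plausible way such a step reduces training cost and memory.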
Abstract:Recent applications of deep convolutional neural networks in medical imaging raise concerns about their interpretability. While most explainable deep learning applications use post hoc methods (such as GradCAM) to generate feature attribution maps, a newer class of case-based reasoning models, namely ProtoPNet and its variants, identifies prototypes during training and compares input image patches with those prototypes. We propose the first medical prototype network (MProtoNet), which extends ProtoPNet to brain tumor classification with 3D multi-parametric magnetic resonance imaging (mpMRI) data. To address the different requirements of 2D natural images and 3D mpMRIs, especially in terms of localizing attention regions, we introduce a new attention module with soft masking and online-CAM loss. Soft masking helps sharpen attention maps, while online-CAM loss directly utilizes image-level labels when training the attention module. Compared with GradCAM and several ProtoPNet variants, MProtoNet achieves statistically significant improvements in interpretability metrics of both correctness and localization coherence (with a best activation precision of $0.713\pm0.058$) without human-annotated labels during training. The source code is available at https://github.com/aywi/mprotonet.
Abstract:This paper presents a simple and effective generalization method for magnetic resonance imaging (MRI) segmentation when data is collected from multiple MRI scanning sites and is consequently affected by (site-)domain shifts. We propose to integrate a traditional encoder-decoder network with a regularization network. The added network contributes an auxiliary loss term that reduces the effect of domain shift and thereby improves generalization. The proposed method was evaluated on multiple sclerosis lesion segmentation from MRI data. We tested the proposed model on an in-house clinical dataset including 117 patients from 56 different scanning sites. In the experiments, our method showed better generalization performance than the baseline networks.