Med-Flamingo: a Multimodal Medical Few-shot Learner - Med-Flamingo

19 Jun 2024

Authors:

(1) Michael Moor, Department of Computer Science, Stanford University, Stanford, USA (equal contribution);

(2) Qian Huang, Department of Computer Science, Stanford University, Stanford, USA (equal contribution);

(3) Shirley Wu, Department of Computer Science, Stanford University, Stanford, USA;

(4) Michihiro Yasunaga, Department of Computer Science, Stanford University, Stanford, USA;

(5) Cyril Zakka, Department of Cardiothoracic Surgery, Stanford Medicine, Stanford, USA;

(6) Yash Dalmia, Department of Computer Science, Stanford University, Stanford, USA;

(7) Eduardo Pontes Reis, Hospital Israelita Albert Einstein, Sao Paulo, Brazil;

(8) Pranav Rajpurkar, Department of Biomedical Informatics, Harvard Medical School, Boston, USA;

(9) Jure Leskovec, Department of Computer Science, Stanford University, Stanford, USA.

Abstract and 1 Introduction

2 Related Works

3 Med-Flamingo

4 Evaluation

5 Results

6 Discussion, Acknowledgments, and References

A Appendix

3 MED-FLAMINGO

To train a Flamingo model adapted to the medical domain, we leverage the pre-trained OpenFlamingo-9B model checkpoint Awadalla et al. (2023), a general-domain VLM built on top of the frozen language model LLaMA-7B Touvron et al. (2023) and the frozen vision encoder CLIP ViT-L/14 Radford et al. We perform continued pre-training in the medical domain, which results in the model we refer to as Med-Flamingo.

3.1 DATA

We pre-train Med-Flamingo by jointly training on interleaved image-text data and paired image-text data. For the interleaved data, we created a dataset from a set of medical textbooks, which we subsequently refer to as MTB. For the paired data, we used PMC-OA Lin et al. (2023).

MTB We construct a new multimodal dataset from a set of 4,721 textbooks from different medical specialties (see Figure 3). During preprocessing, each book is first converted from PDF to HTML with all tags removed, except the image tags, which are converted to tokens. We then carry out data cleaning via deduplication and content filtering. Finally, each book with cleaned text and images is chopped into segments for pre-training, so that each segment contains at least one and at most 10 images and does not exceed a maximum length. In total, MTB consists of approximately 0.8M images and 584M tokens. We use 95% of the data for training and 5% for evaluation during pre-training.
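The segmentation and split steps described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' preprocessing code: the placeholder string `<image>`, the function names `chunk_book` and `split_segments`, the default `max_len`, the choice to drop image-free segments, and the choice to split at the segment level are all hypothetical, since the text does not specify them.

```python
import random
from typing import List, Tuple

# Hypothetical placeholder: the text only says image tags are converted to tokens.
IMAGE_TOKEN = "<image>"


def chunk_book(tokens: List[str], max_len: int = 256, max_images: int = 10) -> List[List[str]]:
    """Greedily split a token stream into segments that hold between 1 and
    max_images image tokens and at most max_len tokens in total."""
    segments: List[List[str]] = []
    current: List[str] = []
    n_images = 0

    def flush() -> None:
        nonlocal current, n_images
        # Assumption: segments without any image are dropped.
        if n_images >= 1:
            segments.append(current)
        current = []
        n_images = 0

    for tok in tokens:
        is_image = tok == IMAGE_TOKEN
        # Close the current segment if this token would break either limit.
        if len(current) >= max_len or (is_image and n_images >= max_images):
            flush()
        current.append(tok)
        n_images += is_image
    flush()
    return segments


def split_segments(
    segments: List[List[str]], eval_fraction: float = 0.05, seed: int = 0
) -> Tuple[List[List[str]], List[List[str]]]:
    """Shuffle segments and hold out eval_fraction of them for evaluation.
    Assumption: the 95%/5% split is made per segment rather than per book."""
    rng = random.Random(seed)
    shuffled = segments[:]
    rng.shuffle(shuffled)
    n_eval = round(len(shuffled) * eval_fraction)
    return shuffled[n_eval:], shuffled[:n_eval]
```

A greedy pass keeps the sketch simple; a real pipeline would likely measure length in tokenizer tokens and might prefer to cut at paragraph boundaries.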

PMC-OA We adopt the PMC-OA dataset Lin et al. (2023), a biomedical dataset with 1.6M image-caption pairs collected from PubMed Central's Open Access subset. Following the public split[2], we use 1.3M image-caption pairs for training and 0.16M pairs for evaluation.

3.2 OBJECTIVES

We follow the original Flamingo model approach Alayrac et al., which casts training as a language modelling problem: each language token is predicted conditioned on the preceding language tokens and the preceding visual inputs.
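For orientation, the language modelling problem referred to here can be written in the form used by the Flamingo paper. The notation below is an assumption taken from that formulation and may differ in detail from the authors' own equation: $y_\ell$ denotes the $\ell$-th language token, $y_{<\ell}$ the preceding language tokens, and $x_{\le \ell}$ the images that precede token $y_\ell$ in the interleaved sequence.

```latex
p(y \mid x) = \prod_{\ell=1}^{L} p\left(y_\ell \mid y_{<\ell},\, x_{\le \ell}\right)
```

Training then amounts to minimizing the negative log-likelihood of the text tokens under this factorization.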

3.3 TRAINING

We performed multi-GPU training on a single node with 8x 80GB NVIDIA A100 GPUs. We trained the model using DeepSpeed ZeRO Stage 2, in which optimizer states and gradients are sharded across devices. To further reduce memory load, we employed the 8-bit AdamW optimizer as well as the memory-efficient attention implementation of PyTorch 2.0. Med-Flamingo was initialized from the OpenFlamingo checkpoint and then pre-trained for 2700 steps (6.75 days of wall time, including the validation steps), using 50 gradient accumulation steps and a per-device batch size of 1, resulting in a total batch size of 400. The model has 1.3B trainable parameters (gated cross-attention layers and perceiver layers) and roughly 7B frozen parameters (decoder layers and vision encoder), for a total of 8.3B parameters. Note that this is the same number of parameters as in the OpenFlamingo-9B model (version 1).
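The batch-size and parameter figures above follow from simple arithmetic, sketched here. The variable names are illustrative, and the `ds_config` dictionary is only an assumed DeepSpeed-style fragment consistent with the stated settings; the authors' actual configuration file is not given in this section.

```python
# Values stated in the training description.
num_gpus = 8
per_device_batch_size = 1
grad_accum_steps = 50
train_steps = 2700
wall_time_days = 6.75

# Total batch size = devices x per-device batch x accumulation steps.
effective_batch_size = num_gpus * per_device_batch_size * grad_accum_steps
print(effective_batch_size)  # 400

# Rough derived quantities (wall time includes validation, so this is an upper bound per step).
samples_seen = effective_batch_size * train_steps
seconds_per_step = wall_time_days * 24 * 3600 / train_steps

# Parameter budget in billions: trainable + frozen.
trainable_params_b = 1.3
frozen_params_b = 7.0
total_params_b = round(trainable_params_b + frozen_params_b, 1)

# Assumed DeepSpeed-style config fragment matching the numbers above (not the authors' file).
ds_config = {
    "train_batch_size": effective_batch_size,
    "train_micro_batch_size_per_gpu": per_device_batch_size,
    "gradient_accumulation_steps": grad_accum_steps,
    "zero_optimization": {"stage": 2},  # shard optimizer states and gradients
}
```

Under these numbers the run processes about 1.08M batch items in total, at roughly 216 seconds per optimizer step.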

This paper is available on arXiv under the CC BY-NC-SA 4.0 DEED license.


[2] https://huggingface.co/datasets/axiong/pmc_oa_beta