MEDICAL IMAGING ANALYSIS USING SELF-SUPERVISED LEARNING
20230410483 · 2023-12-21
Assignee
Inventors
CPC classification
G06V10/26
PHYSICS
G06V10/7753
PHYSICS
G06V10/50
PHYSICS
G06N3/0895
PHYSICS
International classification
G06V10/774
PHYSICS
G06V10/26
PHYSICS
Abstract
A method includes obtaining a first training data set including unannotated multi-dimensional medical images and executing a self-supervised masked image modeling (MIM) training process to pre-train an image encoder on the first training data set. The method also includes obtaining a second training data set that includes annotated multi-dimensional medical images. Here, each annotated multi-dimensional medical image includes a plurality of image voxels each paired with a corresponding ground-truth label indicating a class the corresponding image voxel belongs to. The method also includes executing a supervised training process to train an image analysis model on the second training data set to teach the image analysis model to learn how to predict the corresponding ground-truth labels for the plurality of image voxels of each annotated multi-dimensional medical image. The image analysis model incorporates the pre-trained image encoder.
Claims
1. A computer-implemented method executed on data processing hardware that causes the data processing hardware to perform operations comprising: obtaining a first training data set comprising a plurality of unannotated multi-dimensional medical images; executing a self-supervised masked image modeling (MIM) training process to pre-train an image encoder on the first training data set; obtaining a second training data set comprising a plurality of annotated multi-dimensional medical images, each annotated multi-dimensional medical image comprising a plurality of image voxels each paired with a corresponding ground-truth label indicating a class the corresponding image voxel belongs to; and executing a supervised training process to train an image analysis model on the second training data set to teach the image analysis model to learn how to predict the corresponding ground-truth labels for the plurality of image voxels of each annotated multi-dimensional medical image, wherein the image analysis model incorporates the pre-trained image encoder.
2. The method of claim 1, wherein executing the self-supervised MIM training process to pre-train the image encoder comprises, for each corresponding unannotated multi-dimensional medical image in the first training data set: generating, using an image tokenizer configured to receive the corresponding unannotated multi-dimensional medical image as input, a sequence of discrete visual tokens characterizing the corresponding unannotated multi-dimensional medical image; dividing the corresponding unannotated multi-dimensional medical image into a plurality of image patches; randomly masking a portion of the image patches divided from the corresponding unannotated multi-dimensional medical image; for each masked image patch: generating, using the image encoder, an encoded hidden representation for the masked image patch; and based on the encoded hidden representation, generating, using a decoder, a corresponding predicted token; determining a training loss based on the predicted tokens generated for the masked image patches and corresponding visual tokens from the sequence of discrete visual tokens that are aligned with the masked image patches; and updating parameters of the image encoder based on the training loss.
3. The method of claim 2, wherein: the image encoder comprises a plurality of multi-head attention layers; and the decoder comprises a plurality of multi-head attention layers.
4. The method of claim 2, wherein randomly masking the portion of the image patches comprises randomly masking the portion of the image patches using one of a central region masking strategy, a block-wise masking strategy, or a uniformly random masking strategy using different masked patch sizes and masking ratios.
5. The method of claim 2, wherein a number of visual tokens in the sequence of discrete visual tokens is equal to a number of image patches in the plurality of image patches.
6. The method of claim 1, wherein executing the self-supervised MIM training process to pre-train the image encoder comprises, for each corresponding unannotated multi-dimensional medical image in the first training data set: dividing the corresponding unannotated multi-dimensional medical image into a plurality of image patches, each image patch represented by a corresponding set of raw voxel values; randomly masking a portion of the image patches divided from the corresponding unannotated multi-dimensional medical image; for each masked image patch: generating, using the image encoder, an encoded hidden representation for the masked image patch; and based on the encoded hidden representation, generating, using a prediction head, predicted voxel values for the masked image patch; determining a training loss based on the predicted voxel values generated for the masked image patches and the corresponding sets of the raw voxel values that represent the masked image patches; and updating parameters of the image encoder based on the training loss.
7. The method of claim 6, wherein: the image encoder comprises a plurality of multi-head attention layers; and the prediction head comprises a single linear layer prediction head and is configured to generate the predicted voxel values from the encoded hidden representation without using a decoder.
8. The method of claim 6, wherein randomly masking the portion of the image patches comprises randomly masking the portion of the image patches using one of a central region masking strategy, a block-wise masking strategy, or a uniformly random masking strategy using different masked patch sizes and masking ratios.
9. The method of claim 1, wherein the image analysis model comprises a tumor segmentation model.
10. The method of claim 1, wherein the image analysis model comprises a multi-organ segmentation model.
11. A system comprising: data processing hardware; and memory hardware in communication with the data processing hardware and storing instructions that when executed on the data processing hardware cause the data processing hardware to perform operations comprising: obtaining a first training data set comprising a plurality of unannotated multi-dimensional medical images; executing a self-supervised masked image modeling (MIM) training process to pre-train an image encoder on the first training data set; obtaining a second training data set comprising a plurality of annotated multi-dimensional medical images, each annotated multi-dimensional medical image comprising a plurality of image voxels each paired with a corresponding ground-truth label indicating a class the corresponding image voxel belongs to; and executing a supervised training process to train an image analysis model on the second training data set to teach the image analysis model to learn how to predict the corresponding ground-truth labels for the plurality of image voxels of each annotated multi-dimensional medical image, wherein the image analysis model incorporates the pre-trained image encoder.
12. The system of claim 11, wherein executing the self-supervised MIM training process to pre-train the image encoder comprises, for each corresponding unannotated multi-dimensional medical image in the first training data set: generating, using an image tokenizer configured to receive the corresponding unannotated multi-dimensional medical image as input, a sequence of discrete visual tokens characterizing the corresponding unannotated multi-dimensional medical image; dividing the corresponding unannotated multi-dimensional medical image into a plurality of image patches; randomly masking a portion of the image patches divided from the corresponding unannotated multi-dimensional medical image; for each masked image patch: generating, using the image encoder, an encoded hidden representation for the masked image patch; and based on the encoded hidden representation, generating, using a decoder, a corresponding predicted token; determining a training loss based on the predicted tokens generated for the masked image patches and corresponding visual tokens from the sequence of discrete visual tokens that are aligned with the masked image patches; and updating parameters of the image encoder based on the training loss.
13. The system of claim 12, wherein: the image encoder comprises a plurality of multi-head attention layers; and the decoder comprises a plurality of multi-head attention layers.
14. The system of claim 12, wherein randomly masking the portion of the image patches comprises randomly masking the portion of the image patches using one of a central region masking strategy, a block-wise masking strategy, or a uniformly random masking strategy using different masked patch sizes and masking ratios.
15. The system of claim 12, wherein a number of visual tokens in the sequence of discrete visual tokens is equal to a number of image patches in the plurality of image patches.
16. The system of claim 11, wherein executing the self-supervised MIM training process to pre-train the image encoder comprises, for each corresponding unannotated multi-dimensional medical image in the first training data set: dividing the corresponding unannotated multi-dimensional medical image into a plurality of image patches, each image patch represented by a corresponding set of raw voxel values; randomly masking a portion of the image patches divided from the corresponding unannotated multi-dimensional medical image; for each masked image patch: generating, using the image encoder, an encoded hidden representation for the masked image patch; and based on the encoded hidden representation, generating, using a prediction head, predicted voxel values for the masked image patch; determining a training loss based on the predicted voxel values generated for the masked image patches and the corresponding sets of the raw voxel values that represent the masked image patches; and updating parameters of the image encoder based on the training loss.
17. The system of claim 16, wherein: the image encoder comprises a plurality of multi-head attention layers; and the prediction head comprises a single linear layer prediction head and is configured to generate the predicted voxel values from the encoded hidden representation without using a decoder.
18. The system of claim 16, wherein randomly masking the portion of the image patches comprises randomly masking the portion of the image patches using one of a central region masking strategy, a block-wise masking strategy, or a uniformly random masking strategy using different masked patch sizes and masking ratios.
19. The system of claim 11, wherein the image analysis model comprises a tumor segmentation model.
20. The system of claim 11, wherein the image analysis model comprises a multi-organ segmentation model.
Description
DESCRIPTION OF DRAWINGS
[0027] Like reference symbols in the various drawings indicate like elements.
DETAILED DESCRIPTION
[0028] Computer vision analysis has witnessed a paradigm shift from using Convolutional Neural Networks (CNNs) to using multi-head attention-based architectures. The present disclosure refers to Transformer-based architectures employing self-attention as one type of multi-head attention-based architecture by way of example; however, the present disclosure may employ other types of multi-head attention-based architectures for enhancing multi-dimensional input images. Generally, a Transformer-based architecture (i.e., a vision transformer) splits a multi-dimensional input image into patches and creates patch embeddings as inputs to a Transformer-based model for various vision tasks, including image classification, object detection, and image segmentation.
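The patch-and-embed operation described above can be sketched as follows. This is an illustrative sketch with toy sizes (an 8×8 two-dimensional image and a random projection standing in for a learned embedding layer), not code from this disclosure:

```python
import random

random.seed(0)

IMG, PATCH, DIM = 8, 4, 6          # 8x8 image, 4x4 patches, embedding dim 6
N_PATCHES = (IMG // PATCH) ** 2    # 4 non-overlapping patches

# A toy grayscale image as nested lists of pixel values.
image = [[random.random() for _ in range(IMG)] for _ in range(IMG)]

def extract_patches(img):
    """Split the image into non-overlapping PATCH x PATCH patches,
    each flattened into a vector of PATCH*PATCH raw values."""
    patches = []
    for py in range(0, IMG, PATCH):
        for px in range(0, IMG, PATCH):
            flat = [img[py + y][px + x] for y in range(PATCH) for x in range(PATCH)]
            patches.append(flat)
    return patches

# A random linear projection stands in for the learned patch-embedding layer.
W = [[random.gauss(0, 0.1) for _ in range(DIM)] for _ in range(PATCH * PATCH)]

def embed(patch):
    return [sum(v * W[i][d] for i, v in enumerate(patch)) for d in range(DIM)]

patches = extract_patches(image)
embeddings = [embed(p) for p in patches]
print(len(embeddings), len(embeddings[0]))  # 4 patch embeddings of dimension 6
```

The resulting sequence of embeddings (plus positional information) is what a Transformer-based model consumes for the downstream vision task.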
[0029] Three-dimensional (3D) medical imaging technologies such as computed tomography (CT) or magnetic resonance imaging (MRI) are widely used in diagnosing and treating a wide range of diseases. Generally, 3D medical volumetric images can help increase the speed and accuracy of diagnosing patient conditions. For instance, properly and swiftly discovering and measuring tumor lesions from MRI or CT scans could be critical to disease prevention, early detection, and treatment plan optimization, and could also inspire the development of more successful clinical applications to ultimately improve patients' lives. A fundamental task performed for medical image analysis includes 3D image segmentation. Another fundamental task performed for medical image analysis includes image classification. Image classification tasks classify input images into various categories. Generally, 3D image segmentation (also referred to as 3D semantic segmentation) aims to predict a corresponding class for each voxel of a volumetric input image to classify one or more particular objects and separate each of the particular objects from one another by overlaying respective segmentation masks over the particular objects. 3D image segmentation has the potential to alleviate radiologists' daily workload by automating or assisting the image interpretation workflow to ultimately improve clinical care and patient outcomes. Example 3D image segmentation tasks may include multi-organ segmentation performed as a 13-class segmentation task with single-channel input and brain tumor segmentation performed as a three-class segmentation task with four-channel input.
[0030] Training robust Transformer-based image analysis models requires more annotated training data to surpass the performance of conventional CNNs. However, the high expense of obtaining expert annotations of 3D medical volumetric images in particular domains frequently stymies attempts to leverage advances in clinical outcomes using deep learning approaches for 3D medical image analysis. In short, annotations of 3D medical images at scale by radiologists are limited, expensive, and time-consuming to produce. Another limiting factor in 3D medical image processing is the sheer data volume associated with 3D medical images, which is driven by increased 3D image dimensionality and resolution, resulting in significant processing complexity. As a consequence, effectively integrating radiomics endpoint information with other bio-marker data for downstream tasks in clinical study designs, such as tumor burden assessment and overall survival prediction, can be extremely difficult.
[0031] Transfer learning is the use of a trained model from one context in a different context. Transfer learning from natural images can be utilized in medical image analysis, regardless of disparities in image statistics, scale, and task-relevant characteristics. Transfer learning from, for example, ImageNet can accelerate convergence on medical images, which can be useful when the medical image training data is limited. Transfer learning using domain-specific data can also assist in resolving the domain disparity issue. For instance, improved performance can be achieved following pre-training on labeled data from the same domain. However, this strategy can be frequently impractical for a variety of medical scenarios requiring labeled data that is costly and time-consuming to gather. Self-supervised learning offers a viable alternative, allowing for the utilization of unlabeled/unannotated medical data.
[0032] Self-supervised learning is a training technique that focuses on learning representations from unlabeled data so that a low-capacity classifier can achieve high accuracy using various embeddings. Contrastive learning is another example of self-supervised learning strategies. Contrastive learning models image similarity and dissimilarity (or solely similarity) between two or more views, with data augmentation being crucial for contrastive and related approaches. Self-supervised learning can be used in the medical field such as in domain-specific pretext tasks or tailoring contrastive learning to medical data. A range of self-supervised learning strategies can be applied to 3D medical imaging. For example, a model pretrained on the ImageNet dataset can be applied to dermatology image classification. In another example, inpainting can be combined with contrastive learning for medical image segmentation.
[0033] Masked image modeling approaches, in general, mask out a portion of input images or encoded image tokens and encourage the model to recreate the masked area. Some extant MIM models employ an encoder-decoder design followed by a projection head. The encoder aids in modeling latent feature representations, while the decoder aids in resampling latent vectors back to original images. The encoded or decoded embeddings can subsequently be aligned with the original signals at the masked area by a projection head. Notably, the decoder component can be a lightweight design so as to minimize training time. A lightweight decoder not only reduces computing complexity but can also increase the encoder's ability to learn more generalizable representations that the decoder can easily grasp, translate, and convey. An encoder can be used for fine-tuning. Techniques such as SimMIM can replace the entire decoder with a single projection layer.
[0034] Using a vision transformer (ViT), for example, an image can be divided into regular non-overlapping patches (e.g., a 96×96×96 3D volume can be divided into 216 patches of 16×16×16 smaller volumes), which are often considered as the basic processing units of vision transformers. There are a number of random masking techniques, including but not limited to, a central region masking strategy, a complex block-wise masking strategy, and/or a uniformly random masking method at the patch level using different masked patch sizes and masking ratios.
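The uniformly random patch-level masking strategy mentioned above can be sketched as follows. The masking ratio of 0.75 is an illustrative assumption, not a value stated in this disclosure:

```python
import random

# A 96x96x96 volume split into 16x16x16 patches yields (96 // 16)**3 = 216 patches.
VOLUME, PATCH = 96, 16
n_patches = (VOLUME // PATCH) ** 3

def random_mask(num_patches, mask_ratio, seed=0):
    """Return the indices of patches to mask under a uniformly random strategy
    (sampling without replacement at the patch level)."""
    rng = random.Random(seed)
    n_masked = int(num_patches * mask_ratio)
    return set(rng.sample(range(num_patches), n_masked))

masked = random_mask(n_patches, mask_ratio=0.75)
print(n_patches, len(masked))  # 216 patches, 162 of them masked
```

Central-region or block-wise strategies would differ only in how the index set is chosen; the downstream encoder/decoder sees the same "masked vs. visible" partition either way.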
[0035] In some examples, the image encoder includes a vision transformer (ViT) architecture such as a vanilla ViT (e.g., ViT3D), a Swin-Transformer 3D, and/or a visual attention network (e.g., VAN3D) that can inherit an attention mechanism to derive hierarchical representations similar to, for example, Swin-Transformer 3D, but instead using pure convolutions. Other types of multi-head attention layers may be employed by the image encoder such as, without limitation, Conformer layers, Performer layers, or lightweight convolutional layers.
[0036] Implementations herein are directed toward executing a self-supervised masked image modeling (MIM) training process to pre-train an image encoder on a plurality of unannotated (e.g., unlabeled) multi-dimensional medical images. As used herein, the multi-dimensional images are referred to as 3D medical images, but the disclosure is not so limited and may also include 4D medical images. The 3D medical images may include volumetric slices from CT or MRI scans of interior (or exterior) body regions of patients. The image encoder includes a plurality of multi-head attention layers. For instance, the image encoder may include a Transformer-based architecture with self-attention that employs a stack of Transformer layers. As will become apparent, the image encoder is responsible for modeling latent feature representations of masked image patches, which can subsequently be utilized to forecast original image signals in regions associated with the masked image patches. The image encoder pre-trained on the unannotated 3D medical images via the self-supervised MIM training process is capable of adapting to a wide range of downstream vision tasks such as 3D image segmentation and image classification.
[0037] The pre-trained image encoder may be integrated into an image analysis model and fine-tuned using annotated multi-dimensional medical images to perform a particular downstream vision task. The annotated multi-dimensional medical images used to fine-tune the pre-trained image encoder, and ultimately train the image analysis model to perform the particular vision task, may each include a plurality of image voxels each paired with a corresponding ground-truth label indicating a class the corresponding image voxel belongs to. In this way, implementations of the present disclosure are further directed toward executing a supervised training process to train the image segmentation model on the plurality of annotated multi-dimensional medical images to teach the image segmentation model to learn how to predict the corresponding ground-truth labels for the plurality of image voxels for each annotated multi-dimensional medical image, whereby the image segmentation model includes the pre-trained image encoder initialized on the unannotated multi-dimensional images via the self-supervised MIM training process and fine-tuned on the annotated multi-dimensional images via the supervised training process. In some examples, the trained image analysis model includes an image segmentation model for performing 3D image segmentation tasks such as multi-organ segmentation or tumor segmentation performed on 3D image slices divided from MRI or CT scans of interior body regions. Described in greater detail below, the trained image analysis model may receive, as input, multiple image patches divided from a multi-dimensional medical image (i.e., a volumetric slice from an MRI or CT scan), generate an enhanced medical image based on features extracted from the multi-dimensional medical image, and perform image segmentation or image classification on the enhanced image.
In the image segmentation scenario, the trained image analysis model may be trained to classify one or more particular objects (e.g., tumors or organs) in the enhanced image and separate each of the particular objects from one another by augmenting the enhanced image to include respective segmentation masks overlaying the particular objects. As used herein, augmenting an enhanced image to include a segmentation mask includes augmenting image voxels in the enhanced image that represent each object class and/or define a boundary of the respective object class. The augmenting of image voxels may include changing a color of the image voxels, adjusting an intensity of the image voxels, or augmenting the image voxels in any suitable manner so that each object classified is distinguishable and identifiable within the enhanced image.
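The voxel-augmentation step described above can be sketched as a simple re-coloring of voxels by predicted class. The class-to-color mapping and the flat voxel indexing are illustrative assumptions, not details from this disclosure:

```python
# Class 0 is treated as background; non-background classes get a tint color.
CLASS_COLORS = {1: (255, 0, 0), 2: (0, 255, 0)}  # illustrative colors only

def apply_masks(volume, labels):
    """volume: dict voxel_index -> (r, g, b); labels: dict voxel_index -> class.
    Voxels belonging to a non-background class are tinted with that class's
    color so each classified object is distinguishable in the output image."""
    out = dict(volume)  # leave the input volume unmodified
    for idx, cls in labels.items():
        if cls in CLASS_COLORS:
            out[idx] = CLASS_COLORS[cls]
    return out

# A toy 6-voxel "volume", all mid-gray, with per-voxel class predictions.
volume = {i: (128, 128, 128) for i in range(6)}
labels = {0: 1, 1: 1, 2: 0, 3: 2}
augmented = apply_masks(volume, labels)
print(augmented[0], augmented[3], augmented[2])
```

Intensity adjustment or boundary-only highlighting would follow the same pattern, changing only what is written into `out[idx]`.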
[0039] The first computing system 120a may include a distributed system (e.g., cloud computing environment). The second computing system 120b may include a computing device (e.g., desktop computer, workstation, laptop, tablet, etc.) that downloads the image analysis model 170 from the first computing system 120a. In some other implementations, the first computing system 120a receives the raw 3D medical images 110R from the second computing system 120b and executes the image analysis model 170 to perform the downstream vision task. In additional implementations, the second computing system 120b receives, from the first computing system 120a, the image encoder 150 pre-trained by the self-supervised training process 200 and executes the supervised training process 160 to fine-tune the pre-trained image encoder on the downstream vision task. In this scenario, the annotated MD images 204 may be processed locally on the second computing system 120b via the supervised training process 160, thereby preserving privacy and sensitive data.
[0040] The self-supervised training process 200 trains the image encoder 150 on a first training data set 201 that includes the plurality of unannotated multi-dimensional (MD) images 202. Specifically, and as described in greater detail below with reference to
[0041] Notably, self-supervised MIM training as disclosed herein is especially advantageous for modeling 3D medical images, significantly speeding up training convergence and improving downstream performance. For instance, when compared to naive contrastive learning, training convergence can save up to 1.40× in training cost to reach a same or higher dice score when the pre-trained image encoder 150 is adapted and fine-tuned to perform a downstream vision task. Similarly, the downstream performance of the downstream vision task of image segmentation can achieve over 5-percent (5%) improvements without any hyper-parameter tuning. Additionally, downstream applications incorporating the image encoder pre-trained via self-supervised MIM training are faster and more cost-effective than transfer learning to the particular downstream task for prognosis, treatment sensitivity prediction, tissue segmentation, image classification, and digital representations of patients. As will become apparent, training the image encoder 150 via the self-supervised MIM training process 200 enables prediction of raw voxel values using a high masking ratio and a relatively small patch size. For simply reconstructing raw input 3D medical images 110R into enhanced 3D medical images 110E, a lightweight decoder may be implemented to receive the encoded feature representations 225 output by the image encoder 150 and perform reconstruction of image signals at increased speed and reduced computing and memory costs. Self-supervised MIM training is versatile across raw input 3D medical images 110R having diverse image resolutions and labeled data ratios during the supervised training process 160.
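The raw-voxel prediction objective mentioned above (see also claims 6 and 7) can be sketched as a reconstruction loss computed only over masked patches. The choice of an L1 (mean absolute error) loss and the toy values are illustrative assumptions, not details from this disclosure:

```python
def masked_l1_loss(predicted, target, masked_indices):
    """Mean absolute error between predicted and ground-truth voxel values,
    averaged over masked patches only; visible patches contribute no loss."""
    total, count = 0.0, 0
    for i in masked_indices:
        for p, t in zip(predicted[i], target[i]):
            total += abs(p - t)
            count += 1
    return total / count

# Three toy patches of two voxels each; patches 1 and 2 were masked.
target = [[0.0, 1.0], [0.5, 0.5], [1.0, 0.0]]
predicted = [[0.0, 1.0], [0.0, 1.0], [0.5, 0.5]]
loss = masked_l1_loss(predicted, target, masked_indices=[1, 2])
print(loss)  # 0.5
```

In the single-linear-layer variant of claim 7, `predicted` would come directly from a linear projection of the encoder's hidden representations, with no decoder in between.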
[0042] Generally, MIM learning includes a learning task that includes masking a subset of input signals (e.g., image patches 210) and forecasting the masked signals. Stated differently, MIM learning/training is a self-supervised learning technique that learns representations via masking-corrupted images. Masking can be presented as a noise type. Masked patch prediction for self-supervised learning can predict missing voxels by inpainting a large rectangular area of the source areas and grouping voxel values into different clusters to classify unknown voxel values. Additionally, masked patch prediction for self-supervised learning can be accomplished by predicting a mean color of images.
[0043] After the image encoder 150 is pre-trained via the self-supervised training process 200, the supervised training process 160 trains the image analysis model 170 on a second training data set 203 that includes the plurality of annotated MD medical images 204. The supervised training process 160 fine-tunes the pre-trained image encoder 150 integrated with the image analysis model 170 to teach the image analysis model 170 to perform downstream vision tasks such as image segmentation tasks or image classification tasks. Each annotated MD medical image 204 includes a plurality of image voxels 206 each paired with a corresponding ground-truth label 208 indicating a class the corresponding image voxel 206 belongs to. Notably, the unannotated 3D images 202 in the first training data set 201 used to pre-train the image encoder 150 may be associated with a different medical domain than the annotated 3D images 204 in annotated second training data set 203. For instance, the first data set 201 may include chest CT scans while the second data set 203 may include abdominal CT scans or multimodal MRI scans of brain tumors.
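The supervised objective described above, in which the model learns to predict the ground-truth label 208 of each image voxel 206, can be sketched as a per-voxel cross-entropy loss. The function names and toy logits are illustrative assumptions, not details from this disclosure:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of class scores."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def voxel_cross_entropy(logits_per_voxel, labels):
    """Average negative log-probability of each voxel's ground-truth class."""
    total = 0.0
    for logits, label in zip(logits_per_voxel, labels):
        probs = softmax(logits)
        total += -math.log(probs[label])
    return total / len(labels)

# Two voxels scored over three classes (e.g., background / organ A / organ B).
logits = [[2.0, 0.1, 0.1], [0.1, 0.1, 2.0]]
labels = [0, 2]
loss = voxel_cross_entropy(logits, labels)
print(round(loss, 4))
```

Minimizing this loss over the annotated second training data set 203 fine-tunes the pre-trained encoder together with the rest of the image analysis model 170.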
[0044] The image analysis model 170 may include a U-shaped encoder-decoder architecture that includes the image encoder 150 (employed as a ViT-based encoder, Swin Transformer, or VAN) to produce hierarchical encoded features 225 (
[0045] In one example, the second training data set 203 includes annotated 3D CT scans obtained from the Beyond the Cranial Vault (BTCV) Abdomen dataset, which includes abdominal CT scans acquired from 30 participants/patients with 13 organ annotations by human interpreters under the supervision of clinical radiologists. Each 3D CT scan in the BTCV Abdomen dataset was performed in a portal venous phase with contrast enhancement and includes 80 to 225 slices with 512×512 pixels and a slice thickness ranging from one to six millimeters (mm). During pre-processing, each annotated 3D image 204 may be resampled to 1.5-2.0 isotropic voxel spacing. In this example, the supervised training process 160 trains the image analysis model 170 as a multi-organ segmentation model to perform 13-class segmentation with 1-channel output. Thus, the ground-truth label 208 for each corresponding image voxel 206 in each annotated 3D image 204 may include one of 13 different classes depending on which organ the corresponding image voxel 206 belongs to.
[0046] In another example, the second training data set 203 includes annotated 3D MRI scan images obtained from the Brain Tumor Segmentation (BraTS) public data set, which includes multi-modal and multi-site MRI scans with the ground-truth labels 208 for corresponding image voxels 206 representing regions of edema, non-enhancing core, and necrotic core. In this example, the supervised training process 160 trains the image analysis model 170 as a brain tumor segmentation model to perform 3-class segmentation with 4-channel input. The voxel spacing of the MRI images can be 1.0×1.0×1.0 mm³. The voxel intensities can be pre-processed with z-score normalization.
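The z-score normalization mentioned above shifts each scan's voxel intensities to zero mean and unit variance. A minimal sketch (toy intensity list; not code from this disclosure):

```python
import math

def z_score(voxels):
    """Normalize a list of voxel intensities to zero mean, unit variance."""
    mean = sum(voxels) / len(voxels)
    var = sum((v - mean) ** 2 for v in voxels) / len(voxels)
    std = math.sqrt(var) or 1.0  # guard against constant-intensity input
    return [(v - mean) / std for v in voxels]

normalized = z_score([10.0, 20.0, 30.0, 40.0])
print([round(v, 3) for v in normalized])
```

In practice the statistics would be computed per scan (or per modality channel) over all voxels of the 3D volume rather than a flat list.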
[0047] The self-supervised training process 200 may store the pre-trained image encoder 150 in data storage 180 overlain on the memory hardware 124 of the computing system 120. Likewise, the supervised training process 160 may store the trained image analysis model 170 in the data storage 180. The first computing system 120a and/or any number of second computing systems 120b may access/retrieve the pre-trained image encoder 150 and/or the trained image analysis model 170 for execution thereon.
[0048] During inference, the image analysis model 170 incorporating the pre-trained and fine-tuned image encoder 150 executes on the second computing system 120b (or the first computing system 120a) to process and perform an image analysis task on one or more raw input 3D medical images 110R. Notably, the image analysis task performed by the image analysis model 170 includes the downstream vision task (i.e., image segmentation or image classification) the image analysis model 170 was trained by the supervised training process 160 to perform. Each raw input 3D medical image 110R may correspond to a 3D image slice from a 3D CT scan or a 3D MRI scan of an interior body region of a patient. Optionally, the raw input 3D medical images 110R may correspond to 3D images of an exterior body region of the patient. Each raw input 3D medical image 110R may undergo initial image pre-processing 184 to divide the raw input 3D medical image 110R into a plurality of image patches 210, 210a-n. While nine (9) image patches are shown by way of example, the example is non-limiting and the pre-processing 184 may divide the image into any number of image patches 210. The image analysis model 170 may process the image patches 210 to generate an enhanced 3D medical image 110E and perform the downstream vision task on the enhanced 3D medical image 110E. When the image analysis model 170 performs the downstream vision task of 3D image segmentation, the model 170 predicts a corresponding class for each voxel of the volumetric enhanced 3D medical image 110E to classify one or more particular objects (e.g., tumors, tissue, organs) and separates each of the particular objects from one another by defining a respective segmentation mask to overlay the voxels classifying each object.
Example 3D image segmentation tasks may include multi-organ segmentation performed as a 13-class segmentation task with single-channel input and brain tumor segmentation performed as a three-class segmentation task with four-channel input.
[0049] An image augmenter 360 may receive the enhanced 3D medical image 110E segmented to identify the voxels that represent each particular object class and generate a corresponding segmentation mask to apply to at least a portion of the voxels representing the particular object class. Accordingly, the image augmenter 360 may augment image voxels in the enhanced image that represent each object class and/or define a boundary of the respective object class. The augmenting of image voxels may include changing a color of the image voxels, adjusting an intensity of the image voxels, or augmenting the image voxels in any suitable manner so that each object classified is distinguishable and identifiable within the enhanced image 110E. The segmentation mask may include a graphical feature applied to the enhanced image to convey the location of each object class identified in the enhanced image 110E. The image augmenter 360 may output an enhanced augmented image 110A depicting the segmentation masks that convey the segmentation results produced by the analysis model 170. A graphical user interface 360 executing on the computing system 120 may display the augmented image 110A on a screen in communication with the computing system 120. Additionally or alternatively, the enhanced image and/or the augmented image 110A may be provided as output to one or more additional downstream tasks.
[0050] Referring to
[0051]
[0052] In the example shown, the self-supervised MIM training process 200 adds positional embeddings 215 to the image patches 210. The image encoder 150 receives each masked image patch 210M, whereby each masked image patch may be replaced with a special masking embedding [M]. The special masking token [M] may be randomly initialized as a learnable vector optimized to reveal the corresponding masked image patch 210.
[0053] For each masked image patch [M], the image encoder 150 is configured to generate a corresponding encoded feature representation 225 (also referred to as an encoded hidden representation 225) and a decoder 250 decodes the corresponding encoded feature representation 225 to predict a corresponding predicted token 275 as output from the projection head 260. The objective of the MIM training process 200 is to teach the image encoder 150 and the decoder 250 to learn how to predict the visual tokens 240 obtained from the original 3D image 202. Specifically, the training process 200 teaches the encoder 150 to produce encoded feature representations 225 for the masked image patches 210M for use in generating predicted tokens 275 that match the visual tokens 240 obtained from the original 3D image 202. Here, the training process 200 may determine a training loss based on the predicted tokens 275 generated for the masked image patches 210M and the corresponding visual tokens from the sequence of discrete visual tokens 240 that are aligned (i.e., using the positional embeddings 215) with the masked image patches 210M. Thereafter, the training process 200 updates parameters of the image encoder 150 (and optionally the decoder 250) based on the training loss.
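The training-loss computation in the paragraph above can be sketched as a cross-entropy between the predicted tokens 275 and the aligned visual tokens 240, evaluated only at masked positions. The shapes and function name below are assumptions for illustration, not the source implementation:

```python
import numpy as np

def mim_token_loss(pred_logits, target_tokens, mask):
    """Cross-entropy over the visual-token vocabulary, restricted to
    masked patch positions.

    Hypothetical shapes: pred_logits is (num_patches, vocab_size);
    target_tokens and mask are (num_patches,), with mask True where
    the patch was masked.
    """
    logits = pred_logits[mask]
    targets = target_tokens[mask]
    # Numerically stable log-softmax.
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the correct visual token, averaged
    # over the masked positions only.
    return -log_probs[np.arange(len(targets)), targets].mean()

# With uniform logits the loss equals log(vocab_size).
logits = np.zeros((6, 4))
tokens = np.array([0, 1, 2, 3, 0, 1])
mask = np.array([True, True, False, False, True, True])
loss = mim_token_loss(logits, tokens, mask)
print(round(float(loss), 4))  # 1.3863, i.e. log(4)
```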
[0054] The decoder may include a plurality of multi-head attention layers (e.g., Transformer layers). In some examples, the masked image patches 210M are invisible to the encoder 150, whereby only the decoder 250 has knowledge of the various tokens. This approach may save computation and memory while not interfering with training.
[0055]
[0056] In the example shown, the self-supervised MIM training process 200 adds positional embeddings 215 to the image patches 210. The image encoder 150 receives each masked image patch 210M, whereby each masked image patch may be replaced with a special masking embedding [M]. The special masking token [M] may be randomly initialized as a learnable vector optimized to reveal the corresponding masked image patch 210.
[0057] For each masked image patch 210M, the image encoder 150 is configured to generate a corresponding encoded feature representation 225 and a prediction head 260 generates predicted voxel values 270 for the masked image patch 210M. Notably, the MIM training process 200 for pre-training the image encoder 150 having the SimMIM architecture omits a decoder and instead implements a prediction head 260 to predict raw voxel values 270 for each masked image patch 210M directly from the encoded feature representation 225 generated by the image encoder 150 for the corresponding masked image patch 210M. The training process 200 may determine a training loss based on the predicted voxel values 270 generated for the masked image patches and the corresponding sets of raw voxel values from the original unannotated MD medical image 202 that represent the masked image patches.
[0058] The training loss may be based on a distance in a voxel space between the recovered/estimated raw voxel values 270 and the original voxels from the corresponding sets of raw voxel values that represent the masked image patches. The training loss may include either an ℓ₁ or ℓ₂ loss function. Notably, the training loss may only be computed for the masked patches 210M to prevent the encoder 150 from engaging in self-reconstruction, which could otherwise dominate the learning process and ultimately impede knowledge learning. Thereafter, the training process 200 updates parameters of the image encoder 150 (and optionally the prediction head 260) based on the training loss. The prediction head 260 can transform the predicted voxel values 270 to the original voxel space when the pre-processing down-samples the resolution of the medical image 202. Optionally, a two-layer transposed convolution can up-sample the compressed encoded feature representations 225 to the resolution of the original medical image 202.
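A minimal sketch of the masked-only ℓ₁ loss described in paragraph [0058], with hypothetical shapes (per-patch voxel arrays and a boolean mask over patches):

```python
import numpy as np

def simmim_l1_loss(pred_voxels, target_voxels, patch_mask):
    """Mean absolute error between predicted voxel values 270 and the
    original raw voxel values, computed only over masked patches 210M.
    Unmasked patches contribute nothing, so the encoder cannot lower
    the loss through trivial self-reconstruction."""
    diff = np.abs(pred_voxels[patch_mask] - target_voxels[patch_mask])
    return diff.mean()

pred = np.ones((4, 2, 2, 2))      # four patches of 2x2x2 voxels
target = np.zeros((4, 2, 2, 2))
mask = np.array([True, False, False, True])  # two masked patches
print(simmim_l1_loss(pred, target, mask))  # 1.0
```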
[0059]
[0060]
[0061]
[0062]
[0063]
[0064] The self-supervised MIM training process 200 increases the training speed while reducing the cost to pre-train the image encoder 150 on the first training data set 201.
[0065] In some implementations, various masked patch sizes and masking ratios can be used for training the models using self-supervised MIM. Results of applying machine learning models to 3D medical images using several MIM techniques and then fine-tuning the pre-trained image encoder to perform downstream image segmentation are summarized in the tables of
[0066] A higher masking ratio presents a non-trivial self-supervised learning task that can continually drive the model to build generalizable representations that transfer effectively to downstream tasks. For example, the best dice scores on multi-organ segmentation and brain tumor segmentation tasks are obtained when a masking ratio of approximately 0.75 is used across multiple patch sizes (e.g., 0.7183 for patch size 16 in
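The masking-ratio choice discussed above can be sketched with a random patch-mask generator. The function name and the RNG usage are illustrative assumptions:

```python
import numpy as np

def random_patch_mask(num_patches, mask_ratio, rng):
    """Select approximately mask_ratio of the patches to mask; the
    source reports the best dice scores near a ratio of 0.75."""
    n_mask = int(round(num_patches * mask_ratio))
    mask = np.zeros(num_patches, dtype=bool)
    # Choose n_mask distinct patch indices to mask.
    mask[rng.choice(num_patches, size=n_mask, replace=False)] = True
    return mask

mask = random_patch_mask(64, 0.75, np.random.default_rng(0))
print(int(mask.sum()))  # 48 of 64 patches masked
```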
[0067] Generally, in supervised learning, more training data results in improved performance.
[0068]
[0069]
[0070] At operation 1406, the method 1400 includes obtaining a second training data set 203 that includes a plurality of annotated multi-dimensional medical images 204. Here, each annotated multi-dimensional medical image 204 includes a plurality of image voxels 206 each paired with a corresponding ground-truth label 208 indicating a class the corresponding image voxel belongs to. At operation 1408, the method 1400 includes executing a supervised training process 160 to train the image analysis model 170 on the second training data set 203 to teach the image analysis model 170 to learn how to predict the corresponding ground-truth labels 208 for the plurality of image voxels 206 of each annotated multi-dimensional medical image 204. Here, the image analysis model 170 incorporates the pre-trained image encoder 150. The supervised training process 160 fine-tunes the pre-trained image encoder 150 initialized via the self-supervised MIM training process 200.
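The supervised objective of operation 1408 can be sketched as a per-voxel cross-entropy against the ground-truth labels 208. The flattened shapes and the function name are illustrative assumptions:

```python
import numpy as np

def voxel_segmentation_loss(class_logits, labels):
    """Per-voxel cross-entropy: class_logits is (num_voxels, num_classes)
    and labels is (num_voxels,) of integer class ids (the ground-truth
    labels 208 paired with the image voxels 206)."""
    # Numerically stable log-softmax over the class dimension.
    z = class_logits - class_logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# Uniform logits over a 13-class task (as in the multi-organ
# segmentation example) give a loss of log(13).
logits = np.zeros((10, 13))
labels = np.arange(10) % 13
loss = voxel_segmentation_loss(logits, labels)
print(round(float(loss), 4))  # 2.5649
```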
[0071] A software application (i.e., a software resource) may refer to computer software that causes a computing device to perform a task. In some examples, a software application may be referred to as an application, an app, or a program. Example applications include, but are not limited to, system diagnostic applications, system management applications, system maintenance applications, word processing applications, spreadsheet applications, messaging applications, media streaming applications, social networking applications, and gaming applications.
[0072] The non-transitory memory may be physical devices used to store programs (e.g., sequences of instructions) or data (e.g., program state information) on a temporary or permanent basis for use by a computing device. The non-transitory memory may be volatile and/or non-volatile addressable semiconductor memory. Examples of non-volatile memory include, but are not limited to, flash memory and read-only memory (ROM)/programmable read-only memory (PROM)/erasable programmable read-only memory (EPROM)/electronically erasable programmable read-only memory (EEPROM) (e.g., typically used for firmware, such as boot programs). Examples of volatile memory include, but are not limited to, random access memory (RAM), dynamic random access memory (DRAM), static random access memory (SRAM), phase change memory (PCM) as well as disks or tapes.
[0073]
[0074] The computing device 1500 includes a processor 1510, memory 1520, a storage device 1530, a high-speed interface/controller 1540 connecting to the memory 1520 and high-speed expansion ports 1550, and a low speed interface/controller 1560 connecting to a low speed bus 1570 and a storage device 1530. Each of the components 1510, 1520, 1530, 1540, 1550, and 1560, are interconnected using various busses, and may be mounted on a common motherboard or in other manners as appropriate. The processor 1510 can process instructions for execution within the computing device 1500, including instructions stored in the memory 1520 or on the storage device 1530 to display graphical information for a graphical user interface (GUI) on an external input/output device, such as display 1580 coupled to high speed interface 1540. In other implementations, multiple processors and/or multiple buses may be used, as appropriate, along with multiple memories and types of memory. Also, multiple computing devices 1500 may be connected, with each device providing portions of the necessary operations (e.g., as a server bank, a group of blade servers, or a multi-processor system).
[0075] The memory 1520 stores information non-transitorily within the computing device 1500. The memory 1520 may be a computer-readable medium, a volatile memory unit(s), or non-volatile memory unit(s). The non-transitory memory 1520 may be physical devices used to store programs (e.g., sequences of instructions) or data (e.g., program state information) on a temporary or permanent basis for use by the computing device 1500. Examples of non-volatile memory include, but are not limited to, flash memory and read-only memory (ROM)/programmable read-only memory (PROM)/erasable programmable read-only memory (EPROM)/electronically erasable programmable read-only memory (EEPROM) (e.g., typically used for firmware, such as boot programs). Examples of volatile memory include, but are not limited to, random access memory (RAM), dynamic random access memory (DRAM), static random access memory (SRAM), phase change memory (PCM) as well as disks or tapes.
[0076] The storage device 1530 is capable of providing mass storage for the computing device 1500. In some implementations, the storage device 1530 is a computer-readable medium. In various different implementations, the storage device 1530 may be a floppy disk device, a hard disk device, an optical disk device, or a tape device, a flash memory or other similar solid state memory device, or an array of devices, including devices in a storage area network or other configurations. In additional implementations, a computer program product is tangibly embodied in an information carrier. The computer program product contains instructions that, when executed, perform one or more methods, such as those described above. The information carrier is a computer- or machine-readable medium, such as the memory 1520, the storage device 1530, or memory on processor 1510.
[0077] The high speed controller 1540 manages bandwidth-intensive operations for the computing device 1500, while the low speed controller 1560 manages lower bandwidth-intensive operations. Such allocation of duties is exemplary only. In some implementations, the high-speed controller 1540 is coupled to the memory 1520, the display 1580 (e.g., through a graphics processor or accelerator), and to the high-speed expansion ports 1550, which may accept various expansion cards (not shown). In some implementations, the low-speed controller 1560 is coupled to the storage device 1530 and a low-speed expansion port 1590. The low-speed expansion port 1590, which may include various communication ports (e.g., USB, Bluetooth, Ethernet, wireless Ethernet), may be coupled to one or more input/output devices, such as a keyboard, a pointing device, a scanner, or a networking device such as a switch or router, e.g., through a network adapter.
[0078] The computing device 1500 may be implemented in a number of different forms, as shown in the figure. For example, it may be implemented as a standard server 1500a or multiple times in a group of such servers 1500a, as a laptop computer 1500b, or as part of a rack server system 1500c.
[0079] Various implementations of the systems and techniques described herein can be realized in digital electronic and/or optical circuitry, integrated circuitry, specially designed ASICs (application specific integrated circuits), computer hardware, firmware, software, and/or combinations thereof. These various implementations can include implementation in one or more computer programs that are executable and/or interpretable on a programmable system including at least one programmable processor, which may be special or general purpose, coupled to receive data and instructions from, and to transmit data and instructions to, a storage system, at least one input device, and at least one output device.
[0080] These computer programs (also known as programs, software, software applications or code) include machine instructions for a programmable processor, and can be implemented in a high-level procedural and/or object-oriented programming language, and/or in assembly/machine language. As used herein, the terms machine-readable medium and computer-readable medium refer to any computer program product, non-transitory computer readable medium, apparatus and/or device (e.g., magnetic discs, optical disks, memory, Programmable Logic Devices (PLDs)) used to provide machine instructions and/or data to a programmable processor, including a machine-readable medium that receives machine instructions as a machine-readable signal. The term machine-readable signal refers to any signal used to provide machine instructions and/or data to a programmable processor.
[0081] The processes and logic flows described in this specification can be performed by one or more programmable processors, also referred to as data processing hardware, executing one or more computer programs to perform functions by operating on input data and generating output. The processes and logic flows can also be performed by special purpose logic circuitry, e.g., an FPGA (field programmable gate array) or an ASIC (application specific integrated circuit). Processors suitable for the execution of a computer program include, by way of example, both general and special purpose microprocessors, and any one or more processors of any kind of digital computer. Generally, a processor will receive instructions and data from a read only memory or a random access memory or both. The essential elements of a computer are a processor for performing instructions and one or more memory devices for storing instructions and data. Generally, a computer will also include, or be operatively coupled to receive data from or transfer data to, or both, one or more mass storage devices for storing data, e.g., magnetic, magneto optical disks, or optical disks. However, a computer need not have such devices. Computer readable media suitable for storing computer program instructions and data include all forms of non-volatile memory, media and memory devices, including by way of example semiconductor memory devices, e.g., EPROM, EEPROM, and flash memory devices; magnetic disks, e.g., internal hard disks or removable disks; magneto optical disks; and CD ROM and DVD-ROM disks. The processor and the memory can be supplemented by, or incorporated in, special purpose logic circuitry.
[0082] To provide for interaction with a user, one or more aspects of the disclosure can be implemented on a computer having a display device, e.g., a CRT (cathode ray tube), LCD (liquid crystal display) monitor, or touch screen for displaying information to the user and optionally a keyboard and a pointing device, e.g., a mouse or a trackball, by which the user can provide input to the computer. Other kinds of devices can be used to provide interaction with a user as well; for example, feedback provided to the user can be any form of sensory feedback, e.g., visual feedback, auditory feedback, or tactile feedback; and input from the user can be received in any form, including acoustic, speech, or tactile input. In addition, a computer can interact with a user by sending documents to and receiving documents from a device that is used by the user; for example, by sending web pages to a web browser on a user's client device in response to requests received from the web browser.
[0083] A number of implementations have been described. Nevertheless, it will be understood that various modifications may be made without departing from the spirit and scope of the disclosure. Accordingly, other implementations are within the scope of the following claims.