Medical image segmentation method based on U-Net

11580646 · 2023-02-14

Abstract

A medical image segmentation method based on a U-Net, including: sending a real segmentation image and an original image to a generative adversarial network for data enhancement to generate a composite image with a label; adding the composite image to the original data set to obtain an expanded data set, and sending the expanded data set to an improved multi-feature fusion segmentation network for training. A Dilated Convolution Module is added between the shallow and deep feature skip connections of the segmentation network to obtain receptive fields of different sizes, which enhances the fusion of detail information and deep semantics, improves the adaptability to the size of the segmentation target, and improves medical image segmentation accuracy. The over-fitting problem that occurs when training the segmentation network is alleviated by using the data set expanded by the generative adversarial network.

Claims

1. A medical image segmentation method based on a U-Net, comprising the following steps: step 1: selecting a medical image data set from an existing medical image database; step 2: obtaining a paired original image and a real segmentation map of a target area in the original image from the medical image data set; and generating, by the real segmentation map, a composite image based on a generator G; sending the composite image to a discriminator D for discrimination; and discriminating, by the discriminator D, whether the composite image comes from the medical image data set, and outputting a probability that the composite image comes from the medical image data set; step 3: importing the paired original image and the real segmentation map of the target area in the original image into a generative adversarial network to train the generative adversarial network to obtain a generator model; wherein a generative adversarial joint loss function of the generative adversarial network is:
ℒ_cGAN(G, D) = E_{x,y}[log D(x, y)] + E_{x,z}[log(1 − D(x, G(x, z)))],
wherein x is the real segmentation map of the original image, y is the original image, z is random noise, E[·] represents an expected value of a distribution function, D(x, y) is an output probability of the discriminator D when the input is x and y, and G(x, z) is the composite image; adding an L1 distance loss to constrain a difference between the composite image and the original image and to reduce blurring:
ℒ_L1(G) = E_{x,y,z}[‖y − G(x, z)‖_1]; step 4: using the generator model trained in step 3 to generate the composite image; wherein the composite image and the original image are used as an input data set of a multi-feature fusion segmentation network; and dividing the input data set into a training set and a testing set; step 5: training the multi-feature fusion segmentation network by using the input data set in step 4 to obtain a segmentation network model; wherein, in a decoding process of the multi-feature fusion segmentation network, each decoder layer is connected to a feature mapping of shallow and same layers from an encoder via a Dilated Convolution Module; step 6: inputting the original image to be segmented into the trained segmentation network model for segmentation to obtain an actual segmentation image output by the model.

2. The medical image segmentation method of claim 1, wherein, in step 3, the training of the generative adversarial network comprises a training of the generator G and a training of the discriminator D; forward propagation and backward propagation of a neural network are used to alternately train the discriminator D and the generator G by a gradient descent method; when the probability, as identified by the discriminator D, that the composite image generated by the generator G is a real image reaches 0.5, the training is completed, and the generator model and a discriminator model are obtained.

3. The medical image segmentation method of claim 1, wherein, in step 5, the multi-feature fusion segmentation network comprises feature extraction and resolution enhancement; the feature extraction comprises five convolutional blocks and four down-sampling operations, and the convolutional blocks are connected by the down-sampling operations; the resolution enhancement comprises four convolutional blocks connected by up-sampling.

4. The medical image segmentation method of claim 1, wherein step 5 comprises the following steps: in the multi-feature fusion segmentation network, setting the loss function as a set similarity measurement function, with a specific formula of Dice = 2|A∩B| / (|A| + |B|), wherein |A∩B| represents the common elements between set A and set B, |A| represents a number of elements in the set A, |B| represents a number of elements in the set B, the set A is a segmented image obtained by segmenting the input data set by the multi-feature fusion segmentation network, and the set B is the real segmentation image of the target area in the original image; approximating |A∩B| as a point multiplication between an actual segmented image and the real segmentation image to calculate the set similarity measurement function of a predicted real segmentation image; adding all element values in a result of the point multiplication; and stopping training when the loss function is minimized, to obtain the trained segmentation network model.

5. The medical image segmentation method of claim 3, wherein the generator G is a codec structure, in which residual blocks of the same level are skip-connected in a manner like U-Net; the generator G comprises 9 residual blocks, 2 down-sampled convolutional layers with a stride of 2, and 2 transposed convolutions; after all non-residual blocks, a batch normalization function and a ReLU function are executed; the discriminator D uses a Markovian discriminator model, the same as PatchGAN.

6. The medical image segmentation method of claim 3, wherein a connection sequence in the convolutional block is a 3×3 convolutional layer, a batch normalization layer, a ReLU activation function, a 3×3 convolutional layer, a batch normalization layer, and a ReLU activation function; each down-sampling uses a max-pooling with a stride of 2, the feature map size after down-sampling becomes half of the feature map size before down-sampling, and the number of feature map channels becomes twice the number of feature map channels before down-sampling; up-sampling uses bilinear interpolation to double the resolution of the feature map; among all convolutional layers, the first and last convolutional layers use 7×7 convolutional kernels, and the other convolutional layers use 3×3 convolutional kernels; the 7×7 convolutional kernels use separable convolution to reduce the parameters and calculation amount of the segmentation network model.

7. The medical image segmentation method of claim 2, wherein the step of stopping training when the loss function is minimized, and obtaining a trained segmentation network model, comprises the following steps: initializing weight parameters of the multi-feature fusion segmentation network at each stage based on an Adam optimizer, and randomly initializing the weight parameters by using a Gaussian distribution with an average value of 0; for each sample image in the training set input into the segmentation network model, calculating a total error between the actual segmentation image and the real segmentation map of the target area in the original image by using forward propagation first; then calculating a partial derivative of each weight parameter by using backward propagation of the neural network; and finally updating the weight parameters according to the gradient descent method; and repeating the above steps to minimize the loss function to obtain the trained segmentation network model; wherein the sample image comprises the composite image and the original image.

8. The medical image segmentation method of claim 1, wherein the input data set of the multi-feature fusion segmentation network, comprising the composite image together with the original image, is divided into the training set and the testing set in a ratio of 7:3.

Description

BRIEF DESCRIPTION OF THE DRAWINGS

(1) FIG. 1 is a schematic diagram of a generative adversarial network in the method of the present disclosure;

(2) FIG. 2 is a schematic diagram of a generator in the generative adversarial network in the method of the present disclosure;

(3) FIG. 3 is a schematic diagram of a structure of a segmentation network in the method of the present disclosure;

(4) FIG. 4 is a schematic diagram of a Dilated Convolution Module for multi-feature fusion in the present disclosure.

DETAILED DESCRIPTION OF PREFERRED EMBODIMENTS

(5) The following embodiments are only used to illustrate the technical solutions of the present disclosure more clearly, and do not limit the protection scope of the present disclosure.

(6) The present disclosure provides a medical image segmentation method based on a U-Net, including following steps:

(7) step 1: selecting a medical image data set from existing medical image database;

(8) step 2: obtaining paired original image and real segmentation map of a target area in the original image from the medical image data set; and generating, by the real segmentation map, a composite image based on a generator G;

(9) sending the composite image to a discriminator D for discrimination; and discriminating, by the discriminator D, whether the composite image comes from the medical image data set, and outputting a probability that the composite image comes from the medical image data set;

(10) step 3: importing the paired original image and the real segmentation map of the target area in the original image into a generative adversarial network to train the generative adversarial network to obtain a generator model; wherein a generative adversarial joint loss function of the generative adversarial network is:
ℒ_cGAN(G, D) = E_{x,y}[log D(x, y)] + E_{x,z}[log(1 − D(x, G(x, z)))],

(11) wherein, x is the real segmentation map of the original image, y is the original image, z is random noise, E[·] represents an expected value of a distribution function, D(x, y) is an output probability of the discriminator D when the input is x and y, and G(x, z) is the composite image;

(12) adding an L1 distance loss to constrain the difference between the composite image and the original image and to reduce blurring:
ℒ_L1(G) = E_{x,y,z}[‖y − G(x, z)‖_1];

(13) step 4: using the generator model trained in step 3 to generate the composite image; wherein the composite image and the original image are used as an input data set of a multi-feature fusion segmentation network; and dividing the input data set into a training set and a testing set;

(14) step 5: training the multi-feature fusion segmentation network by using the input data set in step 4 to obtain a segmentation network model; wherein, in a decoding process of the multi-feature fusion segmentation network, each decoder layer is connected to a feature mapping of shallow and same layers from an encoder via a Dilated Convolution Module;

(15) step 6: inputting the original image to be segmented into the trained segmentation network model for segmentation to obtain an actual segmentation image output by the model.

(16) Further, in step 3, the training of the generative adversarial network comprises a training of the generator G and a training of the discriminator D; forward propagation and backward propagation of a neural network are used to alternately train the discriminator D and the generator G by a gradient descent method; when the probability, as identified by the discriminator D, that the composite image generated by the generator G is a real image reaches 0.5, the training is completed, and the generator model and a discriminator model are obtained.

(17) Further, in step 5, the multi-feature fusion segmentation network comprises feature extraction and resolution enhancement; the feature extraction comprises five convolutional blocks and four down-sampling operations, and the convolutional blocks are connected by the down-sampling operations; the resolution enhancement comprises four convolutional blocks connected by up-sampling.

(18) Further, step 5 comprises following steps:

(19) in the multi-feature fusion segmentation network, setting a loss function as a set similarity measurement function, and a specific formula is

(20) Dice = 2|A∩B| / (|A| + |B|),

(21) wherein, |A∩B| represents the common elements between set A and set B, |A| represents a number of elements in the set A, |B| represents a number of elements in the set B, the set A is a segmented image obtained by segmenting the input data set by the multi-feature fusion segmentation network, and the set B is the real segmentation image of the target area in the original image;

(22) approximating the |A∩B| as a point multiplication between an actual segmented image and the real segmentation image to calculate the set similarity measurement function of a predicted real segmentation image; adding all element values in a result of the point multiplication; stopping training when the loss function is minimized, and obtaining the trained segmentation network model.

(23) Further, the generator G is a codec structure, in which residual blocks of the same layer are skip-connected in a manner like U-Net; the generator G comprises 9 residual blocks, 2 down-sampled convolutional layers with a stride of 2, and 2 transposed convolutions;

(24) after all non-residual blocks, a batch normalization function and a ReLU function are executed; the discriminator D uses a Markovian discriminator model, the same as PatchGAN.

(25) Further, a connection order in the convolutional blocks is a 3×3 convolutional layer, a batch normalization layer, a ReLU activation function, a 3×3 convolutional layer, a batch normalization layer, and a ReLU activation function; each down-sampling uses a max-pooling with a stride of 2, the feature map size after down-sampling becomes half of the feature map size before down-sampling, and the number of feature map channels becomes twice the number of feature map channels before down-sampling; up-sampling uses bilinear interpolation to double the resolution of the feature map;

(26) among all convolutional layers, the first and last convolutional layers use 7×7 convolutional kernels, and the other convolutional layers use 3×3 convolutional kernels; the 7×7 convolutional kernels use separable convolution to reduce the parameters and calculation amount of the segmentation network model.
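The parameter saving from making the 7×7 kernels separable can be checked with simple arithmetic. The sketch below is illustrative only; the 3-to-64 channel counts are assumptions, not figures from the disclosure:

```python
def conv_params(k, c_in, c_out, separable=False):
    """Weight count of a k x k convolutional layer (biases omitted).

    A depthwise-separable convolution replaces the dense k x k kernel with a
    k x k depthwise filter per input channel plus a 1 x 1 pointwise projection.
    """
    if not separable:
        return k * k * c_in * c_out
    return k * k * c_in + c_in * c_out

# Hypothetical first layer mapping 3 input channels to 64 feature maps:
standard = conv_params(7, 3, 64)                   # 7*7*3*64 = 9408 weights
depthwise = conv_params(7, 3, 64, separable=True)  # 7*7*3 + 3*64 = 339 weights
```

For this example the separable form needs roughly 28 times fewer weights, which is why it is applied to the large 7×7 kernels rather than the 3×3 ones.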

(27) Further, the step of stopping training when the loss function is minimized, and obtaining a trained segmentation network model, comprises following steps:

(28) initializing weight parameters of the multi-feature fusion segmentation network at each stage based on Adam optimizer, and randomly initializing the weight parameters by using a Gaussian distribution with an average value of 0;

(29) for each sample image in the training set input in the segmentation network model, calculating a total error between the real segmentation image and the real segmentation map of the target area in the original image by using the forward propagation first; then calculating partial derivative of each weight parameter by using the backward propagation of the neural network; and finally updating the weight parameter according to the gradient descent method; and repeating above steps to minimize the loss function to obtain the trained segmentation network model; wherein the sample image comprises the composite image and the original image.

(30) Further, the input data set of the multi-feature fusion segmentation network, comprising the composite image together with the original image, is divided into the training set and the testing set in a ratio of 7:3.
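The 7:3 split can be sketched as a shuffle followed by a cut; the fixed seed and list-of-indices representation here are implementation choices, not part of the disclosure:

```python
import random

def split_dataset(samples, train_ratio=0.7, seed=0):
    """Shuffle and split a data set into training and testing subsets."""
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)  # fixed seed for reproducibility
    cut = round(len(samples) * train_ratio)
    train = [samples[i] for i in indices[:cut]]
    test = [samples[i] for i in indices[cut:]]
    return train, test

# 100 samples -> 70 for training, 30 for testing
train_set, test_set = split_dataset(list(range(100)))
```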

(31) Step 1, the medical image data set is obtained, wherein the medical image data set is a DRIVE fundus vascular data set.

(32) The medical image data set is downloaded from the existing medical image database. The website address is:

(33) https://aistudio.baidu.com/aistudio/projectdetail/462184.

(34) Paired original image of fundus blood vessels and a real segmentation map of a target area in the original image of fundus blood vessels are obtained from the medical image data set; and the composite image is generated by the real segmentation map based on the generator G.

(35) The composite image is sent to the discriminator D for discrimination; it is judged by the discriminator D whether the composite image comes from the medical image data set, and the probability that the composite image comes from the medical image data set is output. There are many methods that can be used for this step in the prior art, and this embodiment does not give examples one by one.

(36) Step 2, the paired original image of fundus blood vessels and the real segmentation map of the target area in the original image of fundus blood vessels are extracted from the DRIVE fundus vascular data set and input into the generative adversarial network.

(37) The generative adversarial network uses the pix2pix algorithm. A real segmentation label image x is used as the input of the generator G to obtain the generated image G(x); then G(x) and x are merged together based on the channel dimension and used as the input of the discriminator D to obtain a predicted probability value, which indicates whether the input is a pair of real images; the closer the probability value is to 1, the more certain the discriminator is that the input is a real pair of images. In addition, the real images y and x are also merged together based on the channel dimension and used as the input of the discriminator D to obtain the probability prediction value.
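"Merging based on the channel dimension" is simply a concatenation along the channel axis. A minimal sketch with numpy, assuming a 1-channel label map and a 3-channel image at an assumed 256×256 resolution:

```python
import numpy as np

# Images stored channels-first as (channels, height, width).
x = np.zeros((1, 256, 256))    # segmentation label map, 1 channel (assumed)
g_x = np.zeros((3, 256, 256))  # generated image G(x), 3 channels (assumed)

# The discriminator sees the label and the image stacked as one 4-channel input.
pair = np.concatenate([x, g_x], axis=0)
```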

(38) Step 3, the loss function is created. The joint loss function of generator and the discriminator is:
ℒ_cGAN(G, D) = E_{x,y}[log D(x, y)] + E_{x,z}[log(1 − D(x, G(x, z)))],

(39) wherein, x is the real segmentation map of the original image, y is the original image, z is random noise, E[*] represents an expected value of a distribution function, and D(x,y) is an output probability of the discriminator D when input is x and y, G(x,z) is the composite image.

(40) An L1 distance loss is added to constrain the difference between the composite image and the original image and to reduce blurring:
ℒ_L1(G) = E_{x,y,z}[‖y − G(x, z)‖_1];

(41) wherein, x is a segmentation label, y is a real fundus blood vessel image, and z is random noise. Dropout is used to generate the random noise.

(42) The total objective function is

(43) F = arg min_G max_D ℒ_cGAN(G, D) + λ ℒ_L1(G).
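The two loss terms can be checked numerically on toy arrays. This is a sketch only: it evaluates the objective for given discriminator outputs rather than training anything, and λ = 100 is the weighting commonly used with pix2pix, an assumption here since the disclosure only names it λ:

```python
import numpy as np

def cgan_loss(d_real, d_fake):
    """E_{x,y}[log D(x,y)] + E_{x,z}[log(1 - D(x,G(x,z)))] over toy probabilities."""
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

def l1_loss(y, g_xz):
    """E_{x,y,z}[||y - G(x,z)||_1], averaged element-wise over the batch."""
    return np.mean(np.abs(y - g_xz))

def total_objective(d_real, d_fake, y, g_xz, lam=100.0):
    # The generator minimizes this while the discriminator maximizes the first term.
    return cgan_loss(d_real, d_fake) + lam * l1_loss(y, g_xz)
```

At the equilibrium described in the disclosure, D outputs 0.5 everywhere, giving a cGAN term of 2·log(0.5) ≈ −1.386.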

(44) Step 4, the generative adversarial network is trained. The generative adversarial network adopts the pix2pix algorithm and includes a generator G and a discriminator D. The generator G is in a codec structure, as shown in FIG. 2, is skip-connected in a manner like U-Net, and includes 9 residual blocks, 2 down-sampled convolutional layers with a stride of 2, and 2 transposed convolutions. The 9 residual blocks are connected in sequence; after all non-residual blocks, a batch normalization function and a ReLU function are executed. The discriminator D uses a Markovian discriminator model, the same as PatchGAN. Batch normalization is abbreviated as BN.

(45) All convolutional layers use 3×3 convolutional kernels except for the first and last layers, which use 7×7 convolutional kernels. The 7×7 convolutional kernels use separable convolution to reduce model parameters and calculation amount.

(46) The input of the generator of the generative adversarial network is an image from the labeled data set, and the output is the composite image of fundus blood vessels. The generator of the generative adversarial network is trained; the number of iterations is M, where M is a positive integer of at least 400. The initial learning rate is α, with 0 < α < 0.01, and the learning rate then decreases linearly.

(47) To train the discriminator of the generative adversarial network, the composite image of fundus blood vessels output by the generator and the corresponding label are used as the input of the discriminator. The discriminator discriminates whether the fake image output by the generator is a real image; the discriminator of the generative adversarial network is trained for N iterations, where N is an even number of at least 300; the initial learning rate is β, with 0 < β < 0.001, and the learning rate then decreases linearly.
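A linearly decaying schedule of the kind described for both networks can be sketched as below. The choice to hold the initial rate for the first half of training before decaying is an assumption for illustration; the disclosure only states that the later learning rate decreases linearly:

```python
def linear_decay_lr(step, total_steps, base_lr, hold_steps):
    """Hold base_lr for the first hold_steps iterations, then decay
    linearly to zero at total_steps."""
    if step < hold_steps:
        return base_lr
    remaining = (total_steps - step) / (total_steps - hold_steps)
    return base_lr * max(remaining, 0.0)

# e.g. 400 generator iterations with base rate 0.002, decay starting at 200
schedule = [linear_decay_lr(s, 400, 0.002, 200) for s in range(401)]
```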

(48) The discriminator and generator are alternately trained until the probability, as judged by the discriminator, that the fake image generated by the generator is real reaches 0.5; the training then ends, and the generator model and the discriminator model of the generative adversarial network are obtained.

(49) Step 5, the multi-feature fusion segmentation network is trained by using the input data set in step 4. The multi-feature fusion segmentation network includes feature extraction and resolution enhancement. In the decoding process of the multi-feature fusion segmentation network, each decoder layer is connected to a feature mapping of shallow and same layers from the encoder via a Dilated Convolution Module to obtain the detailed information of shallow layers at different receptive fields, which is combined with the deep semantics to improve the segmentation accuracy for segmentation targets of different sizes. The trained generator is used to generate a composite fundus blood vessel image from input labels, which is added to the original data set; the expanded data set is divided into the training set and the testing set at a ratio of 7:3, and the two sets are then input into the segmentation network of the present disclosure (as shown in FIG. 3).

(50) The specific connection order of the convolutional blocks is a 3×3 convolutional layer, a batch normalization layer, a ReLU activation function, a 3×3 convolutional layer, a batch normalization layer, and a ReLU activation function. Each down-sampling uses a max-pooling with a stride of 2 to halve the feature map size, while the number of feature map channels is doubled to compensate for the loss of information. Up-sampling uses bilinear interpolation to double the size of the image, that is, to double the resolution.
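The halve-size, double-channels bookkeeping through the five encoder blocks can be tabulated directly. The 256×256 input and 64 base channels are assumptions for illustration, not values fixed by the disclosure:

```python
def encoder_shapes(height, width, base_channels=64, blocks=5):
    """(channels, H, W) after each encoder block: each of the four max-poolings
    halves H and W, and each following block doubles the channel count."""
    shapes = [(base_channels, height, width)]
    c, h, w = base_channels, height, width
    for _ in range(blocks - 1):
        c, h, w = c * 2, h // 2, w // 2
        shapes.append((c, h, w))
    return shapes

# Five blocks from an assumed 256 x 256 input with 64 base channels.
shapes = encoder_shapes(256, 256)
```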

(51) The four branches of the Dilated Convolution Module use convolution kernels with different dilation rates to extract feature map information at different scales. Each branch ends with a 1×1 convolution to control the number of feature map channels at each scale, achieving cross-channel fusion and information integration, and the feature mapping after stitching the different features is guaranteed to have the same dimension as the feature mapping input to the module.

(52) In the decoding process of the multi-feature fusion segmentation network, each decoder layer is connected to a feature mapping of shallow and same layers from the encoder via a Dilated Convolution Module to obtain the detailed information of shallow layers at different receptive fields, combined with the deep semantics to improve the segmentation accuracy for segmentation targets of different sizes.

(53) On the original U-Net basic network, a Dilated Convolution Module is added after each convolutional block in the encoding part. The specific connection of the Dilated Convolution Module is shown in FIG. 4: 3×3 convolutional kernels with different dilation rates are connected in parallel, feature maps of different receptive fields are captured by each branch, and the number of channels is then adjusted by using a 1×1 convolution, so that the input and output feature mappings of the module have the same dimensions, ensuring that the feature mappings match those of the decoding part during fusion.

(54) In addition to the original skip connection of U-Net within the same stage, connections between the stages of the decoding part and the lower or horizontal stages of the encoding part are added to fuse deep semantic information and shallow detailed information. The connections can bridge the semantic gap caused by the large splicing span and keep more of the underlying information.

(55) The segmentation network includes feature extraction and resolution enhancement, the purpose of which is to reduce semantic gaps and fuse deep and shallow semantic features. The training is stopped when the loss function reaches a minimum.

(56) The step 5 is specifically implemented as follows: the loss function of segmentation network based on multi-feature fusion is set up.

(57) In the segmentation network part, the loss function is set as the dice coefficient commonly used in medicine, and the specific formula is:

(58) Dice = 2|A∩B| / (|A| + |B|),

(59) wherein, |A∩B| represents the common elements between set A and set B, |A| represents a number of elements in the set A, |B| represents a number of elements in the set B, the set A is the segmented image obtained by segmenting the input data set by the multi-feature fusion segmentation network, and the set B is the real segmentation image of the target area in the original image.

(60) In order to calculate the set similarity measurement function (the Dice coefficient) of the predicted segmentation image, |A∩B| is approximated as the point multiplication between the predicted segmentation image and the label, and all element values in the result of the point multiplication are added; |A| and |B| are obtained by adding the element values in the set A and the set B, respectively. When the loss function is minimized, the training is stopped to obtain the trained segmentation network model.
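The Dice computation just described can be sketched in a few lines of numpy; the small epsilon is an implementation detail added here to avoid division by zero on empty masks:

```python
import numpy as np

def dice_coefficient(pred, target, eps=1e-6):
    """Dice = 2|A intersect B| / (|A| + |B|); the intersection is approximated
    by the element-wise product of prediction and label, then summed."""
    intersection = np.sum(pred * target)
    return 2.0 * intersection / (np.sum(pred) + np.sum(target) + eps)

def dice_loss(pred, target):
    # Training minimizes 1 - Dice, so perfect overlap gives zero loss.
    return 1.0 - dice_coefficient(pred, target)
```

On soft (probability-valued) predictions the same formula remains differentiable, which is what allows it to serve as the segmentation network's loss.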

(61) Step 7, the segmentation network is trained.

(62) In order to minimize the loss function in step 5, the Adam optimizer is used to initialize the weight parameters of the network in each stage first, and the weight parameters are randomly initialized with a Gaussian distribution with an average value of 0.

(63) For each sample image x, the forward propagation is first used to calculate the total error, and then the back propagation is used to calculate the partial derivative of each weight parameter; finally, the weight parameters are updated according to the gradient descent method. This step is repeated until the loss function reaches the minimum, and a trained segmentation network model is obtained.
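The forward-backward-update cycle in paragraphs (62) and (63) can be illustrated on a toy problem. This is a stand-in, not the disclosed network: a single linear layer with a squared-error loss and plain gradient descent (the Adam bookkeeping is omitted for brevity), keeping only the Gaussian mean-zero initialization and the three-phase update:

```python
import numpy as np

rng = np.random.default_rng(0)

w = rng.normal(0.0, 1.0, size=3)             # Gaussian init, mean 0
x = rng.normal(size=(32, 3))                 # toy input batch
y_true = x @ np.array([1.0, -2.0, 0.5])      # targets from a known weight

lr = 0.1
for _ in range(200):
    y_pred = x @ w                                   # forward propagation
    grad = 2.0 * x.T @ (y_pred - y_true) / len(x)    # backward: dLoss/dw
    w -= lr * grad                                   # gradient-descent update
```

After enough iterations the weights recover the generating vector, the toy analogue of the loss function reaching its minimum.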

(64) Step 8, the fundus blood vessel image to be segmented is input into the segmentation network of the present disclosure to obtain a segmented fundus blood vessel image.

(65) When segmenting the fundus blood vessel data set, the generative adversarial network is used to expand the DRIVE data set. By training a generator that can generate fundus-like blood vessel images, the problem of inaccurate segmentation caused by over-fitting in the training process due to the small data set of medical images is alleviated. At the same time, in the process of generating fake images with segmentation labels, the one-to-one relationship between labels and images is maintained, which provides favorable conditions for final evaluation. By improving the original U-Net structure, the present disclosure solves the problem of loss of shallow detail information in the down-sampling process. The increased multi-scale Dilated Convolution Module improves the fusion of deep and shallow semantics, reduces the semantic gap. The segmentation targets of different scales are extracted effectively, which improves the segmentation accuracy of foreground and background of medical images.

(66) The above are only the preferred embodiments of the present disclosure. It should be pointed out that for those of ordinary skill in the art, without departing from the technical principles of the present disclosure, several improvements and modifications can be made. These improvements and modifications should also fall within protection scope of the present disclosure.