Stochastic Segmentation Networks
An efficient probabilistic method for modelling aleatoric uncertainty with any image segmentation network architecture.
segmentation health medical-imaging mri computer-vision research paper notebook arxiv:2006.06015 code demo

In image segmentation, there is often more than one plausible solution for a given input. In medical imaging, for example, experts will often disagree about the exact location of object boundaries. Estimating this inherent uncertainty and predicting multiple plausible hypotheses is of great interest in many applications, yet this ability is lacking in most current deep learning methods. In this paper, we introduce stochastic segmentation networks (SSNs), an efficient probabilistic method for modeling aleatoric uncertainty with any image segmentation network architecture. In contrast to approaches that produce pixel-wise estimates, SSNs model joint distributions over entire label maps and thus can generate multiple spatially coherent hypotheses for a single image. By using a low-rank multivariate normal distribution over the logit space to model the probability of the label map given the image, we obtain a spatially consistent probability distribution that can be efficiently computed by a neural network without any changes to the underlying architecture. We tested our method on the segmentation of real-world medical data, including lung nodules in 2D CT and brain tumours in 3D multimodal MRI scans. SSNs outperform state-of-the-art methods for modeling correlated uncertainty in ambiguous images while being much simpler, more flexible, and more efficient.
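To make the low-rank idea concrete, here is a minimal NumPy sketch of drawing several label maps from a multivariate normal over logits whose covariance has the form P Pᵀ + D, with P a tall low-rank factor and D a positive diagonal. It is an illustration under stated assumptions, not the authors' implementation: the function name `sample_label_maps`, the array names `mean`, `cov_factor`, and `cov_diag`, and the toy sizes are all hypothetical, and the parameters are random placeholders where a real SSN would predict them from the image.

```python
import numpy as np


def sample_label_maps(mean, cov_factor, cov_diag, num_samples, rng):
    """Draw label maps from N(mean, P P^T + D) over flattened logits.

    mean:       (H, W, C) logit means
    cov_factor: (H*W*C, R) low-rank factor P (hypothetical placeholder)
    cov_diag:   (H*W*C,) positive diagonal of D (hypothetical placeholder)
    """
    h, w, c = mean.shape
    n, rank = cov_factor.shape

    # Reparameterised sample: eta = mu + P z + sqrt(D) * eps,
    # with z ~ N(0, I_R) and eps ~ N(0, I_n). The full n x n covariance
    # matrix is never built.
    z = rng.standard_normal((num_samples, rank))
    eps = rng.standard_normal((num_samples, n))
    logits = mean.reshape(-1) + z @ cov_factor.T + eps * np.sqrt(cov_diag)

    # Each sampled logit map becomes a label map via a per-pixel argmax.
    logits = logits.reshape(num_samples, h, w, c)
    return logits.argmax(axis=-1)


# Toy sizes; in practice these parameters would come from the network.
rng = np.random.default_rng(0)
H, W, C, R = 8, 8, 2, 4
mean = rng.standard_normal((H, W, C))
cov_factor = 0.5 * rng.standard_normal((H * W * C, R))
cov_diag = np.full(H * W * C, 0.01)

samples = sample_label_maps(mean, cov_factor, cov_diag, num_samples=5, rng=rng)
```

Because every pixel shares the same R-dimensional draw `z`, the sampled maps vary jointly across pixels rather than independently, which is what allows each sample to be spatially coherent. The cost of a sample scales with the rank R rather than with the square of the number of pixels.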
