Source code for msd_pytorch.image_dataset

import collections
import torch
from torch.utils.data import Dataset
import numpy as np
import imageio
import tifffile
from pathlib import Path
import re
from .errors import InputError
import logging


def _natural_sort(l):
    def key(x):
        return [int(c) if c.isdigit() else c for c in re.split("([0-9]+)", x)]

    return sorted(l, key=key)


def _convert_to_integral(img):
    """Convert numpy array to integral value type.

    Handles Boolean arrays, signed and unsigned integer type arrays.

    :param img: A numpy array to convert.
    :returns:
    :rtype:

    """
    if img.dtype.kind == "u":
        return img
    elif img.dtype.kind == "i":
        logging.warning("Converting signed integer image to unsigned integer.")
        # The PyTorch segmentation losses require 64 bit labels for
        # some reason. We might as well convert the image to uint64
        # here.
        return img.astype(np.uint64)
    elif img.dtype.kind == "b":
        # convert boolean to unsigned integer.
        return img.astype(np.uint8)
    else:
        raise InputError(
            f"Image could not be converted to an integral value. Its type is {img.dtype}."
        )


def _relabel_image(img, labels):
    """Validate a label image and map its values onto ``0..k-1``.

    The image is first converted to an unsigned integer dtype with
    `_convert_to_integral`.

    :param img: A numpy array containing label values.
    :param labels: Either an iterable of the label values that may
        occur in the image, or an integer ``k`` stating that the image
        values must lie in ``{0, 1, ..., k-1}``.
    :returns: If `labels` is iterable: a copy of the converted image in
        which every occurrence of ``labels[i]`` is replaced by ``i``.
        If `labels` is an integer: the converted image, unchanged.
    :rtype: numpy.ndarray
    :raises InputError: If the image contains a value outside the label
        set / range, or if its dtype cannot be converted.

    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives
    # in `collections.abc`. Imported locally so that this function does
    # not rely on the module-level imports.
    from collections.abc import Iterable

    img = _convert_to_integral(img)

    if isinstance(labels, Iterable):
        # Materialize once: `labels` is traversed twice below, which
        # would silently exhaust a one-shot iterable.
        labels = list(labels)

        # Check for values in the image that are not in the label set:
        non_labels = set(np.unique(img)) - set(labels)
        if non_labels:
            raise InputError(
                f"Encountered unexpected values {non_labels} that are not in the label set."
            )

        # Relabel the image. Comparisons are made against the untouched
        # `img`, so an already-written index is never remapped again.
        data = np.copy(img)
        for i, label in enumerate(labels):
            data[img == label] = i
        # Return only after *all* labels have been processed.
        return data

    # Image values should be contained in [0, labels-1]. We check this
    # and return the image.
    if img.min() < 0 or labels <= img.max():
        raise InputError(
            f"Image pixel value range {[img.min(), img.max()]} exceeded range {[0, labels - 1]}."
        )
    return img


def _load_natural_image(path):
    """Read an image with imageio and return it channels-first.

    :param path: Path of the image file to read.
    :returns: The image as a numpy array. A three-dimensional array
        whose last axis has length 1, 3, or 4 is treated as HxWxC and
        returned as CxHxW; any other array is returned as read.
    :rtype: numpy.ndarray

    """
    img = np.array(imageio.imread(path))

    # If the image is a gray-scale, RGB, or RGBA image, the channel
    # dimension will be last. We move it to the front.
    if img.ndim == 3 and img.shape[2] in [1, 3, 4]:
        # `moveaxis` keeps the order of the remaining axes (H, W).
        # `swapaxes(0, 2)` would yield CxWxH, i.e. a transposed image.
        return np.moveaxis(img, 2, 0)
    return img


class ImageStack(object):
    """A stack of images stored on disk.

    An image stack describes a collection of images matching the file
    path specifier `path_specifier`. The images can be tiff files, or
    any other image filetype supported by imageio.

    The image paths are sorted using a natural sorting mechanism. So
    "scan1.tif" comes before "scan10.tif".

    Images can be retrieved by indexing into the stack. For example:

    ``ImageStack("*.tif")[i]``

    These images are returned as torch tensors with three dimensions
    CxHxW.
    """

    def __init__(self, path_specifier, *, collapse_channels=False, labels=None):
        """Create a new ImageStack.

        :param path_specifier: `string`

            A path with optional glob pattern describing the image
            file paths. Tildes and other HOME directory specifications
            are expanded with `expanduser` and symlinks are resolved.

            If the path points to a directory, then all files in the
            directory are included in the image stack.

            If the path points to file, then that single file is
            included in the image stack.

            Alternatively, one may specify a "glob pattern" to match
            specific files in the directory. Of course, if the glob
            pattern does not contain a '*', then it may match a single
            file.

            Examples:

            * ``"~/train_images/"``
            * ``"~/train_images/cats*.png"``
            * ``"~/train_images/*.tif"``
            * ``"~/train_images/scan*"``
            * ``"~/train_images/just_one_image.jpeg"``

        :param collapse_channels: `bool`

            By default, the images are returned in the CxHxW format,
            where C is the number of channels and H and W specify the
            height and width, respectively.

            If `collapse_channels=True`, then all channels in the
            image will be averaged to a single channel. This can be
            used to convert color images to gray-scale images, for
            instance.

            If `collapse_channels=False`, any channels in the image
            will be retained.

            In either case, the returned images have at least one
            channel.

        :param labels: `int` or `list(int)`

            By default, all image pixel values are converted to
            float32.

            If you want to retrieve the image pixels as integral
            values instead, set

            * `labels=k` for an integer `k` if the labels are
              contained in the set {0, 1, ..., k-1};
            * `labels=[1,2,5]` if the labels are contained in the set
              {1,2,5}.

            Setting labels is useful for segmentation.

        :returns: An ImageStack
        :rtype: ImageStack

        """
        super(ImageStack, self).__init__()

        path_specifier = Path(path_specifier).expanduser().resolve()

        self.path_specifier = path_specifier
        self.collapse_channels = collapse_channels
        self.labels = labels

        self.paths = ImageStack.find_images(path_specifier)

    @staticmethod
    def find_images(path_specifier):
        """Return the naturally sorted image paths matching a specifier.

        Declared as a static method so that it can be called both as
        ``ImageStack.find_images(p)`` and on an instance.

        :param path_specifier: A path to a file, a directory, or a path
            whose final component contains a ``*`` glob pattern.
        :returns: The matching paths as strings, in natural sort order.
            The list is empty (and a warning is logged) if nothing
            matches.
        :rtype: list(str)

        """
        path_specifier = Path(path_specifier)

        if path_specifier.name and "*" in path_specifier.name:
            # Glob pattern in the last path component.
            paths = path_specifier.parent.glob(path_specifier.name)
        elif path_specifier.is_file():
            paths = [path_specifier]
            logging.warning(f"Image stack consists of single file {path_specifier}")
        elif path_specifier.is_dir():
            paths = path_specifier.glob("*")
        else:
            paths = []

        paths = [str(p) for p in paths]
        paths = _natural_sort(paths)

        if not paths:
            logging.warning(
                f"Image stack is empty for path specification {path_specifier}"
            )

        return paths

    @property
    def num_labels(self):
        """The number of labels in this image stack.

        If the stack is not labeled, this property access raises a
        RuntimeError.

        :returns: The number of labels in this image stack.
        :rtype: int

        """
        # `collections.Iterable` was removed in Python 3.10; the ABC
        # lives in `collections.abc`.
        from collections.abc import Iterable

        if self.labels is None:
            raise RuntimeError("This image stack has no labels")
        elif isinstance(self.labels, Iterable):
            return len(list(self.labels))
        else:
            return int(self.labels)

    def __len__(self):
        """Return the number of image paths in the stack."""
        return len(self.paths)

    def __getitem__(self, i):
        """Load the image at index `i` and return it as a torch tensor.

        :param i: Index into the sorted list of image paths.
        :returns: A tensor with three dimensions. Without labels the
            values are float32; with labels they are relabeled integral
            values.
        :rtype: torch.Tensor
        :raises InputError: If the image cannot be read, is not a valid
            label image, or has fewer than 2 or more than 3 dimensions.

        """
        path = self.paths[i]

        # Load image
        try:
            if Path(path).suffix.lower() in [".tif", ".tiff"]:
                img = np.array(tifffile.imread(path))
            else:
                img = _load_natural_image(path)
        except Exception as e:
            # Chain the cause so the underlying reader error stays
            # visible in the traceback.
            raise InputError(f"Could not read image from {path}. Got error {e}") from e

        # Convert image type if necessary:
        if self.labels is not None:
            try:
                img = _relabel_image(img, self.labels)
            except InputError as e:
                raise InputError(
                    f"Expected labeled image from path {path}. Got error {e}"
                ) from e
        else:
            img = img.astype(np.float32)

        # Check and set image dimensions
        if img.ndim > 3:
            raise InputError(f"Image in {path} has more than 3 dimensions.")
        elif img.ndim < 2:
            raise InputError(f"Image in {path} has less than 2 dimensions.")

        # The # of dimensions can be 2 or 3 at this point. Make it 3.
        if img.ndim == 2:
            img = img[None, ...]

        # Collapse channels if necessary
        assert not (
            self.labels is not None and self.collapse_channels
        ), "Cannot collapse channels of segmentation image stack."

        if self.collapse_channels:
            img = np.mean(img, axis=0, keepdims=True)

        img = torch.from_numpy(img)

        return img
class ImageDataset(Dataset):
    """A dataset of (input, target) image pairs stored on disk.

    Inputs and targets are each held in an `ImageStack`; item ``i`` of
    the dataset is the pair of the ``i``-th image of both stacks.
    """

    def __init__(
        self,
        input_path_specifier,
        target_path_specifier,
        *,
        collapse_channels=False,
        labels=None,
    ):
        """Create a new image dataset.

        :param input_path_specifier: `string`

            A path with optional glob pattern describing the image
            file paths. Tildes and other HOME directory specifications
            are expanded with `expanduser` and symlinks are resolved.

            If the path points to a directory, then all files in the
            directory are included in the image stack.

            If the path points to file, then that single file is
            included in the image stack.

            Alternatively, one may specify a "glob pattern" to match
            specific files in the directory. Of course, if the glob
            pattern does not contain a '*', then it may match a single
            file.

            Examples:

            * ``"~/train_images/"``
            * ``"~/train_images/cats*.png"``
            * ``"~/train_images/*.tif"``
            * ``"~/train_images/scan*"``
            * ``"~/train_images/just_one_image.jpeg"``

        :param target_path_specifier: `string`

            A pattern that describes the target data. Format is similar
            to the input path specification.

        :param collapse_channels: `bool`

            By default, the images are returned in the CxHxW format,
            where C is the number of channels and H and W specify the
            height and width, respectively.

            If `collapse_channels=True`, then all channels in the
            image will be averaged to a single channel. This can be
            used to convert color images to gray-scale images, for
            instance.

            If `collapse_channels=False`, any channels in the image
            will be retained.

            In either case, the returned images have at least one
            channel.

        :param labels: `int` or `list(int)`

            By default, both input and target image pixel values are
            converted to float32.

            If you want to retrieve the target image pixels as
            integral values instead, set:

            * ``labels=k`` for an integer ``k`` if the labels are
              contained in the set {0, 1, ..., k-1};
            * ``labels=[1,2,5]`` if the labels are contained in the set
              {1,2,5}.

            Setting labels is useful for segmentation.

        :returns: An ImageDataset
        :rtype: ImageDataset
        :raises InputError: If the input and target stacks contain a
            different number of images.

        """
        super().__init__()

        # Remember the configuration exactly as it was passed in.
        self.labels = labels
        self.collapse_channels = collapse_channels
        self.input_path_specifier = input_path_specifier
        self.target_path_specifier = target_path_specifier

        # Collapsing channels of a segmentation target is not
        # supported, so it is only requested for unlabeled targets.
        collapse_target = collapse_channels and labels is None

        inputs = ImageStack(input_path_specifier, collapse_channels=collapse_channels)
        self.input_stack = inputs

        targets = ImageStack(
            target_path_specifier, collapse_channels=collapse_target, labels=labels
        )
        self.target_stack = targets

        # Every input image needs exactly one target image.
        if len(self.input_stack) != len(self.target_stack):
            raise InputError(
                f"Number of input and target images does not match. "
                f"Got {len(self.input_stack)} input images and {len(self.target_stack)} target images."
            )

    def __len__(self):
        """Return the number of images in the input stack."""
        inputs = self.input_stack
        return len(inputs)

    def __getitem__(self, i):
        """Return the ``(input, target)`` pair stored at index `i`."""
        inp = self.input_stack[i]
        tgt = self.target_stack[i]
        return inp, tgt

    @property
    def num_labels(self):
        """The number of labels in this image stack.

        The value is read from the target stack; if that stack is not
        labeled, this property access raises a RuntimeError.

        :returns: The number of labels in this image stack.
        :rtype: int

        """
        targets = self.target_stack
        return targets.num_labels