"""Source code for fnet.data.bufferedpatchdataset."""

from collections import deque
from typing import List, Sequence, Union
import collections.abc
import logging

from tqdm import tqdm
import numpy as np
import torch


# Module-level logger; used to record buffer insertion events.
logger = logging.getLogger(__name__)
# Array type accepted/returned for dataset components (NumPy or PyTorch).
ArrayLike = Union[np.ndarray, torch.Tensor]


class BufferedPatchDataset:
    """Provides patches from items of a dataset.

    Instances are infinite iterators: ``__next__`` always returns a patch
    and never raises ``StopIteration``.

    Parameters
    ----------
    dataset
        Dataset object. Each item is a sequence of array-like components
        that share the same trailing spatial dimensions.
    patch_shape
        Shape of patch to be extracted from dataset items.
    buffer_size
        Size of buffer.
    buffer_switch_interval
        Number of patches provided between buffer item exchanges. Set to -1
        to disable exchanges.
    shuffle_images
        Set to randomize order of dataset item insertion into buffer.
    """

    def __init__(
        self,
        dataset: collections.abc.Sequence,
        patch_shape: Sequence[int] = (32, 64, 64),
        buffer_size: int = 1,
        buffer_switch_interval: int = -1,
        shuffle_images: bool = True,
    ):
        self.dataset = dataset
        self.patch_shape = patch_shape
        # The buffer can never usefully exceed the dataset length.
        self.buffer_size = min(len(self.dataset), buffer_size)
        self.buffer_switch_interval = buffer_switch_interval
        self.shuffle_images = shuffle_images

        self.counter = 0
        self.epochs = -1  # incremented to 0 when buffer initially filled
        self.buffer = deque()
        self.remaining_to_be_in_buffer = deque()
        self.buffer_history = []
        for _ in tqdm(range(self.buffer_size), desc='Buffering images'):
            self.insert_new_element_into_buffer()

    def __iter__(self):
        return self

    def __next__(self):
        patch = self.get_random_patch()
        self.counter += 1
        # Periodically swap a buffer item for a fresh dataset item; disabled
        # when buffer_switch_interval is non-positive.
        if (self.buffer_switch_interval > 0) and (
            self.counter % self.buffer_switch_interval == 0
        ):
            self.insert_new_element_into_buffer()
        return patch

    def _check_last_datum(self) -> None:
        """Checks last dataset item added to buffer.

        Raises
        ------
        ValueError
            If the item's components disagree on their trailing spatial
            dimensions, or if any component is too small (or has too few
            dimensions) to supply a patch of ``patch_shape``.
        """
        nd = len(self.patch_shape)
        idx_buf = self.buffer_history[-1]
        shape_spatial = None
        for idx_c, component in enumerate(self.buffer[-1]):
            if shape_spatial is None:
                # Spatial dims are the trailing nd dims of the first component.
                shape_spatial = component.shape[-nd:]
            elif component.shape[-nd:] != shape_spatial:
                raise ValueError(
                    f'Dataset item {idx_buf}, component {idx_c} shape '
                    f'{component.shape} incompatible with first component '
                    f'shape {self.buffer[-1][0].shape}'
                )
            if (
                    nd > len(component.shape)
                    or any(
                        self.patch_shape[d] > shape_spatial[d]
                        for d in range(nd)
                    )
            ):
                raise ValueError(
                    f'Dataset item {idx_buf}, component {idx_c} shape '
                    f'{component.shape} incompatible with patch_shape '
                    f'{self.patch_shape}'
                )

    def insert_new_element_into_buffer(self) -> None:
        """Inserts new dataset item into buffer.

        Refills (and optionally shuffles) the queue of pending dataset
        indices whenever it is exhausted, counting that as a new epoch. If
        the buffer is full, the oldest item is evicted first.

        Returns
        -------
        None
        """
        if len(self.remaining_to_be_in_buffer) == 0:
            self.epochs += 1
            indices = list(range(len(self.dataset)))
            if self.shuffle_images:
                # Shuffle a list rather than a deque: deque indexing is O(n),
                # which would make the Fisher-Yates shuffle O(n^2).
                np.random.shuffle(indices)
            self.remaining_to_be_in_buffer = deque(indices)
        if len(self.buffer) >= self.buffer_size:
            self.buffer.popleft()
        new_datum_index = self.remaining_to_be_in_buffer.popleft()
        self.buffer_history.append(new_datum_index)
        self.buffer.append(self.dataset[new_datum_index])
        # Lazy %-style args avoid formatting when INFO is disabled.
        logger.info('Added item %s into buffer', new_datum_index)
        self._check_last_datum()

    def get_random_patch(self) -> List[Union[np.ndarray, torch.Tensor]]:
        """Samples random patch from an item in the buffer.

        Let nd be the number of dimensions of the patch. If the item has
        more dimensions than the patch, then sampling will be from the last
        nd dimensions of the item.

        Returns
        -------
        List[Union[np.ndarray, torch.Tensor]]
            Random patch sampled from a dataset item, one array per
            component. All components share the same spatial slice.
        """
        nd = len(self.patch_shape)
        buffer_index = np.random.randint(len(self.buffer))
        datum = self.buffer[buffer_index]
        shape_spatial = datum[0].shape[-nd:]
        patch = []
        slices = None
        for part in datum:
            if slices is None:
                # Choose one random start per spatial dim; reuse the same
                # slices for every component so patches stay aligned.
                starts = np.array(
                    [
                        np.random.randint(0, d - p + 1)
                        for d, p in zip(shape_spatial, self.patch_shape)
                    ]
                )
                ends = starts + np.array(self.patch_shape)
                slices = tuple(slice(s, e) for s, e in zip(starts, ends))
            # Pad slices with "slice(None)" if there are non-spatial dimensions
            slices_pad = (slice(None),) * (len(part.shape) - len(shape_spatial))
            patch.append(part[slices_pad + slices])
        return patch

    def get_batch(self, batch_size: int) -> Sequence[torch.Tensor]:
        """Returns a batch of patches.

        Parameters
        ----------
        batch_size
            Number of patches in batch.

        Returns
        -------
        Sequence[torch.Tensor]
            Batch of patches, one stacked tensor per component.
        """
        return tuple(
            torch.tensor(np.stack(part))
            for part in zip(*[next(self) for _ in range(batch_size)])
        )

    def get_buffer_history(self) -> List[int]:
        """Returns a list of indices of dataset elements inserted into the
        buffer.

        Returns
        -------
        List[int]
            Indices of dataset elements.
        """
        return self.buffer_history