Source code for fnet.transforms

import logging
from typing import Optional

import numpy as np
import scipy
import scipy.ndimage


logger = logging.getLogger(__name__)


class Normalize:
    """Callable wrapper around the ``normalize`` function.

    Stores ``per_dim`` at construction time and forwards it on every call, so
    the normalization can be used as a transform object.
    """

    def __init__(self, per_dim=None):
        """Remember which dimension (if any) is normalized independently."""
        self.per_dim = per_dim

    def __repr__(self):
        return f'Normalize({self.per_dim})'

    def __call__(self, x):
        return normalize(x, per_dim=self.per_dim)
class ToFloat:
    """Callable that returns its input cast to ``np.float32``."""

    _DTYPE = np.float32  # target dtype of the cast

    def __repr__(self):
        return 'ToFloat()'

    def __call__(self, x):
        return x.astype(self._DTYPE)
def normalize(img, per_dim=None):
    """Subtract mean, set STD to 1.0.

    Parameters:
        img: numpy array to normalize; it is cast to float32 (a copy is made).
        per_dim: None or int. If None, the mean/std are taken over the whole
            array. Otherwise each index along ``per_dim`` is normalized
            independently, using statistics over all the other dimensions.
            Negative values count from the end, as in numpy.

    Returns:
        float32 array of the same shape as ``img``.

    Raises:
        ValueError: if ``per_dim`` is outside ``[-img.ndim, img.ndim)``.

    Note: a zero standard deviation yields inf/nan values (no guard).
    """
    result = img.astype(np.float32)
    ndim = result.ndim
    if per_dim is not None:
        if not -ndim <= per_dim < ndim:
            raise ValueError(
                f'per_dim={per_dim} is out of range for an array with {ndim} dimensions'
            )
        # Map negative indices onto range(ndim); otherwise `i != per_dim`
        # below would never exclude an axis and per_dim would be ignored.
        per_dim = per_dim % ndim
    axis = tuple(i for i in range(ndim) if i != per_dim)
    # keepdims=True keeps the reduced axes as size-1 so the statistics
    # broadcast back against `result`.
    result -= np.mean(result, axis=axis, keepdims=True)
    result /= np.std(result, axis=axis, keepdims=True)
    return result
def do_nothing(img):
    """Return ``img`` cast to 64-bit float.

    Despite the name, this always returns a float64 copy (``astype`` copies
    by default). ``np.float`` (an alias of the builtin ``float``, i.e.
    float64) was removed in NumPy 1.24, so ``np.float64`` is used instead.
    """
    return img.astype(np.float64)
class Propper:
    """Padder + Cropper.

    Builds a ``Padder`` (action '+' or 'pad') or a ``Cropper`` (action '-' or
    'crop') from ``**kwargs`` and delegates calls, ``repr`` and ``undo_last``
    to it. Any other action raises ``NotImplementedError``.
    """

    def __init__(self, action='-', **kwargs):
        self.action = action
        if action in ('+', 'pad'):
            transformer_cls = Padder
        elif action in ('-', 'crop'):
            transformer_cls = Cropper
        else:
            raise NotImplementedError
        self.transformer = transformer_cls(**kwargs)

    def __call__(self, x_in):
        return self.transformer(x_in)

    def undo_last(self, x_in):
        """Undo the most recent transform applied by the wrapped transformer."""
        return self.transformer.undo_last(x_in)

    def __repr__(self):
        return repr(self.transformer)
class Padder(object):
    """Pad arrays with ``numpy.pad`` and remember the padding so it can be undone.

    Pad widths are computed once per input shape and cached in ``self.pads``;
    ``self.last_pad`` records the most recent call for ``undo_last``.
    """

    def __init__(self, padding='+', by=16, mode='constant'):
        """
        padding: '+', int, sequence
            '+': pad dimensions up to multiple of "by"
            int: pad each dimension by this value (on both sides)
            sequence: pad each dimension by the corresponding value ('+' or int)
        by: int
            for use with '+' padding option
        mode: str
            passed to numpy.pad function
        """
        self.padding = padding
        self.by = by
        self.mode = mode
        self.pads = {}  # cache: input shape -> pad_width list
        self.last_pad = None

    def __repr__(self):
        return 'Padder{}'.format((self.padding, self.by, self.mode))

    def _calc_pad_width(self, shape_in):
        """Return a list of (before, after) pad amounts, one per dimension.

        Raises:
            NotImplementedError: for a padding spec that is neither '+' nor an
                integer (matching ``Cropper``'s handling of unknown specs).
        """
        if isinstance(self.padding, (str, int, np.integer)):
            paddings = (self.padding, ) * len(shape_in)
        else:
            paddings = self.padding
        pad_width = []
        for i, size in enumerate(shape_in):
            spec = paddings[i]
            if isinstance(spec, (int, np.integer)):
                pad_width.append((spec, spec))
            elif spec == '+':
                # Amount needed to reach the next multiple of `by`
                # (integer arithmetic, equal to ceil(size/by)*by - size).
                padding_total = -size % self.by
                pad_left = padding_total // 2
                pad_width.append((pad_left, padding_total - pad_left))
            else:
                raise NotImplementedError(
                    f'Unsupported padding spec {spec!r} for dim {i}'
                )
        return pad_width

    def undo_last(self, x_in):
        """Crops input so its dimensions match the dimensions of the last input to __call__.

        Must be called after ``__call__``; ``x_in`` must have the shape that
        call produced.
        """
        assert x_in.shape == self.last_pad['shape_out']
        # Slice by explicit [start, start + original size) bounds. The former
        # `slice(a, -b)` form breaks when b == 0 (the slice becomes empty).
        # NumPy requires a tuple, not a list, for multidimensional indexing.
        slices = tuple(
            slice(before, before + size)
            for (before, _), size in zip(self.last_pad['pad_width'],
                                         self.last_pad['shape_in'])
        )
        return x_in[slices].copy()

    def __call__(self, x_in):
        """Pad ``x_in`` and record the operation in ``self.last_pad``."""
        shape_in = x_in.shape
        pad_width = self.pads.get(shape_in)
        if pad_width is None:
            # Compute only on a cache miss.
            pad_width = self._calc_pad_width(shape_in)
            self.pads[shape_in] = pad_width
        x_out = np.pad(x_in, pad_width, mode=self.mode)
        self.last_pad = {
            'shape_in': shape_in,
            'pad_width': pad_width,
            'shape_out': x_out.shape,
        }
        return x_out
class Cropper(object):
    """Crop arrays to a computed shape and remember the crop so it can be undone.

    Crop geometry is computed once per input shape and cached in
    ``self.crops``; ``self.last_crop`` records the most recent call for
    ``undo_last``.
    """

    def __init__(self, cropping='-', by=16, offset='mid', n_max_pixels=9732096,
                 dims_no_crop=None):
        """Crop input array to given shape.

        Parameters:
            cropping: '-', int, None, or a sequence of those (one per dimension)
                '-': crop the dimension down to a multiple of ``by``
                int: shrink the dimension by this many elements
                None: do not crop the dimension
            by: int
                multiple used by '-' cropping, and the step used when
                shrinking to satisfy ``n_max_pixels``
            offset: 'mid', int, or a sequence of those (one per dimension)
                'mid' centers the crop; an int is the crop's start index
            n_max_pixels: int or None
                if not None, the crop shape is shrunk along its last two
                dimensions (alternately, ``by`` at a time) until it has at
                most this many elements
            dims_no_crop: int, sequence of ints, or None
                dimensions exempt from ``cropping`` (they may still be shrunk
                by the ``n_max_pixels`` step)
        """
        self.cropping = cropping
        self.offset = offset
        self.by = by
        self.n_max_pixels = n_max_pixels
        self.dims_no_crop = [dims_no_crop] if isinstance(dims_no_crop, int) else dims_no_crop
        self.crops = {}  # input shape -> {'shape_crop', 'offsets_crop', 'slices'}
        self.last_crop = None

    def __repr__(self):
        return 'Cropper{}'.format((self.cropping, self.by, self.offset, self.n_max_pixels, self.dims_no_crop))

    @staticmethod
    def _per_dim(spec, ndim):
        """Broadcast a scalar spec (str, int, or None) to one entry per dimension."""
        if spec is None or isinstance(spec, (str, int, np.integer)):
            return (spec, ) * ndim
        return spec

    def _adjust_shape_crop(self, shape_crop):
        """Shrink ``shape_crop`` until it has at most ``n_max_pixels`` elements.

        Alternates between the last two dimensions, removing ``by`` elements
        per step. Dimension 0 is never shrunk once it is <= 64, and no
        dimension is shrunk below 1.

        Raises:
            ValueError: if no dimension can be shrunk further while the
                element count is still above ``n_max_pixels``. Previously this
                looped forever or produced non-positive sizes.
        """
        shape_crop_new = list(shape_crop)
        order_dim_reduce = list(range(len(shape_crop)))[-2:]  # alternate between last two dimensions
        idx_dim_reduce = 0
        n_stalled = 0  # consecutive steps in which the chosen dimension could not shrink
        while np.prod(shape_crop_new) > self.n_max_pixels:
            if n_stalled >= len(order_dim_reduce):
                raise ValueError(
                    f'Cannot shrink crop shape {tuple(shape_crop_new)} to at most '
                    f'{self.n_max_pixels} pixels in steps of {self.by}'
                )
            dim = order_dim_reduce[idx_dim_reduce]
            shrinkable = (
                not (dim == 0 and shape_crop_new[dim] <= 64)
                and self.by > 0
                and shape_crop_new[dim] - self.by >= 1
            )
            if shrinkable:
                shape_crop_new[dim] -= self.by
                n_stalled = 0
            else:
                n_stalled += 1
            idx_dim_reduce = (idx_dim_reduce + 1) % len(order_dim_reduce)
        return tuple(shape_crop_new)

    def _calc_shape_crop(self, shape_in):
        """Compute (and cache) the crop shape for inputs of shape ``shape_in``."""
        croppings = self._per_dim(self.cropping, len(shape_in))
        shape_crop = []
        for i, size in enumerate(shape_in):
            spec = croppings[i]
            if spec is None or (self.dims_no_crop is not None and i in self.dims_no_crop):
                shape_crop.append(size)
            elif isinstance(spec, (int, np.integer)):
                shape_crop.append(size - spec)
            elif spec == '-':
                shape_crop.append(size // self.by * self.by)
            else:
                raise NotImplementedError
        if self.n_max_pixels is not None:
            shape_crop = self._adjust_shape_crop(shape_crop)
        self.crops[shape_in]['shape_crop'] = shape_crop
        return shape_crop

    def _calc_offsets_crop(self, shape_in, shape_crop):
        """Compute (and cache) the crop start index for each dimension.

        Raises:
            AttributeError: if a crop would start before 0 or extend past the
                end of the input (exception type kept for compatibility).
        """
        offsets = self._per_dim(self.offset, len(shape_in))
        offsets_crop = []
        for i, size in enumerate(shape_in):
            offset = (size - shape_crop[i]) // 2 if offsets[i] == 'mid' else offsets[i]
            # A negative start would silently wrap around under Python slicing.
            if offset < 0 or offset + shape_crop[i] > size:
                logger.error(f'Cannot crop outside image dimensions ({offset}:{offset + shape_crop[i]} for dim {i})')
                raise AttributeError(
                    f'Crop {offset}:{offset + shape_crop[i]} is outside dim {i} of size {size}'
                )
            offsets_crop.append(offset)
        self.crops[shape_in]['offsets_crop'] = offsets_crop
        return offsets_crop

    def _calc_slices(self, shape_in):
        """Compute (and cache) the list of per-dimension crop slices."""
        shape_crop = self._calc_shape_crop(shape_in)
        offsets_crop = self._calc_offsets_crop(shape_in, shape_crop)
        slices = [slice(start, start + length) for start, length in zip(offsets_crop, shape_crop)]
        self.crops[shape_in]['slices'] = slices
        return slices

    def __call__(self, x_in):
        """Crop ``x_in`` and record the operation in ``self.last_crop``."""
        shape_in = x_in.shape
        cached = self.crops.get(shape_in, {})
        if 'slices' in cached:
            slices = cached['slices']
        else:
            # Reset any partial entry left behind by an earlier failed computation.
            self.crops[shape_in] = {}
            slices = self._calc_slices(shape_in)
        # NumPy requires a tuple (not a list) for multidimensional indexing.
        x_out = x_in[tuple(slices)].copy()
        self.last_crop = {'shape_in': shape_in, 'slices': slices, 'shape_out': x_out.shape}
        return x_out

    def undo_last(self, x_in):
        """Pads input with zeros so its dimensions match the dimensions of the last input to __call__.

        Must be called after ``__call__``; ``x_in`` must have the shape that
        call produced.
        """
        assert x_in.shape == self.last_crop['shape_out']
        x_out = np.zeros(self.last_crop['shape_in'], dtype=x_in.dtype)
        x_out[tuple(self.last_crop['slices'])] = x_in
        return x_out
class Resizer(object):
    """Resize arrays with ``scipy.ndimage.zoom`` using ``mode='nearest'``."""

    def __init__(self, factors, per_dim=None):
        """
        Parameters:
            factors: zoom factor(s) passed to scipy.ndimage.zoom. When
                per_dim is set, they apply to each slice taken along per_dim,
                which has one fewer dimension than the input.
            per_dim: if not None, resize each slice along this axis
                independently and re-stack the results along the same axis
                (negative values count from the end, as in numpy).
        """
        self.factors = factors
        self.per_dim = per_dim

    def __call__(self, x):
        if self.per_dim is None:
            return scipy.ndimage.zoom(x, self.factors, mode='nearest')
        # np.moveaxis accepts negative axes. The previous index-matching loop
        # never matched a negative per_dim, so it zoomed the whole array.
        ars_resized = [
            scipy.ndimage.zoom(plane, self.factors, mode='nearest')
            for plane in np.moveaxis(x, self.per_dim, 0)
        ]
        return np.stack(ars_resized, axis=self.per_dim)

    def __repr__(self):
        return 'Resizer({:s}, {})'.format(str(self.factors), self.per_dim)
class Capper(object):
    """Clamp array values to an optional lower and/or upper bound.

    The input is not modified; a copy is capped. The upper bound is applied
    first, then the lower bound, so when ``low > hi`` every value ends up
    at ``low``.
    """

    def __init__(self, low=None, hi=None):
        self._low = low
        self._hi = hi

    def __repr__(self):
        return 'Capper({}, {})'.format(self._low, self._hi)

    def __call__(self, ar):
        capped = ar.copy()
        upper, lower = self._hi, self._low
        if upper is not None:
            too_high = capped > upper
            capped[too_high] = upper
        if lower is not None:
            too_low = capped < lower
            capped[too_low] = lower
        return capped
def flip_y(ar: np.ndarray) -> np.ndarray:
    """Reverse an array along its Y axis (the second-to-last dimension).

    The array's trailing dimensions are expected to be ordered YX.

    Parameters
    ----------
    ar
        Array to flip.

    Returns
    -------
    np.ndarray
        ``ar`` with the order of its second-to-last axis reversed.
    """
    y_axis = -2
    return np.flip(ar, y_axis)
def flip_x(ar: np.ndarray) -> np.ndarray:
    """Reverse an array along its X axis (the last dimension).

    The array's trailing dimensions are expected to be ordered YX.

    Parameters
    ----------
    ar
        Array to flip.

    Returns
    -------
    np.ndarray
        ``ar`` with the order of its last axis reversed.
    """
    x_axis = -1
    return np.flip(ar, x_axis)
def norm_around_center(
        ar: np.ndarray,
        z_center: Optional[int] = None,
        chunk_zlen: int = 32,
) -> np.ndarray:
    """Returns normalized version of input array.

    The array is normalized using the mean and std pixel intensity of a
    sub-array of length ``chunk_zlen`` (default 32) along the first (z)
    dimension, centered on ``z_center``. If that window would extend past
    either end of the array, it is shifted inside and a warning is logged.

    Parameters
    ----------
    ar
        Input 3d array to be normalized.
    z_center
        Z-index of cell centers. Defaults to the middle of the first
        dimension.
    chunk_zlen
        Length along the first dimension of the window used to compute the
        statistics.

    Returns
    -------
    np.ndarray
        Normalized array, dtype = float32

    Raises
    ------
    ValueError
        If ``ar`` is not 3d, ``chunk_zlen`` < 1, or the first dimension is
        shorter than ``chunk_zlen``.
    """
    if ar.ndim != 3:
        raise ValueError('Input array must be 3d')
    if chunk_zlen < 1:
        raise ValueError('chunk_zlen must be at least 1')
    if ar.shape[0] < chunk_zlen:
        raise ValueError(
            f'Input array must be at least length {chunk_zlen} in first dimension'
        )
    if z_center is None:
        z_center = ar.shape[0] // 2
    z_start = z_center - chunk_zlen // 2
    # Shift the window so it lies entirely inside the array.
    if z_start < 0:
        z_start = 0
        logger.warning(f'Warning: z_start set to {z_start}')
    if (z_start + chunk_zlen) > ar.shape[0]:
        z_start = ar.shape[0] - chunk_zlen
        logger.warning(f'Warning: z_start set to {z_start}')
    chunk = ar[z_start:z_start + chunk_zlen, :, :]
    ar = ar - chunk.mean()
    ar = ar / chunk.std()  # no guard: a constant chunk (std == 0) gives inf/nan
    return ar.astype(np.float32)