from typing import Optional
import logging
import numpy as np
import scipy
logger = logging.getLogger(__name__)
class Normalize:
    """Callable object form of the module-level ``normalize`` function.

    Parameters:
        per_dim: stored on the instance and forwarded unchanged to
            ``normalize`` on every call.
    """

    def __init__(self, per_dim=None):
        self.per_dim = per_dim

    def __call__(self, x):
        # Delegate to the function; this class only remembers ``per_dim``.
        normalized = normalize(x, per_dim=self.per_dim)
        return normalized

    def __repr__(self):
        return f'Normalize({self.per_dim})'
class ToFloat:
    """Callable that casts its input to 32-bit floating point."""

    def __call__(self, x):
        # ``astype`` returns a new array; the input is left untouched.
        as_float32 = x.astype(np.float32)
        return as_float32

    def __repr__(self):
        return 'ToFloat()'
def normalize(img, per_dim=None):
    """Subtract mean, set STD to 1.0

    Parameters:
        img: array to normalize. It is not modified; a float32 copy is
            returned.
        per_dim: index of the axis whose entries are normalized
            independently, i.e. mean and STD are computed over all other
            axes. Negative values count from the last axis. ``None`` (or an
            index that matches no axis) normalizes over the whole array.

    Returns:
        float32 array with the same shape as ``img``.
    """
    if per_dim is not None and per_dim < 0:
        # A negative index never equals the non-negative axis numbers
        # enumerated below, which used to silently fall back to whole-array
        # normalization. Wrap it so -1 means "last axis".
        per_dim += img.ndim
    axis = tuple(i for i in range(img.ndim) if i != per_dim)
    result = img.astype(np.float32)
    # keepdims leaves size-1 axes in place so the statistics broadcast
    # against ``result`` without manual newaxis indexing.
    result -= np.mean(result, axis=axis, keepdims=True)
    result /= np.std(result, axis=axis, keepdims=True)
    return result
def do_nothing(img):
    """Return ``img`` cast to float64 with no other transformation.

    ``np.float64`` is used explicitly: the ``np.float`` alias (which meant
    the builtin ``float``, i.e. float64) was removed in NumPy 1.24.
    """
    return img.astype(np.float64)
class Propper:
    """Padder + Cropper

    Wraps either a ``Padder`` (action ``'+'`` or ``'pad'``) or a ``Cropper``
    (action ``'-'`` or ``'crop'``). Remaining keyword arguments are passed to
    the chosen class's constructor. Any other action raises
    ``NotImplementedError``.
    """

    def __init__(self, action='-', **kwargs):
        self.action = action
        if action in ('+', 'pad'):
            transformer_cls = Padder
        elif action in ('-', 'crop'):
            transformer_cls = Cropper
        else:
            raise NotImplementedError
        self.transformer = transformer_cls(**kwargs)

    def __repr__(self):
        # Present as the wrapped transformer.
        return repr(self.transformer)

    def __call__(self, x_in):
        transformed = self.transformer(x_in)
        return transformed

    def undo_last(self, x_in):
        """Forward to the wrapped transformer's ``undo_last``."""
        restored = self.transformer.undo_last(x_in)
        return restored
class Padder(object):
    """Pad arrays and remember the padding so it can be undone."""

    def __init__(self, padding='+', by=16, mode='constant'):
        """
        padding: '+', int, sequence
          '+': pad dimensions up to multiple of "by"
          int: pad each dimension by this value (on both sides)
          sequence: pad each dimension by corresponding value in sequence
        by: int
          for use with '+' padding option
        mode: str
          passed to numpy.pad function
        """
        self.padding = padding
        self.by = by
        self.mode = mode
        self.pads = {}  # cache: input shape -> pad_width
        self.last_pad = None  # bookkeeping of the most recent __call__

    def __repr__(self):
        return 'Padder{}'.format((self.padding, self.by, self.mode))

    def _calc_pad_width(self, shape_in):
        """Return a list of (before, after) pad amounts, one per dimension."""
        if isinstance(self.padding, (str, int)):
            paddings = (self.padding, )*len(shape_in)
        else:
            paddings = self.padding
        pad_width = []
        for i, size in enumerate(shape_in):
            padding = paddings[i]
            if isinstance(padding, int):
                pad_width.append((padding,)*2)
            elif padding == '+':
                # Amount needed to reach the next multiple of ``by``.
                padding_total = -size % self.by
                pad_left = padding_total//2
                pad_right = padding_total - pad_left
                pad_width.append((pad_left, pad_right))
        # Fails when an entry was neither an int nor '+'.
        assert len(pad_width) == len(shape_in)
        return pad_width

    def undo_last(self, x_in):
        """Crops input so its dimensions matches dimensions of last input to __call__."""
        assert x_in.shape == self.last_pad['shape_out']
        # ``-b`` with b == 0 would be slice(a, 0), i.e. empty; use an open
        # end instead. The index must be a tuple: NumPy no longer accepts a
        # list of slices for multidimensional indexing.
        slices = tuple(
            slice(a, -b if b else None) for a, b in self.last_pad['pad_width']
        )
        return x_in[slices].copy()

    def __call__(self, x_in):
        """Return a padded copy of ``x_in`` and record how it was padded."""
        shape_in = x_in.shape
        pad_width = self.pads.get(shape_in)
        if pad_width is None:
            # Only compute on a cache miss; passing the computation as the
            # ``dict.get`` default would evaluate it on every call.
            pad_width = self._calc_pad_width(shape_in)
        x_out = np.pad(x_in, pad_width, mode=self.mode)
        # Cache only after the padding succeeded.
        self.pads.setdefault(shape_in, pad_width)
        self.last_pad = {'shape_in': shape_in, 'pad_width': pad_width, 'shape_out': x_out.shape}
        return x_out
class Cropper(object):
    """Crop arrays and remember the crop so it can be undone."""

    def __init__(self, cropping='-', by=16, offset='mid', n_max_pixels=9732096, dims_no_crop=None):
        """Crop input array to given shape.

        cropping: '-', int, None, or sequence of those (one per dimension)
          '-': crop dimension down to a multiple of "by"
          int: shrink dimension by this value
          None: leave dimension unchanged
        by: int
          for use with '-' cropping option; also the step used when
          shrinking to satisfy n_max_pixels
        offset: 'mid', int, or sequence of those (one per dimension)
          'mid': center the crop window
          int: start index of the crop window
        n_max_pixels: int or None
          upper bound on the number of elements in the cropped array;
          None disables the check
        dims_no_crop: int, collection of ints, or None
          dimensions exempt from "cropping" (the n_max_pixels adjustment
          may still shrink the last two dimensions)
        """
        self.cropping = cropping
        self.offset = offset
        self.by = by
        self.n_max_pixels = n_max_pixels
        self.dims_no_crop = [dims_no_crop] if isinstance(dims_no_crop, int) else dims_no_crop
        self.crops = {}  # cache: input shape -> dict of computed crop info
        self.last_crop = None  # bookkeeping of the most recent __call__

    def __repr__(self):
        return 'Cropper{}'.format((self.cropping, self.by, self.offset, self.n_max_pixels, self.dims_no_crop))

    def _adjust_shape_crop(self, shape_crop):
        """Shrink ``shape_crop`` in steps of ``by`` until it holds at most n_max_pixels elements.

        Raises ValueError when no dimension may be shrunk any further.
        """
        shape_crop_new = list(shape_crop)
        order_dim_reduce = list(range(len(shape_crop))[-2:])  # alternate between last two dimensions
        idx_dim_reduce = 0
        n_skipped = 0  # consecutive dimensions that could not be reduced
        while np.prod(shape_crop_new) > self.n_max_pixels:
            dim = order_dim_reduce[idx_dim_reduce]
            if dim == 0 and shape_crop_new[dim] <= 64:
                # Dimension 0 is never reduced once it is at or below 64.
                n_skipped += 1
                if n_skipped >= len(order_dim_reduce):
                    # Nothing left to shrink; without this guard the loop
                    # would spin forever.
                    raise ValueError(
                        'Cannot reduce crop shape {} to at most {} elements'.format(
                            tuple(shape_crop_new), self.n_max_pixels
                        )
                    )
            else:
                shape_crop_new[dim] -= self.by
                n_skipped = 0
            idx_dim_reduce = (idx_dim_reduce + 1) % len(order_dim_reduce)
        return tuple(shape_crop_new)

    def _calc_shape_crop(self, shape_in):
        """Compute the cropped shape for ``shape_in`` and store it in the cache."""
        croppings = (self.cropping, )*len(shape_in) if isinstance(self.cropping, (str, int)) else self.cropping
        shape_crop = []
        for i, size in enumerate(shape_in):
            cropping = croppings[i]
            if (cropping is None) or (self.dims_no_crop is not None and i in self.dims_no_crop):
                shape_crop.append(size)
            elif isinstance(cropping, int):
                shape_crop.append(size - cropping)
            elif cropping == '-':
                shape_crop.append(size//self.by*self.by)
            else:
                raise NotImplementedError
        if self.n_max_pixels is not None:
            shape_crop = self._adjust_shape_crop(shape_crop)
        # setdefault: do not rely on __call__ having created the entry.
        self.crops.setdefault(shape_in, {})['shape_crop'] = shape_crop
        return shape_crop

    def _calc_offsets_crop(self, shape_in, shape_crop):
        """Compute the start index of the crop window in every dimension.

        Raises AttributeError when a window would extend past the input.
        """
        offsets = (self.offset, )*len(shape_in) if isinstance(self.offset, (str, int)) else self.offset
        offsets_crop = []
        for i, size in enumerate(shape_in):
            offset = (size - shape_crop[i])//2 if offsets[i] == 'mid' else offsets[i]
            if offset + shape_crop[i] > size:
                msg = f'Cannot crop outsize image dimensions ({offset}:{offset + shape_crop[i]} for dim {i})'
                logger.error(msg)
                raise AttributeError(msg)
            offsets_crop.append(offset)
        self.crops.setdefault(shape_in, {})['offsets_crop'] = offsets_crop
        return offsets_crop

    def _calc_slices(self, shape_in):
        """Compute one slice per dimension and store the list in the cache."""
        shape_crop = self._calc_shape_crop(shape_in)
        offsets_crop = self._calc_offsets_crop(shape_in, shape_crop)
        slices = [
            slice(start, start + length)
            for start, length in zip(offsets_crop, shape_crop)
        ]
        self.crops.setdefault(shape_in, {})['slices'] = slices
        return slices

    def __call__(self, x_in):
        """Return a cropped copy of ``x_in`` and record how it was cropped."""
        shape_in = x_in.shape
        cached = self.crops.get(shape_in)
        if cached is not None and 'slices' in cached:
            slices = cached['slices']
        else:
            # Also taken when an earlier attempt for this shape failed
            # part-way, so the original error is raised again rather than a
            # KeyError on the incomplete entry.
            self.crops[shape_in] = {}
            slices = self._calc_slices(shape_in)
        # NumPy requires a tuple (not a list) for multidimensional indexing.
        x_out = x_in[tuple(slices)].copy()
        self.last_crop = {'shape_in': shape_in, 'slices': slices, 'shape_out': x_out.shape}
        return x_out

    def undo_last(self, x_in):
        """Pads input with zeros so its dimensions matches dimensions of last input to __call__."""
        assert x_in.shape == self.last_crop['shape_out']
        shape_out = self.last_crop['shape_in']
        slices = self.last_crop['slices']
        x_out = np.zeros(shape_out, dtype=x_in.dtype)
        x_out[tuple(slices)] = x_in
        return x_out
class Resizer(object):
    """Resize arrays with ``scipy.ndimage.zoom`` (``mode='nearest'``)."""

    def __init__(self, factors, per_dim=None):
        """
        Parameters:
            factors: resizing factors passed to scipy.ndimage.zoom. When
                per_dim is set they are applied to each sub-array, i.e. to
                the input with the per_dim axis removed.
            per_dim: axis along which the input is split; every slice is
                resized independently and the results are stacked back on
                that axis. Negative values count from the last axis. None
                resizes the whole array in one call.
        """
        self.factors = factors
        self.per_dim = per_dim

    def __call__(self, x):
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.ndimage`` is loaded.
        from scipy import ndimage

        if self.per_dim is None:
            return ndimage.zoom(x, self.factors, mode='nearest')
        n_slices = x.shape[self.per_dim]  # IndexError if per_dim is out of range
        # Wrap negative axes so they match the non-negative indices below;
        # otherwise no axis is selected and the whole array gets zoomed.
        per_dim = self.per_dim % x.ndim
        ars_resized = []
        for idx in range(n_slices):
            slices = tuple(idx if i == per_dim else slice(None) for i in range(x.ndim))
            ars_resized.append(ndimage.zoom(x[slices], self.factors, mode='nearest'))
        return np.stack(ars_resized, axis=per_dim)

    def __repr__(self):
        return 'Resizer({:s}, {})'.format(str(self.factors), self.per_dim)
class Capper(object):
    """Clamp array values to an optional lower and/or upper bound.

    A bound left as ``None`` is not applied. The input array is never
    modified; a capped copy is returned.
    """

    def __init__(self, low=None, hi=None):
        self._low = low
        self._hi = hi

    def __call__(self, ar):
        capped = ar.copy()
        # Upper bound is applied first, then the lower bound.
        if self._hi is not None:
            above = capped > self._hi
            capped[above] = self._hi
        if self._low is not None:
            below = capped < self._low
            capped[below] = self._low
        return capped

    def __repr__(self):
        return f'Capper({self._low}, {self._hi})'
def flip_y(ar: np.ndarray) -> np.ndarray:
    """Reverse an array along its y axis.

    The array is expected to have its last two dimensions ordered as YX.

    Parameters
    ----------
    ar
        Array to flip.

    Returns
    -------
    np.ndarray
        Array with the second-to-last axis reversed.

    """
    y_axis = -2  # second-to-last axis in a ...YX layout
    flipped = np.flip(ar, axis=y_axis)
    return flipped
def flip_x(ar: np.ndarray) -> np.ndarray:
    """Reverse an array along its x axis.

    The array is expected to have its last two dimensions ordered as YX.

    Parameters
    ----------
    ar
        Array to flip.

    Returns
    -------
    np.ndarray
        Array with the last axis reversed.

    """
    x_axis = -1  # last axis in a ...YX layout
    flipped = np.flip(ar, axis=x_axis)
    return flipped
def norm_around_center(
        ar: np.ndarray,
        z_center: Optional[int] = None,
) -> np.ndarray:
    """Returns normalized version of input array.

    The array will be normalized with respect to the mean, std pixel intensity
    of the sub-array of length 32 in the z-dimension centered around the
    array's "z_center". If that window would extend past either end of the
    array it is shifted to fit and a warning is logged.

    Parameters
    ----------
    ar
        Input 3d array to be normalized.
    z_center
        Z-index of cell centers. Defaults to the middle of the first
        dimension.

    Returns
    -------
    np.ndarray
        Normalized array, dtype = float32

    Raises
    ------
    ValueError
        If the input is not 3d or is shorter than 32 in the first dimension.

    """
    if ar.ndim != 3:
        raise ValueError('Input array must be 3d')
    if ar.shape[0] < 32:
        raise ValueError(
            'Input array must be at least length 32 in first dimension'
        )
    if z_center is None:
        z_center = ar.shape[0]//2
    chunk_zlen = 32
    z_start = z_center - chunk_zlen//2
    # Clamp the window to the array bounds. ``Logger.warning`` is used
    # because the ``warn`` alias is deprecated.
    if z_start < 0:
        z_start = 0
        logger.warning(f'Warning: z_start set to {z_start}')
    if (z_start + chunk_zlen) > ar.shape[0]:
        z_start = ar.shape[0] - chunk_zlen
        logger.warning(f'Warning: z_start set to {z_start}')
    chunk = ar[z_start: z_start + chunk_zlen, :, :]
    # Statistics come from the chunk only but are applied to the whole array.
    ar = ar - chunk.mean()
    ar = ar/chunk.std()
    return ar.astype(np.float32)