Source code for fnet.fnet_model

"""Module to define main fnet model wrapper class."""


from pathlib import Path
from typing import Callable, Iterator, List, Optional, Sequence, Tuple, Union
import logging
import math
import os

from scipy.ndimage import zoom
import numpy as np
import tifffile
import torch

from fnet.metrics import corr_coef
from fnet.predict_piecewise import predict_piecewise as _predict_piecewise_fn
from fnet.transforms import flip_y, flip_x, norm_around_center
from fnet.utils.general_utils import get_args, retry_if_oserror, str_to_object
from fnet.utils.model_utils import move_optim


LOGGER = logging.getLogger(__name__)


def _weights_init(m):
    """Initializes weights of convolution and batch norm layers."""
    classname = m.__class__.__name__
    if classname.startswith('Conv'):
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_per_param_options(module, wd):
    """Returns list of per-parameter group options.

    Applies the specified weight decay (wd) to parameters except parameters
    within batch norm layers and bias parameters.

    """
    if wd == 0:
        return module.parameters()
    with_decay = list()
    without_decay = list()
    for idx_m, (name_m, module_sub) in enumerate(module.named_modules()):
        if list(module_sub.named_children()):
            continue  # Skip "container" modules
        if isinstance(module_sub, torch.nn.modules.batchnorm._BatchNorm):
            for param in module_sub.parameters():
                without_decay.append(param)
            continue
        for name_param, param in module_sub.named_parameters():
            if 'weight' in name_param:
                with_decay.append(param)
            elif 'bias' in name_param:
                without_decay.append(param)
    # Check that no parameters were missed or duplicated
    n_param_module = len(list(module.parameters()))
    n_param_lists = len(with_decay) + len(without_decay)
    n_elem_module = sum([p.numel() for p in module.parameters()])
    n_elem_lists = sum([p.numel() for p in with_decay + without_decay])
    assert n_param_module == n_param_lists
    assert n_elem_module == n_elem_lists
    per_param_options = [
        {
            'params': with_decay,
            'weight_decay': wd,
        },
        {
            'params': without_decay,
            'weight_decay': 0.0,
        },
    ]
    return per_param_options
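
# Illustrative sketch (not part of the original module): the per-parameter
# groups returned above are typically passed straight to an optimizer
# constructor so that weight decay skips batch norm and bias parameters.
# `net` here is a hypothetical torch.nn.Module instance.
#
#   optimizer = torch.optim.Adam(
#       get_per_param_options(net, wd=1e-4), lr=0.001
#   )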


class Model:
    """Class that encompasses a pytorch network and its optimizer."""

    def __init__(
            self,
            betas=(0.5, 0.999),
            criterion_class='fnet.losses.WeightedMSE',
            init_weights=True,
            lr=0.001,
            nn_class='fnet.nn_modules.fnet_nn_3d.Net',
            nn_kwargs={},
            scheduler=None,
            weight_decay=0,
            gpu_ids=-1,
    ):
        self.betas = betas
        self.criterion = str_to_object(criterion_class)()
        self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids
        self.init_weights = init_weights
        self.lr = lr
        self.nn_class = nn_class
        self.nn_kwargs = nn_kwargs
        self.scheduler = scheduler
        self.weight_decay = weight_decay
        self.count_iter = 0
        self.device = (
            torch.device('cuda', self.gpu_ids[0])
            if self.gpu_ids[0] >= 0
            else torch.device('cpu')
        )
        self.optimizer = None
        self._init_model()
        self.fnet_model_kwargs, self.fnet_model_posargs = get_args()
        self.fnet_model_kwargs.pop('self')

    def _init_model(self):
        self.net = str_to_object(self.nn_class)(**self.nn_kwargs)
        if self.init_weights:
            self.net.apply(_weights_init)
        self.net.to(self.device)
        self.optimizer = torch.optim.Adam(
            get_per_param_options(self.net, wd=self.weight_decay),
            lr=self.lr,
            betas=self.betas,
        )
        if self.scheduler is not None:
            if self.scheduler[0] == 'snapshot':
                period = self.scheduler[1]
                self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                    self.optimizer,
                    lambda x: (
                        0.01 + (1 - 0.01)*(
                            0.5 + 0.5*math.cos(math.pi*(x % period)/period)
                        )
                    ),
                )
            elif self.scheduler[0] == 'step':
                step_size = self.scheduler[1]
                self.scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer, step_size
                )
            else:
                raise NotImplementedError

    def __str__(self):
        out_str = [
            f'*** {self.__class__.__name__} ***',
            f'{self.nn_class}(**{self.nn_kwargs})',
            f'iter: {self.count_iter}',
            f'gpu: {self.gpu_ids}',
        ]
        return os.linesep.join(out_str)
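    # Usage sketch (illustrative only; argument values are hypothetical):
    #
    #   model = Model(
    #       nn_class='fnet.nn_modules.fnet_nn_3d.Net',
    #       lr=0.001,
    #       gpu_ids=0,  # use -1 for CPU
    #   )
    #   print(model)    # summarizes network class, iteration count, GPUs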
    def get_state(self):
        return {
            'fnet_model_class': (
                self.__module__ + '.' + self.__class__.__qualname__
            ),
            'fnet_model_kwargs': self.fnet_model_kwargs,
            'fnet_model_posargs': self.fnet_model_posargs,
            'nn_state': self.net.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'count_iter': self.count_iter,
        }
    def to_gpu(self, gpu_ids: Union[int, List[int]]) -> None:
        """Move network to specified GPU(s).

        Parameters
        ----------
        gpu_ids
            GPU(s) on which to perform training or prediction.

        """
        if isinstance(gpu_ids, int):
            gpu_ids = [gpu_ids]
        self.gpu_ids = gpu_ids
        self.device = (
            torch.device('cuda', self.gpu_ids[0])
            if self.gpu_ids[0] >= 0
            else torch.device('cpu')
        )
        self.net.to(self.device)
        if self.optimizer is not None:
            move_optim(self.optimizer, self.device)
    def save(self, path_save: str):
        """Saves model to disk.

        Parameters
        ----------
        path_save
            Filename to which model is saved.

        """
        dirname = os.path.dirname(path_save)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            LOGGER.info(f'Created: {dirname}')
        curr_gpu_ids = self.gpu_ids
        self.to_gpu(-1)
        retry_if_oserror(torch.save)(self.get_state(), path_save)
        self.to_gpu(curr_gpu_ids)
    def load_state(self, state: dict, no_optim: bool = False):
        self.count_iter = state['count_iter']
        self.net.load_state_dict(state['nn_state'])
        if no_optim:
            self.optimizer = None
            return
        self.optimizer.load_state_dict(state['optimizer_state'])
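    # Save/load round-trip sketch (illustrative only; the path is
    # hypothetical, and the restored Model must use the same network
    # architecture as the saved one):
    #
    #   model.save('saved_models/model.p')
    #   state = torch.load('saved_models/model.p')
    #   model_restored = Model()
    #   model_restored.load_state(state)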
    def train_on_batch(
            self,
            x_batch: torch.Tensor,
            y_batch: torch.Tensor,
            weight_map_batch: Optional[torch.Tensor] = None,
    ) -> float:
        """Update model using a batch of inputs and targets.

        Parameters
        ----------
        x_batch
            Batched input.
        y_batch
            Batched target.
        weight_map_batch
            Optional batched weight map.

        Returns
        -------
        float
            Loss as determined by self.criterion.

        """
        if self.scheduler is not None:
            self.scheduler.step()
        self.net.train()
        x_batch = x_batch.to(dtype=torch.float32, device=self.device)
        y_batch = y_batch.to(dtype=torch.float32, device=self.device)
        if len(self.gpu_ids) > 1:
            module = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
        else:
            module = self.net
        self.optimizer.zero_grad()
        y_hat_batch = module(x_batch)
        args = [y_hat_batch, y_batch]
        if weight_map_batch is not None:
            args.append(weight_map_batch)
        loss = self.criterion(*args)
        loss.backward()
        self.optimizer.step()
        self.count_iter += 1
        return loss.item()
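    # Training loop sketch (illustrative only; `dataloader` is a hypothetical
    # iterable of (input, target) batches):
    #
    #   for x_batch, y_batch in dataloader:
    #       loss = model.train_on_batch(x_batch, y_batch)
    #       print(f'iter {model.count_iter}: loss = {loss:.4f}')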
    def _predict_on_batch_tta(self, x_batch: torch.Tensor) -> torch.Tensor:
        """Performs model prediction using test-time augmentation."""
        augs = [
            None,
            [flip_y],
            [flip_x],
            [flip_y, flip_x],
        ]
        x_batch = x_batch.numpy()
        y_hat_batch_mean = None
        for aug in augs:
            x_batch_aug = x_batch.copy()
            if aug is not None:
                for trans in aug:
                    x_batch_aug = trans(x_batch_aug)
            y_hat_batch = self.predict_on_batch(x_batch_aug.copy()).numpy()
            if aug is not None:
                for trans in aug:
                    y_hat_batch = trans(y_hat_batch)
            if y_hat_batch_mean is None:
                y_hat_batch_mean = np.zeros(
                    y_hat_batch.shape, dtype=np.float32
                )
            y_hat_batch_mean += y_hat_batch
        y_hat_batch_mean /= len(augs)
        return torch.tensor(
            y_hat_batch_mean, dtype=torch.float32, device=torch.device('cpu')
        )
    def predict_on_batch(self, x_batch: torch.Tensor) -> torch.Tensor:
        """Performs model prediction on a batch of data.

        Parameters
        ----------
        x_batch
            Batch of input data.

        Returns
        -------
        torch.Tensor
            Batch of model predictions.

        """
        x_batch = torch.tensor(
            x_batch, dtype=torch.float32, device=self.device
        )
        if len(self.gpu_ids) > 1:
            module = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
        else:
            module = self.net
        module.eval()
        with torch.no_grad():
            prediction_batch = module(x_batch).cpu()
        return prediction_batch
    def predict(
            self,
            x: Union[torch.Tensor, np.ndarray],
            tta: bool = False,
    ) -> torch.Tensor:
        """Performs model prediction on a single example.

        Parameters
        ----------
        x
            Input data.
        tta
            Set to use test-time augmentation.

        Returns
        -------
        torch.Tensor
            Model prediction.

        """
        x_batch = torch.unsqueeze(torch.tensor(x), 0)
        if tta:
            return self._predict_on_batch_tta(x_batch).squeeze(0)
        return self.predict_on_batch(x_batch).squeeze(0)
    def predict_piecewise(
            self,
            x: Union[torch.Tensor, np.ndarray],
            **predict_kwargs,
    ) -> torch.Tensor:
        """Performs model prediction piecewise on a single example.

        Predicts on patches of the input and stitches together the
        predictions.

        Parameters
        ----------
        x
            Input data.
        **predict_kwargs
            Kwargs to pass to predict method.

        Returns
        -------
        torch.Tensor
            Model prediction.

        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if len(x.size()) == 4:
            dims_max = [None, 32, 512, 512]
        elif len(x.size()) == 3:
            dims_max = [None, 1024, 1024]
        else:
            raise ValueError('x must be 3d or 4d')
        y_hat = _predict_piecewise_fn(
            self, x, dims_max=dims_max, overlaps=16, **predict_kwargs
        )
        return y_hat
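    # Prediction sketch (illustrative only; the input shape is hypothetical):
    #
    #   x = np.random.rand(1, 32, 256, 256).astype(np.float32)  # (C, Z, Y, X)
    #   y_hat = model.predict(x, tta=True)            # whole-image prediction
    #   y_hat = model.predict_piecewise(x, tta=True)  # patch-wise prediction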
    def test_on_batch(
            self,
            x_batch: torch.Tensor,
            y_batch: torch.Tensor,
            weight_map_batch: Optional[torch.Tensor] = None,
    ) -> float:
        """Test model on a batch of inputs and targets.

        Parameters
        ----------
        x_batch
            Batched input.
        y_batch
            Batched target.
        weight_map_batch
            Optional batched weight map.

        Returns
        -------
        float
            Loss as evaluated by self.criterion.

        """
        if len(self.gpu_ids) > 1:
            network = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
        else:
            network = self.net
        network.eval()
        x_batch = x_batch.to(dtype=torch.float32, device=self.device)
        with torch.no_grad():
            y_hat_batch = network(x_batch).cpu()
        args = [y_hat_batch, y_batch]
        if weight_map_batch is not None:
            args.append(weight_map_batch)
        loss = self.criterion(*args)
        network.train()
        return loss.item()
    def test_on_iterator(
            self,
            iterator: Iterator,
            **kwargs: dict,
    ) -> float:
        """Tests model on an iterator whose items are passed to test_on_batch.

        Parameters
        ----------
        iterator
            Iterator that generates items to be passed to test_on_batch.
        kwargs
            Additional keyword arguments to be passed to test_on_batch.

        Returns
        -------
        float
            Mean loss over items in the iterator.

        """
        loss_sum = 0
        for item in iterator:
            loss_sum += self.test_on_batch(*item, **kwargs)
        return loss_sum/len(iterator)
    def evaluate(
            self,
            x: torch.Tensor,
            y: torch.Tensor,
            metric: Optional[Callable] = None,
            piecewise: bool = False,
            **kwargs,
    ) -> Tuple[float, torch.Tensor]:
        """Evaluates model output using a metric function.

        Parameters
        ----------
        x
            Input data.
        y
            Target data.
        metric
            Metric function. If None, uses fnet.metrics.corr_coef.
        piecewise
            Set to perform predictions piecewise.
        **kwargs
            Additional kwargs to be passed to the predict() method.

        Returns
        -------
        float
            Evaluation as determined by the metric function.
        torch.Tensor
            Model prediction.

        """
        if metric is None:
            metric = corr_coef
        if piecewise:
            y_hat = self.predict_piecewise(x, **kwargs)
        else:
            y_hat = self.predict(x, **kwargs)
        if y is None:
            return None, y_hat
        evaluation = metric(y, y_hat)
        return evaluation, y_hat
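    # Evaluation sketch (illustrative only): with metric=None, the Pearson
    # correlation coefficient (fnet.metrics.corr_coef) is used.
    #
    #   score, y_hat = model.evaluate(x, y, piecewise=True)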
    def apply_on_single_zstack(
            self,
            input_img: Optional[np.ndarray] = None,
            filename: Optional[Union[Path, str]] = None,
            inputCh: Optional[int] = None,
            normalization: Optional[Callable] = None,
            already_normalized: bool = False,
            ResizeRatio: Optional[Sequence[float]] = None,
            cutoff: Optional[float] = None,
    ) -> np.ndarray:
        """Applies model to a single z-stack input.

        This assumes the loaded network architecture can receive 3d grayscale
        images as input.

        Parameters
        ----------
        input_img
            3d or 4d image with shape (Z, Y, X) or (C, Z, Y, X) respectively.
        filename
            Path to input image. Ignored if input_img is supplied.
        inputCh
            Selected channel if filename is a path to a 4d image.
        normalization
            Input image normalization function.
        already_normalized
            Set to skip input normalization.
        ResizeRatio
            If specified, resizes each dimension of the input image by the
            corresponding factor.
        cutoff
            If specified, converts the output to a binary image using cutoff
            as the threshold value.

        Returns
        -------
        np.ndarray
            Predicted image with shape (Z, Y, X). If cutoff is set, dtype will
            be numpy.uint8. Otherwise, dtype will be numpy.float32.

        Raises
        ------
        ValueError
            If parameters are invalid.
        FileNotFoundError
            If the specified file does not exist.
        IndexError
            If inputCh is invalid.

        """
        if input_img is None:
            if filename is None:
                raise ValueError('input_img or filename must be specified')
            input_img = tifffile.imread(str(filename))
        if inputCh is not None:
            if input_img.ndim != 4:
                raise ValueError('input_img must be 4d if inputCh specified')
            input_img = input_img[inputCh, ]
        if input_img.ndim != 3:
            raise ValueError('input_img must be 3d')
        normalization = normalization or norm_around_center
        if not already_normalized:
            input_img = normalization(input_img)
        if ResizeRatio is not None:
            if len(ResizeRatio) != 3:
                raise ValueError('ResizeRatio must be length 3')
            input_img = zoom(input_img, zoom=ResizeRatio, mode='nearest')
        yhat = (
            self.predict_piecewise(input_img[np.newaxis, ], tta=True)
            .squeeze(dim=0)
            .numpy()
        )
        if cutoff is not None:
            yhat = (yhat >= cutoff).astype(np.uint8)*255
        return yhat
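    # Single z-stack sketch (illustrative only; the filename and channel are
    # hypothetical):
    #
    #   prediction = model.apply_on_single_zstack(
    #       filename='example_zstack.tiff', inputCh=0
    #   )
    #   tifffile.imwrite('prediction.tiff', prediction)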