# Source code for multiScaleAnalysis.SegmentationHighres.gradient_watershed.metrics

# -*- coding: utf-8 -*-
"""
Created on Wed Feb 15 03:04:09 2023

@author: fyz11
"""

import numpy as np 

def metrics_np(y_true, y_pred, metric_name, metric_type='standard', drop_last=True, mean_per_class=False, verbose=False):
    """
    Compute mean metrics of two segmentation masks, via numpy.

    Args:
        y_true: true masks, one-hot encoded, shape B*W*H*N
            (B = batch size, W = width, H = height, N = number of classes).
        y_pred: predicted masks, either softmax outputs or one-hot encoded,
            same shape as y_true.
        metric_name: metric to be computed, either 'iou' or 'dice'.
        metric_type: one of 'standard' (default), 'soft', 'naive'.
            'standard': y_pred is one-hot encoded and the mean is taken only
                over classes that are present (in y_true or y_pred).
            'soft': metrics are computed without one-hot encoding y_pred.
            'naive': absent classes contribute to the class mean as 1.0
                instead of being dropped from the mean.
        drop_last: drop the last class (usually reserved for the background
            class in semantic segmentation) before averaging.
        mean_per_class: return the mean along the batch axis for each class.
        verbose: print intermediate results such as intersection and union
            (as numbers of pixels).

    Returns:
        IoU/Dice of y_true and y_pred as a float, unless mean_per_class is
        True, in which case the per-class metric averaged over the batch is
        returned.
    """
    assert y_true.shape == y_pred.shape, 'Input masks should be same shape, instead are {}, {}'.format(y_true.shape, y_pred.shape)
    assert len(y_pred.shape) == 4, 'Inputs should be B*W*H*N tensors, instead have shape {}'.format(y_pred.shape)

    use_soft = metric_type == 'soft'
    use_naive = metric_type == 'naive'

    num_classes = y_pred.shape[-1]
    # with a single class there is no background channel, so never drop it
    drop_last = drop_last and num_classes > 1

    if not use_soft:
        if num_classes > 1:
            # hard one-hot masks from the channel argmax (y_true should
            # already be one-hot, but re-encode it anyway for safety)
            y_pred = np.array([np.argmax(y_pred, axis=-1) == c for c in range(num_classes)]).transpose(1, 2, 3, 0)
            y_true = np.array([np.argmax(y_true, axis=-1) == c for c in range(num_classes)]).transpose(1, 2, 3, 0)
        else:
            y_pred = (y_pred > 0).astype(int)
            y_true = (y_true > 0).astype(int)

    # per-image, per-class pixel counts; resulting shapes are (batch, classes)
    spatial_axes = (1, 2)  # W, H axes of each image
    inter = np.sum(np.abs(y_pred * y_true), axis=spatial_axes)
    total = np.sum(np.abs(y_true), axis=spatial_axes) + np.sum(np.abs(y_pred), axis=spatial_axes)
    union = total - inter

    if verbose:
        print('intersection (pred*true), intersection (pred&true), union (pred+true-inters), union (pred|true)')
        print(inter, np.sum(np.logical_and(y_pred, y_true), axis=spatial_axes), union, np.sum(np.logical_or(y_pred, y_true), axis=spatial_axes))

    smooth = .001
    iou = (inter + smooth) / (union + smooth)
    dice = 2 * (inter + smooth) / (total + smooth)

    score = {'iou': iou, 'dice': dice}[metric_name]

    # presence mask: 0 where neither y_true nor y_pred contains the class,
    # 1 otherwise
    presence = np.not_equal(union, 0).astype(int)

    if drop_last:
        score = score[:, :-1]
        presence = presence[:, :-1]

    # remaining axes are (batch, classes):
    # if mean_per_class, average over the batch axis only;
    # if use_naive, average over absent classes too
    if mean_per_class:
        if use_naive:
            return np.mean(score, axis=0)
        # mean only over non-absent classes in the batch (still ~1 when a
        # class is absent for the whole batch, thanks to the smoothing term)
        return (np.sum(score * presence, axis=0) + smooth) / (np.sum(presence, axis=0) + smooth)

    if use_naive:
        return np.mean(score)
    # mean only over non-absent classes
    class_count = np.sum(presence, axis=0)
    return np.mean(np.sum(score * presence, axis=0)[class_count != 0] / class_count[class_count != 0])
def mean_iou_np(y_true, y_pred, **kwargs):
    """
    Compute the mean Intersection over Union of two segmentation masks, via numpy.

    Thin wrapper around metrics_np with metric_name='iou'; see there for the
    allowed keyword arguments.
    """
    return metrics_np(y_true, y_pred, metric_name='iou', **kwargs)
def mean_dice_np(y_true, y_pred, **kwargs):
    """
    Compute the mean Dice coefficient of two segmentation masks, via numpy.

    Thin wrapper around metrics_np with metric_name='dice'; see there for the
    allowed keyword arguments.
    """
    return metrics_np(y_true, y_pred, metric_name='dice', **kwargs)
""" functions for 2D cell comparison """ def _match_cells(labels1, labels2, com1, com2, K=10, bg_label=0): """ labels1 - ground truth labels2 - predicted labels com1 - center-of-mass for labels1 com2 - center-of-mass for labels2 K - # of nearest neighbor candidates. """ from scipy.optimize import linear_sum_assignment # brute force is the absolute gold-standard, but can we do this really fast? and without needing to do from sklearn.neighbors import NearestNeighbors uniq1 = np.setdiff1d(np.unique(labels1), bg_label) uniq2 = np.setdiff1d(np.unique(labels2), bg_label) n1 = len(uniq1) n2 = len(uniq2) # initialise matrix. sim_matrix = np.zeros((n1,n2)) dice_matrix = np.zeros((n1,n2)) # turn into numpy array. com1 = np.vstack(com1) com2 = np.vstack(com2) # nearest neighbor match on centroids, then bipartite matching on iou ! nbrs = NearestNeighbors(n_neighbors=K, algorithm='ball_tree').fit(com1) _, indices = nbrs.kneighbors(com2) # print(indices.shape) # print(len(com1)) for j in range(len(com2)): cand_i = indices[j] if len(cand_i) > 0: for i in range(len(cand_i)): mask1 = labels1 == uniq1[cand_i[i]] mask2 = labels2 == uniq2[j] intersection = np.sum(np.abs(mask1*mask2)) union = np.sum(mask1) + np.sum(mask2) - intersection # jaccard. overlap = intersection / float(union + 1e-8) # dice? dice = 2*intersection / float(np.sum(mask1) + np.sum(mask2) + 1e-8) # print(overlap) sim_matrix[cand_i[i],j] = np.clip(overlap, 0, 1) dice_matrix[cand_i[i],j] = np.clip(dice, 0, 1) # hungarian. ind_i, ind_j = linear_sum_assignment(1-sim_matrix) # need to reverse this (distance) iou_pair = sim_matrix[ind_i, ind_j].copy() dice_pair = dice_matrix[ind_i, ind_j].copy() valid = iou_pair>0 # must have non-zero overlap! ind_i = ind_i[valid>0].copy() ind_j = ind_j[valid>0].copy() iou_pair = iou_pair[valid>0].copy() dice_pair = dice_pair[valid>0].copy() # return at the end also the sim matrix. return ind_i, ind_j, iou_pair, dice_pair, sim_matrix """ move to segmentation? """
def remove_small_labelled_objects(labels, minsize=64):
    """
    Remove small labelled regions by setting them to background (0).

    Parameters
    ----------
    labels : integer array
        instance label image; 0 is treated as background.
    minsize : int
        regions whose pixel area is <= minsize are removed.

    Returns
    -------
    labels_out : array
        copy of ``labels`` with the small regions set to 0.
    """
    # Region area is simply the per-label pixel count, so np.unique with
    # return_counts replaces the previous skimage.regionprops pass.  This
    # avoids relying on the positional alignment between np.unique and
    # regionprops, and also works for an all-background image (where the old
    # np.hstack over an empty regionprops list raised a ValueError).
    labels_out = labels.copy()
    uniq, counts = np.unique(labels, return_counts=True)
    fg = uniq != 0  # ignore the background label
    small = uniq[fg][counts[fg] <= minsize]
    if len(small) > 0:
        labels_out[np.isin(labels_out, small)] = 0  # set to bg
    return labels_out
def compute_metrics_cells(labels_true, labels_pred, bg_label=0, K=15, iou_thresh=0, eps=1e-5, debug_viz=False):
    """
    Compute instance-matching statistics for a list of label images.

    Parameters
    ----------
    labels_true, labels_pred : lists of 2D integer label images
        ground-truth and predicted instance segmentations, paired by index.
    bg_label : int
        background label value, excluded from matching.
    K : int
        number of nearest-neighbour candidates for the matching prefilter;
        clamped per image to the number of ground-truth cells.
    iou_thresh : float
        only matched pairs with IoU strictly above this threshold count.
    eps : float
        small constant guarding divisions by zero.
    debug_viz : bool
        if True, display each label image pair with matplotlib.

    Returns
    -------
    stats : (n_images, 8) array
        per image: [n_GT, n_Pred, n_match, precision, recall, F1,
        mean IoU, mean Dice] over the matched pairs.
    match_props : list of dict
        per-image matching details: matched indices, IoU/Dice per pair, the
        full candidate IoU matrix, centroids, and the matched label values.
    """
    import scipy.ndimage as ndimage

    n_images = len(labels_true)

    stats = []        # n_GT, n_Pred, n_match, pre, rec, f1, iou, dice per image
    match_props = []

    for ii in range(n_images):
        label_true = labels_true[ii].copy()
        label_pred = labels_pred[ii].copy()

        unique_label_true = np.setdiff1d(np.unique(label_true), bg_label)
        unique_label_pred = np.setdiff1d(np.unique(label_pred), bg_label)

        if debug_viz:
            import pylab as plt  # only needed for visual debugging
            plt.figure()
            plt.imshow(label_pred)
            plt.figure()
            plt.imshow(label_true)
            plt.show()

        # Instance centroids: used as a nearest-neighbour prefilter to
        # expedite the instance matching inside _match_cells.
        com_true = np.vstack(ndimage.center_of_mass(label_true > 0, labels=label_true, index=unique_label_true))
        com_pred = np.vstack(ndimage.center_of_mass(label_pred > 0, labels=label_pred, index=unique_label_pred))

        # FIX: the neighbour search is fitted on the ground-truth centroids,
        # so K must not exceed their count.  The previous clamp used
        # len(label_pred) (= the image's row count), which was meaningless
        # and could still crash the search.
        K_ii = int(np.minimum(K, len(unique_label_true)))

        # Solve the matching problem based on IoU (KNN as a prefilter).
        gt_i, pred_j, iou_ij, dice_ij, iou_matrix = _match_cells(label_true, label_pred,
                                                                 com_true, com_pred,
                                                                 K=K_ii,
                                                                 bg_label=bg_label)

        # gate: restrict to matches above the given IoU threshold.
        val_index = iou_ij > iou_thresh
        gt_i = gt_i[val_index]
        pred_j = pred_j[val_index]
        iou_ij = iou_ij[val_index]
        dice_ij = dice_ij[val_index]

        match_dict = {'gt_index': gt_i,
                      'pred_index': pred_j,
                      'iou_gt_pred': iou_ij,
                      'dice_gt_pred': dice_ij,
                      'iou_matrix': iou_matrix,
                      'gt_com': com_true,
                      'pred_com': com_pred,
                      'matched_labels_gt_pred': [unique_label_true[gt_i], unique_label_pred[pred_j]]}

        # Compute the stats of the matching.
        n_match = len(pred_j)
        n_GT = len(unique_label_true)
        n_Pred = len(unique_label_pred)

        pre = n_match / float(n_Pred + eps)
        rec = n_match / float(n_GT + eps)
        # FIX: guard the 0/0 -> NaN cases when there are no matches at all.
        f1 = 2 * pre * rec / (pre + rec) if (pre + rec) > 0 else 0.0
        iou = float(np.mean(iou_ij)) if n_match > 0 else 0.0
        dice = float(np.mean(dice_ij)) if n_match > 0 else 0.0

        stats.append([n_GT, n_Pred, n_match, pre, rec, f1, iou, dice])
        match_props.append(match_dict)

    return np.vstack(stats), match_props