# -*- coding: utf-8 -*-
"""
Created on Mon Feb 20 22:35:23 2023
@author: fyz11
"""
def _mkdir(directory):
    """ Create the directory (and any missing parent folders) if it does not already exist. """
    import os
    if not os.path.exists(directory):
        os.makedirs(directory)
    return []
def _normalize99(Y,lower=0.01,upper=99.99):
""" normalize image so 0.0 is 0.01st percentile and 1.0 is 99.99th percentile
Upper and lower percentile ranges configurable.
Parameters
----------
Y: ndarray, float
Component array of lenth N by L1 by L2 by ... by LN.
upper: float
upper percentile above which pixels are sent to 1.0
lower: float
lower percentile below which pixels are sent to 0.0
Returns
--------------
normalized array with a minimum of 0 and maximum of 1
"""
import numpy as np
X = Y.copy()
return np.interp(X, (np.percentile(X, lower), np.percentile(X, upper)), (0, 1))
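
# A minimal, illustrative sketch (not part of the original API): _normalize99 maps the
# `lower`/`upper` percentile intensities of an array onto 0/1, with values outside that
# range clipped by np.interp. The image below is a hypothetical random array.
def _example_normalize99():
    import numpy as np
    img = np.random.rand(64, 64) * 255.        # hypothetical intensity image
    img_norm = _normalize99(img)               # defaults: 0.01 and 99.99 percentiles
    return img_norm.min(), img_norm.max()      # ~0.0 and ~1.0
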
def _interp2(query_pts, grid_shape, I_ref, method='linear', cast_uint8=False):
    """ Interpolate a 2D image, I_ref, at the (row, col) coordinates given by query_pts. Out-of-bounds queries return 0. """
    import numpy as np
    from scipy.interpolate import RegularGridInterpolator

    spl = RegularGridInterpolator((np.arange(grid_shape[0]),
                                   np.arange(grid_shape[1])),
                                  I_ref, method=method, bounds_error=False, fill_value=0)
    I_query = spl((query_pts[..., 0],
                   query_pts[..., 1]))
    if cast_uint8:
        I_query = I_query.astype(np.uint8)

    return I_query
def _interp3(query_pts, grid_shape, I_ref, method='linear', cast_uint8=False):
    """ Interpolate a 3D volume, I_ref, at the (z, y, x) coordinates given by query_pts. Out-of-bounds queries return 0. """
    import numpy as np
    from scipy.interpolate import RegularGridInterpolator

    spl_3 = RegularGridInterpolator((np.arange(grid_shape[0]),
                                     np.arange(grid_shape[1]),
                                     np.arange(grid_shape[2])),
                                    I_ref, method=method, bounds_error=False, fill_value=0)
    I_query = spl_3((query_pts[..., 0],
                     query_pts[..., 1],
                     query_pts[..., 2]))
    if cast_uint8:
        I_query = I_query.astype(np.uint8)

    return I_query
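
# A minimal, illustrative sketch (assumed usage, mirroring the propagation loops further
# below): _interp2 samples a 2D reference image at fractional (row, col) positions;
# out-of-bounds queries return 0. The image and query points here are synthetic placeholders.
def _example_interp2():
    import numpy as np
    I = np.random.rand(32, 32)                        # reference image defined on an integer grid
    query_pts = np.array([[10.5, 3.25],
                          [0.0, 31.0]])               # (N, 2) array of (row, col) coordinates
    vals = _interp2(query_pts, grid_shape=I.shape, I_ref=I, method='linear')
    return vals                                       # one interpolated value per query point
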
# =============================================================================
# 2D stuff
# =============================================================================
def connected_components_pts_2D(pts, pts0, shape,
                                smooth_sigma=1,
                                thresh_factor=None,
                                mask=None,
                                min_area=1):
    """ Convert propagated 2D point coordinates, pts, into a labelled image by density voting,
    thresholding and connected component analysis; the labels are then transferred back onto
    the original point coordinates, pts0.
    """
    import numpy as np
    import scipy.ndimage as ndimage
    import skimage.measure as skmeasure
    import skimage.segmentation as sksegmentation
    # accumulate votes: each propagated point votes for the pixel it lands on.
    votes_grid_acc = np.zeros(shape)
    np.add.at(votes_grid_acc,
              (pts[:, 0].astype(int),
               pts[:, 1].astype(int)), 1.)  # np.add.at accumulates repeated indices correctly.
    # smooth to get a density (fast KDE estimation)
    votes_grid_acc = ndimage.gaussian_filter(votes_grid_acc, sigma=smooth_sigma)
    if thresh_factor is not None:
        if mask is not None:
            votes_grid_binary = votes_grid_acc > np.mean(votes_grid_acc[mask]) + thresh_factor * np.std(votes_grid_acc[mask])
        else:
            votes_grid_binary = votes_grid_acc > np.mean(votes_grid_acc) + thresh_factor * np.std(votes_grid_acc)
    else:
        votes_grid_binary = votes_grid_acc > np.mean(votes_grid_acc)  # just threshold over the mean.
    cell_seg_connected = skmeasure.label(votes_grid_binary, connectivity=1)  # label the thresholded density (4-connectivity).
cell_uniq_regions = np.setdiff1d(np.unique(cell_seg_connected),0)
if len(cell_uniq_regions)>0:
props = skmeasure.regionprops(cell_seg_connected)
areas = np.hstack([re.area for re in props])
invalid_areas = cell_uniq_regions[areas<=min_area]
for invalid in invalid_areas:
cell_seg_connected[cell_seg_connected==invalid] = 0
if cell_seg_connected.max() > 0:
cell_seg_connected = sksegmentation.relabel_sequential(cell_seg_connected)[0]
    # transfer the labels of the propagated points back onto their original coordinates.
    cell_seg_connected_original = np.zeros_like(cell_seg_connected)
    cell_seg_connected_original[pts0[:, 0].astype(int),
                                pts0[:, 1].astype(int)] = cell_seg_connected[pts[:, 0].astype(int),
                                                                             pts[:, 1].astype(int)]
if mask is not None:
cell_seg_connected[mask == 0] = 0
cell_seg_connected_original[mask==0] = 0 # also mask the predicted.
return cell_seg_connected_original, cell_seg_connected, votes_grid_acc # return the accumulator.!
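
# A minimal, illustrative sketch (synthetic point clusters, not from the original code):
# connected_components_pts_2D bins the propagated points `pts` into a vote density,
# thresholds and labels that density, then writes the labels back at the original point
# positions `pts0`. Here pts0 is simply set equal to pts.
def _example_connected_components_pts_2D():
    import numpy as np
    rng = np.random.default_rng(0)
    cluster1 = rng.normal(loc=(20., 20.), scale=1.5, size=(200, 2))
    cluster2 = rng.normal(loc=(44., 44.), scale=1.5, size=(200, 2))
    pts = np.clip(np.vstack([cluster1, cluster2]), 0, 63)
    labels0, labels, votes = connected_components_pts_2D(pts, pts0=pts, shape=(64, 64))
    return labels0  # expected: two labelled clusters written at the point positions
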
def connected_components_pts_3D(pts, pts0, shape,
                                smooth_sigma=1,
                                thresh_factor=None,
                                mask=None,
                                min_area=1):
    """ Convert propagated 3D point coordinates, pts, into a labelled volume by density voting,
    thresholding and connected component analysis; the labels are then transferred back onto
    the original point coordinates, pts0.
    """
    import numpy as np
    import scipy.ndimage as ndimage
    import skimage.measure as skmeasure
    import skimage.segmentation as sksegmentation
    # accumulate votes: each propagated point votes for the voxel it lands on.
    votes_grid_acc = np.zeros(shape)
    np.add.at(votes_grid_acc,
              (pts[:, 0].astype(int),
               pts[:, 1].astype(int),
               pts[:, 2].astype(int)), 1.)  # np.add.at accumulates repeated indices correctly.
    # smooth to get a density (fast KDE estimation)
    votes_grid_acc = ndimage.gaussian_filter(votes_grid_acc, sigma=smooth_sigma)
    if thresh_factor is not None:
        if mask is not None:
            votes_grid_binary = votes_grid_acc > np.mean(votes_grid_acc[mask]) + thresh_factor * np.std(votes_grid_acc[mask])
        else:
            votes_grid_binary = votes_grid_acc > np.mean(votes_grid_acc) + thresh_factor * np.std(votes_grid_acc)
    else:
        votes_grid_binary = votes_grid_acc > np.mean(votes_grid_acc)  # just threshold over the mean.
    cell_seg_connected = skmeasure.label(votes_grid_binary, connectivity=2)  # label the thresholded density.
cell_uniq_regions = np.setdiff1d(np.unique(cell_seg_connected),0)
if len(cell_uniq_regions)>0:
props = skmeasure.regionprops(cell_seg_connected)
areas = np.hstack([re.area for re in props])
invalid_areas = cell_uniq_regions[areas<=min_area]
for invalid in invalid_areas:
cell_seg_connected[cell_seg_connected==invalid] = 0
if cell_seg_connected.max() > 0:
cell_seg_connected = sksegmentation.relabel_sequential(cell_seg_connected)[0]
    # transfer the labels of the propagated points back onto their original coordinates.
    cell_seg_connected_original = np.zeros_like(cell_seg_connected)
    cell_seg_connected_original[pts0[:, 0].astype(int),
                                pts0[:, 1].astype(int),
                                pts0[:, 2].astype(int)] = cell_seg_connected[pts[:, 0].astype(int),
                                                                             pts[:, 1].astype(int),
                                                                             pts[:, 2].astype(int)]
if mask is not None:
cell_seg_connected[mask == 0] = 0
cell_seg_connected_original[mask==0] = 0 # also mask the predicted.
return cell_seg_connected_original, cell_seg_connected, votes_grid_acc # return the accumulator.!
def _sdf_distance_transform(binary, rev_sign=True):
    """ Signed distance function of a binary mask computed with Euclidean distance transforms.
    With rev_sign=True the result is positive inside the object and negative outside.
    """
    import numpy as np
    from scipy.ndimage import distance_transform_edt
    # import skfmm
    # import GeodisTK
    pos_binary = binary.copy()
    neg_binary = np.logical_not(pos_binary)
    res = distance_transform_edt(neg_binary) * neg_binary - (distance_transform_edt(pos_binary) - 1) * pos_binary
# res = skfmm.distance(neg_binary, dx=0.5) * neg_binary - (skfmm.distance(pos_binary, dx=0.5) - 1) * pos_binary
# res = skfmm.distance(neg_binary) * neg_binary - (skfmm.distance(pos_binary) - 1) * pos_binary # this was fast!.
# res = geodesic_distance_2d((neg_binary*1.).astype(np.float32), S=neg_binary, lamb=0.8, iter=10) * neg_binary - (geodesic_distance_2d((pos_binary*1.).astype(np.float32), S=neg_binary, lamb=0.5, iter=10) - 1) * pos_binary
if rev_sign:
res = res * -1
return res
def surf_normal_sdf(binary, return_sdf=True, smooth_gradient=None, eps=1e-12, norm_vectors=True):
    """ Gradient vector field of the signed distance function of a binary mask.
    The SDF is signed so that it increases towards the object interior, hence its gradient
    points inwards, towards object centres. The gradient can optionally be Gaussian-smoothed
    and normalized to unit length. Returns (gradient field, SDF) if return_sdf=True, else
    only the gradient field.
    """
    import numpy as np
    import scipy.ndimage as ndimage

    sdf_vol = _sdf_distance_transform(binary, rev_sign=True)  # positive inside, so the gradient points towards object centres.
    # compute the gradient (surface normal direction) of the signed distance function.
    sdf_vol_normal = np.array(np.gradient(sdf_vol))
    # smoothing needs to be done before normalization of the magnitude.
    if smooth_gradient is not None:
        sdf_vol_normal = np.array([ndimage.gaussian_filter(sdf, sigma=smooth_gradient) for sdf in sdf_vol_normal])
    if norm_vectors:
        sdf_vol_normal = sdf_vol_normal / (np.linalg.norm(sdf_vol_normal, axis=0)[None, ...] + eps)

    if return_sdf:
        return sdf_vol_normal, sdf_vol
    return sdf_vol_normal
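
# A minimal, illustrative sketch (hypothetical disc mask): surf_normal_sdf returns an
# (ndim, ...) array of gradient vectors of the signed distance function together with the
# SDF itself; with norm_vectors=True the vectors are (approximately) unit length.
def _example_surf_normal_sdf():
    import numpy as np
    yy, xx = np.indices((64, 64))
    disc = (yy - 32) ** 2 + (xx - 32) ** 2 <= 20 ** 2      # binary disc of radius 20
    normals, sdf = surf_normal_sdf(disc, return_sdf=True)
    return normals.shape, sdf.shape, np.linalg.norm(normals, axis=0).max()  # (2, 64, 64), (64, 64), ~1.0
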
def mean_curvature_sdf(sdf_normal):
    """ Mean curvature of the implicit surface, computed as 0.5 * the divergence of the normal (gradient) field of the signed distance function. """
    import numpy as np

    def divergence(f):
        """
        Computes the divergence of the vector field f, corresponding to dFx/dx + dFy/dy + ...
        :param f: List of ndarrays, where every item of the list is one dimension of the vector field
        :return: Single ndarray of the same shape as each of the items in f, which corresponds to a scalar field
        """
        num_dims = len(f)
        return np.ufunc.reduce(np.add, [np.gradient(f[i], axis=i) for i in range(num_dims)])

    H = .5 * divergence(sdf_normal)  # the divergence of the normal field gives the total curvature.
    return H
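
# A hedged numerical sanity check (illustrative only, synthetic grid): mean_curvature_sdf
# returns 0.5 * divergence of the supplied vector field. For the field (x, y), whose
# divergence dFx/dx + dFy/dy is 2 everywhere, the returned map should therefore be ~1.0.
def _example_mean_curvature_sdf():
    import numpy as np
    x, y = np.meshgrid(np.arange(64, dtype=float),
                       np.arange(64, dtype=float), indexing='ij')
    field = np.array([x, y])          # vector field with divergence 2
    H = mean_curvature_sdf(field)     # 0.5 * divergence -> ~1.0
    return float(H.mean())
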
def gradient_watershed2D_binary(binary,
gradient_img=None,
divergence_rescale=True,
smooth_sigma=1,
smooth_gradient=1,
delta=.5,
n_iter=10,
min_area=5,
eps=1e-20,
interp=True,
thresh_factor=None,
track_flow=True, # if track_flow then we record!.
mask=None,
debug_viz=False):
""" parses the instance level segmentation implicitly given as an input binary or a vector field.
The algorithm works as an inverse watershed.
Step 1: a grid of points is seeds on the image
Step 2: points are propagated for n_iter according to the gradient_img, condensing towards cell centers implicitly implied by the gradient image.
Step 3: individual cluster centers are found by binarisation and connected component, removing objects < min_area
result is an integer image the same size as binary.
Parameters
----------
binary : (MxNxL) numpy array
input binary image defining the voxels that need labeling
gradient_img : (MxNxLx3) numpy array
This is a gradient field such as that from applying np.array(np.gradient(img)).transpose(1,2,3,0) where img is a potential such as a distance transform or probability map.
divergence_rescale :
If True, the gradient_img is scaled by the divergence which is equivalent to the mean curvature, this helps to prevent early breakage for tube-like structures.
smooth_sigma : scalar
controls the catchment area for identifying distinct cells at the final propagation position. Smaller smooth_sigma leads to more oversegmentation.
smooth_gradient : scalar
the isotropic sigma value controlling the Gaussian smoothing of the gradient field. More smoothing results in more cells grouped together
delta: scalar
the voxel size to propagate grid points per iteration. Related to the stability. If too small takes too long. If too large, might not converge. if delta=1, takes a 1 voxel step.
n_iter: int
the number of iterations to run. (To do: monitor convergence and break early to improve speed)
min_area: scalar
volume of cells < min_area are removed.
eps: float
a small number for numerical stability
thresh_factor: scalar
The final cells are identified by thresholding on a threshold mean+thresh_factor*std. Thresh_factor controls what is an object prior to connected components analysis
mask: (MxNxL) numpy array
optional binary mask to gate the region to parse labels for.
debug_viz: bool
if True, visualise the position of the points at every algorithm iteration.
Returns
-------
cell_seg_connected_original : (MxNxL)
an integer image where each unique int > 0 relates to a unique object such that object 1 is retrieved by cell_seg_connected_original==1.
"""
import scipy.ndimage as ndimage
import numpy as np
import skimage.morphology as skmorph
import pylab as plt
import skimage.measure as skmeasure
import skimage.segmentation as sksegmentation
from tqdm import tqdm
    if gradient_img is not None:
        sdf_normals = gradient_img.transpose(2, 0, 1)  # use the supplied gradients!
        sdf_normals = sdf_normals * binary[None, ...]
    else:
        # compute the gradient of the signed distance transform of the binary.
        sdf_normals, sdf_binary = surf_normal_sdf(binary, return_sdf=True, smooth_gradient=smooth_gradient, eps=eps, norm_vectors=True)
        sdf_normals = sdf_normals * binary[None, ...]
if divergence_rescale:
# rescale the speed
curvature_2D = mean_curvature_sdf(sdf_normals/(np.linalg.norm(sdf_normals, axis=0)[None,...]+eps))
curvature_2D = _normalize99(curvature_2D) # rescales to a factor between 0-1
sdf_normals = sdf_normals * curvature_2D[None,...] # multiplicative factor rescaling
# print(sdf_normals.shape)
    pts = np.argwhere(binary > 0)  # (N, ndim) seed points: every foreground pixel.
    tracks = [pts]
for ii in tqdm(np.arange(n_iter)):
pt_ii = tracks[-1].copy()
if interp:
pts_vect_ii = np.array([_interp2(pt_ii, binary.shape, I_ref=sdf_normals[ch], method='linear', cast_uint8=False) for ch in np.arange(len(sdf_normals))]).T
else:
pts_vect_ii = sdf_normals[:,np.rint(pt_ii[:,0]).astype(np.int64), np.rint(pt_ii[:,1]).astype(np.int64)].T
pts_vect_ii = pts_vect_ii / (np.linalg.norm(pts_vect_ii, axis=-1)[:,None] + eps)
pt_ii_next = pt_ii + delta*pts_vect_ii
pt_ii_next[:,0] = np.clip(pt_ii_next[:,0], 0, binary.shape[0]-1)
pt_ii_next[:,1] = np.clip(pt_ii_next[:,1], 0, binary.shape[1]-1)
if track_flow:
tracks.append(pt_ii_next)
else:
tracks[-1] = pt_ii_next.copy() # copy over.
if debug_viz:
plt.figure(figsize=(10,10))
plt.imshow(binary)
plt.plot(pt_ii_next[:,1], pt_ii_next[:,0], 'r.')
plt.show()
tracks = np.array(tracks)
cell_seg_connected_original, cell_seg_connected, votes_grid_acc = connected_components_pts_2D( pts=tracks[-1],
pts0=pts,
shape=binary.shape[:2],
smooth_sigma=smooth_sigma,
thresh_factor=thresh_factor,
mask=mask,
min_area=min_area)
return cell_seg_connected_original, cell_seg_connected, tracks, votes_grid_acc
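
# A minimal end-to-end usage sketch (synthetic image; the parameter values are illustrative
# and may need tuning for real data): split a binary mask of two touching discs into two
# labelled objects. gradient_img=None, so the gradient field is derived internally from the
# signed distance transform of the binary.
def _example_gradient_watershed2D():
    import numpy as np
    yy, xx = np.indices((128, 128))
    disc1 = (yy - 64) ** 2 + (xx - 50) ** 2 <= 15 ** 2
    disc2 = (yy - 64) ** 2 + (xx - 78) ** 2 <= 15 ** 2
    binary = np.logical_or(disc1, disc2)
    labels, labels_final_pos, tracks, votes = gradient_watershed2D_binary(binary,
                                                                          gradient_img=None,
                                                                          n_iter=50,
                                                                          debug_viz=False)
    return labels  # expected: two labelled regions, one per disc
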
def gradient_watershed3D_binary(binary,
gradient_img=None,
divergence_rescale=True,
smooth_sigma=1,
smooth_gradient=1,
delta=1,
n_iter=100,
min_area=5,
eps=1e-12,
thresh_factor=None,
mask=None,
debug_viz=False):
""" parses the instance level segmentation implicitly given as an input binary or a vector field.
The algorithm works as an inverse watershed.
Step 1: a grid of points is seeds on the image
Step 2: points are propagated for n_iter according to the gradient_img, condensing towards cell centers implicitly implied by the gradient image.
Step 3: individual cluster centers are found by binarisation and connected component, removing objects < min_area
result is an integer image the same size as binary.
Parameters
----------
binary : (MxNxL) numpy array
input binary image defining the voxels that need labeling
gradient_img : (MxNxLx3) numpy array
This is a gradient field such as that from applying np.array(np.gradient(img)).transpose(1,2,3,0) where img is a potential such as a distance transform or probability map.
divergence_rescale :
If True, the gradient_img is scaled by the divergence which is equivalent to the mean curvature, this helps to prevent early breakage for tube-like structures.
smooth_sigma : scalar
controls the catchment area for identifying distinct cells at the final propagation position. Smaller smooth_sigma leads to more oversegmentation.
smooth_gradient : scalar
the isotropic sigma value controlling the Gaussian smoothing of the gradient field. More smoothing results in more cells grouped together
    delta : scalar
        the step size (in voxels) to propagate points per iteration. Related to stability: if too small, propagation takes too long; if too large, it might not converge. delta=1 takes a 1 voxel step.
n_iter: int
the number of iterations to run. (To do: monitor convergence and break early to improve speed)
min_area: scalar
volume of cells < min_area are removed.
eps: float
a small number for numerical stability
    thresh_factor : scalar
        The final cells are identified by thresholding the vote density at mean + thresh_factor*std. thresh_factor controls what is an object prior to connected components analysis.
mask: (MxNxL) numpy array
optional binary mask to gate the region to parse labels for.
debug_viz: bool
if True, visualise the position of the points at every algorithm iteration.
    Returns
    -------
    cell_seg_connected_original : (MxNxL) numpy array
        an integer image where each unique int > 0 relates to a unique object such that object 1 is retrieved by cell_seg_connected_original==1. Labels are placed at the original seed positions.
    cell_seg_connected : (MxNxL) numpy array
        the integer labelling at the final propagated point positions.
    tracks : numpy array
        the recorded point positions.
    votes_grid_acc : (MxNxL) numpy array
        the smoothed vote density accumulator used for thresholding.
    """
import scipy.ndimage as ndimage
import numpy as np
import skimage.morphology as skmorph
import pylab as plt
import skimage.measure as skmeasure
import skimage.segmentation as sksegmentation
from tqdm import tqdm
from .plotting import set_axes_equal
if gradient_img is not None:
sdf_normals = gradient_img.transpose(3,0,1,2) # use the supplied gradients!
sdf_normals = sdf_normals * binary[None,...]
else:
# compute the signed distance transform
sdf_normals, sdf_binary = surf_normal_sdf(binary, return_sdf=True, smooth_gradient=smooth_gradient, eps=eps, norm_vectors=True)
sdf_normals = sdf_normals * binary[None,...]
if divergence_rescale:
# rescale the speed
curvature_3D = mean_curvature_sdf(sdf_normals)
curvature_3D = _normalize99(curvature_3D, lower=0.01,upper=99) # rescales to a factor between 0-1
sdf_normals = sdf_normals * curvature_3D[None,...] # multiplicative factor rescaling
    pts = np.argwhere(binary > 0)  # (N, ndim) seed points: every foreground voxel.
    tracks = [pts]
for ii in tqdm(np.arange(n_iter)):
pt_ii = tracks[-1].copy()
"""
interp helps!.
"""
pts_vect_ii = np.array([_interp3(pt_ii, binary.shape, I_ref=sdf_normals[ch], method='linear', cast_uint8=False) for ch in np.arange(len(sdf_normals))]).T
# pts_vect_ii = sdf_normals[:,
# pt_ii[...,0].astype(np.int32),
# pt_ii[...,1].astype(np.int32),
# pt_ii[...,2].astype(np.int32)].T # direct lookup - not interp!.
pts_vect_ii = pts_vect_ii / (np.linalg.norm(pts_vect_ii, axis=-1)[:,None] + 1e-20)
pt_ii_next = pt_ii + delta*pts_vect_ii
pt_ii_next[:,0] = np.clip(pt_ii_next[:,0], 0, binary.shape[0]-1)
pt_ii_next[:,1] = np.clip(pt_ii_next[:,1], 0, binary.shape[1]-1)
pt_ii_next[:,2] = np.clip(pt_ii_next[:,2], 0, binary.shape[2]-1)
tracks[-1] = pt_ii_next # overwrite
# plt.figure(figsize=(10,10))
# plt.imshow(binary.max(axis=0))
# plt.plot(pt_ii_next[:,2],
# pt_ii_next[:,1], 'r.')
# plt.show()
if debug_viz:
sampling = 100
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111, projection='3d')
ax.set_proj_type('ortho') # this works better!.
ax.set_box_aspect(aspect = (1,1,1)) # this works.
# ax.scatter(v_watertight[::sampling,0],
# v_watertight[::sampling,1],
# v_watertight[::sampling,2],
# c='k', s=1, alpha=0.0)#all_labels_branches[np.squeeze(all_dists)<20], s=1)
ax.scatter(pt_ii_next[::sampling,0],
pt_ii_next[::sampling,1],
pt_ii_next[::sampling,2], c='r',s=1)
# ax.scatter(centroids3D_from_xz[:,0],
# centroids3D_from_xz[:,1],
# centroids3D_from_xz[:,2], c='g',s=10)
# ax.scatter(centroids3D_from_yz[:,0],
# centroids3D_from_yz[:,1],
# centroids3D_from_yz[:,2], c='b',s=10)
# # ax.scatter(skel3D_coords[:,0],
# # skel3D_coords[:,1],
# # skel3D_coords[:,2], c='k',s=5, alpha=1)
# ax.view_init(-90,0)
# ax.view_init(0,180)
ax.view_init(180,0)
# ax.set_xlim([0,binary.shape[0]]) # why is this plot not good?
# ax.set_ylim([0,binary.shape[1]])
# ax.set_zlim([0,binary.shape[2]])
set_axes_equal(ax)
plt.show()
tracks = np.array(tracks)
    # parse the final point positions into labelled objects (density voting, thresholding,
    # connected components and size filtering) and transfer the labels back onto the seeds.
cell_seg_connected_original, cell_seg_connected, votes_grid_acc = connected_components_pts_3D( pts=tracks[-1],
pts0=pts,
shape=binary.shape[:3],
smooth_sigma=smooth_sigma,
thresh_factor=thresh_factor,
mask=mask,
min_area=min_area)
# plt.figure(figsize=(10,10))
# plt.imshow(cell_seg_connected.max(axis=0))
# plt.show()
# plt.figure(figsize=(10,10))
# plt.imshow(cell_seg_connected_original.max(axis=0))
# plt.show()
return cell_seg_connected_original, cell_seg_connected, tracks, votes_grid_acc
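
# A minimal end-to-end usage sketch (synthetic volume; default parameters are used and may
# need tuning for real data): label two touching spheres in a 3D binary volume.
# gradient_img=None, so the gradient field is derived internally from the signed distance
# transform. Note that the relative import of .plotting in gradient_watershed3D_binary
# assumes this module is run as part of its package.
def _example_gradient_watershed3D():
    import numpy as np
    zz, yy, xx = np.indices((64, 96, 64))
    sphere1 = (zz - 32) ** 2 + (yy - 38) ** 2 + (xx - 32) ** 2 <= 12 ** 2
    sphere2 = (zz - 32) ** 2 + (yy - 58) ** 2 + (xx - 32) ** 2 <= 12 ** 2
    binary = np.logical_or(sphere1, sphere2)
    labels, labels_final_pos, tracks, votes = gradient_watershed3D_binary(binary,
                                                                          gradient_img=None,
                                                                          debug_viz=False)
    return labels  # expected: two labelled regions, one per sphere
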