# -*- coding: utf-8 -*-
"""
Axons
=====
Expert axon image processing pipeline.
This module provides the basic routines for processing axon data.
The routines are used in the :mod:`ClearMap.Scripts.AxonMap` pipeline.
"""
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE.txt)'
__copyright__ = 'Copyright © 2023 by Christoph Kirst'
import numpy as np
import gc
import scipy.ndimage as ndi
import skimage.filters as skif
import ClearMap.IO.IO as io
import ClearMap.ParallelProcessing.BlockProcessing as bp
import ClearMap.ParallelProcessing.DataProcessing.ArrayProcessing as ap
import ClearMap.ImageProcessing.Filter.Rank as rnk
import ClearMap.ImageProcessing.LocalStatistics as ls
import ClearMap.ImageProcessing.LightsheetCorrection as lc
import ClearMap.ImageProcessing.Differentiation.Hessian as hes
# import ClearMap.ImageProcessing.Binary.Filling as bf
# import ClearMap.ImageProcessing.Binary.Smoothing as bs
import ClearMap.Utils.Timer as tmr
import ClearMap.Utils.HierarchicalDict as hdict
########################################################################################################################
# Preprocess
########################################################################################################################
# Default settings for every stage of the preprocessing pipeline.
# A stage whose entry is None (or otherwise falsy) is skipped by `preprocess_block`.
default_preprocess_parameter = {
    # initial clipping and mask generation
    'clip': {
        'clip_range': (140, 26000),
        'dtype': 'uint16',
    },
    # lightsheet correction
    'lightsheet': {
        'method': 'subtract',
        'percentile': 0.25,
        'lightsheet': {'selem': (150, 1, 1)},
        'background': {
            'selem': (200, 200, 1),
            'spacing': (25, 25, 1),
            'step': (2, 2, 1),
            'interpolate': 1,
        },
        'lightsheet_vs_background': 2.5,
        'lightsheet_minimum': 1,
        'save': False,
    },
    # median filter (disabled by default)
    'median': None,
    # equalization
    'equalize': {
        'percentile': (0.3, 0.975),
        'selem': (200, 200, 5),
        'spacing': (50, 50, 5),
        'interpolate': 1,
        'save': False,
    },
    # general: dtype of the result sink and the intensity range used for rank filters
    'dtype': 'float16',
    'max_bin': 2 ** 12,
}
# Default keyword arguments forwarded by `preprocess` to the block processing routine.
default_preprocess_processing_parameter = {
    'size_max': 5,
    'size_min': 'fixed',
    'overlap': 0,
    'axes': [2],
    'optimization': True,
    'optimization_fix': 'all',
    'verbose': None,
    'processes': None,
}
[docs]
def preprocess(source, sink=None, parameter=None, processing_parameter=None):
    """Preprocess axon image data block by block.

    Arguments
    ---------
    source : source specification
        The data to preprocess.
    sink : sink specification or None
        Destination of the result, handed to ``ap.initialize_sink``.
    parameter : dict or None
        Settings of the individual preprocessing stages. If None or empty,
        ``default_preprocess_parameter`` is used. The dict is not modified.
    processing_parameter : dict or None
        Keyword arguments forwarded to ``bp.process``. If None or empty,
        ``default_preprocess_processing_parameter`` is used.

    Returns
    -------
    sink : source
        The initialized sink that received the preprocessed data.
    """
    parameter = parameter or default_preprocess_parameter
    processing_parameter = processing_parameter or default_preprocess_processing_parameter

    # initialize sink
    shape = io.shape(source)
    order = io.order(source)
    dtype = parameter.get('dtype', default_preprocess_parameter['dtype'])
    sink, _ = ap.initialize_sink(sink=sink, shape=shape, order=order, dtype=dtype, memory='automatic')

    # output saving: create a sink for each stage (and each nested sub-stage)
    # that requests its intermediate result via a truthy 'save' entry
    for stage in parameter.values():
        if not isinstance(stage, dict):
            continue
        nested = [entry for entry in stage.values() if isinstance(entry, dict)]
        for options in [stage] + nested:
            filename = options.get('save', None)
            if filename:
                ap.initialize_sink(filename, shape=shape, order=order, dtype='float')

    # Work on a shallow copy so that adding 'verbose' neither alters the
    # caller's dict nor the shared module-level default.
    block_parameter = dict(parameter)
    block_parameter.update(verbose=processing_parameter.get('verbose', False))

    bp.process(preprocess_block, source, sink, function_type='block', parameter=block_parameter,
               **processing_parameter)

    return sink
[docs]
def preprocess_block(source, sink, parameter):
    """Preprocess a single block.

    Runs, in order, clipping, lightsheet correction, median filtering and
    equalization. A stage is skipped when its entry in `parameter` is missing
    or falsy. The valid region of the final result is written into `sink`.

    Arguments
    ---------
    source : block
        Input block; must provide ``info()``, ``valid.slicing`` and ``as_buffer()``.
    sink : block
        Output block; must provide ``valid.base_slicing`` and item assignment.
    parameter : dict
        Stage settings, see ``default_preprocess_parameter``. A truthy 'save'
        entry of a stage is opened via ``io.as_source`` and receives that
        stage's result for the valid region of this block.

    Returns
    -------
    None
    """
    # general settings and slicing
    verbose = parameter.get('verbose', False)
    if verbose:
        prefix = 'Block %s: ' % (source.info(),)
        total_time = tmr.Timer(prefix)
    else:
        prefix = None
        total_time = None

    max_bin = parameter.get('max_bin', default_preprocess_parameter['max_bin'])

    base_slicing = sink.valid.base_slicing
    valid_slicing = source.valid.slicing

    # binary status for inspection (loaded here but not used further below)
    status = parameter.get('status', None)
    if status:
        status = io.as_source(status)
        status = status[base_slicing]

    # --- clipping ---
    clip_options = parameter.get('clip', None)
    if not clip_options:
        clipped = source.as_buffer()
        mask = None
    else:
        clip_options = clip_options.copy()
        clip_options.update(norm=max_bin)
        timer = None
        if verbose:
            timer = tmr.Timer(prefix)
            hdict.pprint(clip_options, head=prefix + 'Clipping:')
        target = clip_options.pop('save', None)

        clipped, mask = clip(source, **clip_options)

        if target:
            target = io.as_source(target)
            target[base_slicing] = clipped[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Clipping')
    # active arrays: clipped, mask

    # --- lightsheet correction ---
    lightsheet_options = parameter.get('lightsheet', None)
    if not lightsheet_options:
        corrected = clipped
    else:
        lightsheet_options = lightsheet_options.copy()
        timer = None
        if verbose:
            timer = tmr.Timer(prefix)
            hdict.pprint(lightsheet_options, head=prefix + 'Lightsheet:')
        method = lightsheet_options.pop('method', 'subtract')
        target = lightsheet_options.pop('save', None)

        # each method ignores the option that only applies to the other one
        if method == 'subtract':
            lightsheet_options.pop('lightsheet_minimum', None)
            correct = lc.correct_lightsheet
        else:
            lightsheet_options.pop('lightsheet_vs_background', None)
            correct = lc.correct_lightsheet_divide
        corrected = correct(clipped, mask=mask, max_bin=max_bin, **lightsheet_options)

        if target:
            target = io.as_source(target)
            target[base_slicing] = corrected[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Lightsheet')
    del clipped
    # active arrays: corrected, mask

    # --- median filter ---
    median_options = parameter.get('median', None)
    if not median_options:
        median = corrected
    else:
        median_options = median_options.copy()
        timer = None
        if verbose:
            timer = tmr.Timer(prefix)
            hdict.pprint(median_options, head=prefix + 'Median:')
        target = median_options.pop('save', None)

        median = rnk.median(corrected, max_bin=max_bin, mask=mask, **median_options)

        if target:
            target = io.as_source(target)
            target[base_slicing] = median[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Median')
    del corrected
    # active arrays: median, mask

    # --- equalization ---
    equalize_options = parameter.get('equalize', None)
    if not equalize_options:
        equalized = median
    else:
        equalize_options = equalize_options.copy()
        timer = None
        if verbose:
            timer = tmr.Timer(prefix)
            hdict.pprint(equalize_options, head=prefix + 'Equalization:')
        target = equalize_options.pop('save', None)

        equalized = equalize(median, mask=mask, **equalize_options)

        if target:
            target = io.as_source(target)
            target[base_slicing] = equalized[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Equalization')
    del median, mask
    # active arrays: equalized

    # write the valid region of the result
    sink[valid_slicing] = equalized[valid_slicing]

    if verbose:
        total_time.print_elapsed_time('Preprocessing')

    gc.collect()

    return None
[docs]
def clip(source, clip_range, norm, dtype):
    """Clip intensities to a range and rescale them to ``[0, norm]``.

    Arguments
    ---------
    source : array or source
        The data to clip; read via ``source[:]``.
    clip_range : tuple
        ``(clip_low, clip_high)``. Values below `clip_low` are raised to
        `clip_low`, values at or above `clip_high` are set to `clip_high`.
    norm : number
        Value that `clip_high` is mapped to after rescaling.
    dtype : dtype
        Data type of the returned clipped array.

    Returns
    -------
    clipped : array
        The clipped data, shifted by `clip_low` and scaled by
        ``norm / (clip_high - clip_low)``, converted to `dtype`.
    mask : array of bool
        True where the original value was not below `clip_low`.
    """
    clip_low, clip_high = clip_range

    # float32 holds every integer up to 2**24 exactly. float16 only has an
    # 11-bit significand, so it rounds intensities above 2048 and overflows
    # above 65504, which distorts both the comparison below and the result.
    clipped = np.array(source[:], dtype='float32')

    # remember which voxels were below the lower bound before overwriting them
    low = clipped < clip_low
    np.clip(clipped, clip_low, clip_high, out=clipped)

    # shift and rescale in place to avoid extra temporaries
    clipped -= clip_low
    clipped *= float(norm) / (clip_high - clip_low)

    clipped = np.asarray(clipped, dtype=dtype)
    return clipped, np.logical_not(low)
[docs]
def deconvolve(source, binarized, sigma=10):
    """Subtract the blurred signal of the binarized voxels from their surroundings.

    The intensities at `binarized` are copied into an otherwise zero array,
    each slice along the third axis is Gaussian filtered, and the result is
    subtracted from `source` (never removing more than `source` itself).
    Voxels in `binarized` keep their original intensity.

    Arguments
    ---------
    source : array
        The image; indexed as ``[:, :, z]``.
    binarized : array
        Index (e.g. boolean mask) selecting the foreground voxels.
    sigma : number
        Sigma passed to ``scipy.ndimage.gaussian_filter`` for every slice.

    Returns
    -------
    deconvolved : array
        The corrected image.
    """
    foreground = source[binarized]

    blurred = np.zeros(source.shape, dtype=float)
    blurred[binarized] = foreground

    # smooth every slice along the last axis independently
    n_slices = blurred.shape[2]
    for index in range(n_slices):
        plane = blurred[:, :, index]
        blurred[:, :, index] = ndi.gaussian_filter(plane, sigma=sigma)

    # cap the subtracted halo at the source intensity
    halo = np.minimum(source, blurred)
    result = source - halo
    result[binarized] = foreground
    return result
[docs]
def threshold_isodata(source):
    """Return the largest isodata threshold of the data, or 1 as a fallback.

    Arguments
    ---------
    source : array
        The data passed to ``skimage.filters.threshold_isodata``.

    Returns
    -------
    threshold : number
        The last entry of all thresholds returned by skimage, or 1 if no
        threshold was returned or the computation raised an exception.
    """
    try:
        thresholds = skif.threshold_isodata(source, return_all=True)
        if len(thresholds) > 0:
            return thresholds[-1]
        return 1
    except Exception:
        # Best-effort fallback for data skimage cannot handle. Catching
        # Exception rather than everything keeps KeyboardInterrupt and
        # SystemExit working during long block-wise runs.
        return 1
[docs]
def threshold_adaptive(source, function=threshold_isodata, selem=(100, 100, 3), spacing=(25, 25, 3), interpolate=1,
                       mask=None, step=None):
    """Compute a locally varying threshold.

    Reads `source` fully via ``io.as_source(source)[:]`` and delegates to
    ``ls.apply_local_function`` with ``dtype=float``; all other arguments are
    forwarded unchanged.

    Arguments
    ---------
    source : source specification
        The data to threshold.
    function : callable
        Threshold function applied locally.
    selem, spacing, interpolate, mask, step
        Forwarded to ``ls.apply_local_function``.

    Returns
    -------
    threshold : array
        The result of ``ls.apply_local_function``.
    """
    data = io.as_source(source)[:]
    return ls.apply_local_function(
        data,
        function=function,
        mask=mask,
        dtype=float,
        selem=selem,
        spacing=spacing,
        interpolate=interpolate,
        step=step,
    )
[docs]
def equalize(source, percentile=(0.5, 0.95), max_value=1.5, selem=(200, 200, 5), spacing=(50, 50, 5), interpolate=1,
             mask=None):
    """Equalize the data by its local percentiles.

    The data is divided by the first local percentile (floored at 1). Where
    the second local percentile would exceed `max_value` after that division,
    the scale is reduced so that it maps to `max_value` instead.

    Arguments
    ---------
    source : array
        The data to equalize.
    percentile : tuple
        The two percentiles passed to ``ls.local_percentile``.
    max_value : number
        Upper limit for the normalized second percentile.
    selem, spacing, interpolate, mask
        Forwarded to ``ls.local_percentile``.

    Returns
    -------
    equalized : array
        The float array ``source * scale``.
    """
    levels = ls.local_percentile(source, percentile=percentile, mask=mask, dtype=float, selem=selem,
                                 spacing=spacing, interpolate=interpolate)
    lower = levels[..., 0]
    upper = levels[..., 1]

    # default scale: divide by the lower percentile, but never by less than 1
    scale = 1 / np.maximum(lower, 1)

    # limit the scale where the upper percentile would overshoot max_value
    overshoot = upper * scale > max_value
    scale[overshoot] = max_value / upper[overshoot]

    return np.array(source, dtype=float) * scale
[docs]
def tubify(source, sigma=1.0, gamma12=1.0, gamma23=1.0, alpha=0.25):
    """Apply the Hessian based ``hes.lambda123`` measure to the data.

    All arguments are forwarded to ``hes.lambda123`` with ``sink=None``;
    its return value is returned unchanged.
    """
    options = dict(sigma=sigma, gamma12=gamma12, gamma23=gamma23, alpha=alpha)
    return hes.lambda123(source=source, sink=None, **options)
###############################################################################
### Helper
###############################################################################
[docs]
def status_to_description(status):
    """Converts a status int to its description.

    The status is read as a bit field: bit ``k`` being set selects
    ``STATUS_NAMES[k]``. Bits beyond ``len(STATUS_NAMES)`` are ignored.

    Note
    ----
    ``STATUS_NAMES`` must be defined at module level; it is not defined in the
    part of this file visible at the time of review.

    Arguments
    ---------
    status : int
        The status.

    Returns
    -------
    description : str
        The comma separated names of all set bits in ascending bit order,
        or 'Background' if no named bit is set.
    """
    # int() also accepts numpy integer scalars as produced by np.unique
    status = int(status)

    # Test the bits directly. The earlier `status / 2 ** k == 1` check relied
    # on integer division and, under Python 3 true division, only matched
    # exact powers of two, so combined flags were reported as 'Background'.
    names = [name for bit, name in enumerate(STATUS_NAMES) if status & (1 << bit)]

    if not names:
        return 'Background'
    return ','.join(names)
[docs]
def binary_statistics(source):
    """Counts the binarization types.

    Arguments
    ---------
    source : array
        The status array of the binarization process.

    Returns
    -------
    statistics : dict
        A dict with entries {description : count}. If several status values
        share a description, the count of the last one is kept.
    """
    data = io.as_source(source)[:]
    values, counts = np.unique(data, return_counts=True)

    statistics = {}
    for value, count in zip(values, counts):
        statistics[status_to_description(value)] = count
    return statistics
###############################################################################
### Tests
###############################################################################
def _test():
    """Manual test: binarize and postprocess a test volume and plot the results.

    Note: this exercises the Vasculature expert module, not the routines of
    this module, and requires the ClearMap test data and a display.
    """
    import copy
    import numpy as np
    import ClearMap.IO.IO as io
    import ClearMap.Visualization.Plot3d as p3d
    import ClearMap.Tests.Files as tsf
    import ClearMap.ImageProcessing.Experts.Vasculature as vasc

    # crop the test data and zero its border
    source = np.array(tsf.source('vls')[:300, :300, 80:120])
    source[:, :, [0, -1]] = 0
    source[:, [0, -1], :] = 0
    source[[0, -1], :, :] = 0

    # Deep copy: nested entries are changed below and a shallow .copy() would
    # write those changes into the Vasculature module's default dict.
    bpar = copy.deepcopy(vasc.default_binarization_parameter)
    bpar['clip']['clip_range'] = (150, 7000)
    bpar['as_memory'] = True
    # optional: bpar['binary_status'] = 'binary_status.npy'

    # only top-level keys are set here, a shallow copy is sufficient
    ppar = vasc.default_processing_parameter.copy()
    ppar['processes'] = 10
    ppar['size_max'] = 10

    sink = 'binary.npy'
    # optional: sink = None

    binary = vasc.binarize(source, sink=sink, binarization_parameter=bpar, processing_parameter=ppar)
    p3d.plot([source, binary])

    io.delete_file(sink)

    # deep copy for the same reason as above: 'smooth' is a nested dict
    pppar = copy.deepcopy(vasc.default_postprocessing_parameter)
    pppar['smooth']['iterations'] = 3
    smoothed = vasc.postprocess(binary, postprocessing_parameter=pppar)
    p3d.plot([binary, smoothed])