# Source code for ClearMap.ImageProcessing.Experts.Axons

# -*- coding: utf-8 -*-
"""
Axons
=====

Expert axon image processing pipeline.

This module provides the basic routines for processing axon data.
The routines are used in the :mod:`ClearMap.Scripts.AxonMap` pipeline.
"""
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE.txt)'
__copyright__ = 'Copyright © 2023 by Christoph Kirst'

import numpy as np
import gc

import scipy.ndimage as ndi
import skimage.filters as skif

import ClearMap.IO.IO as io

import ClearMap.ParallelProcessing.BlockProcessing as bp
import ClearMap.ParallelProcessing.DataProcessing.ArrayProcessing as ap

import ClearMap.ImageProcessing.Filter.Rank as rnk
import ClearMap.ImageProcessing.LocalStatistics as ls
import ClearMap.ImageProcessing.LightsheetCorrection as lc
import ClearMap.ImageProcessing.Differentiation.Hessian as hes
# import ClearMap.ImageProcessing.Binary.Filling as bf
# import ClearMap.ImageProcessing.Binary.Smoothing as bs

import ClearMap.Utils.Timer as tmr
import ClearMap.Utils.HierarchicalDict as hdict


########################################################################################################################
# Preprocess
########################################################################################################################

# Default parameter for :func:`preprocess`.
# Each entry configures one processing step; a step set to ``None`` is skipped.
# Sub-dicts may contain a ``save`` filename to write that step's intermediate
# result (see :func:`preprocess`).
default_preprocess_parameter = dict(
    # initial clipping and mask generation
    clip=dict(
        clip_range = (140,26000),  # (low, high) intensity clip bounds; pixels below low also define the background mask
        dtype = 'uint16'),         # dtype of the clipped intermediate array

    # lightsheet correction
    lightsheet=dict(method='subtract',  # 'subtract' or divide-based correction (see preprocess_block)
                    percentile=0.25,
                    lightsheet=dict(selem=(150, 1, 1)),   # structuring element along the lightsheet axis
                    background=dict(selem=(200, 200, 1),  # background estimation window
                                    spacing=(25, 25, 1),
                                    step=(2, 2, 1),
                                    interpolate=1),
                    lightsheet_vs_background=2.5,  # used by the 'subtract' method
                    lightsheet_minimum=1,          # used by the divide method
                    save=False),

    # median
    median=None,  # optional rank median filter parameter (dict) or None to skip

    # equalization
    equalize=dict(percentile=(0.3, 0.975),  # (low, high) local percentiles, see :func:`equalize`
                  selem=(200, 200, 5),
                  spacing=(50, 50, 5),
                  interpolate=1,
                  save=False),

    # general
    dtype = 'float16',  # dtype of the final preprocessed sink
    max_bin = 2**12,    # number of intensity bins for rank/lightsheet filters
)

# Default block-processing parameter for :func:`preprocess`; interpreted by
# :func:`ClearMap.ParallelProcessing.BlockProcessing.process` — see that
# module for the exact semantics of each key.
default_preprocess_processing_parameter = dict(
    size_max=5,              # maximal block size (along the split axes)
    size_min='fixed',        # minimal block size handling
    overlap=0,               # no overlap between blocks
    axes=[2],                # split along the z axis
    optimization=True,       # optimize block sizes
    optimization_fix='all',
    verbose=None,            # None: defer to caller / default verbosity
    processes=None           # None: use default number of parallel processes
)


def preprocess(source, sink=None, parameter=None, processing_parameter=None):
    """Preprocess a source for axon detection using parallel block processing.

    Arguments
    ---------
    source : source specification
        The image source to preprocess.
    sink : sink specification or None
        Where to write the preprocessed image; created automatically if None.
    parameter : dict or None
        Preprocessing parameter as in :data:`default_preprocess_parameter`.
    processing_parameter : dict or None
        Block processing parameter as in
        :data:`default_preprocess_processing_parameter`.

    Returns
    -------
    sink : Source
        The preprocessed image sink.
    """
    # copy so the 'verbose' update below cannot mutate the caller's dict —
    # in particular the shared module-level default when parameter is None
    parameter = dict(parameter or default_preprocess_parameter)
    processing_parameter = processing_parameter or default_preprocess_processing_parameter

    # initialize the result sink with the source's geometry
    shape = io.shape(source)
    order = io.order(source)
    dtype = parameter.get('dtype', default_preprocess_parameter['dtype'])
    sink, sink_buffer = ap.initialize_sink(sink=sink, shape=shape, order=order,
                                           dtype=dtype, memory='automatic')

    # initialize sinks for intermediate results requested via 'save' entries
    for key in parameter.keys():
        par = parameter[key]
        if isinstance(par, dict):
            filename = par.get('save', None)
            if filename:
                ap.initialize_sink(filename, shape=shape, order=order, dtype='float')

            # second-level 'save' entries (e.g. lightsheet sub-parameter)
            for sub_key in par.keys():
                sub_par = par[sub_key]
                if isinstance(sub_par, dict):
                    filename = sub_par.get('save', None)
                    if filename:
                        ap.initialize_sink(filename, shape=shape, order=order, dtype='float')

    # propagate verbosity to the per-block worker
    parameter.update(verbose=processing_parameter.get('verbose', False))

    bp.process(preprocess_block, source, sink, function_type='block',
               parameter=parameter, **processing_parameter)

    return sink
def preprocess_block(source, sink, parameter):
    """Preprocess a Block."""
    # verbosity / timing setup
    verbose = parameter.get('verbose', False)
    prefix = 'Block %s: ' % (source.info(),) if verbose else None
    total_time = tmr.Timer(prefix) if verbose else None
    max_bin = parameter.get('max_bin', default_preprocess_parameter['max_bin'])

    base_slicing = sink.valid.base_slicing
    valid_slicing = source.valid.slicing

    # initialize binary status for inspection
    status = parameter.get('status', None)
    if status:
        status = io.as_source(status)
        status = status[base_slicing]

    # clipping
    clip_par = parameter.get('clip', None)
    if clip_par:
        clip_par = clip_par.copy()
        clip_par.update(norm=max_bin)
        timer = tmr.Timer(prefix) if verbose else None
        if verbose:
            hdict.pprint(clip_par, head=prefix + 'Clipping:')
        save = clip_par.pop('save', None)

        clipped, mask = clip(source, **clip_par)

        if save:
            save = io.as_source(save)
            save[base_slicing] = clipped[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Clipping')
    else:
        clipped = source.as_buffer()
        mask = None
    # active arrays: clipped, mask

    # lightsheet correction
    ls_par = parameter.get('lightsheet', None)
    if ls_par:
        ls_par = ls_par.copy()
        timer = tmr.Timer(prefix) if verbose else None
        if verbose:
            hdict.pprint(ls_par, head=prefix + 'Lightsheet:')

        method = ls_par.pop('method', 'subtract')
        save = ls_par.pop('save', None)

        if method == 'subtract':
            # 'lightsheet_minimum' only applies to the divide-based method
            ls_par.pop('lightsheet_minimum', None)
            corrected = lc.correct_lightsheet(clipped, mask=mask, max_bin=max_bin, **ls_par)
        else:
            # 'lightsheet_vs_background' only applies to the subtract method
            ls_par.pop('lightsheet_vs_background', None)
            corrected = lc.correct_lightsheet_divide(clipped, mask=mask, max_bin=max_bin, **ls_par)

        if save:
            save = io.as_source(save)
            save[base_slicing] = corrected[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Lightsheet')
    else:
        corrected = clipped
    del clipped
    # active arrays: corrected, mask

    # median filter
    med_par = parameter.get('median', None)
    if med_par:
        med_par = med_par.copy()
        timer = tmr.Timer(prefix) if verbose else None
        if verbose:
            hdict.pprint(med_par, head=prefix + 'Median:')
        save = med_par.pop('save', None)

        median = rnk.median(corrected, max_bin=max_bin, mask=mask, **med_par)

        if save:
            save = io.as_source(save)
            save[base_slicing] = median[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Median')
    else:
        median = corrected
    del corrected
    # active arrays: median, mask

    # equalize
    eq_par = parameter.get('equalize', None)
    if eq_par:
        eq_par = eq_par.copy()
        timer = tmr.Timer(prefix) if verbose else None
        if verbose:
            hdict.pprint(eq_par, head=prefix + 'Equalization:')
        save = eq_par.pop('save', None)

        equalized = equalize(median, mask=mask, **eq_par)

        if save:
            save = io.as_source(save)
            save[base_slicing] = equalized[valid_slicing]
        if verbose:
            timer.print_elapsed_time('Equalization')
    else:
        equalized = median
    del median, mask
    # active arrays: equalized

    # write result into the valid region of the sink
    sink[valid_slicing] = equalized[valid_slicing]

    if verbose:
        total_time.print_elapsed_time('Preprocessing')

    gc.collect()
    return None
def clip(source, clip_range, norm, dtype):
    """Clip intensities to *clip_range* and rescale the clipped range to [0, norm].

    Returns the rescaled array converted to *dtype* together with a boolean
    mask that is False where the input was below the lower clip bound.
    """
    low_value, high_value = clip_range

    # compact float intermediate (halves memory vs float32/64)
    data = np.array(source[:], dtype='float16')

    # clip low and remember which voxels were below the bound
    below = data < low_value
    data[below] = low_value

    # clip high (no-op if high_value exceeds the data range)
    data[data >= high_value] = high_value

    # affine map [low_value, high_value] -> [0, norm]
    data -= low_value
    data *= float(norm) / (high_value - low_value)

    return np.asarray(data, dtype=dtype), np.logical_not(below)
def deconvolve(source, binarized, sigma=10):
    """Subtract a smooth background estimated from the binarized foreground.

    The foreground intensities are blurred slice-wise (along z) with a 2d
    Gaussian; that blur is subtracted from *source* (clamped at zero) and the
    original foreground intensities are restored afterwards.
    """
    # foreground-only image: zero everywhere except the binarized voxels
    blurred = np.zeros(source.shape, dtype=float)
    blurred[binarized] = source[binarized]

    # smooth each z-slice independently
    for z in range(blurred.shape[2]):
        blurred[:, :, z] = ndi.gaussian_filter(blurred[:, :, z], sigma=sigma)

    # subtract the blur without going negative, then keep foreground intact
    result = source - np.minimum(source, blurred)
    result[binarized] = source[binarized]
    return result
def threshold_isodata(source):
    """Return the largest isodata threshold of *source*.

    Falls back to 1 when the threshold computation fails (e.g. for a constant
    image) or yields no threshold at all.
    """
    try:
        thresholds = skif.threshold_isodata(source, return_all=True)
    except Exception:
        # was a bare 'except:' — narrowed so KeyboardInterrupt/SystemExit
        # are not swallowed; best-effort fallback is kept
        return 1
    # several thresholds may be returned; use the largest
    return thresholds[-1] if len(thresholds) > 0 else 1
def threshold_adaptive(source, function=threshold_isodata, selem=(100, 100, 3),
                       spacing=(25, 25, 3), interpolate=1, mask=None, step=None):
    """Compute a locally varying threshold array for *source*.

    *function* is evaluated on local windows (shape *selem*, sampled every
    *spacing* voxels) and the resulting threshold field is interpolated back
    to the full image shape.
    """
    data = io.as_source(source)[:]
    return ls.apply_local_function(data, function=function, mask=mask, dtype=float,
                                   selem=selem, spacing=spacing,
                                   interpolate=interpolate, step=step)
def equalize(source, percentile=(0.5, 0.95), max_value=1.5, selem=(200, 200, 5),
             spacing=(50, 50, 5), interpolate=1, mask=None):
    """Equalize local contrast by dividing by the local lower percentile.

    The gain is capped so that the local upper percentile is mapped to at
    most *max_value*.
    """
    # per-voxel (lower, upper) local percentiles, stacked on the last axis
    percentiles = ls.local_percentile(source, percentile=percentile, mask=mask,
                                      dtype=float, selem=selem, spacing=spacing,
                                      interpolate=interpolate)
    lower = percentiles[..., 0]
    upper = percentiles[..., 1]

    # gain = 1 / lower percentile (clamped at 1 to avoid blow-up)
    scale = 1 / np.maximum(lower, 1)

    # cap the gain where it would push the upper percentile beyond max_value
    too_high = upper * scale > max_value
    scale[too_high] = max_value / upper[too_high]

    return np.array(source, dtype=float) * scale
def tubify(source, sigma=1.0, gamma12=1.0, gamma23=1.0, alpha=0.25):
    """Tube-likeness measure of *source* via the Hessian lambda123 filter.

    Thin wrapper around :func:`ClearMap...Differentiation.Hessian.lambda123`
    that returns the result as a new array (sink=None).
    """
    return hes.lambda123(source=source, sink=None, sigma=sigma,
                         gamma12=gamma12, gamma23=gamma23, alpha=alpha)
###############################################################################
### Helper
###############################################################################
def status_to_description(status):
    """Converts a status int to its description.

    Arguments
    ---------
    status : int
        The status, a bit mask whose bit k corresponds to STATUS_NAMES[k].

    Returns
    -------
    description : str
        The description corresponding to the status ('Background' if no
        bit is set).
    """
    description = ''
    for k in range(len(STATUS_NAMES) - 1, -1, -1):
        # test bit k of the remaining status; the original used Python-2
        # style 'status / 2 ** k == 1', which under Python 3's true division
        # never matches combined flags (e.g. status == 3)
        if status // 2 ** k == 1:
            description = STATUS_NAMES[k] + ',' + description
            status -= 2 ** k
    if len(description) == 0:
        description = 'Background'
    else:
        description = description[:-1]  # drop the trailing comma
    return description
def binary_statistics(source):
    """Counts the binarization types.

    Arguments
    ---------
    source : array
        The status array of the binarization process.

    Returns
    -------
    statistics : dict
        A dict with entires {description : count}.
    """
    values, counts = np.unique(io.as_source(source)[:], return_counts=True)
    return {status_to_description(value): count
            for value, count in zip(values, counts)}
###############################################################################
### Tests
###############################################################################

def _test():
    """Tests."""
    import numpy as np
    import ClearMap.Visualization.Plot3d as p3d
    import ClearMap.Tests.Files as tsf
    import ClearMap.ImageProcessing.Experts.Vasculature as vasc

    # small test volume with zeroed borders
    source = np.array(tsf.source('vls')[:300, :300, 80:120])
    source[:, :, [0, -1]] = 0
    source[:, [0, -1], :] = 0
    source[[0, -1], :, :] = 0

    bpar = vasc.default_binarization_parameter.copy()
    bpar['clip']['clip_range'] = (150, 7000)
    bpar['as_memory'] = True
    # bpar['binary_status'] = 'binary_status.npy'

    ppar = vasc.default_processing_parameter.copy()
    ppar['processes'] = 10
    ppar['size_max'] = 10

    sink = 'binary.npy'
    # sink=None;

    binary = vasc.binarize(source, sink=sink,
                           binarization_parameter=bpar, processing_parameter=ppar)
    p3d.plot([source, binary])

    import ClearMap.IO.IO as io
    io.delete_file(sink)

    pppar = vasc.default_postprocessing_parameter.copy()
    pppar['smooth']['iterations'] = 3
    smoothed = vasc.postprocess(binary, postprocessing_parameter=pppar)
    p3d.plot([binary, smoothed])