#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CellMap
=======
This module contains the class to analyze (detect) individual cells,
e.g. to analyze immediate early gene expression data from iDISCO+ cleared tissue [Renier2016]_.
.. image:: ../static/cell_abstract_2016.jpg
   :target: https://doi.org/10.1016/j.cell.2020.01.028
   :width: 300

.. figure:: ../static/CellMap_pipeline.png

   iDISCO+ and ClearMap: A Pipeline for Cell Detection, Registration, and
   Mapping in Intact Samples Using Light Sheet Microscopy.
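
Example
-------
A minimal usage sketch, mirroring the ``__main__`` block at the bottom of this
module (the configuration paths are placeholders, not files shipped with
ClearMap):

>>> from ClearMap.processors.sample_preparation import PreProcessor
>>> preprocessor = PreProcessor()
>>> preprocessor.setup(('/path/to/machine_config.cfg', '/path/to/sample_config.cfg'))
>>> preprocessor.setup_atlases()
>>> detector = CellDetector(preprocessor)
>>> detector.run()  # detect, filter, atlas-align and export statistics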
"""
import copy
import importlib
import os
import platform
import re
import warnings
from concurrent.futures.process import BrokenProcessPool
import numpy as np
import pandas as pd
from PyQt5.QtGui import QColor
from matplotlib import pyplot as plt
import pyqtgraph as pg
# noinspection PyPep8Naming
import ClearMap.Alignment.Elastix as elastix
# noinspection PyPep8Naming
import ClearMap.IO.IO as clearmap_io
# noinspection PyPep8Naming
import ClearMap.Visualization.Plot3d as plot_3d
import ClearMap.Visualization.Qt.Plot3d as qplot_3d
# noinspection PyPep8Naming
import ClearMap.Alignment.Resampling as resampling
# noinspection PyPep8Naming
import ClearMap.ImageProcessing.Experts.Cells as cell_detection
# noinspection PyPep8Naming
import ClearMap.Analysis.Measurements.Voxelization as voxelization
# noinspection PyPep8Naming
import ClearMap.Alignment.Annotation as annotation
from ClearMap.processors.sample_preparation import PreProcessor
from ClearMap.processors.generic_tab_processor import TabProcessor
from ClearMap.Utils.utilities import runs_on_ui
from ClearMap.Visualization.Qt.widgets import Scatter3D
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>, Charly Rousseau <charly.rousseau@icm-institute.org>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__ = 'https://idisco.info'
__download__ = 'https://www.github.com/ChristophKirst/ClearMap2'
if platform.system().lower().startswith('darwin'):
USE_BINARY_POINTS_FILE = False
else:
USE_BINARY_POINTS_FILE = True
class CellDetector(TabProcessor):
def __init__(self, preprocessor=None):
super().__init__()
self.sample_config = None
self.processing_config = None
self.machine_config = None
self.preprocessor = None
self.workspace = None
self.cell_detection_re = ('Processing block',
re.compile(r'.*?Processing block \d+/\d+.*?\selapsed time:\s\d+:\d+:\d+\.\d+'))
self.setup(preprocessor)
def setup(self, preprocessor):
self.preprocessor = preprocessor
if preprocessor is not None:
self.workspace = preprocessor.workspace
configs = preprocessor.get_configs()
self.sample_config = configs['sample']
self.machine_config = configs['machine']
self.processing_config = self.preprocessor.config_loader.get_cfg('cell_map')
self.set_progress_watcher(self.preprocessor.progress_watcher)
@property
def detected(self):
return os.path.exists(self.workspace.filename('cells', postfix='raw'))
def run(self):
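        """Run detection on a test subset first (for tuning), optionally
        preview it, then detect on the full volume and post-process."""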
# select sub-slice for testing the pipeline
slicing = (
slice(*self.processing_config['test_set_slicing']['dim_0']),
slice(*self.processing_config['test_set_slicing']['dim_1']),
slice(*self.processing_config['test_set_slicing']['dim_2'])
)
self.create_test_dataset(slicing)
self.run_cell_detection(tuning=True)
if self.processing_config['detection']['preview']:
self.preview_cell_detection()
self.run_cell_detection()
# print(f"Number of cells detected: {self.get_n_detected_cells()}")
self.post_process_cells()
def post_process_cells(self):
self.processing_config.reload()
if self.processing_config['detection']['plot_cells'] and not runs_on_ui():
self.plot_cells()
self.filter_cells()
if self.processing_config['cell_filtration']['preview'] and not runs_on_ui():
self.plot_filtered_cells()
self.atlas_align()
self.export_collapsed_stats()
def voxelize(self, postfix=''):
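        """Voxelize the cell coordinates into density maps. Only the
        unweighted (count) map is generated; the intensity-weighted map is
        currently disabled below."""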
self.processing_config.reload()
coordinates, cells, voxelization_parameter = self.get_voxelization_params(postfix=postfix)
# %% Unweighted
coordinates, counts_file_path = self.voxelize_unweighted(coordinates, voxelization_parameter)
if self.processing_config['voxelization']['preview']['counts'] and not runs_on_ui():
self.plot_voxelized_counts()
# %% Weighted
# intensities_file_path = self.voxelize_weighted(coordinates, cells, voxelization_parameter) # WARNING: Currently causing issues
# if self.processing_config['voxelization']['preview']['densities']:
# self.plot_voxelized_intensities()
def plot_voxelized_counts(self, arrange=True, parent=None):
scale = self.preprocessor.processing_config['registration']['resampling']['raw_sink_resolution']
return plot_3d.plot(self.workspace.filename('density', postfix='counts'),
scale=scale, title='Cell density (voxelized)', lut='flame',
arrange=arrange, parent=parent)
def create_test_dataset(self, slicing):
self.workspace.create_debug('stitched', slicing=slicing)
self.update_watcher_main_progress()
def get_voxelization_params(self, postfix=''):
voxelization_parameter = {
'radius': self.processing_config['voxelization']['radii'],
'verbose': True
}
if self.workspace.debug: # Path will use debug
voxelization_parameter['shape'] = clearmap_io.shape(self.workspace.filename('cells', postfix='shape'))
elif self.preprocessor.was_registered:
voxelization_parameter['shape'] = clearmap_io.shape(self.preprocessor.annotation_file_path)
else:
voxelization_parameter['shape'] = self.preprocessor.resampled_shape
        if postfix:  # Hack: the intermediate 'raw'/'filtered' tables are not atlas-aligned, so use raw coordinates
cells, coordinates = self.get_coords(coord_type=postfix, aligned=False)
else:
cells, coordinates = self.get_coords(coord_type=None, aligned=True)
return coordinates, cells, voxelization_parameter
def get_coords(self, coord_type='filtered', aligned=False):
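        """Return the cells table and an (n, 3) coordinate array.

        coord_type selects the saved table ('raw', 'filtered', or None for the
        final feather file); aligned selects atlas-space ('xt', 'yt', 'zt')
        over raw ('x', 'y', 'z') coordinates.
        """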
if coord_type not in ('filtered', 'raw', None):
raise ValueError(f'Coordinate type "{coord_type}" not recognised')
if coord_type is None:
dataframe_path = self.workspace.filename('cells', extension='.feather')
if os.path.exists(dataframe_path):
table = pd.read_feather(dataframe_path)
else:
table = np.load(self.workspace.filename('cells')).T
else:
table = np.load(self.workspace.filename('cells', postfix=coord_type))
if aligned:
coordinates = np.array([table[axis] for axis in ['xt', 'yt', 'zt']]).T
else:
coordinates = np.array([table[axis] for axis in ['x', 'y', 'z']]).T
return table, coordinates
def voxelize_unweighted(self, coordinates, voxelization_parameter):
"""
Voxelize un weighted i.e. for cell counts
Parameters
----------
coordinates
str, array or Source
Source of point of nxd coordinates.
voxelization_parameter
dict
Returns
-------
"""
counts_file_path = self.workspace.filename('density', postfix='counts') # TODO: improve var name
clearmap_io.delete_file(counts_file_path)
self.set_watcher_step('Unweighted voxelisation')
voxelization.voxelize(coordinates, sink=counts_file_path, **voxelization_parameter) # WARNING: prange
self.update_watcher_main_progress()
# self.remove_crust(coordinates, voxelization_parameter) # WARNING: currently causing issues
return coordinates, counts_file_path
def voxelize_weighted(self, coordinates, source, voxelization_parameter):
"""
Voxelize weighted i.e. for cell intensities
Parameters
----------
coordinates
np.array
source
Source.Source
voxelization_parameter
dict
Returns
-------
"""
intensities_file_path = self.workspace.filename('density', postfix='intensities')
intensities = source['source']
voxelization.voxelize(coordinates, sink=intensities_file_path, weights=intensities, **voxelization_parameter) # WARNING: prange
return intensities_file_path
def atlas_align(self):
"""Atlas alignment and annotation """
table, coordinates = self.get_coords(coord_type='filtered')
df = pd.DataFrame({'x': coordinates[:, 0], 'y': coordinates[:, 1], 'z': coordinates[:, 2]})
df['size'] = table['size']
df['source'] = table['source']
if self.preprocessor.was_registered:
coordinates_transformed = self.transform_coordinates(coordinates)
df['xt'] = coordinates_transformed[:, 0]
df['yt'] = coordinates_transformed[:, 1]
df['zt'] = coordinates_transformed[:, 2]
structure_ids = annotation.label_points(coordinates_transformed,
annotation_file=self.preprocessor.annotation_file_path,
key='id')
df['id'] = structure_ids
hemisphere_labels = annotation.label_points(coordinates_transformed,
annotation_file=self.preprocessor.hemispheres_file_path,
key='id')
df['hemisphere'] = hemisphere_labels
names = annotation.convert_label(structure_ids, key='id', value='name')
df['name'] = names
unique_ids = np.sort(df['id'].unique())
order_map = {id_: annotation.find(id_, key='id')['order'] for id_ in unique_ids}
df['order'] = df['id'].map(order_map)
        color_map = {id_: annotation.find(id_, key='id')['rgb'] for id_ in unique_ids}  # WARNING: key 'RGB' (upper case) should give integers but does not work
df['color'] = df['id'].map(color_map)
volumes = annotation.annotation.get_lateralised_volume_map(
self.preprocessor.processing_config['registration']['resampling']['autofluo_sink_resolution'],
self.preprocessor.hemispheres_file_path
)
df['volume'] = df.set_index(['id', 'hemisphere']).index.map(volumes.get)
df.to_feather(self.workspace.filename('cells', extension='.feather'))
def filter_cells(self):
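        """Filter the raw cells by the intensity and size thresholds from the
        cell_map config, writing the 'filtered' cells file."""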
self.processing_config.reload()
thresholds = {
'source': self.processing_config['cell_filtration']['thresholds']['intensity'],
'size': self.processing_config['cell_filtration']['thresholds']['size']
}
cell_detection.filter_cells(source=self.workspace.filename('cells', postfix='raw'),
sink=self.workspace.filename('cells', postfix='filtered'),
thresholds=thresholds)
def run_cell_detection(self, tuning=False, save_shape=False):
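        """Run block-wise cell detection on the stitched image, writing the
        raw cells table. With tuning=True, run on the debug subset and save
        the intermediate background and shape images for previewing."""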
self.processing_config.reload()
self.workspace.debug = tuning # TODO: use context manager
cell_detection_param = copy.deepcopy(cell_detection.default_cell_detection_parameter)
cell_detection_param['illumination_correction'] = None # WARNING: illumination or illumination_correction
cell_detection_param['background_correction']['shape'] = self.processing_config['detection']['background_correction']['diameter']
cell_detection_param['maxima_detection']['shape'] = self.processing_config['detection']['maxima_detection']['shape']
cell_detection_param['intensity_detection']['measure'] = ['source']
cell_detection_param['shape_detection']['threshold'] = self.processing_config['detection']['shape_detection']['threshold']
if tuning:
clearmap_io.delete_file(self.workspace.filename('cells', postfix='bkg'))
cell_detection_param['background_correction']['save'] = self.workspace.filename('cells', postfix='bkg')
clearmap_io.delete_file(self.workspace.filename('cells', postfix='shape'))
cell_detection_param['shape_detection']['save'] = self.workspace.filename('cells', postfix='shape')
if save_shape:
cell_detection_param['shape_detection']['save'] = self.workspace.filename('cells', postfix='shape')
# clearmap_io.delete_file(workspace.filename('cells', postfix='maxima'))
# cell_detection_param['maxima_detection']['save'] = workspace.filename('cells', postfix='maxima')
processing_parameter = copy.deepcopy(cell_detection.default_cell_detection_processing_parameter)
processing_parameter.update( # TODO: store as other dict and run .update(**self.extra_detection_params)
processes=self.machine_config['n_processes_cell_detection'],
size_min=self.machine_config['detection_chunk_size_min'],
size_max=self.machine_config['detection_chunk_size_max'],
overlap=self.machine_config['detection_chunk_overlap'],
verbose=True
)
n_steps = self.get_n_blocks(self.workspace.source('stitched').shape[2]) # OPTIMISE: read metadata w/out load # TODO: round to processors
self.prepare_watcher_for_substep(n_steps, self.cell_detection_re, 'Detecting cells')
try:
cell_detection.detect_cells(self.workspace.filename('stitched'),
self.workspace.filename('cells', postfix='raw'),
cell_detection_parameter=cell_detection_param,
processing_parameter=processing_parameter,
workspace=self.workspace) # WARNING: prange inside multiprocess (including arrayprocessing and devolvepoints for vox)
        except BrokenProcessPool:
            print('Cell detection canceled')
            return
finally:
self.workspace.debug = False
self.update_watcher_main_progress()
def export_as_csv(self):
"""
        Export the cell coordinates to CSV.

        .. deprecated:: 2.1
            Use :func:`atlas_align` and :func:`export_collapsed_stats` instead.
        """
        warnings.warn("export_as_csv is deprecated and will be removed in future versions; "
                      "please use the new formats from atlas_align and export_collapsed_stats",
                      DeprecationWarning, 2)
csv_file_path = self.workspace.filename('cells', extension='.csv')
self.get_cells_df().to_csv(csv_file_path)
def export_collapsed_stats(self, all_regions=True):
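        """Export per-structure cell statistics (counts, average cell size and
        structure volume), split by hemisphere, to the 'stats' CSV file."""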
df = self.get_cells_df()
collapsed = pd.DataFrame()
relevant_columns = ['id', 'order', 'name', 'hemisphere', 'volume', 'size']
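        # In the hemispheres atlas, the two hemispheres are labelled 0 and 255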
for i in (0, 255): # Split by hemisphere to group by structure and reconcatenate hemispheres after
grouped = df[df['hemisphere'] == i][relevant_columns].groupby(['id'], as_index=False)
tmp = pd.DataFrame()
first = grouped.first()
tmp['Structure ID'] = first['id']
tmp['Structure order'] = first['order']
tmp['Structure name'] = first['name']
tmp['Hemisphere'] = first['hemisphere']
tmp['Structure volume'] = first['volume']
tmp['Cell counts'] = grouped.count()['name']
tmp['Average cell size'] = grouped.mean()['size']
collapsed = pd.concat((collapsed, tmp))
if all_regions: # Add regions even if they are empty
uniq_ids = np.unique(annotation.annotation.atlas)
tmp = pd.DataFrame({'Structure ID': uniq_ids, 'mock': ''})
tmp['Structure name'] = annotation.convert_label(uniq_ids, key='id', value='name')
df_mock = pd.DataFrame({'Hemisphere': [0, 255], 'mock': ''})
tmp = tmp.merge(df_mock, on='mock').drop(columns='mock')
vol_map = annotation.annotation.get_lateralised_volume_map(
self.preprocessor.processing_config['registration']['resampling']['autofluo_sink_resolution'],
self.preprocessor.hemispheres_file_path
)
tmp['Structure volume'] = tmp.set_index(['Structure ID', 'Hemisphere']).index.map(vol_map.get)
order_map = {id_: annotation.find(id_, key='id')['order'] for id_ in uniq_ids}
tmp['Structure order'] = tmp['Structure ID'].map(order_map)
collapsed = tmp.merge(collapsed[['Structure ID', 'Hemisphere', 'Cell counts', 'Average cell size']],
how='left', on=['Structure ID', 'Hemisphere'])
collapsed = collapsed.sort_values(by='Structure ID')
csv_file_path = self.workspace.filename('cells', postfix='stats', extension='.csv')
collapsed.to_csv(csv_file_path, index=False)
    def plot_cells(self):  # For non-GUI use
source = self.workspace.source('cells', postfix='raw')
plt.figure(1)
plt.clf()
names = source.dtype.names
nx, ny = plot_3d.subplot_tiling(len(names))
for i, name in enumerate(names):
plt.subplot(nx, ny, i + 1)
plt.hist(source[name])
plt.title(name)
plt.tight_layout()
def plot_cells_3d_scatter_w_atlas_colors(self, raw=False, parent=None):
if raw:
dv = qplot_3d.plot(self.workspace.filename('stitched'), title='Stitched and cells',
# scale=self.preprocessor.sample_config['resolutions']['raw'],# FIXME: correct scaling for anisotropic
arrange=False, lut='white', parent=parent)[0]
else:
if self.preprocessor.was_registered: # REFACTORING: could extract
dv = qplot_3d.plot(clearmap_io.source(self.preprocessor.reference_file_path),
title='Reference and cells',
arrange=False, lut='white', parent=parent)[0]
else:
dv = qplot_3d.plot(self.workspace.filename('resampled'), title='Resampled and cells',
arrange=False, lut='white', parent=parent)[0]
scatter = pg.ScatterPlotItem()
dv.view.addItem(scatter)
dv.scatter = scatter
df = self.get_cells_df()
if raw:
coordinates = df[['x', 'y', 'z']].values.astype(int)
# coordinates = coordinates * np.array(self.preprocessor.sample_config['resolutions']['raw'])
# coordinates = coordinates.astype(int) # required to match integer z # FIXME: correct scaling for anisotropic
else:
coordinates = df[['xt', 'yt', 'zt']].values.astype(int) # required to match integer z
dv.atlas = clearmap_io.read(self.preprocessor.annotation_file_path)
dv.structure_names = annotation.get_names_map()
if 'hemisphere' in df.columns:
hemispheres = df['hemisphere']
else:
hemispheres = None
dv.scatter_coords = Scatter3D(coordinates, colors=df['color'].values,
hemispheres=hemispheres, half_slice_thickness=0)
dv.refresh()
return [dv]
@property
def df_path(self):
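        """Path to the cells table: the feather file if it exists, else the 'cells' numpy file."""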
feather_path = self.workspace.filename('cells', extension='.feather')
        if os.path.exists(feather_path):
return feather_path
else:
return self.workspace.filename('cells')
def get_cells_df(self):
if self.df_path.endswith('.feather'):
return pd.read_feather(self.df_path)
else:
return pd.DataFrame(np.load(self.df_path))
def plot_filtered_cells(self, parent=None, smarties=False):
_, coordinates = self.get_coords('filtered')
stitched_path = self.workspace.filename('stitched')
dv = qplot_3d.plot(stitched_path, title='Stitched and filtered cells', arrange=False,
lut='white', parent=parent)[0]
scatter = pg.ScatterPlotItem()
dv.view.addItem(scatter)
dv.scatter = scatter
dv.scatter_coords = Scatter3D(coordinates, smarties=smarties, half_slice_thickness=3)
dv.refresh()
return [dv]
def plot_background_subtracted_img(self):
coordinates = np.hstack([self.workspace.source('cells', postfix='raw')[c][:, None] for c in 'xyz'])
p = plot_3d.list_plot_3d(coordinates)
return plot_3d.plot_3d(self.workspace.filename('stitched'), view=p, cmap=plot_3d.grays_alpha(alpha=1))
def remove_crust(self, coordinates, voxelization_parameter):
dist2surf = clearmap_io.read(self.preprocessor.distance_file_path)
threshold = 3
# Convert coordinates to integer
int_coordinates = np.floor(coordinates).astype(int)
# Ensure all coordinates are within the dist2surf array bounds
valid_indices = (int_coordinates[:, 0] < dist2surf.shape[0]) & \
(int_coordinates[:, 1] < dist2surf.shape[1]) & \
(int_coordinates[:, 2] < dist2surf.shape[2])
# Apply the mask to get valid coordinates
valid_coordinates = int_coordinates[valid_indices]
# Get the dist2surf values at the valid coordinates
dist_values = dist2surf[valid_coordinates[:, 0],
valid_coordinates[:, 1],
valid_coordinates[:, 2]]
        # Keep only cells deeper than the threshold, i.e. remove the surface 'crust'
        coordinates_wcrust = valid_coordinates[dist_values > threshold]
voxelization.voxelize(coordinates_wcrust, sink=self.workspace.filename('density', postfix='counts_wcrust'),
**voxelization_parameter) # WARNING: prange
def preview_cell_detection(self, parent=None, arrange=True, sync=True):
sources = [self.workspace.filename('stitched'),
self.workspace.filename('cells', postfix='bkg'),
self.workspace.filename('cells', postfix='shape')
]
sources = [s for s in sources if os.path.exists(s)] # Remove missing files (if not tuning)
titles = [os.path.basename(s) for s in sources]
luts = ['white', 'white', 'random']
return plot_3d.plot(sources, title=titles, arrange=arrange, sync=sync, lut=luts, parent=parent)
def get_n_detected_cells(self):
if os.path.exists(self.workspace.filename('cells', postfix='raw')):
_, coords = self.get_coords(coord_type='raw')
            return coords.shape[0]  # coords has shape (n, 3)
else:
return 0
def get_n_filtered_cells(self):
if os.path.exists(self.workspace.filename('cells', postfix='filtered')):
_, coords = self.get_coords(coord_type='filtered')
            return coords.shape[0]  # coords has shape (n, 3)
else:
return 0
def plot_voxelized_intensities(self, arrange=True):
return plot_3d.plot(self.workspace.filename('density', postfix='intensities'), arrange=arrange)
def get_n_blocks(self, dim_size):
blk_size = self.machine_config['detection_chunk_size_max']
overlap = self.machine_config['detection_chunk_overlap']
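        # The first block covers blk_size planes; each further block adds (blk_size - overlap)
        # new planes, e.g. dim_size=100, blk_size=30, overlap=10 -> ceil((100 - 30) / 20 + 1) = 5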
n_blocks = int(np.ceil((dim_size - blk_size) / (blk_size - overlap) + 1))
return n_blocks
def export_to_clearmap1_fmt(self):
"""
ClearMap 1.0 export (will generate the files cells_ClearMap1_intensities, cells_ClearMap1_points_transformed,
cells_ClearMap1_points necessaries to use the analysis script of ClearMap1.
In ClearMap2 the 'cells' file contains already all this information)
In order to align the coordinates when we have right and left hemispheres,
if the orientation of the brain is left, will calculate the new coordinates for the Y axes,
this change will not affect the orientation of the heatmaps, since these are generated from
the ClearMap2 file 'cells'
.. deprecated:: 2.1
Use :func:`atlas_align` and `export_collapsed_stats` instead.
"""
warnings.warn("export_to_clearmap1_fmt is deprecated and will be removed in future versions;"
"please use the new formats from atlas_align and export_collapsed_stats", DeprecationWarning, 2)
source = self.workspace.source('cells')
clearmap1_format = {'points': ['x', 'y', 'z'],
'points_transformed': ['xt', 'yt', 'zt'],
'intensities': ['source', 'dog', 'background', 'size']}
for filename, names in clearmap1_format.items():
sink = self.workspace.filename('cells', postfix=['ClearMap1', filename])
print(filename, sink)
data = np.array(
[source[name] if name in source.dtype.names else np.full(source.shape[0], np.nan) for name in names]
)
data = data.T
clearmap_io.write(sink, data)
def convert_cm2_to_cm2_1_fmt(self):
"""Atlas alignment and annotation """
cells = np.load(self.workspace.filename('cells'))
df = pd.DataFrame({ax: cells[ax] for ax in 'xyz'})
df['size'] = cells['size']
df['source'] = cells['source']
for ax in 'xyz':
df[f'{ax}t'] = cells[f'{ax}t']
df['order'] = cells['order']
df['name'] = cells['name']
coordinates_transformed = np.vstack([cells[f'{ax}t'] for ax in 'xyz']).T
# FIXME: Put key ID and get ID directly
hemisphere_label = annotation.label_points(coordinates_transformed,
annotation_file=self.preprocessor.hemispheres_file_path,
key='id')
unique_labels = np.sort(df['order'].unique())
        color_map = {lbl: annotation.find(lbl, key='order')['rgb'] for lbl in
                     unique_labels}  # WARNING: key 'RGB' (upper case) should give integers but does not work
id_map = {lbl: annotation.find(lbl, key='order')['id'] for lbl in unique_labels}
atlas = clearmap_io.read(self.preprocessor.annotation_file_path)
atlas_scale = self.preprocessor.processing_config['registration']['resampling']['autofluo_sink_resolution']
atlas_scale = np.prod(atlas_scale)
volumes = {_id: (atlas == _id).sum() * atlas_scale for _id in
id_map.values()} # Volumes need a lookup on ID since the atlas is in ID space
df['id'] = df['order'].map(id_map)
df['hemisphere'] = hemisphere_label
df['color'] = df['order'].map(color_map)
df['volume'] = df['id'].map(volumes)
df.to_feather(self.workspace.filename('cells', extension='.feather'))
if __name__ == "__main__":
import sys
preprocessor = PreProcessor()
preprocessor.setup(sys.argv[1:3])
preprocessor.setup_atlases()
# preprocessor.run()
detector = CellDetector(preprocessor)
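    # detector.run()  # uncomment to run the full detection pipeline on the sample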