Source code for ClearMap.Visualization.Qt.DataViewerAxon

"""
DataViewer
==========

Data viewer showing 3d data as 2d slices.

Usage
-----

.. image:: ../static/DataViewer.jpg

Note
----
This viewer is based on the pyqtgraph package.
"""
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>, Charly Rousseau <charly.rousseau@icm-institute.org>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__ = 'https://idisco.info'
__download__ = 'https://www.github.com/ChristophKirst/ClearMap2'

import time
import functools as ft

import numpy as np

import pyqtgraph as pg
from PyQt5 import QtWidgets
from PyQt5.QtCore import QEvent, QRect, QSize, pyqtSignal, Qt
from PyQt5.QtGui import QPainter
from PyQt5.QtWidgets import QWidget, QRadioButton, QLabel, QSplitter, QApplication, QSizePolicy, QPushButton, QCheckBox, \
  QGraphicsPathItem, QGridLayout, QLineEdit, QScrollArea

from ClearMap.Utils.utilities import runs_on_spyder
from ClearMap.IO.IO import as_source
from ClearMap.IO.Source import Source
from ClearMap.Visualization.Qt.data_viewer_luts import LUT

pg.CONFIG_OPTIONS['useOpenGL'] = False  # set to False if trouble seeing data.

if not pg.QAPP:
    pg.mkQApp()


[docs] class DataViewer(QWidget): mouse_clicked = pyqtSignal(int, int, int) DEFAULT_SCATTER_PARAMS = { 'pen': 'red', 'brush': 'red', 'symbol': '+', 'size': 10 } def __init__(self, source, points=None, vectors=None, orientations=None, annotation=None, axis=None, scale=None, title=None, invertY=False, minMax=None, screen=None, parent=None, default_lut='flame', max_projection=None, points_style=None, vectors_style=None, orientations_style=None, **kwargs): QWidget.__init__(self, parent, **kwargs) # Images sources self.initializeSources(source, axis=axis, scale=scale) self.points = points if self.points is not None: self.points = as_source(points).array self.points_item = None self.points_style = dict(pen=None, brush='white') if points_style is not None: self.points_style.update(points_style) self.vectors = vectors if self.vectors is not None: self.vectors = as_source(vectors).array self.vectors_item = None self.vectors_base_item = None self.vectors_style = dict(pen=None, brush='lightblue') if vectors_style is not None: self.vectors_style.update(vectors_style) self.orientations = orientations if self.orientations is not None: self.orientations = as_source(orientations).array self.orientations_item = None self.orientations_style = dict(pen='gray') if orientations_style is not None: self.orientations_style.update(orientations_style) self.vectors = vectors self.vectors_item = None # print('init') self.annotation = annotation if title is None: if isinstance(source, str): title = source elif isinstance(source, Source): title = source.location if title is None: title = 'DataViewer' self.setWindowTitle(title) self.resize(1600, 1200) self.layout = QtWidgets.QGridLayout(self) self.layout.setContentsMargins(0, 0, 0, 0) # image pane self.view = pg.ViewBox() self.view.setAspectLocked(True) self.view.invertY(invertY) self.graphicsView = pg.GraphicsView() self.graphicsView.setObjectName("GraphicsView") self.graphicsView.setCentralItem(self.view) splitter = QSplitter() splitter.setOrientation(Qt.Horizontal) splitter.setSizes([self.width() - 10, 10]) self.layout.addWidget(splitter) image_splitter = QSplitter() image_splitter.setOrientation(Qt.Vertical) image_splitter.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) splitter.addWidget(image_splitter) # Image plots image_options = dict(clipToView=True, autoDownsample=True, autoLevels=False, useOpenGL=None) self.image_items = [pg.ImageItem(s[self.source_slice[:s.ndim]], **image_options) for s in self.sources] for itm in self.image_items: itm.setRect(QRect(0, 0, int(self.source_range_x), int(self.source_range_y))) itm.setCompositionMode(QPainter.CompositionMode_Plus) self.view.addItem(itm) self.view.setXRange(0, self.source_range_x) self.view.setYRange(0, self.source_range_y) # slice selector self.slicePlot = pg.PlotWidget() size_policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) size_policy.setHorizontalStretch(0) size_policy.setVerticalStretch(0) size_policy.setHeightForWidth(self.slicePlot.sizePolicy().hasHeightForWidth()) self.slicePlot.setSizePolicy(size_policy) self.slicePlot.setMinimumSize(QSize(0, 40)) self.slicePlot.setObjectName("roiPlot") self.sliceLine = pg.InfiniteLine(0, movable=True) self.sliceLine.setPen((255, 255, 255, 200)) self.sliceLine.setZValue(1) self.slicePlot.addItem(self.sliceLine) self.slicePlot.hideAxis('left') self.updateSlicer() self.sliceLine.sigPositionChanged.connect(self.updateSlice) # Axis tools self.axis_buttons = [] axis_tools_layout, axis_tools_widget = self.__setup_axes_controls() # max projection depth self.max_projection = max_projection self.max_projection_edit = QLineEdit() if self.max_projection is not None: self.max_projection_edit.setText('%d' % self.max_projection) # self.max_projection_edit.setValidator(pg.QtGui.QIntValidator()) self.max_projection_edit.setMaxLength(4) self.max_projection_edit.setAlignment(Qt.AlignRight) self.max_projection_edit.setMaximumWidth(60) self.max_projection_edit.editingFinished.connect(self.change_max_projection) axis_tools_layout.addWidget(self.max_projection_edit, 0, 3) # points color self.points_color_button = pg.ColorButton(color=self.points_style.get('brush')) self.points_color_button.setMaximumWidth(30) self.points_color_button.sigColorChanged.connect(self.change_points_color) axis_tools_layout.addWidget(self.points_color_button, 0, 4) # vectors color and threshold self.vectors_color_button = pg.ColorButton(color=self.vectors_style.get('brush')) self.vectors_color_button.setMaximumWidth(30) self.vectors_color_button.sigColorChanged.connect(self.change_vectors_color) axis_tools_layout.addWidget(self.vectors_color_button, 0, 5) self.vectors_threshold_edit = QLineEdit() vectors_threshold = self.vectors_style.get('threshold', None) if vectors_threshold is not None: self.vectors_threshold_edit.setText('%.4f' % vectors_threshold) self.vectors_threshold_edit.setMaxLength(6) self.vectors_threshold_edit.setAlignment(Qt.AlignRight) self.vectors_threshold_edit.setMaximumWidth(60) self.vectors_threshold_edit.editingFinished.connect(self.change_vectors_threshold) axis_tools_layout.addWidget(self.vectors_threshold_edit, 0, 6) # orientation threshold self.orientations_color_button = pg.ColorButton(color=self.orientations_style.get('pen')) self.orientations_color_button.setMaximumWidth(30) self.orientations_color_button.sigColorChanged.connect(self.change_orientations_color) axis_tools_layout.addWidget(self.orientations_color_button, 0, 7) self.orientations_threshold_edit = QLineEdit() orientations_threshold = self.orientations_style.get('threshold', None) if orientations_threshold is not None: self.orientations_threshold_edit.setText('%.4f' % orientations_threshold) self.orientations_threshold_edit.setMaxLength(6) self.orientations_threshold_edit.setAlignment(Qt.AlignRight) self.orientations_threshold_edit.setMaximumWidth(60) self.orientations_threshold_edit.editingFinished.connect(self.change_orientations_threshold) axis_tools_layout.addWidget(self.orientations_threshold_edit, 0, 8) # coordinate label self.source_pointer = [0, 0, 0] self.source_label = QLabel("") self.source_label_scroll = QScrollArea() self.source_label_scroll.setMaximumHeight(30) self.source_label_scroll.setWidgetResizable(True) self.source_label_scroll.horizontalScrollBar().setStyleSheet("QScrollBar {height:0px;}") self.source_label_scroll.setWidget(self.source_label) axis_tools_layout.addWidget(self.source_label_scroll, 0, 9) self.graphicsView.scene().sigMouseMoved.connect(self.updateLabelFromMouseMove) # compose the image viewer image_splitter.addWidget(self.graphicsView) image_splitter.addWidget(self.slicePlot) image_splitter.addWidget(axis_tools_widget) image_splitter.setSizes([self.height() - 35 - 20, 35, 20]) # lut widgets self.luts = [LUT(image=i, color=c) for i, c in zip(self.image_items, self.__get_colors(default_lut))] lut_layout = QtWidgets.QGridLayout() lut_layout.setContentsMargins(0, 0, 0, 0) for d, lut in enumerate(self.luts): lut_layout.addWidget(lut, 0, d) lut_widget = QWidget() lut_widget.setLayout(lut_layout) lut_widget.setContentsMargins(0, 0, 0, 0) lut_widget.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) # TODO: add option for sizepolicy splitter.addWidget(lut_widget) splitter.setStretchFactor(0, 1) splitter.setStretchFactor(1, 0) # update scale for lut in self.luts: lut.range_buttons[1][2].click() if minMax is not None: self.setMinMax(minMax) self.initialize_points_item() self.initialize_vectors_item() self.change_max_projection() # self.change_orientations_threshold() self.show()
[docs] def eventFilter(self, source, event): if event.type() == QEvent.Wheel: angle = event.angleDelta().y() # steps = angle / abs(angle) steps = angle / 120 self.sliceLine.setValue(self.sliceLine.value() + steps) return super().eventFilter(source, event)
def __cast_source(self, source): if isinstance(source, tuple): source = list(source) if not isinstance(source, list): source = [source] return source
[docs] def initializeSources(self, source, scale=None, axis=None, update=True): # initialize sources and axis settings source = self.__cast_source(source) self.n_sources = len(source) self.sources = [as_source(s) for s in source] # self.__cast_bools() # self.__ensure_3d() # source shapes self.source_shape = self.shape3d(self.sources[0].shape) for s in self.sources: if s.ndim > 3: raise RuntimeError(f'Source has {s.ndim} > 3 dimensions: {s}!') if self.shape3d(s.shape) != self.source_shape: raise RuntimeError(f'Sources shape {self.source_shape} vs {s.shape} in source {s}!') # slicing if axis is None: axis = 2 self.source_axis = axis self.source_index = np.array(np.array(self.source_shape, dtype=float) / 2, dtype=int) # scaling if scale is None: scale = np.ones(3) else: scale = np.array(scale) scale = np.hstack([scale, [1] * 3])[:3] self.source_scale = scale self.updateSourceRange() self.updateSourceSlice()
[docs] def setSource(self, source, index='all'): # TODO: see if could factor with __init__ if index == 'all': source = self.__cast_source(source) if self.n_sources != len(source): raise RuntimeError(f'Number of sources does not match! got {len(source)}, expected {self.n_sources}') source = [as_source(s) for s in source] index = range(self.n_sources) else: s = self.sources s[index] = as_source(source) source = s index = [index] # self.__cast_bools() for i in index: s = source[i] if s.shape != self.source_shape: raise RuntimeError('Shape of sources does not match!') if s.ndim == 2: s.shape = s.shape + (1,) if s.ndim != 3: raise RuntimeError(f'Sources dont have dimensions 2 or 3 but {s.ndim} in source {i}!') self.image_items[i].updateImage(s[self.source_slice[:s.ndim]]) self.sources = source
def __setup_axes_controls(self): axis_tools_layout = QGridLayout() for d, ax in enumerate('xyz'): button = QRadioButton(ax) button.setMaximumWidth(50) axis_tools_layout.addWidget(button, 0, d) button.clicked.connect(ft.partial(self.setSliceAxis, d)) self.axis_buttons.append(button) self.axis_buttons[self.source_axis].setChecked(True) axis_tools_widget = QWidget() axis_tools_widget.setLayout(axis_tools_layout) for i in range(self.n_sources): box = QCheckBox(f'{i}') box.setMaximumWidth(50) box.setChecked(True) box.stateChanged.connect(ft.partial(self.toggle_layer, i)) axis_tools_layout.addWidget(box, 1, i) self.axis_buttons.append(box) return axis_tools_layout, axis_tools_widget
[docs] def toggle_layer(self, i, state): self.image_items[i].setVisible(state == Qt.Checked)
def __get_colors(self, default_lut): if self.n_sources == 1: cols = [default_lut] elif self.n_sources == 2: cols = ['purple', 'green'] else: cols = np.array(['white', 'green', 'red', 'blue', 'purple'] * self.n_sources)[:self.n_sources] return cols
[docs] def getXYAxes(self): return [a for a in [0, 1, 2] if a != self.source_axis]
[docs] def updateSourceRange(self): x, y = self.getXYAxes() self.source_range_x = self.source_scale[x] * self.source_shape[x] self.source_range_y = self.source_scale[y] * self.source_shape[y]
[docs] def sourceSlice(self): axis = self.source_axis return tuple(slice(None) if d != axis else self.source_index[axis] for d in range(3))
[docs] def updateSourceSlice(self): """Set the current slice of the source""" self.source_slice = self.sourceSlice()
[docs] def updateSlicer(self): ax = self.source_axis self.slicePlot.setXRange(0, self.source_shape[ax]) self.sliceLine.setValue(self.source_index[ax]) stop = self.source_shape[ax] + 0.5 self.sliceLine.setBounds([0, stop])
[docs] def updateLabelFromMouseMove(self, event_pos): x, y = self.get_coords(event_pos) self.sync_cursors(x, y) x_axis, y_axis = self.getXYAxes() scaled_x, scaled_y = self.scale_coords(x, x_axis, y, y_axis) z = self.source_index[self.source_axis] pos = [z] * 3 pos[x_axis] = scaled_x pos[y_axis] = scaled_y self.source_pointer = pos self.updateLabel()
[docs] def scale_coords(self, x, x_axis, y, y_axis): scaled_x = min(int(x / self.source_scale[x_axis]), self.source_shape[x_axis] - 1) scaled_y = min(int(y / self.source_scale[y_axis]), self.source_shape[y_axis] - 1) return scaled_x, scaled_y
[docs] def get_coords(self, pos): mouse_point = self.view.mapSceneToView(pos) x, y = mouse_point.x(), mouse_point.y() x = min(max(0, x), self.source_range_x) y = min(max(0, y), self.source_range_y) return x, y
[docs] def sync_cursors(self, x, y): if self.cross is not None: self.cross.set_coords([x, y]) self.view.update() for pal in self.pals: pal.cross.set_coords([x, y]) pal._updateCoords(x, y) pal.view.update()
[docs] def updateLabel(self): x, y, z = self.source_pointer xs, ys, zs = self.source_scale vals = ", ".join([str(s[x, y, z]) for s in self.sources]) text = "<span style='font-size: 12pt; color: black'>(%d, %d, %d) {%.2f, %.2f, %.2f} [%s]" % ( x, y, z, x * xs, y * ys, z * zs, vals) if self.annotation is not None: a = self.annotation.get(self.sources[0][x, y, z], None) if a is not None: text = text + ("[%s]" % a) text += "</span>" self.source_label.setText(text)
[docs] def updateSlice(self, force_update=False): ax = self.source_axis index = min(max(0, int(self.sliceLine.value())), self.source_shape[ax] - 1) if self.max_projection is not None: slc_ax = ( slice(max(0, index - self.max_projection), min(self.source_shape[ax], index + self.max_projection)),); else: slc_ax = (index,) if index != self.source_index[ax] or force_update: self.source_index[ax] = index self.source_slice = self.source_slice[:ax] + slc_ax + self.source_slice[ax + 1:] self.source_pointer[ax] = index self.updateLabel() self.updateImage() self.update_points() self.update_vectors() self.update_orientations()
[docs] def setSliceAxis(self, axis): self.source_axis = axis self.updateSourceRange() self.updateSourceSlice() for img_itm, src in zip(self.image_items, self.sources): img_itm.updateImage(src[self.source_slice]) img_itm.setRect(QRect(0, 0, self.source_range_x, self.source_range_y)) self.view.setXRange(0, self.source_range_x) self.view.setYRange(0, self.source_range_y) self.updateSlicer()
[docs] def updateImage(self): for img_item, src in zip(self.image_items, self.sources): if self.max_projection is not None: image = np.max(src[self.source_slice[:src.ndim]], axis=self.source_axis) else: image = src[self.source_slice[:src.ndim]] if image.dtype == bool: image = image.view('uint8') img_item.updateImage(image)
[docs] def setMinMax(self, min_max, source=None): if source is None: if not isinstance(min_max, list): min_max = [min_max] * len(self.sources) source = list(range(len(self.sources))) else: if not isinstance(source, list): source = [source] if not isinstance(min_max, list): min_max = [min_max] * len(source) for s, mM in enumerate(min_max): self.luts[s].lut.region.setRegion(mM)
[docs] def shape3d(self, shape): return (shape + (1,) * 3)[:3]
[docs] def change_max_projection(self, value=None): if value is not None: self.max_projection_edit.setText('%d' % value) text = self.max_projection_edit.text() try: text = int(text) if text <= 0: text = None except ValueError: text = None self.max_projection = text self.updateSlice(force_update=True)
[docs] def initialize_points_item(self): if self.points_item is not None: self.view.removeItem(self.points_item) self.points_item = pg.ScatterPlotItem(**self.points_style) self.view.addItem(self.points_item)
[docs] def set_points(self, points): self.points = points if self.points is not None: self.points = as_source(points) self.initialize_points_item() self.update_points()
[docs] def update_points(self): if self.points is not None: points = self.points axis = self.source_axis axes = [d for d in range(3) if d != axis] index = self.source_index[axis] # select points in slice valid_min, valid_max = index - 0.5, index + 0.5 if self.max_projection is not None: valid_min, valid_max = valid_min - self.max_projection, valid_max + self.max_projection valid = np.logical_and(valid_min < points[..., axis], points[..., axis] <= valid_max) points = points[valid] x, y = [points[:, a] + 0.5 for a in axes] self.points_item.setData(x=x, y=y)
[docs] def change_points_color(self): color = self.points_color_button.color() self.points_style['brush'] = color self.points_item.setBrush(self.points_style['brush'])
[docs] def initialize_vectors_item(self): if self.vectors_base_item is not None: self.view.removeItem(self.vectors_base_item) self.vectors_base_item = pg.ScatterPlotItem(**self.vectors_style) self.view.addItem(self.vectors_base_item)
[docs] def set_vectors(self, vectors): self.vectors = vectors self.update_vectors()
[docs] def update_vectors(self): if self.vectors is not None: vectors = self.vectors axis = self.source_axis index = self.source_index[axis] slicing = tuple(slice(None) if a != axis else index for a in range(3)) axes = [d for d in range(3) if d != axis] vx, vy = [vectors[slicing + (a,)] for a in axes] x, y = np.meshgrid(np.arange(vectors.shape[axes[0]], dtype=float), np.arange(vectors.shape[axes[1]], dtype=float), indexing='ij') if self.vectors_style.get('threshold', None) is not None: select = self.sources[0][self.sourceSlice()] > self.vectors_style.get('threshold') vx, vy = vx[select], vy[select] x, y = x[select], y[select] else: vx, vy = vx.flatten(), vy.flatten() x, y = x.flatten(), y.flatten() x += 0.5 y += 0.5 px, py = np.zeros(x.shape[0] * 2), np.zeros(y.shape[0] * 2) px[0::2] = x px[1::2] = x + vx py[0::2] = y py[1::2] = y + vy path = pg.arrayToQPath(px, py, 'pairs') if self.vectors_item is not None: self.view.removeItem(self.vectors_item) self.vectors_item = QGraphicsPathItem(path) self.vectors_item.setPen(pg.mkPen(self.vectors_style.get('brush'))) self.view.addItem(self.vectors_item) self.vectors_base_item.setData(x=x, y=y)
[docs] def change_vectors_threshold(self, value=None): if value is not None: self.vectors_threshold_edit.setText('%d' % value) text = self.vectors_threshold_edit.text() # print('text=',text) try: value = float(text) except ValueError: value = None self.vectors_style['threshold'] = value self.updateSlice(force_update=True)
[docs] def change_vectors_color(self): color = self.vectors_color_button.color() self.vectors_style['brush'] = color self.vectors_item.setPen(pg.mkPen(self.vectors_style['brush'])) self.vectors_base_item.setBrush(self.vectors_style['brush'])
[docs] def set_orientations(self, orientations): self.orientations = orientations self.update_orientations()
[docs] def update_orientations(self): if self.orientations is not None: orientations = self.orientations axis = self.source_axis index = self.source_index[axis] slicing = tuple(slice(None) if a != axis else index for a in range(3)) axes = [d for d in range(3) if d != axis] vx, vy = [orientations[slicing + (a,)] for a in axes] x, y = np.meshgrid(np.arange(orientations.shape[axes[0]], dtype=float), np.arange(orientations.shape[axes[1]], dtype=float), indexing='ij') if self.orientations_style.get('threshold', None) is not None: select = self.sources[0][self.sourceSlice()] > self.orientations_style.get('threshold') vx, vy = vx[select], vy[select] x, y = x[select], y[select] else: vx, vy = vx.flatten(), vy.flatten() x, y = x.flatten(), y.flatten() x += 0.5 y += 0.5 px, py = np.zeros(x.shape[0] * 2), np.zeros(y.shape[0] * 2) l = 0.45 px[0::2] = x - l * vx; px[1::2] = x + l * vx py[0::2] = y - l * vy; py[1::2] = y + l * vy path = pg.arrayToQPath(px, py, 'pairs') if self.orientations_item is not None: self.view.removeItem(self.orientations_item) self.orientations_item = QGraphicsPathItem(path) self.orientations_item.setPen(pg.mkPen(self.orientations_style.get('pen'))) self.view.addItem(self.orientations_item)
[docs] def change_orientations_threshold(self, value=None): if value is not None: self.orientations_threshold_edit.setText('%d' % value) text = self.orientations_threshold_edit.text() # print('text=',text) try: value = float(text) except ValueError: value = None self.orientations_style['threshold'] = value self.updateSlice(force_update=True)
[docs] def change_orientations_color(self): color = self.orientations_color_button.color() self.orientations_style['pen'] = color self.orientations_item.setPen(self.orientations_style['pen'])
[docs] def set_color_scheme(self, type_, lut=0): self.luts[lut].lut.item.gradient.loadPreset(type_)
[docs] def handleMouseClick(self, event): event.accept() x, y = self.get_coords(event.scenePos()) btn = event.button() if btn != 1: return x_axis, y_axis = self.getXYAxes() scaled_x, scaled_y = self.scale_coords(x, x_axis, y, y_axis) self.mouse_clicked.emit(scaled_x, scaled_y, self.source_index[self.scroll_axis])
# def __cast_bools(self): # for i, s in enumerate(self.sources): # if s.dtype == bool: # self.sources[i] = s.view('uint8') # def __ensure_3d(self): # for i, s in enumerate(self.sources): # if s.ndim == 2: # s = s.view() # s.shape = s.shape + (1,) # self.sources[i] = s # if s.ndim != 3: # raise RuntimeError(f"Sources don't have dimensions 2 or 3 but {s.ndim} in source {i}!") ############################################################################################################ # ## Tests ############################################################################################################ def _test(): import numpy as np import ClearMap.Visualization.Qt.DataViewerAxon as dv from importlib import reload reload(dv) img1 = np.random.rand(*(100, 80, 30)) points = np.array(np.where(img1 > 0.99)).T points.shape vectors = np.random.rand(*(img1.shape + (3,))) # %gui qt reload(dv) dv.DataViewer(img1, points=points, vectors=vectors) from importlib import reload reload(dv) # TODOs: # real rgb images # cleanup interface (new pyhton style) # histogram functions