"""
DataViewer
==========
Data viewer showing 3D data as 2D slices.
Usage
-----
.. image:: ../static/DataViewer.jpg
Note
----
This viewer is based on the pyqtgraph package.
"""
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>, Charly Rousseau <charly.rousseau@icm-institute.org>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__ = 'https://idisco.info'
__download__ = 'https://www.github.com/ChristophKirst/ClearMap2'
import time
import functools as ft
import numpy as np
import pyqtgraph as pg
from PyQt5 import QtWidgets
from PyQt5.QtCore import QEvent, QRect, QSize, pyqtSignal, Qt
from PyQt5.QtGui import QPainter
from PyQt5.QtWidgets import QWidget, QRadioButton, QLabel, QSplitter, QApplication, QSizePolicy, QPushButton, QCheckBox, \
QGraphicsPathItem, QGridLayout, QLineEdit, QScrollArea
from ClearMap.Utils.utilities import runs_on_spyder
from ClearMap.IO.IO import as_source
from ClearMap.IO.Source import Source
from ClearMap.Visualization.Qt.data_viewer_luts import LUT
pg.CONFIG_OPTIONS['useOpenGL'] = False  # set to False if trouble seeing data.
# Make sure a Qt application object exists before any widget of this module is created.
if not pg.QAPP:
    pg.mkQApp()
class DataViewer(QWidget):
    """Viewer widget showing 3D (or 4D colour) data as 2D slices.

    One additively composited image layer is created per source. The slice
    along ``scroll_axis`` is chosen with the draggable line below the image
    (or the mouse wheel over it); the x / y / z radio buttons change the
    scroll axis. Points, vectors and orientations can be overlaid on the
    current slice.
    """

    # Emitted by handleMouseClick (left button, after enable_mouse_clicks())
    # with the clicked voxel as (x-axis index, y-axis index, scroll-axis index).
    mouse_clicked = pyqtSignal(int, int, int)

    # Default style of the markers drawn by plot_scatter_markers().
    DEFAULT_SCATTER_PARAMS = {
        'pen': 'red',
        'brush': 'red',
        'symbol': '+',
        'size': 10
    }

    def __init__(self, source,
                 points=None, vectors=None, orientations=None, annotation=None,
                 axis=None, scale=None, title=None,
                 invertY=False, minMax=None, screen=None, parent=None, default_lut='flame', max_projection=None,
                 points_style=None, vectors_style=None, original_orientation='zcxy', orientations_style=None,
                 **kwargs):
        """Build the viewer widgets and show the window.

        Parameters
        ----------
        source : source specification or list/tuple of them
            Image(s) to display, each converted with ``as_source``. 2D sources
            get a trailing singleton axis.
        points : array-like or None
            Point coordinates drawn as a scatter overlay.
        vectors : array-like or None
            Per-voxel vectors drawn as line segments starting at each voxel.
        orientations : array-like or None
            Per-voxel orientations drawn as segments centred on each voxel.
        annotation : mapping or None
            Maps a value of the first source to text appended to the label.
        axis : int or None
            Initial scroll axis; defaults to the last spatial axis.
        scale : sequence of float or None
            Per-axis display scaling; missing entries default to 1.
        title : str or None
            Window title. Defaults to the source path/location, else
            'DataViewer'. An explicit title is also shown above the slice
            selector.
        invertY : bool
            Invert the y axis of the image view.
        minMax : range or list of ranges or None
            Initial display range(s), see :meth:`setMinMax`.
        screen : object
            Not used by this class.
        parent : QWidget or None
            Qt parent widget.
        default_lut : str
            Colour map used when there is a single source.
        max_projection : int or None
            Half-width of the max-projection window along the scroll axis.
        points_style, vectors_style, orientations_style : dict or None
            Style overrides for the overlays. ``vectors_style`` and
            ``orientations_style`` may contain a ``'threshold'`` entry: only
            voxels where the first source exceeds it are drawn.
        original_orientation : str
            Stored as ``self.original_orientation``.
        **kwargs
            Forwarded to ``QWidget.__init__``.
        """
        # TODO: check why super().__init__() doesn't handle **kwargs properly
        QWidget.__init__(self, parent, **kwargs)

        # ## State
        # Images sources
        self.sources = []
        self.original_orientation = original_orientation
        self.n_sources = 0
        self.scroll_axis = None
        self.source_shape = None
        self.source_scale = None  # xyz scaling factors between display and real coordinates
        self.source_index = None  # The xyz center of the current view
        self.source_range_x = None
        self.source_range_y = None
        self.source_slice = None  # current slice (in scroll axis)
        # Half-width of the max projection window, or None. Set early because updateSlice() reads it.
        self.max_projection = max_projection
        self.cross = None  # cursor
        self.pals = []  # linked DataViewers
        self.scatter = None
        self.scatter_coords = None
        self.atlas = None  # WARNING: may overlap with self.annotation
        self.structure_names = None
        self.z_cursor_width = 5

        # Overlays: array-like inputs are converted to arrays once, here
        self.points = as_source(points).array if points is not None else None
        self.points_item = None
        self.points_style = dict(pen=None, brush='white')
        if points_style is not None:
            self.points_style.update(points_style)

        self.vectors = as_source(vectors).array if vectors is not None else None
        self.vectors_item = None  # line segments
        self.vectors_base_item = None  # scatter marking the segment origins
        self.vectors_style = dict(pen=None, brush='lightblue')
        if vectors_style is not None:
            self.vectors_style.update(vectors_style)

        self.orientations = as_source(orientations).array if orientations is not None else None
        self.orientations_item = None
        self.orientations_style = dict(pen='gray')
        if orientations_style is not None:
            self.orientations_style.update(orientations_style)

        self.annotation = annotation

        self.initializeSources(source, axis=axis, scale=scale)

        # ## Gui Construction
        original_title = title
        if title is None:
            if isinstance(source, str):
                title = source
            elif isinstance(source, Source):
                title = source.location
        if title is None:
            title = 'DataViewer'
        self.setWindowTitle(title)
        self.resize(1600, 1200)

        # NOTE: this attribute shadows QWidget.layout(); kept as-is since callers may rely on it.
        self.layout = QtWidgets.QGridLayout(self)
        self.layout.setContentsMargins(0, 0, 0, 0)

        # image pane
        self.view = pg.ViewBox()
        self.view.setAspectLocked(True)
        self.view.invertY(invertY)
        self.graphicsView = pg.GraphicsView()
        self.graphicsView.setObjectName("GraphicsView")
        self.graphicsView.setCentralItem(self.view)

        splitter = QSplitter()
        splitter.setOrientation(Qt.Horizontal)
        splitter.setSizes([self.width() - 10, 10])
        self.layout.addWidget(splitter)

        image_splitter = QSplitter()
        image_splitter.setOrientation(Qt.Vertical)
        image_splitter.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        splitter.addWidget(image_splitter)

        # Image plots: one additively composited layer per source
        image_options = dict(clipToView=True, autoDownsample=True, autoLevels=False, useOpenGL=None)
        self.image_items = [pg.ImageItem(self._get_image(s), **image_options) for s in self.sources]
        for itm in self.image_items:
            itm.setRect(QRect(0, 0, int(self.source_range_x), int(self.source_range_y)))
            itm.setCompositionMode(QPainter.CompositionMode_Plus)
            self.view.addItem(itm)
        self.view.setXRange(0, self.source_range_x)
        self.view.setYRange(0, self.source_range_y)

        # slice selector
        if original_title:
            self.slicePlot = pg.PlotWidget(title=f"""
<html><head/><body>
<h1 style=" margin-top:18px; margin-bottom:12px; margin-left:0px; margin-right:0px;
-qt-block-indent:0; text-indent:0px;">
<span style=" font-size:xx-large; font-weight:700;">{original_title}</span></h1></body></html>
""")
        else:
            self.slicePlot = pg.PlotWidget()
        size_policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)  # TODO: add option for sizepolicy
        size_policy.setHorizontalStretch(0)
        size_policy.setVerticalStretch(0)
        size_policy.setHeightForWidth(self.slicePlot.sizePolicy().hasHeightForWidth())
        self.slicePlot.setSizePolicy(size_policy)
        self.slicePlot.setMinimumSize(QSize(0, 40 + 40 * bool(original_title)))
        self.slicePlot.setObjectName("roiPlot")

        self.sliceLine = pg.InfiniteLine(0, movable=True)
        self.sliceLine.setPen((255, 255, 255, 200), width=self.z_cursor_width)
        self.sliceLine.setZValue(1)
        self.slicePlot.addItem(self.sliceLine)
        self.slicePlot.hideAxis('left')
        self.slicePlot.installEventFilter(self)  # mouse wheel moves the slice line, see eventFilter()
        self.updateSlicer()
        # Every move of the line forces a redraw (refresh() relies on this);
        # the line object emitted with the signal is ignored.
        self.sliceLine.sigPositionChanged.connect(lambda *_: self.updateSlice(force_update=True))

        # Axis tools
        self.axis_buttons = []
        axis_tools_layout, axis_tools_widget = self.__setup_axes_controls()

        # max projection depth
        self.max_projection_edit = QLineEdit()
        if self.max_projection is not None:
            self.max_projection_edit.setText('%d' % self.max_projection)
        self.max_projection_edit.setMaxLength(4)
        self.max_projection_edit.setAlignment(Qt.AlignRight)
        self.max_projection_edit.setMaximumWidth(60)
        self.max_projection_edit.editingFinished.connect(self.change_max_projection)
        axis_tools_layout.addWidget(self.max_projection_edit, 0, 3)

        # points color
        self.points_color_button = pg.ColorButton(color=self.points_style.get('brush'))
        self.points_color_button.setMaximumWidth(30)
        self.points_color_button.sigColorChanged.connect(self.change_points_color)
        axis_tools_layout.addWidget(self.points_color_button, 0, 4)

        # vectors color and threshold
        self.vectors_color_button = pg.ColorButton(color=self.vectors_style.get('brush'))
        self.vectors_color_button.setMaximumWidth(30)
        self.vectors_color_button.sigColorChanged.connect(self.change_vectors_color)
        axis_tools_layout.addWidget(self.vectors_color_button, 0, 5)

        self.vectors_threshold_edit = QLineEdit()
        vectors_threshold = self.vectors_style.get('threshold', None)
        if vectors_threshold is not None:
            self.vectors_threshold_edit.setText('%.4f' % vectors_threshold)
        self.vectors_threshold_edit.setMaxLength(6)
        self.vectors_threshold_edit.setAlignment(Qt.AlignRight)
        self.vectors_threshold_edit.setMaximumWidth(60)
        self.vectors_threshold_edit.editingFinished.connect(self.change_vectors_threshold)
        axis_tools_layout.addWidget(self.vectors_threshold_edit, 0, 6)

        # orientations color and threshold
        self.orientations_color_button = pg.ColorButton(color=self.orientations_style.get('pen'))
        self.orientations_color_button.setMaximumWidth(30)
        self.orientations_color_button.sigColorChanged.connect(self.change_orientations_color)
        axis_tools_layout.addWidget(self.orientations_color_button, 0, 7)

        self.orientations_threshold_edit = QLineEdit()
        orientations_threshold = self.orientations_style.get('threshold', None)
        if orientations_threshold is not None:
            self.orientations_threshold_edit.setText('%.4f' % orientations_threshold)
        self.orientations_threshold_edit.setMaxLength(6)
        self.orientations_threshold_edit.setAlignment(Qt.AlignRight)
        self.orientations_threshold_edit.setMaximumWidth(60)
        self.orientations_threshold_edit.editingFinished.connect(self.change_orientations_threshold)
        axis_tools_layout.addWidget(self.orientations_threshold_edit, 0, 8)

        # coordinate label
        self.source_pointer = np.zeros(self.sources[0].ndim, dtype=int)
        self.source_label = QLabel("")
        self.source_label_scroll = QScrollArea()
        self.source_label_scroll.setMaximumHeight(30)
        self.source_label_scroll.setWidgetResizable(True)
        self.source_label_scroll.horizontalScrollBar().setStyleSheet("QScrollBar {height:0px;}")
        self.source_label_scroll.setWidget(self.source_label)
        axis_tools_layout.addWidget(self.source_label_scroll, 0, 9)
        self.graphicsView.scene().sigMouseMoved.connect(self.updateLabelFromMouseMove)

        # compose the image viewer
        image_splitter.addWidget(self.graphicsView)
        image_splitter.addWidget(self.slicePlot)
        image_splitter.addWidget(axis_tools_widget)
        image_splitter.setSizes([self.height() - 35 - 20, 35, 20])

        # lut widgets
        self.luts = [LUT(image=i, color=c) for i, c in zip(self.image_items, self.__get_colors(default_lut))]
        lut_layout = QtWidgets.QGridLayout()
        lut_layout.setContentsMargins(0, 0, 0, 0)
        for d, lut in enumerate(self.luts):
            lut_layout.addWidget(lut, 0, d)
        lut_widget = QWidget()
        lut_widget.setLayout(lut_layout)
        lut_widget.setContentsMargins(0, 0, 0, 0)
        lut_widget.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)  # TODO: add option for sizepolicy
        splitter.addWidget(lut_widget)
        splitter.setStretchFactor(0, 1)
        splitter.setStretchFactor(1, 0)

        # update scale
        for lut in self.luts:
            lut.range_buttons[1][2].click()
        if minMax is not None:
            self.setMinMax(minMax)

        self.initialize_points_item()
        self.initialize_vectors_item()
        self.change_max_projection()  # also performs the first forced slice / overlay update
        self.show()

    @property
    def space_axes(self):
        """Spatial axes of the first source, i.e. every axis except the colour axis."""
        color_axis = self.color_axis
        return [ax for ax in range(self.sources[0].ndim) if ax != color_axis]

    @property
    def source_axis(self):
        """Alias of :attr:`scroll_axis`, kept for code that uses this name."""
        return self.scroll_axis

    def eventFilter(self, source, event):
        """Move the slice line by one slice per wheel notch (120 angle units).

        Installed on the slice plot. The event is still passed on to the base
        implementation.
        """
        if event.type() == QEvent.Wheel:
            steps = event.angleDelta().y() / 120
            self.sliceLine.setValue(self.sliceLine.value() + steps)
        return super().eventFilter(source, event)

    def __cast_source(self, source):
        """Return ``source`` as a list (tuples converted, single items wrapped)."""
        if isinstance(source, tuple):
            return list(source)
        if isinstance(source, list):
            return source
        return [source]

    def initializeSources(self, source, scale=None, axis=None, update=True):
        """(Re)initialise the sources, scroll axis, scaling and current slice.

        Parameters
        ----------
        source : source specification or list/tuple of them
        scale : sequence of float or None
            Per-axis display scaling; missing entries default to 1.
        axis : int or None
            Scroll axis; defaults to the last spatial (non-colour) axis.
        update : bool
            Currently unused.

        Raises
        ------
        RuntimeError
            If a source has more than 4 dimensions or its first two
            dimensions differ from those of the first source.
        """
        self.sources = [as_source(s) for s in self.__cast_source(source)]
        self.n_sources = len(self.sources)
        for s in self.sources:
            if s.ndim == 2:
                s.shape = s.shape + (1,)  # Add empty z dimension # FIXME: see if works or need to expand_dims

        # source shapes
        self.source_shape = self.padded_shape(self.sources[0].shape)
        for s in self.sources:
            if s.ndim > 4:
                raise RuntimeError(f'Source has {s.ndim} > 4 dimensions: {s}!')
            if s.shape[:2] != self.source_shape[:2]:
                raise RuntimeError(f'Sources shape {self.source_shape} vs {s.shape} in source {s}!')

        # slicing
        self.scroll_axis = axis if axis is not None else self.space_axes[-1]
        self.source_index = (np.array(self.source_shape, dtype=float) / 2).astype(int)

        # scaling
        scale = np.array(scale) if scale is not None else np.array([])  # Test Not default np.ones(3) ??
        self.source_scale = np.pad(scale, (0, self.sources[0].ndim - len(scale)), 'constant', constant_values=1)

        self.updateSourceRange()
        self.updateSourceSlice()

    def setSource(self, source, index='all'):  # TODO: see if could factor with __init__
        """Replace all sources (``index='all'``) or the source at ``index`` and redraw.

        The new sources must match the current number of sources and, after
        padding 2D sources to 3D, the current source shape. ``self.sources``
        is only modified once every new source has been validated.

        Raises
        ------
        RuntimeError
            On a count, dimensionality or shape mismatch.
        """
        if index == 'all':
            new_sources = self.__cast_source(source)
            if self.n_sources != len(new_sources):
                raise RuntimeError(f'Number of sources does not match! got {len(new_sources)}, expected {self.n_sources}')
            new_sources = [as_source(s) for s in new_sources]
            indices = range(self.n_sources)
        else:
            new_sources = list(self.sources)  # copy: a failed validation must leave self.sources untouched
            new_sources[index] = as_source(source)
            indices = [index]

        for i in indices:
            s = new_sources[i]
            if s.ndim < 2 or s.ndim > 4:  # FIXME: handle RGB
                raise RuntimeError(f'Sources dont have dimensions 2, 3 or 4 but {s.ndim} in source {i}!')
            if s.ndim == 2:
                s.shape = s.shape + (1,)  # pad before comparing with the (padded) source_shape
            if s.shape != self.source_shape:
                raise RuntimeError('Shape of sources does not match!')

        self.sources = new_sources
        for i in indices:
            self.image_items[i].updateImage(self._get_image(self.sources[i]))

    def __setup_axes_controls(self):
        """Create the x/y/z scroll-axis radio buttons and one visibility checkbox per source.

        Both kinds of widgets are appended to ``self.axis_buttons``.

        Returns
        -------
        (QGridLayout, QWidget)
            The tools layout and the widget holding it.
        """
        axis_tools_layout = QGridLayout()
        for d, ax in enumerate('xyz'):
            button = QRadioButton(ax)
            button.setMaximumWidth(50)
            axis_tools_layout.addWidget(button, 0, d)
            button.clicked.connect(ft.partial(self.setSliceAxis, d))
            self.axis_buttons.append(button)
        self.axis_buttons[self.space_axes.index(self.scroll_axis)].setChecked(True)

        axis_tools_widget = QWidget()
        axis_tools_widget.setLayout(axis_tools_layout)
        for i in range(self.n_sources):
            box = QCheckBox(f'{i}')
            box.setMaximumWidth(50)
            box.setChecked(True)
            box.stateChanged.connect(ft.partial(self.toggle_layer, i))
            axis_tools_layout.addWidget(box, 1, i)
            self.axis_buttons.append(box)
        return axis_tools_layout, axis_tools_widget

    def toggle_layer(self, i, state):
        """Show image layer ``i`` when ``state`` is ``Qt.Checked``, hide it otherwise."""
        self.image_items[i].setVisible(state == Qt.Checked)

    def __get_colors(self, default_lut):
        """Return one colour map name per source."""
        if self.n_sources == 1:
            return [default_lut]
        if self.n_sources == 2:
            return ['purple', 'green']
        return np.array(['white', 'green', 'red', 'blue', 'purple'] * self.n_sources)[:self.n_sources]

    def color_last(self, source):
        """Return ``source`` transposed so that its first length-3 axis comes last.

        Raises IndexError if no axis has length 3.
        """
        color_axis = int(np.flatnonzero(np.asarray(source.shape) == 3)[0])
        order = [ax for ax in range(source.ndim) if ax != color_axis] + [color_axis]
        return source.transpose(order)

    def is_color(self, source):
        """Return True for 4D sources that have a length-3 (colour) axis."""
        return source.ndim > 3 and 3 in source.shape

    @property
    def color_axis(self):
        """Index of the colour (length-3) axis of the first source, or None.

        Only sources with more than 3 dimensions are treated as colour data,
        consistent with :meth:`is_color`. A 3D volume that happens to have an
        axis of length 3 is a plain volume.
        """
        first = self.sources[0]
        if first.ndim <= 3:
            return None
        try:
            return tuple(first.shape).index(3)
        except ValueError:
            return None

    @property
    def all_colour(self):
        """True if every source is a colour source (see :meth:`is_color`)."""
        return all(self.is_color(s) for s in self.sources)

    def getXYAxes(self):  # FIXME: properties
        """Return the two axes displayed in-plane (neither scroll nor colour axis)."""
        return [ax for ax in range(self.sources[0].ndim) if ax not in (self.scroll_axis, self.color_axis)]

    def updateSourceRange(self):
        """Recompute the displayed image extent (scale * shape) along the in-plane axes."""
        x, y = self.getXYAxes()
        # QRect needs Python ints; round() on numpy floats may return a numpy float.
        self.source_range_x = int(round(self.source_scale[x] * self.source_shape[x]))
        self.source_range_y = int(round(self.source_scale[y] * self.source_shape[y]))

    def updateSourceSlice(self):
        """Set ``source_slice`` to the full plane at the current index of the scroll axis."""
        n_dims = 4 if self.all_colour else 3  # TODO: check if could use self.sources[0].ndim
        slc = [slice(None)] * n_dims
        if self.scroll_axis is not None:  # axis 0 is valid, so test against None explicitly
            slc[self.scroll_axis] = self.source_index[self.scroll_axis]
        self.source_slice = tuple(slc)

    def updateSlicer(self):
        """Fit the slice selector to the scroll axis and move its line to the current index."""
        ax = self.scroll_axis
        self.slicePlot.setXRange(0, self.source_shape[ax])
        self.sliceLine.setValue(self.source_index[ax])
        self.sliceLine.setBounds([0, self.source_shape[ax] + 0.5])

    def updateLabelFromMouseMove(self, event_pos):
        """Mouse-move handler: update the cursors (own and linked) and the coordinate label."""
        x, y = self.get_coords(event_pos)
        self.sync_cursors(x, y)
        self._updateCoords(x, y)

    def _updateCoords(self, x, y):
        """Set ``source_pointer`` from view coordinates on the current slice and refresh the label."""
        x_axis, y_axis = self.getXYAxes()
        pos = [None] * self.sources[0].ndim
        pos[x_axis], pos[y_axis] = self.scale_coords(x, x_axis, y, y_axis)
        pos[self.scroll_axis] = self.source_index[self.scroll_axis]
        self.source_pointer = np.array(pos)
        self.updateLabel()

    def scale_coords(self, x, x_axis, y, y_axis):
        """Convert view coordinates to voxel indices along ``x_axis``/``y_axis``, clamped to the last index."""
        scaled_x = min(int(x / self.source_scale[x_axis]), self.source_shape[x_axis] - 1)
        scaled_y = min(int(y / self.source_scale[y_axis]), self.source_shape[y_axis] - 1)
        return scaled_x, scaled_y

    def get_coords(self, pos):
        """Map a scene position to view coordinates clamped to the displayed image."""
        mouse_point = self.view.mapSceneToView(pos)
        x = min(max(0, mouse_point.x()), self.source_range_x)
        y = min(max(0, mouse_point.y()), self.source_range_y)
        return x, y

    def sync_cursors(self, x, y):
        """Move this viewer's cursor and those of the linked viewers (``self.pals``) to (x, y)."""
        if self.cross is not None:
            self.cross.set_coords([x, y])
            self.view.update()
        for pal in self.pals:
            if pal.cross is not None:
                pal.cross.set_coords([x, y])
            pal._updateCoords(x, y)
            pal.view.update()

    def updateLabel(self):
        """Refresh the coordinate / value label for ``source_pointer``.

        Shows voxel coordinates, scaled coordinates, the value of every source
        and, when available, the annotation or atlas structure at that voxel.
        """
        x_axis, y_axis = self.getXYAxes()
        axes = [x_axis, y_axis, self.scroll_axis]
        x, y, z = self.source_pointer[axes]
        xs, ys, zs = self.source_scale[axes]
        slc = [Ellipsis] * max(3, self.sources[0].ndim)
        slc[x_axis] = x
        slc[y_axis] = y
        slc[self.scroll_axis] = z
        slc = tuple(slc)
        if self.all_colour:
            vals = ", ".join([str(s.array[slc]) for s in self.sources])
        else:  # FIXME: check why array does not work for ndim = 3 (i.e. why we need 2 versions)
            vals = ", ".join([str(s[slc]) for s in self.sources])
        label = f"({x}, {y}, {z}) {{{x*xs:.2f}, {y*ys:.2f}, {z*zs:.2f}}} [{vals}]"
        if self.annotation is not None:
            # Index in array-axis order (like the values above), whatever the scroll axis.
            struct_info = self.annotation.get(self.sources[0][slc], None)
            if struct_info:
                label += f"[{struct_info}]"
        elif self.atlas is not None:
            id_ = np.asarray(self.atlas[slc]).item()  # np.asscalar was removed from numpy
            label = f" <b style='color:#2d9cfc;'>{self.structure_names[id_]} ({id_})</b>" + label
        if self.parent() is None or not self.parent().objectName().lower().startswith('dataviewer'):
            label = f"<span style='font-size: 12pt; color: black'>{label}</span>"
        self.source_label.setText(label)

    def updateSlice(self, force_update=False):
        """Move to the slice selected by the slice line and redraw image and overlays.

        Nothing is redrawn if the selected index did not change, unless
        ``force_update`` is true. With ``max_projection`` set, the scroll-axis
        entry of ``source_slice`` becomes a window around the index.
        """
        ax = self.scroll_axis
        index = min(max(0, int(self.sliceLine.value())), self.source_shape[ax] - 1)
        if self.max_projection is not None:
            slc_ax = (slice(max(0, index - self.max_projection),
                            min(self.source_shape[ax], index + self.max_projection)),)
        else:
            slc_ax = (index,)
        if index == self.source_index[ax] and not force_update:
            return
        self.source_index[ax] = index
        self.source_slice = self.source_slice[:ax] + slc_ax + self.source_slice[ax + 1:]
        self.source_pointer[ax] = index
        self.updateLabel()
        self.updateImage()
        self.update_points()
        self.update_vectors()
        self.update_orientations()
        if self.scatter is not None:
            self.plot_scatter_markers(ax, index)

    def refresh(self):
        """Force the plot to refresh, notably to display scatter info on top.

        Nudges the slice line forward and back; every line move forces an update.
        """
        self.sliceLine.setValue(self.sliceLine.value() + 1)
        self.sliceLine.setValue(self.sliceLine.value() - 1)

    def setSliceAxis(self, axis):
        """Scroll along the ``axis``-th spatial axis (0, 1, 2 for the x, y, z buttons)."""
        self.scroll_axis = self.space_axes[axis]
        self.updateSourceRange()
        self.updateSourceSlice()
        for img_itm, src in zip(self.image_items, self.sources):
            img_itm.updateImage(self._get_image(src))
            img_itm.setRect(QRect(0, 0, self.source_range_x, self.source_range_y))
        self.view.setXRange(0, self.source_range_x)
        self.view.setYRange(0, self.source_range_y)
        self.updateSlicer()
        self.refresh()

    def _get_image(self, src):
        """Return the image of ``src`` to display for the current ``source_slice``.

        If the scroll-axis entry of the slice is a window (max projection),
        the maximum over that window is taken. For colour data the channel
        axis is moved last; boolean data is viewed as uint8.
        """
        slc = self.source_slice[:src.ndim]
        image = src.array[slc] if self.all_colour else src[slc]
        if isinstance(slc[self.scroll_axis], slice):
            image = np.max(image, axis=self.scroll_axis)
        if self.all_colour:
            image = self.color_last(image)
        if image.dtype == bool:
            image = image.view('uint8')
        return image

    def updateImage(self):
        """Redraw every image layer for the current slice."""
        for img_item, src in zip(self.image_items, self.sources):
            img_item.updateImage(self._get_image(src))

    def setMinMax(self, min_max, source=None):
        """Set the display range of one or several LUTs.

        Parameters
        ----------
        min_max : range or list of ranges
            A list is taken as one range per selected source; anything else
            (e.g. a ``(min, max)`` tuple) is applied to every selected source.
        source : int, list of int or None
            LUT indices to update; all of them if None.
        """
        if source is None:
            source = list(range(len(self.sources)))
        elif not isinstance(source, list):
            source = [source]
        if not isinstance(min_max, list):
            min_max = [min_max] * len(source)
        for s, mM in zip(source, min_max):
            self.luts[s].lut.region.setRegion(mM)

    def plot_scatter_markers(self, ax, index):
        """Redraw the markers of ``self.scatter_coords`` for slice ``index`` along axis ``ax``."""
        self.scatter.clear()
        self.scatter_coords.axis = ax
        pos = self.scatter_coords.get_pos(index)
        if all(pos.shape):
            if self.scatter_coords.has_colours:
                self.scatter.setData(pos=pos,
                                     symbol=self.scatter_coords.get_symbols(index),
                                     size=10,  # FIXME: scale size as function of zoom
                                     **self.scatter_coords.get_draw_params(index))
            else:
                self.scatter.setData(pos=pos, **DataViewer.DEFAULT_SCATTER_PARAMS)
        try:  # FIXME: check why some markers trigger errors
            if self.scatter_coords.half_slice_thickness is not None:
                marker_params = self.scatter_coords.get_all_data(index)
                if marker_params['pos'].shape[0]:
                    self.scatter.addPoints(symbol='o', brush=pg.mkBrush((0, 0, 0, 0)),
                                           **marker_params)  # FIXME: scale size as function of zoom
        except KeyError as err:
            print(f'DataViewer error: {err}')

    def change_max_projection(self, value=None):
        """Read (or first set from ``value``) the max-projection half-width and redraw.

        Non-integer or non-positive entries disable the max projection.
        """
        if value is not None:
            self.max_projection_edit.setText('%d' % value)
        try:
            max_projection = int(self.max_projection_edit.text())
        except ValueError:
            max_projection = None
        else:
            if max_projection <= 0:
                max_projection = None
        self.max_projection = max_projection
        self.updateSlice(force_update=True)

    def initialize_points_item(self):
        """(Re)create the scatter item used to draw ``self.points``."""
        if self.points_item is not None:
            self.view.removeItem(self.points_item)
        self.points_item = pg.ScatterPlotItem(**self.points_style)
        self.view.addItem(self.points_item)

    def set_points(self, points):
        """Replace the point overlay (None removes it) and redraw it."""
        self.points = as_source(points).array if points is not None else None
        self.initialize_points_item()
        self.update_points()

    def update_points(self):
        """Draw the points lying in the current slice (or max-projection window)."""
        if self.points is None:
            return
        axis = self.scroll_axis
        axes = [d for d in range(3) if d != axis]
        index = self.source_index[axis]
        # select points in slice
        valid_min, valid_max = index - 0.5, index + 0.5
        if self.max_projection is not None:
            valid_min -= self.max_projection
            valid_max += self.max_projection
        coords = self.points[..., axis]
        points = self.points[(valid_min < coords) & (coords <= valid_max)]
        x, y = [points[:, a] + 0.5 for a in axes]
        self.points_item.setData(x=x, y=y)

    def change_points_color(self):
        """Apply the colour chosen in the points colour button."""
        self.points_style['brush'] = self.points_color_button.color()
        self.points_item.setBrush(self.points_style['brush'])

    def initialize_vectors_item(self):
        """(Re)create the scatter item marking the origins of the vectors."""
        if self.vectors_base_item is not None:
            self.view.removeItem(self.vectors_base_item)
        # 'threshold' is a setting of this viewer, not a ScatterPlotItem option.
        scatter_style = {k: v for k, v in self.vectors_style.items() if k != 'threshold'}
        self.vectors_base_item = pg.ScatterPlotItem(**scatter_style)
        self.view.addItem(self.vectors_base_item)

    def set_vectors(self, vectors):
        """Replace the vector overlay (None removes future redraws) and redraw it."""
        self.vectors = as_source(vectors).array if vectors is not None else None
        self.update_vectors()

    def _field_in_slice(self, field, threshold):
        """Return in-plane components and pixel-centre positions of ``field`` on the current slice.

        ``field`` is indexed as ``field[i, j, k, component]``; its spatial
        shape is assumed to match the first source (TODO confirm). If
        ``threshold`` is not None, only voxels where the first source exceeds
        it are kept.

        Returns
        -------
        vx, vy, x, y : 1D arrays
        """
        axis = self.scroll_axis
        index = self.source_index[axis]
        slicing = tuple(slice(None) if a != axis else index for a in range(3))
        axes = [d for d in range(3) if d != axis]
        vx, vy = [field[slicing + (a,)] for a in axes]
        x, y = np.meshgrid(np.arange(field.shape[axes[0]], dtype=float),
                           np.arange(field.shape[axes[1]], dtype=float),
                           indexing='ij')
        if threshold is not None:
            select = self.sources[0][slicing] > threshold
            vx, vy, x, y = vx[select], vy[select], x[select], y[select]
        else:
            vx, vy, x, y = vx.flatten(), vy.flatten(), x.flatten(), y.flatten()
        return vx, vy, x + 0.5, y + 0.5

    def update_vectors(self):
        """Draw the vectors of the current slice as segments starting at each voxel centre."""
        if self.vectors is None:
            return
        vx, vy, x, y = self._field_in_slice(self.vectors, self.vectors_style.get('threshold', None))
        px, py = np.zeros(x.shape[0] * 2), np.zeros(y.shape[0] * 2)
        px[0::2] = x
        px[1::2] = x + vx
        py[0::2] = y
        py[1::2] = y + vy
        path = pg.arrayToQPath(px, py, 'pairs')
        if self.vectors_item is not None:
            self.view.removeItem(self.vectors_item)
        self.vectors_item = QGraphicsPathItem(path)
        self.vectors_item.setPen(pg.mkPen(self.vectors_style.get('brush')))
        self.view.addItem(self.vectors_item)
        self.vectors_base_item.setData(x=x, y=y)

    def change_vectors_threshold(self, value=None):
        """Read (or first set from ``value``) the vectors threshold and redraw; invalid text clears it."""
        if value is not None:
            self.vectors_threshold_edit.setText('%.4f' % value)
        try:
            threshold = float(self.vectors_threshold_edit.text())
        except ValueError:
            threshold = None
        self.vectors_style['threshold'] = threshold
        self.updateSlice(force_update=True)

    def change_vectors_color(self):
        """Apply the colour chosen in the vectors colour button to the existing vector items."""
        color = self.vectors_color_button.color()
        self.vectors_style['brush'] = color
        if self.vectors_item is not None:  # no segments are drawn without vectors
            self.vectors_item.setPen(pg.mkPen(color))
        if self.vectors_base_item is not None:
            self.vectors_base_item.setBrush(color)

    def set_orientations(self, orientations):
        """Replace the orientation overlay (None removes future redraws) and redraw it."""
        self.orientations = as_source(orientations).array if orientations is not None else None
        self.update_orientations()

    def update_orientations(self):
        """Draw the orientations of the current slice as segments centred on each voxel."""
        if self.orientations is None:
            return
        vx, vy, x, y = self._field_in_slice(self.orientations, self.orientations_style.get('threshold', None))
        half_length = 0.45  # segments span 0.9 * the (unit) orientation vector
        px, py = np.zeros(x.shape[0] * 2), np.zeros(y.shape[0] * 2)
        px[0::2] = x - half_length * vx
        px[1::2] = x + half_length * vx
        py[0::2] = y - half_length * vy
        py[1::2] = y + half_length * vy
        path = pg.arrayToQPath(px, py, 'pairs')
        if self.orientations_item is not None:
            self.view.removeItem(self.orientations_item)
        self.orientations_item = QGraphicsPathItem(path)
        self.orientations_item.setPen(pg.mkPen(self.orientations_style.get('pen')))
        self.view.addItem(self.orientations_item)

    def change_orientations_threshold(self, value=None):
        """Read (or first set from ``value``) the orientations threshold and redraw; invalid text clears it."""
        if value is not None:
            self.orientations_threshold_edit.setText('%.4f' % value)
        try:
            threshold = float(self.orientations_threshold_edit.text())
        except ValueError:
            threshold = None
        self.orientations_style['threshold'] = threshold
        self.updateSlice(force_update=True)

    def change_orientations_color(self):
        """Apply the colour chosen in the orientations colour button."""
        color = self.orientations_color_button.color()
        self.orientations_style['pen'] = color
        if self.orientations_item is not None:  # nothing is drawn without orientations
            self.orientations_item.setPen(pg.mkPen(color))

    def set_color_scheme(self, type_, lut=0):
        """Load the gradient preset ``type_`` into LUT number ``lut``."""
        self.luts[lut].lut.item.gradient.loadPreset(type_)

    def enable_mouse_clicks(self):
        """Emit :attr:`mouse_clicked` on left clicks in the image view."""
        self.graphicsView.scene().sigMouseClicked.connect(self.handleMouseClick)

    def handleMouseClick(self, event):
        """On a left click, emit :attr:`mouse_clicked` with the clicked voxel indices."""
        event.accept()
        if event.button() != Qt.LeftButton:
            return
        x, y = self.get_coords(event.scenePos())
        x_axis, y_axis = self.getXYAxes()
        scaled_x, scaled_y = self.scale_coords(x, x_axis, y, y_axis)
        # cast numpy integers for the int-typed signal
        self.mouse_clicked.emit(int(scaled_x), int(scaled_y), int(self.source_index[self.scroll_axis]))

    def padded_shape(self, shape):
        """Return ``shape`` as a tuple padded with trailing 1s to at least 3 dimensions."""
        pad_size = max(3, len(shape))
        return (tuple(shape) + (1,) * pad_size)[:pad_size]
############################################################################################################
# ## Tests
############################################################################################################
def _test():
    """Manual smoke test: open viewers on random data (needs a display).

    Returns the created viewers so they stay referenced when the test is run
    from an interactive session.
    """
    from importlib import reload
    import ClearMap.Visualization.Qt.DataViewerAxon as dv
    reload(dv)

    img1 = np.random.rand(100, 80, 30)
    if not runs_on_spyder():
        pg.mkQApp()
    # Keep a reference: an unreferenced top-level widget can be garbage collected and closed.
    viewer = DataViewer(img1)
    if not runs_on_spyder():
        QApplication.instance().exec_()

    points = np.array(np.where(img1 > 0.99)).T
    vectors = np.random.rand(*(img1.shape + (3,)))
    # %gui qt
    reload(dv)
    viewer_with_overlays = dv.DataViewer(img1, points=points, vectors=vectors)
    if not runs_on_spyder():
        # Without a running event loop the second window would be unresponsive.
        QApplication.instance().exec_()
    return viewer, viewer_with_overlays
if __name__ == '__main__':
    # Run the manual viewer smoke test when the module is executed as a script.
    print('testing')
    _test()
    # NOTE(review): keeps the process alive for 60 s after _test() returns;
    # the reason is not visible here - confirm whether this is still needed.
    time.sleep(60)