"""
Plot3d
======

Plotting routines based on Qt.

Note
----
This module is based on the pyqtgraph package.
"""
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
__webpage__ = 'https://idisco.info'
__download__ = 'https://www.github.com/ChristophKirst/ClearMap2'
import itertools
from pathlib import Path
import numpy as np
import pyqtgraph as pg
import functools as ft
from PyQt5.QtCore import QRect
from PyQt5.QtWidgets import QApplication
import ClearMap.Visualization.Qt.DataViewer as dv
import ClearMap.Visualization.Qt.utils as qtu
############################################################################################################
# Plotting
############################################################################################################
# TODO: figure / windows handler to update data in existing windows
from ClearMap.Utils.utilities import runs_on_spyder
[docs]
def plot(source, axis=None, scale=None, title=None, invert_y=True, min_max=None, screen=None,
         arrange=True, lut=None, max_projection=None, to_front=True, parent=None, sync=True):
    """
    Plot a source as 2d slices.

    Arguments
    ---------
    source : Source, pathlib.Path, list or dict
      The source to plot. If a list is given several synchronized windows are
      generated. If an element in the list is a list of sources those are
      overlayed in different colors in that window.
    axis : int or None
      The axis along which to slice the data.
    scale : tuple of float
      A spatial scale for each axis used for the spatial cursor position.
    title : str or None
      The title of the window.
    invert_y : bool
      If True invert the y axis (as typically done for images).
    min_max : tuple or None
      The minimal and maximal values for each source. If None, determine them from
      the source.
    screen : int or None
      Specify on which screen to open the window.
    arrange : bool
      If True, tile the windows on the screen.
    lut : str, list or None
      Look up table(s) handed to the viewers as their default lut. Note that
      the default here is None, not the 'flame' default of multi_plot.
    max_projection : object, list or None
      Forwarded unchanged to the viewers (one entry per source if a list).
    to_front : bool
      If True, raise the windows to the front after creation.
    parent : object or None
      Forwarded to the viewers as their parent.
    sync : bool
      If True, synchronize all created viewers pairwise.

    Returns
    -------
    plots : list of DataViewer
      The data viewers created by multi_plot, one per window.
    """
    # A path is handled as its string form by the viewers.
    if isinstance(source, Path):
        source = str(source)
    # A single source becomes a one-window list.
    if not isinstance(source, (list, tuple)):
        source = [source]

    # No Qt event loop is started here (the former exec_() call was disabled),
    # so the viewers are simply returned to the caller.
    return multi_plot(source, axis=axis, scale=scale, title=title, invert_y=invert_y,
                      min_max=min_max, max_projection=max_projection, screen=screen,
                      arrange=arrange, lut=lut, to_front=to_front,
                      parent=parent, sync=sync)
[docs]
def multi_plot(sources, axis=None, scale=None, title=None, invert_y=True, min_max=None,
               max_projection=None, arrange=True, screen=None, lut='flame', screen_percent=90, parent=None, sync=True, to_front=True):
    """
    Plot several sources as 2d slices, one window per source.

    Arguments
    ---------
    sources : list of sources
      The sources to plot. If an element in the list is a list of sources
      those are overlayed in different colors in that window.
    axis : int or None
      The axis along which to slice the data.
    scale : tuple of float
      A spatial scale for each axis used for the spatial cursor position.
    title : str, list or None
      The title of the window(s). A single value is used for every window.
    invert_y : bool
      If True invert the y axis (as typically done for images).
    min_max : tuple, list of tuples or None
      The minimal and maximal values for each source. If None, determine them from
      the source. A single pair (first element scalar) is used for every source.
    max_projection : object, list or None
      Forwarded to the viewers; a non-list value is used for every source.
    arrange : bool
      If True, tile the windows on the screen (best effort).
    screen : int or None
      Specify on which screen to open the window.
    lut : str, list or None
      Default look up table(s) of the viewers. A single value is used for
      every window.
    screen_percent : int
      Percentage handed to the tiled layout when arranging the windows.
    parent : object or None
      Forwarded to the viewers as their parent.
    sync : bool
      If True, synchronize every pair of viewers.
    to_front : bool
      If True, raise the windows to the front.

    Returns
    -------
    plots : list of DataViewers
      A list of viewer classes.
    """
    import warnings

    n_sources = len(sources)

    # Broadcast per-window options given as a single value to all sources.
    if not isinstance(title, (tuple, list)):
        title = [title] * n_sources
    if not isinstance(lut, (list, tuple)):
        lut = [lut] * n_sources
    if min_max is None or np.isscalar(min_max[0]):  # Because it is a list of lists
        min_max = [min_max] * n_sources
    if not isinstance(max_projection, list):
        max_projection = [max_projection] * n_sources

    dvs = [dv.DataViewer(source=src, axis=axis, scale=scale, title=title_, invertY=invert_y,
                         minMax=min_max_, max_projection=max_projection_, default_lut=lut_, parent=parent)
           for src, title_, lut_, min_max_, max_projection_ in zip(sources, title, lut, min_max, max_projection)]

    if arrange:
        # Tiling is best effort: a layout failure must not prevent plotting,
        # but it is reported instead of being silently swallowed, and
        # KeyboardInterrupt / SystemExit are no longer intercepted.
        try:
            geo = qtu.tiled_layout(len(dvs), percent=screen_percent, screen=screen)
            for d, g in zip(dvs, geo):
                d.setGeometry(QRect(*g))
        except Exception as err:
            warnings.warn('Could not arrange plot windows: %r' % (err,), RuntimeWarning)

    if sync:
        for d1, d2 in itertools.combinations(dvs, 2):
            synchronize(d1, d2)

    if to_front:
        bring_to_front(dvs)

    return dvs
[docs]
def arrange_plots(plots, screen=None, screen_percent=90):
    """
    Tile the given plot windows on a screen (best effort).

    Arguments
    ---------
    plots : list of DataViewer
      The windows to arrange.
    screen : int or None
      Specify on which screen to arrange the windows.
    screen_percent : int
      Percentage handed to the tiled layout.

    Note
    ----
    A failure to compute or apply the layout does not raise; it is reported
    as a RuntimeWarning and the windows keep their current geometry.
    """
    import warnings

    try:
        geo = qtu.tiled_layout(len(plots), percent=screen_percent, screen=screen)
        for d, g in zip(plots, geo):
            # each layout entry is unpacked into a QRect
            d.setGeometry(pg.QtCore.QRect(*g))
    except Exception as err:  # keep best-effort semantics, but do not hide the cause
        warnings.warn('Could not arrange plot windows: %r' % (err,), RuntimeWarning)
[docs]
def synchronize(viewer1, viewer2):
    """
    Couple two data viewers so that they scroll and slice together.

    Arguments
    ---------
    viewer1, viewer2 : DataViewer
      The viewers to couple. In both directions, a change of the slice line
      position is copied to the other viewer, and a click on an axis button
      sets the same slice axis and mirrors the button's checked state in the
      other viewer. In addition the x and y axes of viewer1's view are linked
      to viewer2's view.
    """
    def link_one_way(leader, follower):
        """Make follower track the slice position and axis buttons of leader."""
        def follow_scroll():
            follower.sliceLine.setValue(leader.sliceLine.value())

        def follow_button(button, ax):
            follower.axis_buttons[ax].setChecked(button.isChecked())

        leader.sliceLine.sigPositionChanged.connect(follow_scroll)
        for index, axis_button in enumerate(leader.axis_buttons):
            # first switch the slice axis, then mirror the checked state
            axis_button.clicked.connect(ft.partial(follower.setSliceAxis, index))
            axis_button.clicked.connect(ft.partial(follow_button, axis_button, index))

    link_one_way(viewer1, viewer2)
    link_one_way(viewer2, viewer1)

    primary_view = viewer1.view
    primary_view.setXLink(viewer2.view)
    primary_view.setYLink(viewer2.view)
[docs]
def set_source(viewer, source):
    """
    Replace the data shown in a viewer.

    Arguments
    ---------
    viewer : DataViewer
      The viewer whose source is replaced.
    source : Source
      The new source, handed to the viewer's setSource method.

    Returns
    -------
    viewer : DataViewer
      The same viewer object that was passed in.
    """
    viewer.setSource(source)
    return viewer
[docs]
def bring_to_front(plots):
    """
    Raise plot windows to the front and show them.

    Arguments
    ---------
    plots : DataViewer, list or tuple of DataViewer
      The window(s) to raise. A single window is accepted as well as a
      list or tuple of windows.

    Note
    ----
    Each window gets the WindowStaysOnTopHint flag set before it is raised,
    activated and shown.
    """
    # Accept tuples as well as lists; anything else is treated as one window.
    if not isinstance(plots, (list, tuple)):
        plots = [plots]

    for viewer in plots:
        viewer.setWindowFlag(pg.Qt.QtCore.Qt.WindowStaysOnTopHint)
        viewer.raise_()
        viewer.activateWindow()
        # show() is called last, after the window flag has been set
        viewer.show()
[docs]
def close(plots='all'):
    """
    Close plot windows.

    Arguments
    ---------
    plots : 'all', DataViewer, list or tuple of DataViewer
      The window(s) to close. With the string 'all' (default) all windows
      of the application are closed. A single window is accepted as well as
      a list or tuple of windows.
    """
    # Only compare strings against 'all' so that arbitrary objects (which may
    # define their own ``==``) are never compared to the sentinel.
    if isinstance(plots, str) and plots == 'all':
        pg.Qt.App.closeAllWindows()
        return

    # Accept tuples as well as lists; anything else is treated as one window.
    if not isinstance(plots, (list, tuple)):
        plots = [plots]
    for viewer in plots:
        viewer.close()
############################################################################################################
# ## Tests
############################################################################################################
def _test():
    """Open two synchronized viewers on random data: a float volume and a boolean mask."""
    import numpy as np
    import ClearMap.Visualization.Qt.Plot3d as p3d

    shape = (100, 80, 30)
    volume = np.random.rand(*shape)
    mask = np.random.rand(*shape) > 0.5

    viewers = p3d.plot([volume, mask])  # analysis:ignore