# Source code for ClearMap.Alignment.Transformations.Transformation
# -*- coding: utf-8 -*-
"""
Transformation
==============
Abstract base class to handle transformation of coordinates and data
when cropping data or mapping between raw data and reference atlases.
"""
__author__ = 'Christoph Kirst <christoph.kirst.ck@gmail.com>'
__license__ = 'GPLv3 - GNU General Public License v3 (see LICENSE)'
__copyright__ = 'Copyright © 2020 by Christoph Kirst'
import io
import pickle
import numpy as np
# Base class
class TransformationBase(object):
    """Base class for transformations of data, points and shapes.

    Attributes
    ----------
    ttype : str or None
        Class-level type tag; it is stored in the serialized dictionary and
        shown by ``__repr__``.
    inverse : bool
        Whether this transformation is flagged as inverted.
    """
    ttype = None

    def __init__(self, inverse=False):
        self.inverse = inverse

    def get_inverse(self, inverse=False):
        """Combine the stored inverse flag with a requested direction.

        Returns ``self.inverse`` unchanged when `inverse` is falsy and its
        negation when `inverse` is truthy.
        """
        return self.inverse if not inverse else not self.inverse

    def to_dict(self) -> dict:
        """Return the serializable state shared by all transformations.

        Subclasses are expected to extend the returned dictionary with their
        own fields.
        """
        return {'ttype': self.ttype, 'inverse': self.inverse}

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Create an instance from a dictionary as produced by `to_dict`.

        Only the ``inverse`` flag is restored here; subclasses holding more
        state should override this method.
        """
        return cls(inverse=dictionary.get('inverse', False))

    def write(self, filename):
        """Pickle the dictionary returned by `to_dict` to `filename`."""
        with open(filename, 'wb') as file:
            pickle.dump(self.to_dict(), file, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def read(cls, filename):
        """Load a pickled dictionary from `filename` and build an instance.

        The dictionary is passed to ``cls.from_dict``.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only read
        # files that come from a trusted source.
        with open(filename, 'rb') as file:
            dictionary = pickle.load(file)
        return cls.from_dict(dictionary)

    def __repr__(self):
        inverse = '|i|' if self.inverse else ''
        return '%s%s' % (self.ttype, inverse)
# Factory
# Registry mapping a ``ttype`` tag to the class that rebuilds it.
ttype_to_transformation = {}


def transformation_from_dict(dictionary):
    """Rebuild a transformation from its dictionary representation.

    The ``'ttype'`` entry of `dictionary` selects the class registered in
    `ttype_to_transformation`; that class' ``from_dict`` does the actual
    construction.

    Raises
    ------
    ValueError
        If no class is registered for the dictionary's ``'ttype'``.
    """
    type_tag = dictionary.get('ttype')
    # Keep the try body to the lookup only, so a KeyError raised inside
    # from_dict is not mistaken for an unknown type.
    try:
        transformation_class = ttype_to_transformation[type_tag]
    except KeyError:
        raise ValueError('No transformation of type %r' % type_tag) from None
    return transformation_class.from_dict(dictionary)
# Transformations
class Transformation(TransformationBase):
    """Composite transformation chaining a list of sub-transformations.

    Sub-transformations are applied in list order. When the composite is
    evaluated in the inverse direction, the list is traversed in reverse and
    each sub-transformation is asked to apply its inverse.

    Attributes
    ----------
    transformations : list
        The sub-transformations; each must provide ``ttype``,
        ``transform_points``, ``transform_shape``, ``transform_data`` and
        ``to_dict``.
    """
    ttype = 'Transformation'

    def __init__(self, transformations=None, inverse=False):
        super().__init__(inverse=inverse)
        # Create a fresh list per instance instead of sharing a default.
        self.transformations = [] if transformations is None else transformations

    @staticmethod
    def _index_of(transformations, ttype):
        """Return the index of the first sub-transformation of type `ttype`."""
        types = [transformation.ttype for transformation in transformations]
        try:
            return types.index(ttype)
        except ValueError:
            raise ValueError('Transformation has no sub-transformation %r (%r)' % (ttype, types)) from None

    def get_transformations(self, inverse=False, start_at=None, stop_at=None) -> list:
        """Return the sub-transformations in the order they are to be applied.

        Arguments
        ---------
        inverse : bool
            Requested direction; combined with this object's own flag.
        start_at : str or None
            ``ttype`` of the first sub-transformation to keep.
        stop_at : str or None
            ``ttype`` of the last sub-transformation to keep (inclusive).
            It is searched for after the `start_at` trimming.

        Raises
        ------
        ValueError
            If `start_at` or `stop_at` matches no sub-transformation.
        """
        inverse = self.get_inverse(inverse)
        transformations = self.transformations
        if start_at is not None:
            transformations = transformations[self._index_of(transformations, start_at):]
        if stop_at is not None:
            transformations = transformations[:self._index_of(transformations, stop_at) + 1]
        if inverse:
            transformations = transformations[::-1]
        return transformations

    def _apply(self, method_name, source, inverse, start_at, stop_at, kwargs):
        """Run `method_name` of every selected sub-transformation in sequence.

        `kwargs` is taken as a plain dict so caller keywords cannot collide
        with this helper's own parameter names.
        """
        chain = self.get_transformations(inverse=inverse, start_at=start_at, stop_at=stop_at)
        # Forward the effective direction (own flag combined with the request)
        # so an inverted composite also inverts each step, not only the order.
        effective_inverse = self.get_inverse(inverse)
        result = source
        for transformation in chain:
            result = getattr(transformation, method_name)(result, inverse=effective_inverse, **kwargs)
        return result

    def transform_points(self, source, inverse=False, start_at=None, stop_at=None, **kwargs):
        """Pass `source` through ``transform_points`` of each sub-transformation."""
        return self._apply('transform_points', source, inverse, start_at, stop_at, kwargs)

    def transform_shape(self, shape, inverse=False, start_at=None, stop_at=None, **kwargs):
        """Pass `shape` through ``transform_shape`` of each sub-transformation."""
        return self._apply('transform_shape', shape, inverse, start_at, stop_at, kwargs)

    def transform_data(self, source, inverse=False, start_at=None, stop_at=None, **kwargs):
        """Pass `source` through ``transform_data`` of each sub-transformation."""
        return self._apply('transform_data', source, inverse, start_at, stop_at, kwargs)

    def to_dict(self) -> dict:
        """Return a dictionary describing this composite and its children.

        The ``'ttype'`` entry is what `transformation_from_dict` uses to pick
        the class when the dictionary is read back.
        """
        # Built explicitly so serialization does not depend on the base class
        # providing a `to_dict`.
        return {
            'ttype': self.ttype,
            'inverse': self.inverse,
            'transformations': [transformation.to_dict() for transformation in self.transformations],
        }

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Rebuild a composite from a dictionary as produced by `to_dict`.

        Each entry of ``dictionary['transformations']`` is rebuilt through
        `transformation_from_dict`. A missing ``'inverse'`` entry defaults to
        ``False``.
        """
        transformations = [transformation_from_dict(transformation)
                           for transformation in dictionary['transformations']]
        return cls(transformations=transformations, inverse=dictionary.get('inverse', False))

    def __len__(self):
        return len(self.transformations)

    def __getitem__(self, item):
        # A slice yields a new composite; any other index (int, numpy
        # integer, ...) is delegated to the list and yields one element.
        if isinstance(item, slice):
            return Transformation(transformations=self.transformations[item])
        return self.transformations[item]

    def __setitem__(self, *args):
        self.transformations.__setitem__(*args)

    def __repr__(self):
        return "%s[\n %s]" % (super().__repr__(),
                              '\n ->'.join([repr(transformation) for transformation in self.transformations]))
###############################################################################
# Test
###############################################################################
def _test():
    """Manual smoke test: prints the transformations and boolean check results."""
    import numpy as np
    import ClearMap.Alignment.Transformation as tfm
    from importlib import reload
    reload(tfm)

    # combined transformations: scale by 5, then shift by 3
    scaling = tfm.AffineTransformation(M=5 * np.identity(3))
    shifting = tfm.AffineTransformation(M=np.identity(3), b=3)
    chain = tfm.Transformation(transformations=[scaling, shifting])
    print(chain)

    coordinates = np.random.rand(5, 3)
    mapped = chain.transform_points(coordinates)
    print(np.allclose(coordinates * 5 + 3, mapped))

    # slicing
    image = np.random.rand(30, 50)
    region = (slice(5, 15, 3), slice(None, 40))
    image_region = image[region]
    slicer = tfm.SlicingTransformation(slicing=region, shape=image.shape)
    print(slicer)
    print(np.all(slicer.transform_data(image) == image_region))

    # map indices found in the sliced array back to the full array and
    # compare the values they address
    region_indices = np.array(np.where(image_region > 0.75)).T
    full_indices = slicer.transform_points(region_indices, inverse=True)
    values_full = image[full_indices[:, 0], full_indices[:, 1]]
    values_region = image_region[region_indices[:, 0], region_indices[:, 1]]
    print(np.all(values_full == values_region))