# -*- coding: utf-8 -*-
"""
Otsu, Yamamoto and Hachisuka (2018) - Reflectance Recovery
==========================================================
Defines objects for reflectance recovery, i.e. spectral upsampling, using
*Otsu et al. (2018)* method:
- :class:`colour.recovery.Dataset_Otsu2018`
- :func:`colour.recovery.XYZ_to_sd_Otsu2018`
- :func:`colour.recovery.NodeTree_Otsu2018`
References
----------
- :cite:`Otsu2018` : Otsu, H., Yamamoto, M., & Hachisuka, T. (2018).
Reproducing Spectral Reflectances From Tristimulus Colours. Computer
Graphics Forum, 37(6), 370-381. doi:10.1111/cgf.13332
"""
from __future__ import division, print_function, unicode_literals
import numpy as np
import six
from collections import namedtuple
from colour.colorimetry import (MSDS_CMFS_STANDARD_OBSERVER, SDS_ILLUMINANTS,
SpectralDistribution, SpectralShape,
msds_to_XYZ, sd_to_XYZ)
from colour.models import XYZ_to_xy
from colour.recovery import (SPECTRAL_SHAPE_OTSU2018, BASIS_FUNCTIONS_OTSU2018,
CLUSTER_MEANS_OTSU2018, SELECTOR_ARRAY_OTSU2018)
from colour.utilities import (as_float_array, domain_range_scale,
is_tqdm_installed, message_box, runtime_warning,
to_domain_1, zeros)
if six.PY3:
from unittest import mock
else:
import mock
if is_tqdm_installed():
from tqdm import tqdm
else:
tqdm = mock.MagicMock()
# Module metadata.
__author__ = 'Colour Developers'
__copyright__ = 'Copyright (C) 2013-2020 - Colour Developers'
__license__ = 'New BSD License - https://opensource.org/licenses/BSD-3-Clause'
__maintainer__ = 'Colour Developers'
__email__ = 'colour-developers@colour-science.org'
__status__ = 'Production'
# Public names exported by "from colour.recovery.otsu2018 import *".
__all__ = [
'Dataset_Otsu2018', 'DATASET_REFERENCE_OTSU2018', 'XYZ_to_sd_Otsu2018',
'PartitionAxis', 'ColourData', 'Node', 'NodeTree_Otsu2018'
]
class Dataset_Otsu2018(object):
    """
    Stores all the information needed for the *Otsu et al. (2018)* spectral
    upsampling method.

    Datasets can either be generated and converted to a
    :class:`colour.recovery.Dataset_Otsu2018` class instance using the
    :meth:`colour.recovery.NodeTree_Otsu2018.to_dataset` method, or loaded
    from disk with the :meth:`colour.recovery.Dataset_Otsu2018.read` method.

    Parameters
    ----------
    shape: SpectralShape
        Shape of the spectral data.
    basis_functions : array_like, (n, 3, m)
        Three basis functions for every cluster.
    means : array_like, (n, m)
        Mean for every cluster.
    selector_array : array_like, (k, 4)
        Array describing how to select the appropriate cluster. See
        :meth:`colour.recovery.Dataset_Otsu2018.select` method for details.

    Attributes
    ----------
    -   :attr:`~colour.recovery.Dataset_Otsu2018.shape`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.basis_functions`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.means`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.selector_array`

    Methods
    -------
    -   :meth:`~colour.recovery.Dataset_Otsu2018.__init__`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.select`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.cluster`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.read`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.write`

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> import os
    >>> import colour
    >>> from colour.characterisation import SDS_COLOURCHECKERS
    >>> reflectances = [
    ...     sd.copy().align(SPECTRAL_SHAPE_OTSU2018).values
    ...     for sd in SDS_COLOURCHECKERS['ColorChecker N Ohta'].values()
    ... ]
    >>> node_tree = NodeTree_Otsu2018(reflectances)
    >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
    >>> dataset = node_tree.to_dataset()
    >>> path = os.path.join(colour.__path__[0], 'recovery', 'tests',
    ...                     'resources', 'ColorChecker_Otsu2018.npz')
    >>> dataset.write(path)  # doctest: +SKIP
    >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
    >>> dataset.read(path)  # doctest: +SKIP
    """

    def __init__(self,
                 shape=None,
                 basis_functions=None,
                 means=None,
                 selector_array=None):
        self._shape = shape
        self._basis_functions = as_float_array(basis_functions)
        self._means = as_float_array(means)
        self._selector_array = selector_array

    @property
    def shape(self):
        """
        Getter property for the shape used by the *Otsu et al. (2018)* dataset.

        Returns
        -------
        SpectralShape
            Shape used by the *Otsu et al. (2018)* dataset.
        """

        return self._shape

    @property
    def basis_functions(self):
        """
        Getter property for the basis functions of the *Otsu et al. (2018)*
        dataset.

        Returns
        -------
        ndarray
            Basis functions of the *Otsu et al. (2018)* dataset.
        """

        return self._basis_functions

    @property
    def means(self):
        """
        Getter property for means of the *Otsu et al. (2018)* dataset.

        Returns
        -------
        ndarray
            Means of the *Otsu et al. (2018)* dataset.
        """

        return self._means

    @property
    def selector_array(self):
        """
        Getter property for the selector array of the *Otsu et al. (2018)*
        dataset.

        Returns
        -------
        ndarray
            Selector array of the *Otsu et al. (2018)* dataset.
        """

        return self._selector_array

    def __str__(self):
        """
        Returns a formatted string representation of the dataset.

        Returns
        -------
        unicode
            Formatted string representation.
        """

        return '{0}({1} basis functions)'.format(
            self.__class__.__name__, self._basis_functions.shape[0])

    def select(self, xy):
        """
        Returns the cluster index appropriate for the given *CIE xy*
        coordinates.

        Each row of the selector array is
        ``[direction, origin, lesser_index, greater_index]``. The *CIE xy*
        coordinate picked by ``direction`` is compared against ``origin``.
        The lesser index is followed when it is lower than or equal to
        ``origin``, and the greater index otherwise. A negative index refers
        to another row, whose position is its negation. A non-negative index
        is the selected cluster index. Traversal starts at row 0.

        Parameters
        ----------
        xy : array_like, (2,)
            *CIE xy* chromaticity coordinates.

        Returns
        -------
        int
            Cluster index.
        """

        # A tree made of a single root leaf can be stored as a 1-D "(4, )"
        # array, and users may pass nested sequences. Viewing the data as
        # "(k, 4)" handles both cases uniformly.
        selector_array = np.reshape(self._selector_array, (-1, 4))

        i = 0
        while True:
            direction, origin, lesser_index, greater_index = (
                selector_array[i, :])

            if xy[int(direction)] <= origin:
                index = int(lesser_index)
            else:
                index = int(greater_index)

            if index < 0:
                i = -index
            else:
                return index

    def cluster(self, xy):
        """
        Returns the basis functions and dataset mean for the given *CIE xy*
        coordinates.

        Parameters
        ----------
        xy : array_like, (2,)
            *CIE xy* chromaticity coordinates.

        Returns
        -------
        basis_functions : ndarray, (3, n)
            Three basis functions.
        mean : ndarray, (n,)
            Dataset mean.
        """

        index = self.select(xy)

        return self._basis_functions[index, :, :], self._means[index, :]

    def read(self, path):
        """
        Reads and loads a dataset from an *.npz* file.

        The instance is only modified once the file content has been
        validated, so a failed read leaves it untouched.

        Parameters
        ----------
        path : unicode
            Path to the file.

        Raises
        ------
        ValueError, KeyError
            Raised when loading the file succeeded but it did not contain the
            expected data.

        Examples
        --------
        >>> import os
        >>> import colour
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> reflectances = [
        ...     sd.copy().align(SPECTRAL_SHAPE_OTSU2018).values
        ...     for sd in SDS_COLOURCHECKERS['ColorChecker N Ohta'].values()
        ... ]
        >>> node_tree = NodeTree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> dataset = node_tree.to_dataset()
        >>> path = os.path.join(colour.__path__[0], 'recovery', 'tests',
        ...                     'resources', 'ColorChecker_Otsu2018.npz')
        >>> dataset.write(path)  # doctest: +SKIP
        >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
        >>> dataset.read(path)  # doctest: +SKIP
        """

        npz = np.load(path)
        if not isinstance(npz, np.lib.npyio.NpzFile):
            raise ValueError('The loaded file is not an ".npz" type file!')

        # "NpzFile" instances keep the underlying file open until closed.
        with npz:
            start, end, interval = npz['shape']
            basis_functions = npz['basis_functions']
            means = npz['means']
            selector_array = npz['selector_array']

        # Files written from a single root leaf tree store a 1-D selector.
        if selector_array.ndim == 1 and selector_array.size == 4:
            selector_array = np.reshape(selector_array, (1, 4))

        error_message = ('Unexpected array shapes encountered, the file could '
                         'be corrupted or in a wrong format!')

        if basis_functions.ndim != 3:
            raise ValueError(error_message)

        n, three, m = basis_functions.shape
        if (three != 3 or means.shape != (n, m) or selector_array.ndim != 2
                or selector_array.shape[1] != 4):
            raise ValueError(error_message)

        self._shape = SpectralShape(start, end, interval)
        self._basis_functions = basis_functions
        self._means = means
        self._selector_array = selector_array

    def write(self, path):
        """
        Writes the dataset to an *.npz* file at given path.

        Parameters
        ----------
        path : unicode
            Path to the file.

        Examples
        --------
        >>> import os
        >>> import colour
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> reflectances = [
        ...     sd.copy().align(SPECTRAL_SHAPE_OTSU2018).values
        ...     for sd in SDS_COLOURCHECKERS['ColorChecker N Ohta'].values()
        ... ]
        >>> node_tree = NodeTree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> dataset = node_tree.to_dataset()
        >>> path = os.path.join(colour.__path__[0], 'recovery', 'tests',
        ...                     'resources', 'ColorChecker_Otsu2018.npz')
        >>> dataset.write(path)  # doctest: +SKIP
        """

        # The spectral shape is stored as "[start, end, interval]".
        shape_array = as_float_array(
            [self._shape.start, self._shape.end, self._shape.interval])

        np.savez(
            path,
            shape=shape_array,
            basis_functions=self._basis_functions,
            means=self._means,
            selector_array=self._selector_array)
DATASET_REFERENCE_OTSU2018 = Dataset_Otsu2018(
    shape=SPECTRAL_SHAPE_OTSU2018,
    basis_functions=BASIS_FUNCTIONS_OTSU2018,
    means=CLUSTER_MEANS_OTSU2018,
    selector_array=SELECTOR_ARRAY_OTSU2018)
"""
Builtin *Otsu et al. (2018)* dataset, built from the published basis
functions, cluster means and selector array, as a
:class:`colour.recovery.Dataset_Otsu2018` class instance. It is the default
dataset of the :func:`colour.recovery.XYZ_to_sd_Otsu2018` definition, among
others.
"""
def XYZ_to_sd_Otsu2018(
        XYZ,
        cmfs=MSDS_CMFS_STANDARD_OBSERVER['CIE 1931 2 Degree Standard Observer']
        .copy().align(SPECTRAL_SHAPE_OTSU2018),
        illuminant=SDS_ILLUMINANTS['D65'].copy().align(
            SPECTRAL_SHAPE_OTSU2018),
        dataset=DATASET_REFERENCE_OTSU2018,
        clip=True):
    """
    Recovers the spectral distribution of given *CIE XYZ* tristimulus values
    using *Otsu et al. (2018)* method.

    Parameters
    ----------
    XYZ : array_like, (3,)
        *CIE XYZ* tristimulus values to recover the spectral distribution from.
    cmfs : XYZ_ColourMatchingFunctions, optional
        Standard observer colour matching functions.
    illuminant : SpectralDistribution, optional
        Illuminant spectral distribution.
    dataset : Dataset_Otsu2018, optional
        Dataset to use for reconstruction. The default is to use the published
        data.
    clip : bool, optional
        If *True*, the default, values below zero and above unity in the
        recovered spectral distributions will be clipped. This ensures that the
        returned reflectance is physical and conserves energy, but will cause
        noticeable colour differences in case of very saturated colours.

    Returns
    -------
    SpectralDistribution
        Recovered spectral distribution. Its wavelengths are those of
        ``dataset.shape``. With the default dataset, this is the
        :attr:`colour.recovery.SPECTRAL_SHAPE_OTSU2018` shape.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the tristimulus values of the selected cluster basis functions
        form a singular matrix.

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> from colour.colorimetry import CCS_ILLUMINANTS, sd_to_XYZ_integration
    >>> from colour.models import XYZ_to_sRGB
    >>> from colour.utilities import numpy_print_options
    >>> XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    >>> cmfs = (
    ...     MSDS_CMFS_STANDARD_OBSERVER['CIE 1931 2 Degree Standard Observer'].
    ...     copy().align(SPECTRAL_SHAPE_OTSU2018)
    ... )
    >>> illuminant = SDS_ILLUMINANTS['D65'].copy().align(cmfs.shape)
    >>> sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant)
    >>> with numpy_print_options(suppress=True):
    ...     # Doctests skip for Python 2.x compatibility.
    ...     sd  # doctest: +SKIP
    SpectralDistribution([[ 380. , 0.0601939...],
                          [ 390. , 0.0568063...],
                          [ 400. , 0.0517429...],
                          [ 410. , 0.0495841...],
                          [ 420. , 0.0502007...],
                          [ 430. , 0.0506489...],
                          [ 440. , 0.0510020...],
                          [ 450. , 0.0493782...],
                          [ 460. , 0.0468046...],
                          [ 470. , 0.0437132...],
                          [ 480. , 0.0416957...],
                          [ 490. , 0.0403783...],
                          [ 500. , 0.0405197...],
                          [ 510. , 0.0406031...],
                          [ 520. , 0.0416912...],
                          [ 530. , 0.0430956...],
                          [ 540. , 0.0444474...],
                          [ 550. , 0.0459336...],
                          [ 560. , 0.0507631...],
                          [ 570. , 0.0628967...],
                          [ 580. , 0.0844661...],
                          [ 590. , 0.1334277...],
                          [ 600. , 0.2262428...],
                          [ 610. , 0.3599330...],
                          [ 620. , 0.4885571...],
                          [ 630. , 0.5752546...],
                          [ 640. , 0.6193023...],
                          [ 650. , 0.6450744...],
                          [ 660. , 0.6610548...],
                          [ 670. , 0.6688673...],
                          [ 680. , 0.6795426...],
                          [ 690. , 0.6887933...],
                          [ 700. , 0.7003469...],
                          [ 710. , 0.7084128...],
                          [ 720. , 0.7154674...],
                          [ 730. , 0.7234334...]],
                         interpolator=SpragueInterpolator,
                         interpolator_kwargs={},
                         extrapolator=Extrapolator,
                         extrapolator_kwargs={...})
    >>> sd_to_XYZ_integration(sd, cmfs, illuminant) / 100  # doctest: +ELLIPSIS
    array([ 0.2065494...,  0.1219712...,  0.0514002...])
    """

    XYZ = to_domain_1(XYZ)
    xy = XYZ_to_xy(XYZ)

    basis_functions, mean = dataset.cluster(xy)
    wavelengths = dataset.shape.range()

    # "sd_to_XYZ" output is kept in its reference range irrespective of the
    # active domain-range scale, then divided by 100 to match "XYZ", which
    # was converted to domain-range 1 above.
    with domain_range_scale('ignore'):
        # Column "i" of "M" holds the tristimulus values of basis function
        # "i", so that "M . weights" maps basis weights to *CIE XYZ*.
        M = np.transpose([
            sd_to_XYZ(
                SpectralDistribution(basis_function, wavelengths), cmfs,
                illuminant) / 100 for basis_function in basis_functions
        ])
        XYZ_mu = sd_to_XYZ(
            SpectralDistribution(mean, wavelengths), cmfs, illuminant) / 100

    # Solving "M . weights = XYZ - XYZ_mu" directly is numerically preferable
    # to forming the explicit inverse of "M".
    weights = np.linalg.solve(M, XYZ - XYZ_mu)

    recovered_sd = np.dot(weights, basis_functions) + mean
    if clip:
        recovered_sd = np.clip(recovered_sd, 0, 1)

    return SpectralDistribution(recovered_sd, wavelengths)
class PartitionAxis(namedtuple('PartitionAxis', ('origin', 'direction'))):
    """
    Represents a horizontal or vertical line, partitioning the 2D space in
    two half-planes.

    Parameters
    ----------
    origin : numeric
        The x coordinate of a vertical line or the y coordinate of a horizontal
        line.
    direction : int
        *0* if vertical, *1* if horizontal.

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.PartitionAxis.__str__`
    """

    # Prevents the creation of a per-instance "__dict__", as recommended for
    # "namedtuple" sub-classes, keeping instances as light as plain tuples.
    __slots__ = ()

    def __str__(self):
        """
        Returns a formatted string representation of the partition axis.

        Returns
        -------
        unicode
            Formatted string representation.
        """

        return '{0}({1} partition at {2} = {3})'.format(
            self.__class__.__name__,
            'horizontal' if self.direction else 'vertical',
            'y' if self.direction else 'x', self.origin)
class ColourData(object):
    """
    Represents the data for multiple colours: their spectral reflectance
    distributions, *CIE XYZ* tristimulus values and *CIE xy* chromaticity
    coordinates. The standard observer colour matching functions and
    illuminant are accessed via the parent tree.

    This class also supports partitioning: creating two smaller instances of
    :class:`colour.recovery.otsu2018.ColourData` class by splitting along a
    horizontal or a vertical axis on the *CIE xy* plane.

    Parameters
    ----------
    tree : NodeTree_Otsu2018
        The parent tree which determines the standard observer colour matching
        functions and illuminant used in colourimetric calculations.
    reflectances : ndarray, (n, m) or None
        Reflectances of the *n* colours to be stored in this class. The shape
        must match ``tree.shape`` with *m* points for each colour. If *None*,
        the instance is left empty.

    Attributes
    ----------
    -   :attr:`~colour.recovery.otsu2018.ColourData.tree`
    -   :attr:`~colour.recovery.otsu2018.ColourData.reflectances`
    -   :attr:`~colour.recovery.otsu2018.ColourData.XYZ`
    -   :attr:`~colour.recovery.otsu2018.ColourData.xy`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.ColourData.__init__`
    -   :meth:`~colour.recovery.otsu2018.ColourData.__str__`
    -   :meth:`~colour.recovery.otsu2018.ColourData.__len__`
    -   :meth:`~colour.recovery.otsu2018.ColourData.partition`
    """

    def __init__(self, tree, reflectances):
        self._tree = tree

        self._XYZ = None
        self._xy = None
        self._reflectances = None
        # Goes through the setter so that "XYZ" and "xy" are derived.
        self.reflectances = reflectances

    @property
    def tree(self):
        """
        Getter property for the colour data tree.

        Returns
        -------
        NodeTree_Otsu2018
            Colour data tree.
        """

        return self._tree

    @property
    def reflectances(self):
        """
        Getter and setter property for the colour data reflectances.

        Setting the reflectances also recomputes the *CIE XYZ* tristimulus
        values and *CIE xy* chromaticity coordinates. Setting *None* is a
        no-op.

        Parameters
        ----------
        value : array_like
            Value to set the colour data reflectances with.

        Returns
        -------
        ndarray
            Colour data reflectances.
        """

        return self._reflectances

    @reflectances.setter
    def reflectances(self, value):
        """
        Setter for the **self.reflectances** property.
        """

        if value is None:
            return

        self._reflectances = as_float_array(value)
        # The tree's own integration is used so that these values share the
        # exact normalisation of the quantities computed by "Node.PCA". The
        # result also does not depend on the active domain-range scale, whose
        # "1" setting would otherwise rescale "colour.msds_to_XYZ" output.
        self._XYZ = self.tree.msds_to_XYZ(self._reflectances)
        self._xy = XYZ_to_xy(self._XYZ)

    @property
    def XYZ(self):
        """
        Getter property for the colour data *CIE XYZ* tristimulus values.

        Returns
        -------
        ndarray
            Colour data *CIE XYZ* tristimulus values.
        """

        return self._XYZ

    @property
    def xy(self):
        """
        Getter property for the colour data *CIE xy* chromaticity coordinates.

        Returns
        -------
        ndarray
            Colour data *CIE xy* chromaticity coordinates.
        """

        return self._xy

    def __str__(self):
        """
        Returns a formatted string representation of the colour data.

        Returns
        -------
        unicode
            Formatted string representation.
        """

        return '{0}({1} Reflectances)'.format(self.__class__.__name__,
                                              len(self))

    def __len__(self):
        """
        Returns the number of colours in the colour data.

        Returns
        -------
        int
            Number of colours in the colour data, *0* if empty.
        """

        if self._reflectances is None:
            return 0

        return self._reflectances.shape[0]

    def partition(self, axis):
        """
        Splits the colour data in two along given partition axis.

        Colours whose *CIE xy* coordinate selected by ``axis.direction`` is
        lower than or equal to ``axis.origin`` go to the lesser part. The
        others go to the greater part.

        Parameters
        ----------
        axis : PartitionAxis
            Partition axis used to partition the colour data.

        Returns
        -------
        lesser : ColourData
            The left or lower part.
        greater : ColourData
            The right or upper part.
        """

        lesser = ColourData(self.tree, None)
        greater = ColourData(self.tree, None)

        mask = self.xy[:, axis.direction] <= axis.origin

        # Already computed values are sliced rather than recomputed.
        lesser._reflectances = self.reflectances[mask, :]
        greater._reflectances = self.reflectances[~mask, :]

        lesser._XYZ = self.XYZ[mask, :]
        greater._XYZ = self.XYZ[~mask, :]

        lesser._xy = self.xy[mask, :]
        greater._xy = self.xy[~mask, :]

        return lesser, greater
class Node(object):
    """
    Represents a node in a :meth:`colour.recovery.NodeTree_Otsu2018` class
    instance node tree.

    Parameters
    ----------
    tree : NodeTree_Otsu2018
        The parent tree which determines the standard observer colour matching
        functions and illuminant used in colourimetric calculations.
    colour_data : ColourData
        The colour data belonging to this node.

    Attributes
    ----------
    -   :attr:`~colour.recovery.otsu2018.Node.id`
    -   :attr:`~colour.recovery.otsu2018.Node.tree`
    -   :attr:`~colour.recovery.otsu2018.Node.colour_data`
    -   :attr:`~colour.recovery.otsu2018.Node.children`
    -   :attr:`~colour.recovery.otsu2018.Node.partition_axis`
    -   :attr:`~colour.recovery.otsu2018.Node.basis_functions`
    -   :attr:`~colour.recovery.otsu2018.Node.mean`
    -   :attr:`~colour.recovery.otsu2018.Node.leaves`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.Node.__init__`
    -   :meth:`~colour.recovery.otsu2018.Node.__str__`
    -   :meth:`~colour.recovery.otsu2018.Node.__len__`
    -   :meth:`~colour.recovery.otsu2018.Node.is_leaf`
    -   :meth:`~colour.recovery.otsu2018.Node.split`
    -   :meth:`~colour.recovery.otsu2018.Node.PCA`
    -   :meth:`~colour.recovery.otsu2018.Node.reconstruct`
    -   :meth:`~colour.recovery.otsu2018.Node.leaf_reconstruction_error`
    -   :meth:`~colour.recovery.otsu2018.Node.branch_reconstruction_error`
    -   :meth:`~colour.recovery.otsu2018.Node.partition_reconstruction_error`
    -   :meth:`~colour.recovery.otsu2018.Node.find_best_partition`
    """

    _NODE_COUNT = 1
    """
    Identifier given to the next created node, i.e. the total node count plus
    one.

    _NODE_COUNT : int
    """

    def __init__(self, tree, colour_data):
        # Every node receives a unique, incrementing identifier.
        self._id = Node._NODE_COUNT
        Node._NODE_COUNT += 1

        self._tree = tree
        self._colour_data = colour_data
        self._children = []
        self._partition_axis = None

        # PCA results, populated by "PCA".
        self._mean = None
        self._basis_functions = None
        self._M = None
        self._M_inverse = None
        self._XYZ_mu = None

        # Optimisation caches.
        self._best_partition = None
        self._cached_leaf_reconstruction_error = None

    @property
    def id(self):
        """
        Getter property for the node id.

        Returns
        -------
        int
            Node id.
        """

        return self._id

    @property
    def tree(self):
        """
        Getter property for the node tree.

        Returns
        -------
        NodeTree_Otsu2018
            Node tree.
        """

        return self._tree

    @property
    def colour_data(self):
        """
        Getter property for the node colour data.

        Returns
        -------
        ColourData
            Node colour data, *None* once the node has been split.
        """

        return self._colour_data

    @property
    def children(self):
        """
        Getter property for the node children.

        Returns
        -------
        list or tuple
            Node children, empty for a leaf node.
        """

        return self._children

    @property
    def partition_axis(self):
        """
        Getter property for the node partition axis.

        Returns
        -------
        PartitionAxis
            Node partition axis.
        """

        return self._partition_axis

    @property
    def basis_functions(self):
        """
        Getter property for the node basis functions.

        Returns
        -------
        array_like
            Node basis functions.
        """

        return self._basis_functions

    @property
    def mean(self):
        """
        Getter property for the node mean distribution.

        Returns
        -------
        array_like
            Node mean distribution.
        """

        return self._mean

    @property
    def leaves(self):
        """
        Getter property for the node leaves.

        Returns
        -------
        generator
            Generator of all the leaves connected to this node.
        """

        if self.is_leaf():
            yield self
        else:
            for child in self._children:
                # TODO: Python 3 "yield from child.leaves".
                for leaf in child.leaves:
                    yield leaf

    def __str__(self):
        """
        Returns a formatted string representation of the node.

        Returns
        -------
        unicode
            Formatted string representation.
        """

        return '{0}#{1}({2})'.format(self.__class__.__name__, self._id,
                                     self._colour_data)

    def __len__(self):
        """
        Returns the number of leaves of the node, a leaf counting itself.

        Returns
        -------
        int
            Number of leaves of the node.
        """

        return len(list(self.leaves))

    def is_leaf(self):
        """
        Returns whether the node is a leaf.

        :class:`colour.recovery.NodeTree_Otsu2018` class instance tree leaves
        do not have any children and store instances of
        :class:`colour.recovery.otsu2018.ColourData` class.

        Returns
        -------
        bool
            Whether the node is a leaf.
        """

        return len(self._children) == 0

    def split(self, children, partition_axis):
        """
        Converts the leaf node into a non-leaf node using given children and
        partition axis.

        Parameters
        ----------
        children : tuple
            Tuple of two :class:`colour.recovery.otsu2018.Node` classes
            instances.
        partition_axis : PartitionAxis
            Partition axis.
        """

        # A non-leaf node delegates its data to its children, hence all the
        # leaf-specific state is discarded.
        self._colour_data = None
        self._children = children
        self._partition_axis = partition_axis

        self._mean = None
        self._basis_functions = None
        self._M = None
        self._M_inverse = None
        self._XYZ_mu = None
        self._best_partition = None
        self._cached_leaf_reconstruction_error = None

    #
    # PCA and Reconstruction
    #

    def PCA(self):
        """
        Performs the *Principal Component Analysis* (PCA) on the colours data
        of the node and sets the relevant private attributes accordingly.

        The analysis is only performed once per node.

        Raises
        ------
        RuntimeError
            If the node is not a leaf node.
        """

        if not self.is_leaf():
            raise RuntimeError('{0} is not a leaf node!'.format(self))

        if self._M is not None:
            return

        self._mean = np.mean(self._colour_data.reflectances, axis=0)
        self._XYZ_mu = self._tree.msds_to_XYZ(self._mean)

        matrix_data = self._colour_data.reflectances - self._mean
        matrix_covariance = np.dot(np.transpose(matrix_data), matrix_data)
        # "np.linalg.eigh" returns the eigenvalues in ascending order, thus
        # the last three eigenvectors are the three principal components.
        _eigenvalues, eigenvectors = np.linalg.eigh(matrix_covariance)
        self._basis_functions = np.transpose(eigenvectors[:, -3:])

        # Column "i" of "M" holds the tristimulus values of basis function "i".
        self._M = np.transpose(self._tree.msds_to_XYZ(self._basis_functions))
        self._M_inverse = np.linalg.inv(self._M)

    def reconstruct(self, XYZ):
        """
        Reconstructs the reflectance for the given *CIE XYZ* tristimulus
        values.

        If the node is a leaf, its basis functions and mean are used, running
        the PCA first if needed. Otherwise the branch is traversed recursively
        to find the appropriate leaf.

        Parameters
        ----------
        XYZ : ndarray, (3,)
            *CIE XYZ* tristimulus values to recover the spectral distribution
            from.

        Returns
        -------
        SpectralDistribution
            Recovered spectral distribution.
        """

        if not self.is_leaf():
            xy = XYZ_to_xy(XYZ)
            if (xy[self._partition_axis.direction] <=
                    self._partition_axis.origin):
                return self._children[0].reconstruct(XYZ)
            else:
                return self._children[1].reconstruct(XYZ)

        if self._M is None:
            self.PCA()

        weights = np.dot(self._M_inverse, XYZ - self._XYZ_mu)
        reflectance = np.dot(weights, self._basis_functions) + self._mean
        reflectance = np.clip(reflectance, 0, 1)

        return SpectralDistribution(reflectance, self._tree.cmfs.wavelengths)

    #
    # Optimisation
    #

    def leaf_reconstruction_error(self):
        """
        Reconstructs the reflectance of the *CIE XYZ* tristimulus values in
        the colour data of this node using PCA and compares the reconstructed
        spectrum against the measured spectrum. The reconstruction errors are
        then summed up and returned.

        Returns
        -------
        error : float
            The reconstruction errors summation for the node.

        Notes
        -----
        -   The reconstruction error is cached upon being computed and thus is
            only computed once per node.

        Raises
        ------
        RuntimeError
            If the node is not a leaf node.
        """

        if not self.is_leaf():
            raise RuntimeError('{0} is not a leaf node!'.format(self))

        # An error of exactly zero is a valid cached value.
        if self._cached_leaf_reconstruction_error is not None:
            return self._cached_leaf_reconstruction_error

        if self._M is None:
            self.PCA()

        # Vectorised equivalent of calling "reconstruct" on every colour of
        # the node: one row of weights, reflectance and error per colour.
        weights = np.dot(self._colour_data.XYZ - self._XYZ_mu,
                         np.transpose(self._M_inverse))
        recovered = np.clip(
            np.dot(weights, self._basis_functions) + self._mean, 0, 1)
        error = np.sum((self._colour_data.reflectances - recovered) ** 2)

        self._cached_leaf_reconstruction_error = error

        return error

    def branch_reconstruction_error(self):
        """
        Computes the reconstruction error for an entire branch of the tree,
        starting from the node, i.e. the reconstruction errors summation for
        all the leaves in the branch.

        Returns
        -------
        error : float
            Reconstruction errors summation for all the leaves in the branch.
        """

        if self.is_leaf():
            return self.leaf_reconstruction_error()

        return sum(
            child.branch_reconstruction_error() for child in self._children)

    def partition_reconstruction_error(self, axis):
        """
        Computes the reconstruction errors summation of the two nodes created
        by splitting the node with a given partition.

        Parameters
        ----------
        axis : PartitionAxis
            Partition axis used to compute the reconstruction error.

        Returns
        -------
        error : float
            Reconstruction errors summation of the two nodes created
            by splitting the node with a given partition.
        lesser, greater : tuple
            Nodes created by splitting the node with the given partition.

        Raises
        ------
        RuntimeError
            If any part is smaller than the tree minimum cluster size.
        """

        partition = self.colour_data.partition(axis)

        if (len(partition[0]) < self._tree.minimum_cluster_size or
                len(partition[1]) < self._tree.minimum_cluster_size):
            raise RuntimeError('Partition generated parts smaller '
                               'than the minimum cluster size!')

        lesser = Node(self._tree, partition[0])
        lesser.PCA()

        greater = Node(self._tree, partition[1])
        greater.PCA()

        error = (lesser.leaf_reconstruction_error() +
                 greater.leaf_reconstruction_error())

        return error, (lesser, greater)

    def find_best_partition(self):
        """
        Finds the best partition for the node, i.e. the candidate partition
        with the lowest reconstruction error, provided that it is lower than
        the error of the node itself.

        Every colour of the node is tried as the origin of a vertical and of
        a horizontal partition axis. The result is cached.

        Returns
        -------
        partition_error : float
            Partition error
        axis : PartitionAxis
            Horizontal or vertical line, partitioning the 2D space in
            two half-planes.
        partition : tuple
            Nodes created by splitting a node with a given partition.

        Raises
        ------
        RuntimeError
            If no partition improves on the node reconstruction error.
        """

        if self._best_partition is not None:
            return self._best_partition

        leaf_error = self.leaf_reconstruction_error()
        best_error = None

        with tqdm(total=2 * len(self.colour_data)) as progress:
            for direction in [0, 1]:
                for i in range(len(self.colour_data)):
                    progress.update()

                    origin = self.colour_data.xy[i, direction]
                    axis = PartitionAxis(origin, direction)

                    try:
                        partition_error, partition = (
                            self.partition_reconstruction_error(axis))
                    except RuntimeError:
                        continue

                    if partition_error >= leaf_error:
                        continue

                    if best_error is None or partition_error < best_error:
                        # Tracking the best error is required, otherwise the
                        # last improving partition would win.
                        best_error = partition_error
                        self._best_partition = (partition_error, axis,
                                                partition)

        if self._best_partition is None:
            raise RuntimeError('Could not find a best partition!')

        return self._best_partition
[docs]class NodeTree_Otsu2018(Node):
"""
A sub-class of :class:`colour.recovery.otsu2018.Node` class representing
the root node of a tree containing information shared with all the nodes,
such as the standard observer colour matching functions and the illuminant,
if any is used.
Global operations involving the entire tree, such as optimisation and
reconstruction, are implemented in this sub-class.
Parameters
----------
reflectances : ndarray, (n, m)
Reflectances of the *n* reference colours to use for optimisation.
cmfs : XYZ_ColourMatchingFunctions, optional
Standard observer colour matching functions.
illuminant : SpectralDistribution, optional
Illuminant spectral distribution.
Attributes
----------
- :attr:`~colour.recovery.NodeTree_Otsu2018.reflectances`
- :attr:`~colour.recovery.NodeTree_Otsu2018.cmfs`
- :attr:`~colour.recovery.NodeTree_Otsu2018.illuminant`
- :attr:`~colour.recovery.NodeTree_Otsu2018.minimum_cluster_size`
Methods
-------
- :meth:`~colour.recovery.otsu2018.NodeTree_Otsu2018.__init__`
- :meth:`~colour.recovery.otsu2018.NodeTree_Otsu2018.__str__`
- :meth:`~colour.recovery.otsu2018.NodeTree_Otsu2018.msds_to_XYZ`
- :meth:`~colour.recovery.otsu2018.NodeTree_Otsu2018.optimise`
- :meth:`~colour.recovery.otsu2018.NodeTree_Otsu2018.to_dataset`
References
----------
:cite:`Otsu2018`
Examples
--------
>>> import os
>>> import colour
>>> from colour.characterisation import SDS_COLOURCHECKERS
>>> from colour.utilities import numpy_print_options
>>> XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
>>> cmfs = (
... MSDS_CMFS_STANDARD_OBSERVER['CIE 1931 2 Degree Standard Observer'].
... copy().align(SpectralShape(360, 780, 10))
... )
>>> illuminant = SDS_ILLUMINANTS['D65'].copy().align(cmfs.shape)
>>> reflectances = [
... sd.copy().align(cmfs.shape).values
... for sd in SDS_COLOURCHECKERS['ColorChecker N Ohta'].values()
... ]
>>> node_tree = NodeTree_Otsu2018(reflectances, cmfs, illuminant)
>>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
>>> dataset = node_tree.to_dataset()
>>> path = os.path.join(colour.__path__[0], 'recovery', 'tests',
... 'resources', 'ColorChecker_Otsu2018.npz')
>>> dataset.write(path) # doctest: +SKIP
>>> dataset = Dataset_Otsu2018() # doctest: +SKIP
>>> dataset.read(path) # doctest: +SKIP
>>> sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant, dataset)
... # doctest: +SKIP
>>> with numpy_print_options(suppress=True):
... # Doctests skip for Python 2.x compatibility.
... sd # doctest: +SKIP
SpectralDistribution([[ 360. , 0.0651341...],
[ 370. , 0.0651341...],
[ 380. , 0.0651341...],
[ 390. , 0.0749684...],
[ 400. , 0.0815578...],
[ 410. , 0.0776439...],
[ 420. , 0.0721897...],
[ 430. , 0.0649064...],
[ 440. , 0.0567185...],
[ 450. , 0.0484685...],
[ 460. , 0.0409768...],
[ 470. , 0.0358964...],
[ 480. , 0.0307857...],
[ 490. , 0.0270148...],
[ 500. , 0.0273773...],
[ 510. , 0.0303157...],
[ 520. , 0.0331285...],
[ 530. , 0.0363027...],
[ 540. , 0.0425987...],
[ 550. , 0.0513442...],
[ 560. , 0.0579256...],
[ 570. , 0.0653850...],
[ 580. , 0.0929522...],
[ 590. , 0.1600326...],
[ 600. , 0.2586159...],
[ 610. , 0.3701242...],
[ 620. , 0.4702243...],
[ 630. , 0.5396261...],
[ 640. , 0.5737561...],
[ 650. , 0.590848 ...],
[ 660. , 0.5935371...],
[ 670. , 0.5923295...],
[ 680. , 0.5956326...],
[ 690. , 0.5982513...],
[ 700. , 0.6017904...],
[ 710. , 0.6016419...],
[ 720. , 0.5996892...],
[ 730. , 0.6000018...],
[ 740. , 0.5964443...],
[ 750. , 0.5868181...],
[ 760. , 0.5860973...],
[ 770. , 0.5614878...],
[ 780. , 0.5289331...]],
interpolator=SpragueInterpolator,
interpolator_kwargs={},
extrapolator=Extrapolator,
extrapolator_kwargs={...})
"""
[docs] def __init__(self,
reflectances,
cmfs=MSDS_CMFS_STANDARD_OBSERVER[
'CIE 1931 2 Degree Standard Observer'].copy().align(
SPECTRAL_SHAPE_OTSU2018),
illuminant=SDS_ILLUMINANTS['D65'].copy().align(
SPECTRAL_SHAPE_OTSU2018)):
self._reflectances = as_float_array(reflectances)
self._cmfs = cmfs
shape = cmfs.shape
if illuminant.shape != shape:
runtime_warning(
'Aligning "{0}" illuminant shape to "{1}" colour matching '
'functions shape.'.format(illuminant.name, cmfs.name))
illuminant = illuminant.copy().align(cmfs.shape)
self._illuminant = illuminant
self._dw = shape.interval
# Normalising constant :math:`k`, see :func:`colour.msds_to_XYZ`
# definition.
self._k = 1 / (np.sum(
self._cmfs.values[:, 1] * self._illuminant.values) * self._dw)
self._minimum_cluster_size = None
super(NodeTree_Otsu2018, self).__init__(
self, ColourData(self, self._reflectances))
@property
def reflectances(self):
"""
Getter property for the reflectances.
Returns
-------
ndarray
Reflectances.
"""
return self._reflectances
@property
def cmfs(self):
"""
Getter property for the standard observer colour matching functions.
Returns
-------
XYZ_ColourMatchingFunctions
Standard observer colour matching functions.
"""
return self._cmfs
@property
def illuminant(self):
"""
Getter property for the illuminant.
Returns
-------
SpectralDistribution
Illuminant.
"""
return self._illuminant
@property
def minimum_cluster_size(self):
"""
Getter property for the minimum cluster size.
Returns
-------
int
Minimum cluster size.
"""
return self._minimum_cluster_size
def __str__(self):
"""
Returns a formatted string representation of the tree.
Returns
-------
unicode
Formatted string representation.
"""
child_count = len(self)
return '{0}({1} {2})'.format(self.__class__.__name__, child_count,
'Node' if child_count == 1 else 'Nodes')
def _create_selector_array(self):
    """
    Creates an array that describes how to select the appropriate cluster
    for given *CIE xy* coordinates.

    The tree is traversed depth-first in pre-order, visiting the children of
    each node in their stored order. Every internal node contributes one
    row ``[direction, origin, symbol_1, symbol_2]``. Each symbol is either
    the leaf index of the corresponding child (non-negative, assigned in
    traversal order) or, when the child is itself an internal node, the
    negated index of that child's row. See
    :meth:`colour.recovery.Dataset_Otsu2018.select` method for how the
    array is consumed.

    Returns
    -------
    ndarray
        Selector array with one row per internal node, or ``zeros(4)`` when
        the tree consists of the root node only.
    """

    rows = []
    symbol_table = {}
    leaf_count = 0

    # Explicit stack instead of recursion: children are pushed in reverse
    # so that they are popped, and thus numbered, in their natural order,
    # reproducing a recursive pre-order traversal.
    pending = [self]
    while pending:
        node = pending.pop()

        if node.is_leaf():
            symbol_table[node] = leaf_count
            leaf_count += 1
            continue

        # Internal nodes are referenced by their negated row index.
        symbol_table[node] = -len(rows)
        axis = node.partition_axis
        rows.append(
            [axis.direction, axis.origin, node.children[0], node.children[1]])
        pending.extend(reversed(node.children))

    # Special case for tree with just a root node.
    # NOTE(review): this returns a 1-D array whereas the general case is 2-D;
    # confirm that `Dataset_Otsu2018.select` handles this shape.
    if not rows:
        return zeros(4)

    # Second pass: every node now has a symbol, so child nodes can be
    # replaced by their symbols.
    for row in rows:
        row[2] = symbol_table[row[2]]
        row[3] = symbol_table[row[3]]

    return as_float_array(rows)
def msds_to_XYZ(self, reflectances):
    """
    Computes the *CIE XYZ* tristimulus values of given reflectances using
    the colour matching functions and illuminant stored in the tree.

    This is a lightweight alternative to :func:`colour.msds_to_XYZ`: the
    colour matching functions, illuminant and normalising constant are
    taken from the tree, and a direct weighted sum is used.

    Parameters
    ----------
    reflectances : array_like
        Reflectance, or reflectances stacked along the first axis, whose
        last dimension matches the spectral shape used to construct this
        tree.

    Returns
    -------
    ndarray
        *CIE XYZ* tristimulus values, shape (3, ) for a single reflectance
        or (n, 3) for *n* reflectances. They are normalised so that a
        perfect reflector has a luminance *Y* of 1.
    """

    # Spectral stimulus: illuminant power weighted by reflectance.
    stimulus = self._illuminant.values * reflectances
    tristimulus = np.dot(stimulus, self._cmfs.values)

    return self._k * tristimulus * self._dw
def optimise(self,
             iterations=8,
             minimum_cluster_size=None,
             print_callable=print):
    """
    Optimises the tree by repeatedly performing optimal partitioning of the
    nodes, creating a tree that minimises the total reconstruction error.

    At each iteration, every leaf is asked for its best partition and the
    single split yielding the lowest total reconstruction error is applied.
    The optimisation stops early when no leaf can be partitioned.

    Parameters
    ----------
    iterations : int, optional
        Maximum number of splits. If the dataset is too small, this number
        might not be reached. The default is to create 8 clusters, like in
        :cite:`Otsu2018`.
    minimum_cluster_size : int, optional
        Smallest acceptable cluster size. By default it is chosen
        automatically, based on the size of the dataset and desired number
        of clusters. It must be at least 3 or the
        *Principal Component Analysis* (PCA) will not be possible.
    print_callable : callable, optional
        Callable used to print progress and diagnostic information.

    Examples
    --------
    >>> import os
    >>> import colour
    >>> from colour.characterisation import SDS_COLOURCHECKERS
    >>> cmfs = MSDS_CMFS_STANDARD_OBSERVER[
    ...     'CIE 1931 2 Degree Standard Observer'].copy().align(
    ...     SpectralShape(360, 780, 10))
    >>> illuminant = SDS_ILLUMINANTS['D65'].copy().align(cmfs.shape)
    >>> reflectances = [
    ...     sd.copy().align(cmfs.shape).values
    ...     for sd in SDS_COLOURCHECKERS['ColorChecker N Ohta'].values()
    ... ]
    >>> node_tree = NodeTree_Otsu2018(reflectances, cmfs, illuminant)
    >>> node_tree.optimise(iterations=2)  # doctest: +ELLIPSIS
    ======================================================================\
=========
    * \
*
    * "Otsu et al. (2018)" Node Tree Optimisation \
*
    * \
*
    ======================================================================\
=========
    Initial branch error is: 4.8705353...
    <BLANKLINE>
    Iteration 1 of 2:
    <BLANKLINE>
    Optimising "NodeTree_Otsu2018(1 Node)"...
    <BLANKLINE>
    Split "NodeTree_Otsu2018(1 Node)" into \
"Node#...(ColourData(10 Reflectances))" and \
"Node#...(ColourData(14 Reflectances))" along "\
PartitionAxis(horizontal partition at y = 0.3240945...)".
    Error is reduced by 0.0054840... and is now 4.8650513..., \
99.9% of the initial error.
    <BLANKLINE>
    Iteration 2 of 2:
    <BLANKLINE>
    Optimising "Node#...(ColourData(10 Reflectances))"...
    Optimisation failed: Could not find a best partition!
    Optimising "Node#...(ColourData(14 Reflectances))"...
    <BLANKLINE>
    Split "Node#...(ColourData(14 Reflectances))" into \
"Node#...(ColourData(7 Reflectances))" and \
"Node#...(ColourData(7 Reflectances))" along \
"PartitionAxis(horizontal partition at y = 0.3600663...)".
    Error is reduced by 0.9681059... and is now 3.8969453..., \
80.0% of the initial error.
    Node tree optimisation is complete!
    >>> len(node_tree)
    3
    """

    if minimum_cluster_size is None:
        # Default heuristic: half of the average cluster size for the
        # requested number of clusters. Floor division keeps the value an
        # ``int``; it equals the flooring of the true division.
        minimum_cluster_size = len(self.colour_data) // iterations // 2

    # At least 3 reflectances are required for the PCA to be possible.
    self._minimum_cluster_size = max(minimum_cluster_size, 3)

    initial_branch_error = self.branch_reconstruction_error()

    message_box(
        '"Otsu et al. (2018)" Node Tree Optimisation',
        print_callable=print_callable)

    print_callable(
        'Initial branch error is: {0}'.format(initial_branch_error))

    for i in range(iterations):
        print_callable('\nIteration {0} of {1}:\n'.format(
            i + 1, iterations))

        total_error = self.branch_reconstruction_error()

        # Best candidate split found so far during this iteration. All the
        # values below describe the SAME leaf, so that the reported error
        # reduction matches the split that is actually performed.
        best_leaf = best_axis = best_partition = None
        best_total_error = best_error_reduction = None

        for leaf in self.leaves:
            print_callable('Optimising "{0}"...'.format(leaf))

            try:
                partition_error, axis, partition = (
                    leaf.find_best_partition())
            except RuntimeError as error:
                print_callable('Optimisation failed: {0}'.format(error))
                continue

            leaf_error = leaf.leaf_reconstruction_error()
            new_total_error = total_error - leaf_error + partition_error

            if (best_total_error is None or
                    new_total_error < best_total_error):
                best_total_error = new_total_error
                best_error_reduction = leaf_error - partition_error
                best_leaf = leaf
                best_axis = axis
                best_partition = partition

        if best_total_error is None:
            print_callable('\nNo further improvements are possible!\n'
                           'Terminating at iteration {0}.\n'.format(i))
            break

        print_callable(
            '\nSplit "{0}" into "{1}" and "{2}" along "{3}".'.format(
                best_leaf, best_partition[0], best_partition[1],
                best_axis))

        print_callable(
            'Error is reduced by {0} and is now {1}, '
            '{2:.1f}% of the initial error.'.format(
                best_error_reduction,
                best_total_error,
                100 * best_total_error / initial_branch_error))

        best_leaf.split(best_partition, best_axis)

    print_callable('Node tree optimisation is complete!')
def to_dataset(self):
    """
    Converts the tree into a :class:`colour.recovery.Dataset_Otsu2018` class
    instance.

    The basis functions and mean of every leaf, together with the selector
    array describing the tree partitions, are gathered into the dataset. The
    dataset can then be saved to disk or used to recover reflectances with
    the :func:`colour.recovery.XYZ_to_sd_Otsu2018` definition.

    Returns
    -------
    Dataset_Otsu2018
        The dataset object.

    Examples
    --------
    >>> from colour.characterisation import SDS_COLOURCHECKERS
    >>> reflectances = [
    ...     sd.copy().align(SPECTRAL_SHAPE_OTSU2018).values
    ...     for sd in SDS_COLOURCHECKERS['ColorChecker N Ohta'].values()
    ... ]
    >>> node_tree = NodeTree_Otsu2018(reflectances)
    >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
    >>> node_tree.to_dataset()  # doctest: +ELLIPSIS
    <colour.recovery.otsu2018.Dataset_Otsu2018 object at 0x...>
    """

    # Leaf order defines the cluster indices referenced by the selector
    # array, hence both lists below follow the same leaf sequence.
    leaves = list(self.leaves)
    basis_functions = [leaf.basis_functions for leaf in leaves]
    means = [leaf.mean for leaf in leaves]

    return Dataset_Otsu2018(self._cmfs.shape, basis_functions, means,
                            self._create_selector_array())