Source code for colour.recovery.otsu2018

"""
Otsu, Yamamoto and Hachisuka (2018) - Reflectance Recovery
==========================================================

Define the objects for reflectance recovery, i.e., spectral upsampling, using
the *Otsu et al. (2018)* method:

-   :class:`colour.recovery.Dataset_Otsu2018`
-   :func:`colour.recovery.XYZ_to_sd_Otsu2018`
-   :class:`colour.recovery.Tree_Otsu2018`

References
----------
-   :cite:`Otsu2018` : Otsu, H., Yamamoto, M., & Hachisuka, T. (2018).
    Reproducing Spectral Reflectances From Tristimulus Colours. Computer
    Graphics Forum, 37(6), 370-381. doi:10.1111/cgf.13332
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np

from colour.algebra import eigen_decomposition
from colour.colorimetry import (
    MultiSpectralDistributions,
    SpectralDistribution,
    SpectralShape,
    handle_spectral_arguments,
    msds_to_XYZ_integration,
    reshape_msds,
    sd_to_XYZ,
)
from colour.hints import (
    Any,
    ArrayLike,
    Callable,
    Dict,
    NDArrayFloat,
    Self,
    Sequence,
    Tuple,
    cast,
)
from colour.models import XYZ_to_xy
from colour.recovery import (
    BASIS_FUNCTIONS_OTSU2018,
    CLUSTER_MEANS_OTSU2018,
    SELECTOR_ARRAY_OTSU2018,
    SPECTRAL_SHAPE_OTSU2018,
)
from colour.utilities import (
    TreeNode,
    as_float_array,
    as_float_scalar,
    domain_range_scale,
    is_tqdm_installed,
    message_box,
    optional,
    to_domain_1,
    zeros,
)

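# When "tqdm" is not installed, the block below substitutes a
# "unittest.mock.MagicMock" stand-in so that the progress bar calls in
# "Node_Otsu2018.minimise" become harmless no-ops.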
if is_tqdm_installed():
    from tqdm import tqdm
else:  # pragma: no cover
    from unittest import mock

    tqdm = mock.MagicMock()

__author__ = "Colour Developers"
__copyright__ = "Copyright 2013 Colour Developers"
__license__ = "BSD-3-Clause - https://opensource.org/licenses/BSD-3-Clause"
__maintainer__ = "Colour Developers"
__email__ = "colour-developers@colour-science.org"
__status__ = "Production"

__all__ = [
    "Dataset_Otsu2018",
    "DATASET_REFERENCE_OTSU2018",
    "XYZ_to_sd_Otsu2018",
    "PartitionAxis",
    "Data_Otsu2018",
    "Node_Otsu2018",
    "Tree_Otsu2018",
]


class Dataset_Otsu2018:
    """
    Store all the information needed for the *Otsu et al. (2018)* spectral
    upsampling method.

    Datasets can either be generated and converted into a
    :class:`colour.recovery.Dataset_Otsu2018` class instance using the
    :meth:`colour.recovery.Tree_Otsu2018.to_dataset` method or, alternatively,
    loaded from disk with the :meth:`colour.recovery.Dataset_Otsu2018.read`
    method.

    Parameters
    ----------
    shape
        Shape of the spectral data.
    basis_functions
        Three basis functions for every cluster.
    means
        Mean for every cluster.
    selector_array
        Array describing how to select the appropriate cluster. See
        :meth:`colour.recovery.Dataset_Otsu2018.select` method for details.

    Attributes
    ----------
    -   :attr:`~colour.recovery.Dataset_Otsu2018.shape`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.basis_functions`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.means`
    -   :attr:`~colour.recovery.Dataset_Otsu2018.selector_array`

    Methods
    -------
    -   :meth:`~colour.recovery.Dataset_Otsu2018.__init__`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.select`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.cluster`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.read`
    -   :meth:`~colour.recovery.Dataset_Otsu2018.write`

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> import os
    >>> import colour
    >>> from colour.characterisation import SDS_COLOURCHECKERS
    >>> from colour.colorimetry import sds_and_msds_to_msds
    >>> reflectances = sds_and_msds_to_msds(
    ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
    ... )
    >>> node_tree = Tree_Otsu2018(reflectances)
    >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
    >>> dataset = node_tree.to_dataset()
    >>> path = os.path.join(
    ...     colour.__path__[0],
    ...     "recovery",
    ...     "tests",
    ...     "resources",
    ...     "ColorChecker_Otsu2018.npz",
    ... )
    >>> dataset.write(path)  # doctest: +SKIP
    >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
    >>> dataset.read(path)  # doctest: +SKIP
    """

    def __init__(
        self,
        shape: SpectralShape | None = None,
        basis_functions: NDArrayFloat | None = None,
        means: NDArrayFloat | None = None,
        selector_array: NDArrayFloat | None = None,
    ) -> None:
        self._shape: SpectralShape | None = shape
        self._basis_functions: NDArrayFloat | None = (
            basis_functions
            if basis_functions is None
            else as_float_array(basis_functions)
        )
        self._means: NDArrayFloat | None = (
            means if means is None else as_float_array(means)
        )
        self._selector_array: NDArrayFloat | None = (
            selector_array
            if selector_array is None
            else as_float_array(selector_array)
        )

    @property
    def shape(self) -> SpectralShape | None:
        """
        Getter property for the shape used by the *Otsu et al. (2018)*
        dataset.

        Returns
        -------
        :py:data:`None` or :class:`colour.SpectralShape`
            Shape used by the *Otsu et al. (2018)* dataset.
        """

        return self._shape

    @property
    def basis_functions(self) -> NDArrayFloat | None:
        """
        Getter property for the basis functions of the *Otsu et al. (2018)*
        dataset.

        Returns
        -------
        :py:data:`None` or :class:`numpy.ndarray`
            Basis functions of the *Otsu et al. (2018)* dataset.
        """

        return self._basis_functions

    @property
    def means(self) -> NDArrayFloat | None:
        """
        Getter property for the means of the *Otsu et al. (2018)* dataset.

        Returns
        -------
        :py:data:`None` or :class:`numpy.ndarray`
            Means of the *Otsu et al. (2018)* dataset.
        """

        return self._means

    @property
    def selector_array(self) -> NDArrayFloat | None:
        """
        Getter property for the selector array of the *Otsu et al. (2018)*
        dataset.

        Returns
        -------
        :py:data:`None` or :class:`numpy.ndarray`
            Selector array of the *Otsu et al. (2018)* dataset.
        """

        return self._selector_array

    def __str__(self) -> str:
        """
        Return a formatted string representation of the dataset.

        Returns
        -------
        :class:`str`
            Formatted string representation.
        """

        if self._basis_functions is not None:
            return (
                f"{self.__class__.__name__}"
                f"({self._basis_functions.shape[0]} basis functions)"
            )
        else:
            return f"{self.__class__.__name__}()"

    def select(self, xy: ArrayLike) -> int:
        """
        Return the cluster index appropriate for the given *CIE xy*
        coordinates.

        Parameters
        ----------
        xy
            *CIE xy* chromaticity coordinates.

        Returns
        -------
        :class:`int`
            Cluster index.

        Raises
        ------
        ValueError
            If the selector array is undefined.
        """

        xy = as_float_array(xy)

        if self._selector_array is not None:
            i = 0
            while True:
                row = self._selector_array[i, :]
                origin, direction, lesser_index, greater_index = row

                if xy[int(direction)] <= origin:
                    index = int(lesser_index)
                else:
                    index = int(greater_index)

                if index < 0:
                    i = -index
                else:
                    return index
        else:
            raise ValueError('The "selector array" is undefined!')

    def cluster(self, xy: ArrayLike) -> Tuple[NDArrayFloat, NDArrayFloat]:
        """
        Return the basis functions and dataset mean for the given *CIE xy*
        coordinates.

        Parameters
        ----------
        xy
            *CIE xy* chromaticity coordinates.

        Returns
        -------
        :class:`tuple`
            Tuple of three basis functions and dataset mean.

        Raises
        ------
        ValueError
            If the basis functions or means are undefined.
        """

        if self._basis_functions is not None and self._means is not None:
            index = self.select(xy)

            return self._basis_functions[index, :, :], self._means[index, :]
        else:
            raise ValueError('The "basis functions" or "means" are undefined!')

    def read(self, path: str | Path) -> None:
        """
        Read and load a dataset from an *.npz* file.

        Parameters
        ----------
        path
            Path to the file.

        Examples
        --------
        >>> import os
        >>> import colour
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> from colour.colorimetry import sds_and_msds_to_msds
        >>> reflectances = sds_and_msds_to_msds(
        ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        ... )
        >>> node_tree = Tree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> dataset = node_tree.to_dataset()
        >>> path = os.path.join(
        ...     colour.__path__[0],
        ...     "recovery",
        ...     "tests",
        ...     "resources",
        ...     "ColorChecker_Otsu2018.npz",
        ... )
        >>> dataset.write(path)  # doctest: +SKIP
        >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
        >>> dataset.read(path)  # doctest: +SKIP
        """

        path = str(path)

        data = np.load(path)

        start, end, interval = data["shape"]
        self._shape = SpectralShape(start, end, interval)
        self._basis_functions = data["basis_functions"]
        self._means = data["means"]
        self._selector_array = data["selector_array"]

    def write(self, path: str | Path) -> None:
        """
        Write the dataset to an *.npz* file at given path.

        Parameters
        ----------
        path
            Path to the file.

        Raises
        ------
        ValueError
            If the shape is undefined.

        Examples
        --------
        >>> import os
        >>> import colour
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> from colour.colorimetry import sds_and_msds_to_msds
        >>> reflectances = sds_and_msds_to_msds(
        ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        ... )
        >>> node_tree = Tree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> dataset = node_tree.to_dataset()
        >>> path = os.path.join(
        ...     colour.__path__[0],
        ...     "recovery",
        ...     "tests",
        ...     "resources",
        ...     "ColorChecker_Otsu2018.npz",
        ... )
        >>> dataset.write(path)  # doctest: +SKIP
        """

        path = str(path)

        if self._shape is not None:
            np.savez(
                path,
                shape=as_float_array(
                    [
                        self._shape.start,
                        self._shape.end,
                        self._shape.interval,
                    ]
                ),
                basis_functions=cast(NDArrayFloat, self._basis_functions),
                means=cast(NDArrayFloat, self._means),
                selector_array=cast(NDArrayFloat, self._selector_array),
            )
        else:
            raise ValueError('The "shape" is undefined!')


DATASET_REFERENCE_OTSU2018: Dataset_Otsu2018 = Dataset_Otsu2018(
    SPECTRAL_SHAPE_OTSU2018,
    BASIS_FUNCTIONS_OTSU2018,
    CLUSTER_MEANS_OTSU2018,
    SELECTOR_ARRAY_OTSU2018,
)
"""
Builtin *Otsu et al. (2018)* dataset as a
:class:`colour.recovery.Dataset_Otsu2018` class instance, usable by
:func:`colour.recovery.XYZ_to_sd_Otsu2018` definition among others.
"""
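

# The following illustrative sketch is not part of the public API: it merely
# shows how the built-in reference dataset above can be queried directly. The
# helper name "_example_cluster_lookup" is arbitrary.
def _example_cluster_lookup() -> None:
    """Illustrate a direct cluster lookup on the reference dataset."""

    xy = np.array([0.31270, 0.32900])  # e.g. CIE D65 white point chromaticity.

    # "select" walks the selector array: each row stores an origin, a
    # direction (0 for x, 1 for y) and the lesser / greater indices, where
    # negative entries point at further rows (their negation is the next row
    # index) rather than at leaf clusters.
    index = DATASET_REFERENCE_OTSU2018.select(xy)

    # "cluster" combines "select" with the lookup of the three basis
    # functions and the mean reflectance of the selected cluster, i.e. the
    # quantities combined by "XYZ_to_sd_Otsu2018" below.
    basis_functions, mean = DATASET_REFERENCE_OTSU2018.cluster(xy)

    print(index, basis_functions.shape, mean.shape)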


def XYZ_to_sd_Otsu2018(
    XYZ: ArrayLike,
    cmfs: MultiSpectralDistributions | None = None,
    illuminant: SpectralDistribution | None = None,
    dataset: Dataset_Otsu2018 = DATASET_REFERENCE_OTSU2018,
    clip: bool = True,
) -> SpectralDistribution:
    """
    Recover the spectral distribution of given *CIE XYZ* tristimulus values
    using the *Otsu et al. (2018)* method.

    Parameters
    ----------
    XYZ
        *CIE XYZ* tristimulus values to recover the spectral distribution
        from.
    cmfs
        Standard observer colour matching functions, default to the
        *CIE 1931 2 Degree Standard Observer*.
    illuminant
        Illuminant spectral distribution, default to
        *CIE Standard Illuminant D65*.
    dataset
        Dataset to use for reconstruction. The default is to use the published
        data.
    clip
        If *True*, the default, values below zero and above unity in the
        recovered spectral distributions will be clipped. This ensures that
        the returned reflectance is physical and conserves energy, but will
        cause noticeable colour differences in case of very saturated colours.

    Returns
    -------
    :class:`colour.SpectralDistribution`
        Recovered spectral distribution. Its shape is always that of the
        :class:`colour.recovery.SPECTRAL_SHAPE_OTSU2018` class instance.

    Raises
    ------
    ValueError
        If the dataset shape is undefined.

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> from colour import (
    ...     CCS_ILLUMINANTS,
    ...     SDS_ILLUMINANTS,
    ...     MSDS_CMFS,
    ...     XYZ_to_sRGB,
    ... )
    >>> from colour.colorimetry import sd_to_XYZ_integration
    >>> from colour.utilities import numpy_print_options
    >>> XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    >>> cmfs = (
    ...     MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
    ...     .copy()
    ...     .align(SPECTRAL_SHAPE_OTSU2018)
    ... )
    >>> illuminant = SDS_ILLUMINANTS["D65"].copy().align(cmfs.shape)
    >>> sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant)
    >>> with numpy_print_options(suppress=True):
    ...     sd  # doctest: +ELLIPSIS
    SpectralDistribution([[ 380.        ,    0.0601939...],
                          [ 390.        ,    0.0568063...],
                          [ 400.        ,    0.0517429...],
                          [ 410.        ,    0.0495841...],
                          [ 420.        ,    0.0502007...],
                          [ 430.        ,    0.0506489...],
                          [ 440.        ,    0.0510020...],
                          [ 450.        ,    0.0493782...],
                          [ 460.        ,    0.0468046...],
                          [ 470.        ,    0.0437132...],
                          [ 480.        ,    0.0416957...],
                          [ 490.        ,    0.0403783...],
                          [ 500.        ,    0.0405197...],
                          [ 510.        ,    0.0406031...],
                          [ 520.        ,    0.0416912...],
                          [ 530.        ,    0.0430956...],
                          [ 540.        ,    0.0444474...],
                          [ 550.        ,    0.0459336...],
                          [ 560.        ,    0.0507631...],
                          [ 570.        ,    0.0628967...],
                          [ 580.        ,    0.0844661...],
                          [ 590.        ,    0.1334277...],
                          [ 600.        ,    0.2262428...],
                          [ 610.        ,    0.3599330...],
                          [ 620.        ,    0.4885571...],
                          [ 630.        ,    0.5752546...],
                          [ 640.        ,    0.6193023...],
                          [ 650.        ,    0.6450744...],
                          [ 660.        ,    0.6610548...],
                          [ 670.        ,    0.6688673...],
                          [ 680.        ,    0.6795426...],
                          [ 690.        ,    0.6887933...],
                          [ 700.        ,    0.7003469...],
                          [ 710.        ,    0.7084128...],
                          [ 720.        ,    0.7154674...],
                          [ 730.        ,    0.7234334...]],
                         SpragueInterpolator,
                         {},
                         Extrapolator,
                         {'method': 'Constant', 'left': None, 'right': None})
    >>> sd_to_XYZ_integration(sd, cmfs, illuminant) / 100  # doctest: +ELLIPSIS
    array([ 0.2065494...,  0.1219712...,  0.0514002...])
    """

    shape = dataset.shape

    if shape is not None:
        XYZ = to_domain_1(XYZ)

        cmfs, illuminant = handle_spectral_arguments(
            cmfs, illuminant, shape_default=SPECTRAL_SHAPE_OTSU2018
        )

        xy = XYZ_to_xy(XYZ)

        basis_functions, mean = dataset.cluster(xy)

        M = np.empty((3, 3))
        for i in range(3):
            sd = SpectralDistribution(basis_functions[i, :], shape.wavelengths)

            with domain_range_scale("ignore"):
                M[:, i] = sd_to_XYZ(sd, cmfs, illuminant) / 100

        M_inverse = np.linalg.inv(M)

        sd = SpectralDistribution(mean, shape.wavelengths)

        with domain_range_scale("ignore"):
            XYZ_mu = sd_to_XYZ(sd, cmfs, illuminant) / 100

        weights = np.dot(M_inverse, XYZ - XYZ_mu)
        recovered_sd = np.dot(weights, basis_functions) + mean

        recovered_sd = np.clip(recovered_sd, 0, 1) if clip else recovered_sd

        return SpectralDistribution(recovered_sd, shape.wavelengths)
    else:
        raise ValueError('The dataset "shape" is undefined!')
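

# Illustrative round-trip sketch (not part of the upstream module): it checks
# that the recovery above is colorimetrically exact up to clipping and
# integration error, i.e. "w = M^-1 (XYZ - XYZ_mu)" is solved exactly. The
# helper name "_example_round_trip" is arbitrary.
def _example_round_trip() -> None:
    """Recover a reflectance and re-integrate it back to *CIE XYZ*."""

    from colour.colorimetry import sd_to_XYZ_integration

    # Default observer and illuminant, aligned to the dataset shape.
    cmfs, illuminant = handle_spectral_arguments(
        shape_default=SPECTRAL_SHAPE_OTSU2018
    )

    XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant)

    # Re-integrating the recovered reflectance should give back the input
    # tristimulus values to within a small tolerance.
    XYZ_r = sd_to_XYZ_integration(sd, cmfs, illuminant) / 100

    print(np.allclose(XYZ, XYZ_r, atol=1e-3))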


@dataclass
class PartitionAxis:
    """
    Represent a horizontal or vertical line, partitioning the 2D space in
    two half-planes.

    Parameters
    ----------
    origin
        The x coordinate of a vertical line or the y coordinate of a
        horizontal line.
    direction
        *0* if vertical, *1* if horizontal.

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.PartitionAxis.__str__`
    """

    origin: float
    direction: int

    def __str__(self) -> str:
        """
        Return a formatted string representation of the partition axis.

        Returns
        -------
        :class:`str`
            Formatted string representation.
        """

        return (
            f"{self.__class__.__name__}"
            f"({'horizontal' if self.direction else 'vertical'} partition "
            f"at {'y' if self.direction else 'x'} = {self.origin})"
        )


class Data_Otsu2018:
    """
    Store the reference reflectances and derived information along with the
    methods to process them for a leaf
    :class:`colour.recovery.otsu2018.Node` class instance.

    This class also supports partitioning: creating two smaller instances of
    the :class:`colour.recovery.otsu2018.Data` class by splitting along a
    horizontal or a vertical axis on the *CIE xy* plane.

    Parameters
    ----------
    reflectances
        Reference reflectances of the *n* colours to be stored. The shape
        must match ``tree.shape`` with *m* points for each colour.
    cmfs
        Standard observer colour matching functions.
    illuminant
        Illuminant spectral distribution.

    Attributes
    ----------
    -   :attr:`~colour.recovery.otsu2018.Data.reflectances`
    -   :attr:`~colour.recovery.otsu2018.Data.cmfs`
    -   :attr:`~colour.recovery.otsu2018.Data.illuminant`
    -   :attr:`~colour.recovery.otsu2018.Data.basis_functions`
    -   :attr:`~colour.recovery.otsu2018.Data.mean`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.Data.__str__`
    -   :meth:`~colour.recovery.otsu2018.Data.__len__`
    -   :meth:`~colour.recovery.otsu2018.Data.origin`
    -   :meth:`~colour.recovery.otsu2018.Data.partition`
    -   :meth:`~colour.recovery.otsu2018.Data.PCA`
    -   :meth:`~colour.recovery.otsu2018.Data.reconstruct`
    -   :meth:`~colour.recovery.otsu2018.Data.reconstruction_error`
    """

    def __init__(
        self,
        reflectances: ArrayLike | None,
        cmfs: MultiSpectralDistributions,
        illuminant: SpectralDistribution,
    ) -> None:
        self._cmfs: MultiSpectralDistributions = cmfs
        self._illuminant: SpectralDistribution = illuminant

        self._XYZ: NDArrayFloat | None = None
        self._xy: NDArrayFloat | None = None

        self._reflectances: NDArrayFloat | None = np.array([])
        self.reflectances = reflectances

        self._basis_functions: NDArrayFloat | None = None
        self._mean: NDArrayFloat | None = None

        self._M: NDArrayFloat | None = None
        self._XYZ_mu: NDArrayFloat | None = None

        self._reconstruction_error: float | None = None

    @property
    def reflectances(self) -> NDArrayFloat | None:
        """
        Getter and setter property for the reference reflectances.

        Parameters
        ----------
        value
            Value to set the reference reflectances with.

        Returns
        -------
        :class:`numpy.ndarray`
            Reference reflectances.
        """

        return self._reflectances

    @reflectances.setter
    def reflectances(self, value: ArrayLike | None):
        """Setter for the **self.reflectances** property."""

        if value is not None:
            self._reflectances = as_float_array(value)
            self._XYZ = (
                msds_to_XYZ_integration(
                    self._reflectances,
                    self._cmfs,
                    self._illuminant,
                    shape=self._cmfs.shape,
                )
                / 100
            )
            self._xy = XYZ_to_xy(self._XYZ)
        else:
            self._reflectances, self._XYZ, self._xy = None, None, None

    @property
    def cmfs(self) -> MultiSpectralDistributions:
        """
        Getter property for the standard observer colour matching functions.

        Returns
        -------
        :class:`colour.MultiSpectralDistributions`
            Standard observer colour matching functions.
        """

        return self._cmfs

    @property
    def illuminant(self) -> SpectralDistribution:
        """
        Getter property for the illuminant.

        Returns
        -------
        :class:`colour.SpectralDistribution`
            Illuminant.
        """

        return self._illuminant

    @property
    def basis_functions(self) -> NDArrayFloat | None:
        """
        Getter property for the basis functions.

        Returns
        -------
        :class:`numpy.ndarray`
            Basis functions.
        """

        return self._basis_functions

    @property
    def mean(self) -> NDArrayFloat | None:
        """
        Getter property for the mean distribution.

        Returns
        -------
        :py:data:`None` or :class:`numpy.ndarray`
            Mean distribution.
        """

        return self._mean

    def __str__(self) -> str:
        """
        Return a formatted string representation of the data.

        Returns
        -------
        :class:`str`
            Formatted string representation.
        """

        return f"{self.__class__.__name__}({len(self)} Reflectances)"

    def __len__(self) -> int:
        """
        Return the number of colours in the data.

        Returns
        -------
        :class:`int`
            Number of colours in the data.
        """

        return self._reflectances.shape[0] if self._reflectances is not None else 0

    def origin(self, i: int, direction: int) -> float:
        """
        Return the origin *CIE x* or *CIE y* chromaticity coordinate for given
        index and direction.

        Parameters
        ----------
        i
            Origin index.
        direction
            Origin direction.

        Returns
        -------
        :class:`float`
            Origin *CIE x* or *CIE y* chromaticity coordinate.

        Raises
        ------
        ValueError
            If the chromaticity coordinates are undefined.
        """

        if self._xy is not None:
            return self._xy[i, direction]
        else:
            raise ValueError('The "chromaticity coordinates" are undefined!')

    def partition(self, axis: PartitionAxis) -> Tuple[Data_Otsu2018, Data_Otsu2018]:
        """
        Partition the data using given partition axis.

        Parameters
        ----------
        axis
            Partition axis used to partition the data.

        Returns
        -------
        :class:`tuple`
            Tuple of left or lower part and right or upper part.

        Raises
        ------
        ValueError
            If the tristimulus values or chromaticity coordinates are
            undefined.
        """

        lesser = Data_Otsu2018(None, self._cmfs, self._illuminant)
        greater = Data_Otsu2018(None, self._cmfs, self._illuminant)

        if (
            self._XYZ is not None
            and self._xy is not None
            and self._reflectances is not None
        ):
            mask = self._xy[:, axis.direction] <= axis.origin

            lesser._reflectances = self._reflectances[mask, :]
            greater._reflectances = self._reflectances[~mask, :]

            lesser._XYZ = self._XYZ[mask, :]
            greater._XYZ = self._XYZ[~mask, :]

            lesser._xy = self._xy[mask, :]
            greater._xy = self._xy[~mask, :]

            return lesser, greater
        else:
            raise ValueError(
                'The "tristimulus values" or "chromaticity coordinates" are '
                "undefined!"
            )

    def PCA(self) -> None:
        """
        Perform the *Principal Component Analysis* (PCA) on the data and set
        the relevant attributes accordingly.
        """

        if self._M is None and self._reflectances is not None:
            settings: Dict[str, Any] = {
                "cmfs": self._cmfs,
                "illuminant": self._illuminant,
                "shape": self._cmfs.shape,
            }

            self._mean = np.mean(self._reflectances, axis=0)
            self._XYZ_mu = (
                msds_to_XYZ_integration(cast(NDArrayFloat, self._mean), **settings)
                / 100
            )

            _w, w = eigen_decomposition(
                self._reflectances - self._mean,  # pyright: ignore
                descending_order=False,
                covariance_matrix=True,
            )
            self._basis_functions = np.transpose(w[:, -3:])

            self._M = np.transpose(
                msds_to_XYZ_integration(self._basis_functions, **settings) / 100
            )

    def reconstruct(self, XYZ: ArrayLike) -> SpectralDistribution:
        """
        Reconstruct the reflectance for the given *CIE XYZ* tristimulus
        values.

        Parameters
        ----------
        XYZ
            *CIE XYZ* tristimulus values to recover the spectral distribution
            from.

        Returns
        -------
        :class:`colour.SpectralDistribution`
            Recovered spectral distribution.

        Raises
        ------
        ValueError
            If the matrix :math:`M`, the mean tristimulus values or the basis
            functions are undefined.
        """

        if (
            self._M is not None
            and self._XYZ_mu is not None
            and self._basis_functions is not None
        ):
            XYZ = as_float_array(XYZ)

            weights = np.dot(np.linalg.inv(self._M), XYZ - self._XYZ_mu)
            reflectance = np.dot(weights, self._basis_functions) + self._mean
            reflectance = np.clip(reflectance, 0, 1)

            return SpectralDistribution(reflectance, self._cmfs.wavelengths)
        else:
            raise ValueError(
                'The matrix "M", the "mean tristimulus values" or the '
                '"basis functions" are undefined!'
            )

    def reconstruction_error(self) -> float:
        """
        Return the reconstruction error of the data.

        The error is computed by reconstructing the reflectances for the
        reference *CIE XYZ* tristimulus values using PCA and comparing the
        reconstructed reflectances against the reference reflectances.

        Returns
        -------
        :class:`float`
            The reconstruction error for the data.

        Raises
        ------
        ValueError
            If the tristimulus values are undefined.

        Notes
        -----
        -   The reconstruction error is cached upon being computed and thus
            is only computed once per node.
        """

        if self._reconstruction_error is not None:
            return self._reconstruction_error

        if self._XYZ is not None and self._reflectances is not None:
            self.PCA()

            error: float = 0.0
            for i in range(len(self)):
                sd = self._reflectances[i, :]
                XYZ = self._XYZ[i, :]
                recovered_sd = self.reconstruct(XYZ)
                error += cast(float, np.sum((sd - recovered_sd.values) ** 2))

            self._reconstruction_error = error

            return error
        else:
            raise ValueError('The "tristimulus values" are undefined!')


class Node_Otsu2018(TreeNode):
    """
    Represent a node in a :meth:`colour.recovery.Tree_Otsu2018` class instance
    node tree.

    Parameters
    ----------
    parent
        Parent of the node.
    children
        Children of the node.
    data
        The colour data belonging to this node.

    Attributes
    ----------
    -   :attr:`~colour.recovery.otsu2018.Node.partition_axis`
    -   :attr:`~colour.recovery.otsu2018.Node.row`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.Node.__init__`
    -   :meth:`~colour.recovery.otsu2018.Node.split`
    -   :meth:`~colour.recovery.otsu2018.Node.minimise`
    -   :meth:`~colour.recovery.otsu2018.Node.leaf_reconstruction_error`
    -   :meth:`~colour.recovery.otsu2018.Node.branch_reconstruction_error`
    """

    def __init__(
        self,
        parent: Self | None = None,
        children: list | None = None,
        data: Self | None = None,
    ) -> None:
        super().__init__(parent=parent, children=children, data=data)

        self._partition_axis: PartitionAxis | None = None
        self._best_partition: (
            Tuple[Sequence[Node_Otsu2018], PartitionAxis, float] | None
        ) = None

    @property
    def partition_axis(self) -> PartitionAxis | None:
        """
        Getter property for the node partition axis.

        Returns
        -------
        :class:`colour.recovery.otsu2018.PartitionAxis`
            Node partition axis.
        """

        return self._partition_axis

    @property
    def row(self) -> Tuple[float, float, Self, Self]:
        """
        Getter property for the node row for the selector array.

        Returns
        -------
        :class:`tuple`
            Node row for the selector array.

        Raises
        ------
        ValueError
            If the partition axis is undefined.
        """

        if self._partition_axis is not None:
            return (
                self._partition_axis.origin,
                self._partition_axis.direction,
                self.children[0],
                self.children[1],
            )
        else:
            raise ValueError('The "partition axis" is undefined!')

    def split(self, children: Sequence[Self], axis: PartitionAxis):
        """
        Convert the leaf node into an inner node using given children and
        partition axis.

        Parameters
        ----------
        children
            Tuple of two :class:`colour.recovery.otsu2018.Node` class
            instances.
        axis
            Partition axis.
        """

        self.data = None
        self.children = list(children)  # pyright: ignore
        self._best_partition = None
        self._partition_axis = axis

    def minimise(
        self, minimum_cluster_size: int
    ) -> Tuple[Sequence[Node_Otsu2018], PartitionAxis, float]:
        """
        Find the best partition for the node that minimises the leaf
        reconstruction error.

        Parameters
        ----------
        minimum_cluster_size
            Smallest acceptable cluster size. It must be at least 3 or the
            *Principal Component Analysis* (PCA) is not possible.

        Returns
        -------
        :class:`tuple`
            Tuple of tuple of nodes created by splitting a node with a given
            partition, partition axis, i.e., the horizontal or vertical line,
            partitioning the 2D space in two half-planes and partition error.
        """

        if self._best_partition is not None:
            return self._best_partition

        leaf_error = self.leaf_reconstruction_error()
        best_error = None

        with tqdm(total=2 * len(self.data)) as progress:
            for direction in [0, 1]:
                for i in range(len(self.data)):
                    progress.update()

                    axis = PartitionAxis(self.data.origin(i, direction), direction)

                    data_lesser, data_greater = self.data.partition(axis)

                    if np.any(
                        np.array(
                            [
                                len(data_lesser),
                                len(data_greater),
                            ]
                        )
                        < minimum_cluster_size
                    ):
                        continue

                    lesser = Node_Otsu2018(data=data_lesser)
                    lesser.data.PCA()

                    greater = Node_Otsu2018(data=data_greater)
                    greater.data.PCA()

                    partition_error = (
                        lesser.leaf_reconstruction_error()
                        + greater.leaf_reconstruction_error()
                    )

                    partition = [lesser, greater]

                    if partition_error >= leaf_error:
                        continue

                    if best_error is None or partition_error < best_error:
                        self._best_partition = (
                            partition,
                            axis,
                            partition_error,
                        )

        if self._best_partition is None:
            raise RuntimeError("Could not find the best partition!")

        return self._best_partition

    def leaf_reconstruction_error(self) -> float:
        """
        Return the reconstruction error of the node data.

        The error is computed by reconstructing the reflectances for the data
        reference *CIE XYZ* tristimulus values using PCA and comparing the
        reconstructed reflectances against the data reference reflectances.

        Returns
        -------
        :class:`float`
            The reconstruction errors summation for the node data.
        """

        return self.data.reconstruction_error()

    def branch_reconstruction_error(self) -> float:
        """
        Compute the reconstruction error for all the leaves data connected to
        the node or its children, i.e., the reconstruction errors summation
        for all the leaves in the branch.

        Returns
        -------
        :class:`float`
            Reconstruction errors summation for all the leaves' data in the
            branch.
        """

        if self.is_leaf():
            return self.leaf_reconstruction_error()
        else:
            return as_float_scalar(
                np.sum(
                    [child.branch_reconstruction_error() for child in self.children]
                )
            )
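

# Illustrative sketch (not part of the upstream module) of the partitioning
# machinery used by "Node_Otsu2018.minimise": a "Data_Otsu2018" instance is
# split along a "PartitionAxis" and the two halves are compared by their PCA
# reconstruction errors. The helper name "_example_partition_error" is
# arbitrary and the reflectances are assumed to match the dataset shape.
def _example_partition_error(reflectances: ArrayLike) -> None:
    """Split given reflectances at the median *CIE y* and report the errors."""

    cmfs, illuminant = handle_spectral_arguments(
        shape_default=SPECTRAL_SHAPE_OTSU2018
    )

    data = Data_Otsu2018(reflectances, cmfs, illuminant)

    # Partition at the median *CIE y* chromaticity coordinate of the data.
    origins = [data.origin(i, 1) for i in range(len(data))]
    axis = PartitionAxis(float(np.median(origins)), 1)

    lesser, greater = data.partition(axis)
    lesser.PCA()
    greater.PCA()

    # "minimise" scans every candidate axis and keeps a split whose summed
    # error improves upon the un-split error; here a single split is
    # evaluated for illustration.
    print(
        data.reconstruction_error(),
        lesser.reconstruction_error() + greater.reconstruction_error(),
    )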


class Tree_Otsu2018(Node_Otsu2018):
    """
    A sub-class of the :class:`colour.recovery.otsu2018.Node` class
    representing the root node of a tree containing information shared with
    all the nodes, such as the standard observer colour matching functions
    and the illuminant, if any is used.

    Global operations involving the entire tree, such as optimisation and
    conversion to dataset, are implemented in this sub-class.

    Parameters
    ----------
    reflectances
        Reference reflectances of the *n* reference colours to use for
        optimisation.
    cmfs
        Standard observer colour matching functions, default to the
        *CIE 1931 2 Degree Standard Observer*.
    illuminant
        Illuminant spectral distribution, default to
        *CIE Standard Illuminant D65*.

    Attributes
    ----------
    -   :attr:`~colour.recovery.Tree_Otsu2018.reflectances`
    -   :attr:`~colour.recovery.Tree_Otsu2018.cmfs`
    -   :attr:`~colour.recovery.Tree_Otsu2018.illuminant`

    Methods
    -------
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.__init__`
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.__str__`
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.optimise`
    -   :meth:`~colour.recovery.otsu2018.Tree_Otsu2018.to_dataset`

    References
    ----------
    :cite:`Otsu2018`

    Examples
    --------
    >>> import os
    >>> import colour
    >>> from colour import MSDS_CMFS, SDS_COLOURCHECKERS, SDS_ILLUMINANTS
    >>> from colour.colorimetry import sds_and_msds_to_msds
    >>> from colour.utilities import numpy_print_options
    >>> XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    >>> cmfs = (
    ...     MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
    ...     .copy()
    ...     .align(SpectralShape(360, 780, 10))
    ... )
    >>> illuminant = SDS_ILLUMINANTS["D65"].copy().align(cmfs.shape)
    >>> reflectances = sds_and_msds_to_msds(
    ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
    ... )
    >>> node_tree = Tree_Otsu2018(reflectances, cmfs, illuminant)
    >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
    >>> dataset = node_tree.to_dataset()
    >>> path = os.path.join(
    ...     colour.__path__[0],
    ...     "recovery",
    ...     "tests",
    ...     "resources",
    ...     "ColorChecker_Otsu2018.npz",
    ... )
    >>> dataset.write(path)  # doctest: +SKIP
    >>> dataset = Dataset_Otsu2018()  # doctest: +SKIP
    >>> dataset.read(path)  # doctest: +SKIP
    >>> sd = XYZ_to_sd_Otsu2018(XYZ, cmfs, illuminant, dataset)
    >>> with numpy_print_options(suppress=True):
    ...     sd  # doctest: +ELLIPSIS
    SpectralDistribution([[ 360.        ,    0.0651341...],
                          [ 370.        ,    0.0651341...],
                          [ 380.        ,    0.0651341...],
                          [ 390.        ,    0.0749684...],
                          [ 400.        ,    0.0815578...],
                          [ 410.        ,    0.0776439...],
                          [ 420.        ,    0.0721897...],
                          [ 430.        ,    0.0649064...],
                          [ 440.        ,    0.0567185...],
                          [ 450.        ,    0.0484685...],
                          [ 460.        ,    0.0409768...],
                          [ 470.        ,    0.0358964...],
                          [ 480.        ,    0.0307857...],
                          [ 490.        ,    0.0270148...],
                          [ 500.        ,    0.0273773...],
                          [ 510.        ,    0.0303157...],
                          [ 520.        ,    0.0331285...],
                          [ 530.        ,    0.0363027...],
                          [ 540.        ,    0.0425987...],
                          [ 550.        ,    0.0513442...],
                          [ 560.        ,    0.0579256...],
                          [ 570.        ,    0.0653850...],
                          [ 580.        ,    0.0929522...],
                          [ 590.        ,    0.1600326...],
                          [ 600.        ,    0.2586159...],
                          [ 610.        ,    0.3701242...],
                          [ 620.        ,    0.4702243...],
                          [ 630.        ,    0.5396261...],
                          [ 640.        ,    0.5737561...],
                          [ 650.        ,    0.590848 ...],
                          [ 660.        ,    0.5935371...],
                          [ 670.        ,    0.5923295...],
                          [ 680.        ,    0.5956326...],
                          [ 690.        ,    0.5982513...],
                          [ 700.        ,    0.6017904...],
                          [ 710.        ,    0.6016419...],
                          [ 720.        ,    0.5996892...],
                          [ 730.        ,    0.6000018...],
                          [ 740.        ,    0.5964443...],
                          [ 750.        ,    0.5868181...],
                          [ 760.        ,    0.5860973...],
                          [ 770.        ,    0.5614878...],
                          [ 780.        ,    0.5289331...]],
                         SpragueInterpolator,
                         {},
                         Extrapolator,
                         {'method': 'Constant', 'left': None, 'right': None})
    """

    def __init__(
        self,
        reflectances: MultiSpectralDistributions,
        cmfs: MultiSpectralDistributions | None = None,
        illuminant: SpectralDistribution | None = None,
    ) -> None:
        super().__init__()

        cmfs, illuminant = handle_spectral_arguments(
            cmfs, illuminant, shape_default=SPECTRAL_SHAPE_OTSU2018
        )

        self._cmfs: MultiSpectralDistributions = cmfs
        self._illuminant: SpectralDistribution = illuminant

        self._reflectances: NDArrayFloat = np.transpose(
            reshape_msds(reflectances, self._cmfs.shape, copy=False).values
        )

        self.data: Data_Otsu2018 = Data_Otsu2018(
            self._reflectances, self._cmfs, self._illuminant
        )

    @property
    def reflectances(self) -> NDArrayFloat:
        """
        Getter property for the reference reflectances.

        Returns
        -------
        :class:`numpy.ndarray`
            Reference reflectances.
        """

        return self._reflectances

    @property
    def cmfs(self) -> MultiSpectralDistributions:
        """
        Getter property for the standard observer colour matching functions.

        Returns
        -------
        :class:`colour.MultiSpectralDistributions`
            Standard observer colour matching functions.
        """

        return self._cmfs

    @property
    def illuminant(self) -> SpectralDistribution:
        """
        Getter property for the illuminant.

        Returns
        -------
        :class:`colour.SpectralDistribution`
            Illuminant.
        """

        return self._illuminant

    def optimise(
        self,
        iterations: int = 8,
        minimum_cluster_size: int | None = None,
        print_callable: Callable = print,
    ):
        """
        Optimise the tree by repeatedly performing optimal partitioning of the
        nodes, creating a tree that minimises the total reconstruction error.

        Parameters
        ----------
        iterations
            Maximum number of splits. If the dataset is too small, this number
            might not be reached. The default is to create 8 clusters, like in
            :cite:`Otsu2018`.
        minimum_cluster_size
            Smallest acceptable cluster size. By default, it is chosen
            automatically, based on the size of the dataset and the desired
            number of clusters. It must be at least 3 or the *Principal
            Component Analysis* (PCA) is not possible.
        print_callable
            Callable used to print progress and diagnostic information.

        Examples
        --------
        >>> from colour.colorimetry import sds_and_msds_to_msds
        >>> from colour import MSDS_CMFS, SDS_COLOURCHECKERS, SDS_ILLUMINANTS
        >>> cmfs = (
        ...     MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
        ...     .copy()
        ...     .align(SpectralShape(360, 780, 10))
        ... )
        >>> illuminant = SDS_ILLUMINANTS["D65"].copy().align(cmfs.shape)
        >>> reflectances = sds_and_msds_to_msds(
        ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        ... )
        >>> node_tree = Tree_Otsu2018(reflectances, cmfs, illuminant)
        >>> node_tree.optimise(iterations=2)  # doctest: +ELLIPSIS
        ======================================================================\
=========
        *                                                                     \
        *
        *   "Otsu et al. (2018)" Tree Optimisation                            \
        *
        *                                                                     \
        *
        ======================================================================\
=========
        Initial branch error is: 4.8705353...
        <BLANKLINE>
        Iteration 1 of 2:
        <BLANKLINE>
        Optimising "Tree_Otsu2018#...(Data_Otsu2018(24 Reflectances))"...
        <BLANKLINE>
        Splitting "Tree_Otsu2018#...(Data_Otsu2018(24 Reflectances))" into \
"Node_Otsu2018#...(Data_Otsu2018(10 Reflectances))" and \
"Node_Otsu2018#...(Data_Otsu2018(14 Reflectances))" along \
"PartitionAxis(horizontal partition at y = 0.3240945...)".
        Error is reduced by 0.0054840... and is now 4.8650513..., 99.9% of \
the initial error.
        <BLANKLINE>
        Iteration 2 of 2:
        <BLANKLINE>
        Optimising "Node_Otsu2018#...(Data_Otsu2018(10 Reflectances))"...
        Optimisation failed: Could not find the best partition!
        Optimising "Node_Otsu2018#...(Data_Otsu2018(14 Reflectances))"...
        <BLANKLINE>
        Splitting "Node_Otsu2018#...(Data_Otsu2018(14 Reflectances))" into \
"Node_Otsu2018#...(Data_Otsu2018(7 Reflectances))" and \
"Node_Otsu2018#...(Data_Otsu2018(7 Reflectances))" along \
"PartitionAxis(horizontal partition at y = 0.3600663...)".
        Error is reduced by 0.9681059... and is now 3.8969453..., 80.0% of \
the initial error.
        Tree optimisation is complete!
        >>> print(node_tree.render())  # doctest: +ELLIPSIS
        |----"Tree_Otsu2018#..."
            |----"Node_Otsu2018#..."
            |----"Node_Otsu2018#..."
                |----"Node_Otsu2018#..."
                |----"Node_Otsu2018#..."
        <BLANKLINE>
        >>> len(node_tree)
        4
        """

        default_cluster_size = len(self.data) / iterations // 2
        minimum_cluster_size = max(
            cast(int, optional(minimum_cluster_size, default_cluster_size)), 3
        )

        initial_branch_error = self.branch_reconstruction_error()

        message_box(
            '"Otsu et al. (2018)" Tree Optimisation',
            print_callable=print_callable,
        )

        print_callable(f"Initial branch error is: {initial_branch_error}")

        best_leaf, best_partition, best_axis, partition_error = [None] * 4

        for i in range(iterations):
            print_callable(f"\nIteration {i + 1} of {iterations}:\n")

            total_error = self.branch_reconstruction_error()
            optimised_total_error = None

            for leaf in self.leaves:
                print_callable(f'Optimising "{leaf}"...')

                try:
                    partition, axis, partition_error = leaf.minimise(
                        minimum_cluster_size
                    )
                except RuntimeError as error:
                    print_callable(f"Optimisation failed: {error}")
                    continue

                new_total_error = (
                    total_error - leaf.leaf_reconstruction_error() + partition_error
                )

                if (
                    optimised_total_error is None
                    or new_total_error < optimised_total_error
                ):
                    optimised_total_error = new_total_error
                    best_axis = axis
                    best_leaf = leaf
                    best_partition = partition

            if optimised_total_error is None:
                print_callable(
                    f"\nNo further improvement is possible!"
                    f"\nTerminating at iteration {i}.\n"
                )
                break

            if best_partition is not None:
                print_callable(
                    f'\nSplitting "{best_leaf}" into "{best_partition[0]}" '
                    f'and "{best_partition[1]}" along "{best_axis}".'
                )

            print_callable(
                f"Error is reduced by "
                f"{leaf.leaf_reconstruction_error() - partition_error} and "
                f"is now {optimised_total_error}, "
                f"{100 * optimised_total_error / initial_branch_error:.1f}% "
                f"of the initial error."
            )

            if best_leaf is not None:
                best_leaf.split(best_partition, best_axis)

        print_callable("Tree optimisation is complete!")

    def to_dataset(self) -> Dataset_Otsu2018:
        """
        Create a :class:`colour.recovery.Dataset_Otsu2018` class instance
        based on data stored in the tree.

        The dataset can then be saved to disk or used to recover reflectance
        with the :func:`colour.recovery.XYZ_to_sd_Otsu2018` definition.

        Returns
        -------
        :class:`colour.recovery.Dataset_Otsu2018`
            The dataset object.

        Examples
        --------
        >>> from colour.colorimetry import sds_and_msds_to_msds
        >>> from colour.characterisation import SDS_COLOURCHECKERS
        >>> reflectances = sds_and_msds_to_msds(
        ...     SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        ... )
        >>> node_tree = Tree_Otsu2018(reflectances)
        >>> node_tree.optimise(iterations=2, print_callable=lambda x: x)
        >>> node_tree.to_dataset()  # doctest: +ELLIPSIS
        <colour.recovery.otsu2018.Dataset_Otsu2018 object at 0x...>
        """

        basis_functions = as_float_array(
            [leaf.data.basis_functions for leaf in self.leaves]
        )

        means = as_float_array([leaf.data.mean for leaf in self.leaves])

        if len(self.children) == 0:
            selector_array = zeros(4)
        else:

            def add_rows(node: Node_Otsu2018, data: dict | None = None) -> dict | None:
                """Add rows for given node and its children."""

                data = optional(data, {"rows": [], "node_to_leaf_id": {}, "leaf_id": 0})

                if node.is_leaf():
                    data["node_to_leaf_id"][node] = data["leaf_id"]
                    data["leaf_id"] += 1

                    return None

                data["node_to_leaf_id"][node] = -len(data["rows"])
                data["rows"].append(list(node.row))

                for child in node.children:
                    add_rows(child, data)

                return data

            data = cast(dict, add_rows(self))

            rows = data["rows"]
            for i, row in enumerate(rows):
                for j in (2, 3):
                    rows[i][j] = data["node_to_leaf_id"][row[j]]

            selector_array = as_float_array(rows)

        return Dataset_Otsu2018(
            self._cmfs.shape,
            basis_functions,
            means,
            selector_array,
        )
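

# End-to-end usage sketch (illustrative, not part of the upstream module):
# train a tree on the "ColorChecker N Ohta" reflectances, convert it to a
# "Dataset_Otsu2018" instance and recover a reflectance with it. Wrapped in a
# "__main__" guard so that importing this module stays side-effect free.
if __name__ == "__main__":
    from colour.characterisation import SDS_COLOURCHECKERS
    from colour.colorimetry import sds_and_msds_to_msds

    reflectances = sds_and_msds_to_msds(
        SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
    )

    # Build and optimise the tree, then bake it into a queryable dataset.
    node_tree = Tree_Otsu2018(reflectances)
    node_tree.optimise(iterations=2)
    dataset = node_tree.to_dataset()

    # Recover a reflectance with the freshly trained dataset.
    XYZ = np.array([0.20654008, 0.12197225, 0.05136952])
    sd = XYZ_to_sd_Otsu2018(XYZ, dataset=dataset)
    print(sd.values)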