"""gustaf/gustaf/helpers/data.py.

Helpers that let a helpee object manage its data, plus some useful data
structures.
"""
from collections import namedtuple
from functools import wraps
import numpy as np
from gustaf.helpers._base import HelperBase
class TrackedArray(np.ndarray):
    """numpy array object that keeps mirroring inplace changes to the source.

    Meant to help control_points. Every in-place operation (augmented
    assignment or item assignment) sets the ``modified`` flag. For arrays
    derived from an array created by ``make_tracked_array`` the flag is
    stored on that source array, so changes made through a child are
    visible on the source.
    """

    __slots__ = (
        "_super_arr",
        "_modified",
    )

    def __array_finalize__(self, obj):
        """Sets default flags for any arrays that maybe generated based on
        physical space array. For more information,
        see https://numpy.org/doc/stable/user/basics.subclassing.html

        ``_super_arr`` is ``True`` for the array marked by
        ``make_tracked_array``, a reference to that array for arrays derived
        from it, and ``None`` otherwise.
        """
        self._super_arr = None
        self._modified = True

        # for arrays created based on this subclass
        if isinstance(obj, type(self)):
            # this is copy. nothing to worry here
            if self.base is None:
                return None

            # first child array
            if self.base is obj:
                # make sure this is not a recursively born child
                # for example, `arr[[1,2]][:,2]`
                # we should have set _super_arr to True
                # if we made this array using `make_tracked_array`
                if obj._super_arr is True:
                    self._super_arr = obj
                return None

            # multi generation child array
            if obj._super_arr is not None and self.base is obj.base:
                self._super_arr = obj._super_arr
                return None

        return None

    @property
    def modified(self):
        """
        Modified flag getter.

        Child arrays report the flag of the source array they were derived
        from.
        """
        # have super arr and self is not super_arr,
        if self._super_arr is not None and self._super_arr is not True:
            return self._super_arr._modified
        return self._modified

    @modified.setter
    def modified(self, m):
        # child arrays write the flag through to their source array
        if self._super_arr is not None and self._super_arr is not True:
            self._super_arr._modified = m
        else:
            self._modified = m

    def copy(self, order="K", **kwargs):
        """copy creates regular numpy array.

        Parameters
        ----------
        order: str
          Memory layout of the copy, as in ``np.array``. Default is "K".
        **kwargs: kwargs
          Forwarded to ``np.array`` (e.g. ``dtype``).

        Returns
        -------
        copied: np.ndarray
        """
        # `order` is taken explicitly so that a positional call such as
        # `arr.copy("F")` means memory order, like `np.ndarray.copy`.
        # Forwarding it positionally to `np.array` would read it as a dtype.
        return np.array(self, copy=True, order=order, **kwargs)

    def view(self, *args, **kwargs):
        """Returns a view of this array with the writeable flag set to
        False."""
        v = super().view(*args, **kwargs)
        v.flags.writeable = False
        return v

    # In-place operators: delegate to ndarray first (an invalid operation
    # raises before the flag is touched), then flag the array as modified.
    # Zero-argument `super()` is used so that subclasses of TrackedArray do
    # not recurse into these methods forever.
    def __iadd__(self, *args, **kwargs):
        sr = super().__iadd__(*args, **kwargs)
        self.modified = True
        return sr

    def __isub__(self, *args, **kwargs):
        sr = super().__isub__(*args, **kwargs)
        self.modified = True
        return sr

    def __imul__(self, *args, **kwargs):
        sr = super().__imul__(*args, **kwargs)
        self.modified = True
        return sr

    def __idiv__(self, *args, **kwargs):
        # Python 2 leftover: Python 3 dispatches `/=` to `__itruediv__`.
        sr = super().__idiv__(*args, **kwargs)
        self.modified = True
        return sr

    def __itruediv__(self, *args, **kwargs):
        sr = super().__itruediv__(*args, **kwargs)
        self.modified = True
        return sr

    def __imatmul__(self, *args, **kwargs):
        sr = super().__imatmul__(*args, **kwargs)
        self.modified = True
        return sr

    def __ipow__(self, *args, **kwargs):
        sr = super().__ipow__(*args, **kwargs)
        self.modified = True
        return sr

    def __imod__(self, *args, **kwargs):
        sr = super().__imod__(*args, **kwargs)
        self.modified = True
        return sr

    def __ifloordiv__(self, *args, **kwargs):
        sr = super().__ifloordiv__(*args, **kwargs)
        self.modified = True
        return sr

    def __ilshift__(self, *args, **kwargs):
        sr = super().__ilshift__(*args, **kwargs)
        self.modified = True
        return sr

    def __irshift__(self, *args, **kwargs):
        sr = super().__irshift__(*args, **kwargs)
        self.modified = True
        return sr

    def __iand__(self, *args, **kwargs):
        sr = super().__iand__(*args, **kwargs)
        self.modified = True
        return sr

    def __ixor__(self, *args, **kwargs):
        sr = super().__ixor__(*args, **kwargs)
        self.modified = True
        return sr

    def __ior__(self, *args, **kwargs):
        sr = super().__ior__(*args, **kwargs)
        self.modified = True
        return sr

    def __setitem__(self, key, value):
        # set first. invalid setting will cause error
        sr = super().__setitem__(key, value)
        self.modified = True
        return sr
def make_tracked_array(array, dtype=None, copy=True):
    """Motivated by nice implementations of `trimesh` (see LICENSE.txt).
    `https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`.

    Factory-like wrapper function for TrackedArray.
    If you want to use TrackedArray, it is recommended to use this function.

    Parameters
    ------------
    array: array-like object
      To be turned into a TrackedArray. None is treated as an empty array.
    dtype: np.dtype
      Which dtype to use for the array
    copy: bool
      Default is True. If True, data is copied; otherwise the returned array
      may share memory with `array`.

    Returns
    ------------
    tracked : TrackedArray
      Contains input array data
    """
    # if someone passed us None, just create an empty array
    if array is None:
        array = []

    if copy:
        array = np.array(array, dtype=dtype)
    else:
        array = np.asanyarray(array, dtype=dtype)

    # Call ndarray.view directly: with copy=False, `array` may already be a
    # TrackedArray, whose overridden `view` would return a read-only array.
    tracked = np.ndarray.view(array, type=TrackedArray)

    # this marks original array
    tracked._super_arr = True

    return tracked
class DataHolder(HelperBase):
    """Base class for any data holder. Behaves similar to dict.

    Items are kept in the ``_saved`` dict. Direct item assignment raises
    here; subclasses decide whether and how values may be set.
    """

    __slots__ = ("_saved",)

    def __init__(self, helpee):
        """Base class for any data holder. Behaves similar to dict.

        Parameters
        -----------
        helpee: object
          GustafBase objects would probably make the most sense here.
        """
        self._helpee = helpee
        self._saved = {}

    def __setitem__(self, key, value):
        """Raise Error to disable direct value setting.

        Parameters
        -----------
        key: str
        value: object

        Raises
        ------
        NotImplementedError
          Always.
        """
        raise NotImplementedError(
            "Sorry, you can't set items directly for "
            f"{type(self).__qualname__}"
        )

    def __getitem__(self, key):
        """Returns stored item if the key exists.

        Parameters
        -----------
        key: str

        Returns
        --------
        value: object

        Raises
        ------
        KeyError
          If `key` is not stored.
        """
        if key in self._saved:
            return self._saved[key]
        raise KeyError(f"`{key}` is not stored for {type(self._helpee)}")

    def __contains__(self, key):
        """Returns if saved data contains the given key.

        Parameters
        ----------
        key: str

        Returns
        -------
        contains: bool
        """
        return key in self._saved

    def __iter__(self):
        """Iterates over stored keys, like a dict.

        Returns
        -------
        keys_iterator: iterator
        """
        return iter(self._saved)

    def __len__(self):
        """
        Returns number of items.

        Parameters
        ----------
        None

        Returns
        -------
        len: int
        """
        return len(self._saved)

    def pop(self, key, default=None):
        """
        Removes `key` from saved data and returns its value, or `default`
        if it is not stored.

        Parameters
        ----------
        key: str
        default: object

        Returns
        -------
        value: object
        """
        return self._saved.pop(key, default)

    def clear(self):
        """
        Clears saved data by reassigning new dict
        """
        self._saved = {}

    def get(self, key, default_values=None):
        """Returns stored item if the key exists. Else, given default value.

        Parameters
        -----------
        key: str
        default_values: object
          Returned if `key` is not stored. Default is None.

        Returns
        --------
        value: object
        """
        return self._saved.get(key, default_values)

    def keys(self):
        """Returns keys of data holding dict.

        Returns
        --------
        keys: dict_keys
        """
        return self._saved.keys()

    def values(self):
        """Returns values of data holding dict.

        Returns
        --------
        values: dict_values
        """
        return self._saved.values()

    def items(self):
        """Returns items of data holding dict.

        Returns
        --------
        items: dict_items
        """
        return self._saved.items()

    def update(self, **kwargs):
        """
        Updates saved data with the given kwargs.

        Values are written directly into the underlying dict; this does NOT
        go through ``__setitem__``, so any validation a subclass performs
        there is bypassed.

        Parameters
        ----------
        **kwargs: kwargs

        Returns
        -------
        None
        """
        self._saved.update(**kwargs)
class ComputedData(DataHolder):
    """Stores last computed values.

    Keys are expected to be the same as the name of the helpee's function
    that computes the value. Dependency bookkeeping is filled in by the
    ``depends_on`` decorator.
    """

    # _depends: dict(str: list) - function name -> attribute names it
    # depends on.
    # _inv_depends: dict(str: list) - attribute name -> names of functions
    # depending on it.
    # NOTE(review): both dicts are created on the first class that calls
    # `depends_on`; subclasses of that class see (and extend) the same dicts.
    _depends = None
    _inv_depends = None

    __slots__ = ()

    def __init__(self, helpee, **_kwargs):
        """Stores last computed values.

        Keys are expected to be the same as helpee's function that computes
        the value.

        Parameters
        -----------
        helpee: GustafBase
        """
        super().__init__(helpee)

    @classmethod
    def depends_on(cls, var_names, make_property=False):
        """Decorator as classmethod.

        Checks if the key should be computed. The answer is yes, if:
          1. there's modification on arrays that the key depends on.
             -> erases all saved values that depend on those arrays
          2. the corresponding saved value is None.
          3. arguments other than the helpee are given, unless
             `recompute=False` is passed.
        Supports multi-dependency.

        Parameters
        -----------
        var_names: list
          Names of helpee attributes the decorated function depends on.
          Each attribute is expected to carry a `_modified` flag
          (e.g. TrackedArray).
        make_property: bool
          If True, the decorated function is returned as a property.

        Returns
        --------
        inner: callable
          The decorator.

        Raises
        ------
        TypeError
          If `var_names` is not a list (raised when decorating).
        """

        def inner(func):
            # following are done once while modules are loaded
            # just subclass this class to make a special helper
            # for each helpee class.
            if not isinstance(var_names, list):
                raise TypeError("var_names should be a list")

            name = func.__name__

            # initialize property
            if cls._depends is None:
                cls._depends = {}
            # add dependency info
            cls._depends.setdefault(name, []).extend(var_names)

            if cls._inv_depends is None:
                cls._inv_depends = {}
            # add inverse dependency
            for vn in var_names:
                cls._inv_depends.setdefault(vn, []).append(name)

            @wraps(func)
            def compute_or_return_saved(*args, **kwargs):
                """Check if the key should be computed,"""
                helpee = args[0]  # the helpee itself

                # Explicit arguments (keyword or positional besides the
                # helpee) can change the result, so they always trigger a
                # recomputation unless `recompute=False` is given.
                # Calling the function again without arguments returns the
                # last saved result.
                recompute = False
                if kwargs or len(args) > 1:
                    recompute = kwargs.get("recompute", True)

                saved_values = helpee._computed._saved

                # loop over dependencies; a modified one invalidates every
                # saved value that depends on it.
                for dependee_str in cls._depends[name]:
                    dependee = getattr(helpee, dependee_str)
                    if dependee._modified:
                        for inv in cls._inv_depends[dependee_str]:
                            saved_values[inv] = None

                # is saved / want to recompute?
                saved = saved_values.get(name, None)
                if saved is not None and not recompute:
                    return saved

                # we've reached this point because we have to compute this
                computed = func(*args, **kwargs)
                if isinstance(computed, np.ndarray):
                    computed.flags.writeable = False  # configurable?
                saved_values[name] = computed

                # so, all fresh. we can press NOT-modified button
                for dependee_str in cls._depends[name]:
                    getattr(helpee, dependee_str)._modified = False

                return computed

            if make_property:
                return property(compute_or_return_saved)
            return compute_or_return_saved

        return inner
class VertexData(DataHolder):
    """
    Minimal manager for vertex data. Checks input array size, transforms
    data on request. __setitem__ and __getitem__ will perform length checks.
    key(), values(), items(), and get() will return whatever is currently
    stored.

    gustaf supports two kinds of data representation: scalar-data with cmap
    and vector-data with arrows.
    """

    __slots__ = ()

    def __init__(self, helpee):
        """Checks if helpee has vertices as attr beforehand.

        Parameters
        ----------
        helpee: Vertices
          Vertices and its derived classes.

        Raises
        ------
        AttributeError
          If helpee has no `vertices`.
        """
        if not hasattr(helpee, "vertices"):
            raise AttributeError("Helpee does not have `vertices`.")

        super().__init__(helpee)

    def _validate_len(self, value=None, raise_=True):
        """Checks if given value is a valid vertex_data based of its length.
        If raise_, throws error, else, deletes all incompatible values.

        Only checks len(). If array has (1, len) shape, this will still return
        False.

        Parameters
        ----------
        value: array-like
          Default is None. If None, checks all existing values.
        raise_: bool
          Default is True, If True, raises in case of incompatibility.

        Returns
        -------
        validity: bool
          If raise_ is False.
        """
        helpee_len = len(self._helpee.vertices)

        # single value check. nothing is deleted here.
        if value is not None:
            valid = len(value) == helpee_len
            if raise_ and not valid:
                raise ValueError(
                    f"Expected ({helpee_len}) length data, "
                    f"Given ({len(value)})"
                )
            return valid

        # here, check all saved values. Each entry is judged on its own, so
        # one incompatible entry does not get valid entries deleted.
        valid = True
        to_pop = []
        for key, d_value in self._saved.items():
            if len(d_value) == helpee_len:
                continue

            valid = False
            if raise_:
                raise ValueError(
                    f"`{key}`-data len ({len(d_value)}) doesn't match "
                    f"expected len ({helpee_len})"
                )

            self._logd(
                f"`{key}`-data len ({len(d_value)}) doesn't match "
                f"expected len ({helpee_len}). Deleting `{key}`."
            )
            # pop invalid data and its cached norm
            to_pop.append(key)
            to_pop.append(key + "__norm")

        # pop after iterating, so the dict is not resized during iteration
        for tp in to_pop:
            self._saved.pop(tp, None)

        return valid

    def __setitem__(self, key, value):
        """
        Performs len() based check before storing vertex_data.

        The value is stored as a TrackedArray reshaped to
        (number of vertices, -1).

        Parameters
        ----------
        key: str
        value: object

        Returns
        -------
        None
        """
        self._validate_len(value, raise_=True)

        # we are here because this is valid
        self._saved[key] = make_tracked_array(value, copy=False).reshape(
            len(self._helpee.vertices), -1
        )

        # if "data" or "arrow_data" is empty in show_options, we want to
        # set this data to show. We will always set this as "data".
        show_options = getattr(self._helpee, "show_options", None)
        if show_options is not None:
            if "data" in show_options or "arrow_data" in show_options:
                return None
            show_options["data"] = key

    def __getitem__(self, key):
        """
        Validates data length before returning item.

        Parameters
        ----------
        key: str

        Returns
        -------
        data: array-like

        Raises
        ------
        KeyError
          If key is not stored or its length no longer matches vertices.
        """
        value = super().__getitem__(key)  # raises KeyError
        valid = self._validate_len(value, raise_=False)
        if valid:
            return value
        else:
            raise KeyError(
                "Either requested data is not stored or deleted due to "
                "changes in number of vertices."
            )

    def as_scalar(self, key, default=None):
        """
        Returns scalar version of requested data. If it is already a scalar,
        will return as is. Else, will return a norm. using `np.linalg.norm()`.
        The result is saved under `<key>__norm` and reused until the original
        data is modified.

        Parameters
        ----------
        key: str
        default: object

        Returns
        -------
        data_as_scalar: (n, 1) np.ndarray
        """
        if key not in self.keys():
            return default

        # interpret scalar as norm
        # save the norm once it is called.
        if "__norm" not in key:
            norm_key = key + "__norm"
        else:
            norm_key = key
            key = key.replace("__norm", "")

        if norm_key in self.keys():
            saved = self[norm_key]  # performs len check
            # Use the `modified` property, not the `_modified` slot: stored
            # data are reshaped children of a tracked array, and in-place
            # edits through them flag that parent array.
            if not self[key].modified:
                return saved
            self._saved.pop(norm_key)

        # we are here because we have to compute norm. let's save norm
        value = self[key]
        if value.shape[1] == 1:
            value_norm = value
        else:
            value_norm = np.linalg.norm(value, axis=1).reshape(-1, 1)

        # save norm
        self[norm_key] = value_norm

        # considered not modified (resets the flag where it is tracked)
        self[key].modified = False

        return value_norm

    def as_arrow(self, key, default=None, raise_=True):
        """
        Returns an array as is, only if it is showable as arrow.

        Parameters
        ----------
        key: str
        default: object
        raise_: bool
          If True, raises ValueError for 1D data.

        Returns
        -------
        data: (n, d) np.ndarray
        """
        if key not in self.keys():
            return default

        value = self[key]
        if value.shape[1] == 1:
            self._logd(f"as_arrow() requested data ({key}) is 1D data.")
            if raise_:
                raise ValueError(
                    f"`{key}`-data is 1D and cannot be represented as arrows."
                )

        return value
# Result containers for "unique" queries on 2D arrays. Field docs state the
# expected shapes: n = number of unique entries, m = number of original
# entries, d = dimension.
Unique2DFloats = namedtuple(
    "Unique2DFloats", ["values", "ids", "inverse", "intersection"]
)
Unique2DFloats.__doc__ = """
namedtuple to hold unique information of float type arrays.
Note that for float types, "close enough" might be a better name than unique.
This way, all tracked arrays, as long as they are 2D, have a dot separated
syntax to access unique info. For example, `mesh.unique_vertices.ids`.
"""
Unique2DFloats.values.__doc__ = """`(n, d) np.ndarray`
Field number 0"""
Unique2DFloats.ids.__doc__ = """`(n) np.ndarray`
Field number 1"""
Unique2DFloats.inverse.__doc__ = """`(m) np.ndarray`
Field number 2"""
Unique2DFloats.intersection.__doc__ = """`(m) list of list`
given original array's index, returns overlapping arrays, including itself.
Field number 3
"""

Unique2DIntegers = namedtuple(
    "Unique2DIntegers", ["values", "ids", "inverse", "counts"]
)
Unique2DIntegers.__doc__ = """
namedtuple to hold unique information of integer type arrays.
Similar approach to Unique2DFloats.
"""
Unique2DIntegers.values.__doc__ = """`(n, d) np.ndarray`
Field number 0"""
Unique2DIntegers.ids.__doc__ = """`(n) np.ndarray`
Field number 1"""
Unique2DIntegers.inverse.__doc__ = """`(m) np.ndarray`
Field number 2"""
Unique2DIntegers.counts.__doc__ = """`(n) np.ndarray`
Field number 3"""
class ComputedMeshData(ComputedData):
    """Holder for values computed from mesh data.

    This is a dedicated subclass of ComputedData so that mesh classes keep
    their own dependency info, which ``depends_on`` records on the class it
    is called on.
    """