Source code for meow.base_model

""" Provides a custom pydantic BaseModel which handles numpy arrays better """

from functools import wraps
from hashlib import md5
from typing import Any, Tuple

import numpy as np
from pydantic.v1.main import BaseModel as _BaseModel
from pydantic.v1.main import ModelMetaclass, PrivateAttr

from .cache import cache_array, cache_model


class _array(np.ndarray):
    """just an immutable array with a nicer repr"""

    def __new__(cls, *args, **kwargs):
        arr = np.ndarray.__new__(cls, *args, **kwargs)
        arr.setflags(write=False)
        return arr

    def __hash__(self):
        return (
            int.from_bytes(md5(self.tobytes()).digest(), byteorder="big")
            % 1000000000000
        )

    def __repr__(self):
        if self.ndim == 0:
            return f"{self:.3e}"
        cls_str = f"array{{{'x'.join(str(i) for i in self.shape)}}}"
        self = self.ravel()
        num_els = self.shape[0]
        if num_els == 0:
            return f"{cls_str}([])"
        elif num_els == 1:
            return f"{cls_str}([{self[0]:.3e}])"
        elif num_els == 2:
            return f"{cls_str}([{self[0]:.3e}, {self[1]:.3e}])"
        elif num_els == 3:
            return f"{cls_str}([{self[0]:.3e}, {self[1]:.3e}, {self[2]:.3e}])"
        else:
            return f"{cls_str}([{self[0]:.3e}, {self[1]:.3e}, ..., {self[-1]:.3e}])"


def _array_cls(annot) -> type:
    def get_validators(cls) -> Any:
        yield cls.validate

    def validate(cls, x) -> _array:
        if isinstance(x, dict):
            r = np.asarray(x.get("real", 0.0), dtype=np.float_)
            i = np.asarray(x.get("imag", 0.0), dtype=np.float_)
            x = r + 1j * i

        shape, dtype = _parse_array_type_info(cls.annot)
        x = np.asarray(x, dtype=(None if dtype is Any else dtype))
        if shape is not Any:
            try:
                shape = np.broadcast_shapes(shape, x.shape)
            except ValueError:
                raise ValueError(
                    f"Invalid shape for attribute 'x': Expected: {shape}. Got: {x.shape}."
                )
            x = np.broadcast_to(x, shape)

        return x.view(_array)

    def modify_schema(cls, schema, field):
        schema["title"] = field.name.replace("_", " ").title()
        schema["default"] = "array"

    Array = type(
        "Array",
        (np.ndarray,),
        {
            "annot": annot,
            "__get_validators__": classmethod(get_validators),
            "validate": classmethod(validate),
            "__modify_schema__": classmethod(modify_schema),
        },
    )
    return Array


def _serialize_array(arr):
    if np.iscomplexobj(arr):
        return {
            "real": _serialize_array(np.real(arr)),
            "imag": _serialize_array(np.imag(arr)),
        }
    else:
        return np.round(arr, 16).tolist()


def _parse_array_type_info(annotation) -> Tuple[Any, Any]:
    try:
        type_info = annotation.__args__
    except AttributeError:
        type_info = tuple()
    if len(type_info) == 1:
        (shape,) = type_info
        dtype = Any
    elif len(type_info) == 2:
        shape, dtype = type_info
    else:
        shape, dtype = Any, Any
    try:
        (dtype,) = dtype.__args__  # type: ignore
    except Exception:
        dtype = Any

    try:
        shape = shape.__args__  # type: ignore
        shape = tuple(_try_parse_shape_int(i) for i in shape)
    except Exception:
        shape = Any

    return shape, dtype


def _try_parse_shape_int(value):
    from typing import _LiteralGenericAlias  # fmt: skip # type: ignore

    if isinstance(value, _LiteralGenericAlias):
        (value,) = value.__args__
    try:
        return int(value)
    except Exception:
        return 1


class _ModelMetaclass(ModelMetaclass):
    """A metaclass (used in our custom BaseModel) to handle numpy array type hints better.

    Generic numpy type hints (will) have the following syntax::

        arr: np.ndarray[Shape, DType]

    for example::

        arr: np.ndarray[Tuple[Literal[1], int], np.dtype[np.float64]] = np.array([[1, 2, 3]])

    When using this _ModelMetaclass, the BaseModel will automatically try to cast the arrays according to the given array type-hint.

    Moreover, the following config will be injected (added to the config class that might already be present)::

        Config:
            arbitrary_types_allowed = True
            json_encoders = {
                np.ndarray: lambda arr: np.round(arr, 16).tolist()
            }

    """

    def __new__(cls, name, bases, dct, **kwargs):
        # Inject sensible default Configuration (Config class)
        extra_config = {
            "allow_mutation": False,
            "arbitrary_types_allowed": True,
            "json_encoders": {
                np.ndarray: _serialize_array,
                _array: _serialize_array,
                complex: lambda value: {"real": np.real(value), "imag": np.imag(value)},
            },
        }
        if "Config" not in dct:
            dct["Config"] = type("Config", (), extra_config)
        else:
            config = {**dct["Config"].__dict__}
            config["arbitrary_types_allowed"] = True
            config["json_encoders"] = {
                **config.get("json_encoders", {}),
                **extra_config["json_encoders"],
            }
            dct["Config"] = type("Config", (dct["Config"],), config)

        # Enforce numpy array annotations
        annotations = dct.get("__annotations__", {})
        for attr, annot in annotations.items():
            if cls._is_array_annot(annot):
                annotations[attr] = _array_cls(annot)
            if annot is complex:
                annotations[attr] = _array_cls(np.ndarray[Any, np.dtype[np.complex_]])
        return super().__new__(cls, name, bases, dct, **kwargs)

    @staticmethod
    def _is_array_annot(annot):
        try:
            arr = annot(0)
        except Exception:
            return False
        return isinstance(arr, np.ndarray)


[docs]class BaseModel(_BaseModel, metaclass=_ModelMetaclass): """A customized pydantic base model that handles numpy array type hints""" _cache = PrivateAttr({})
[docs] def __init__(self, **kwargs): super().__init__(**kwargs) self.__dict__.update({k: _view_arrays(k, v) for k, v in self.__dict__.items()}) model = cache_model(self) if not model is self: object.__setattr__(self, "__dict__", model.__dict__)
def _repr(self, indent=0, shift=2): s = f"{self.__class__.__name__}(" dic = self.dict() if dic: s = s + "\n" for key in dic: if key == "data": continue attr = getattr(self, key) if isinstance(attr, BaseModel): attr_str = attr._repr(indent=indent + shift, shift=shift) else: attr_str = repr(attr) s += f"{' '*(indent + shift)}{key}={attr_str}\n" s += f"{' '*indent})" return s def _visualize(self): raise NotImplementedError( f"visualization for {self.__class__.__name__!r} not (yet) implemented." ) def __hash__(self): try: arr = np.frombuffer(md5(self.json().encode()).digest(), dtype=np.uint8)[-8:] idx = np.arange(arr.shape[0], dtype=np.int64)[::-1] return int(np.sum(arr * 255**idx)) except Exception: return None def __repr__(self): return self._repr() def __str__(self): return self._repr()
def _view_arrays(key, obj): if isinstance(obj, dict): return {k: _view_arrays(k, v) for k, v in obj.items()} elif isinstance(obj, np.ndarray) or isinstance(obj, _array): obj = obj.view(_array) # arbitrary: let's not spam the cache with this. if not key in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]: obj = cache_array(obj) return obj else: return obj
[docs]def cache(prop): prop_name = prop.__name__ @wraps(prop) def getter(self): stored_value = self._cache.get(prop_name) if stored_value is not None: return stored_value computed = prop(self) self._cache[prop_name] = computed return computed return getter
[docs]def cached_property(method): return property(cache(method))