""" Visualizations for common meow-datatypes """
from collections.abc import Iterable
from typing import Any, Callable
import numpy as np
from meow.structures import _visualize_structures
try:
import matplotlib.pyplot as plt # fmt: skip
except ImportError:
plt = None
try:
import gdsfactory as gf # fmt: skip
except ImportError:
gf = None
try:
from jaxlib.xla_extension import DeviceArray # fmt: skip # type: ignore
except ImportError:
DeviceArray = None
def _visualize_s_matrix(S, fmt=None, title=None, show=True, phase=False, ax=None):
import matplotlib.pyplot as plt # fmt: skip
if phase:
fmt = ".0f"
else:
fmt = ".3f"
Z = np.abs(S)
y, x = np.arange(Z.shape[0])[::-1], np.arange(Z.shape[1])
Y, X = np.meshgrid(y, x)
if ax:
plt.sca(ax)
else:
plt.figure(figsize=(2 * x.shape[0] / 3, 2 * y.shape[0] / 3))
plt.pcolormesh(X, Y, Z[::-1].T, cmap="Greys", vmin=0.0, vmax=2.0 * Z.max())
coords_ = np.concatenate(
[np.array([x[0] - 1]), x, np.array([x[-1] + 1])], axis=0, dtype=float
)
labels = ["" for _ in coords_]
plt.xticks(coords_ + 0.5, labels)
plt.xlim(coords_[0] + 0.5, coords_[-1] - 0.5)
coords_ = np.concatenate(
[np.array([y[0] + 1]), y, np.array([y[-1] - 1])], axis=0, dtype=float
)
coords_ = coords_[::-1] # reverse
labels = ["" for _ in coords_]
plt.yticks(coords_ + 0.5, labels)
plt.ylim(coords_[-1] - 0.5, coords_[0] + 0.5)
plt.grid(True)
for x, y, z in zip(X.ravel(), Y.ravel(), S[::-1].T.ravel()):
if np.abs(z) > 0.0005:
if phase:
z = np.angle(z) * 180 / np.pi
text = eval(f"f'{{z:{fmt}}}'") # 😅
text = text.replace("+", "\n+")
text = text.replace("-", "\n-")
if text[0] == "\n":
text = text[1:]
plt.text(x, y, text, ha="center", va="center", fontsize=8)
if title is not None:
plt.title(title)
if show:
plt.show()
def _visualize_s_pm_matrix(Spm, fmt=None, title=None, show=True, phase=False, ax=None):
import matplotlib.pyplot as plt # fmt: skip
S, pm = Spm
_visualize_s_matrix(S, fmt=fmt, title=title, show=False, phase=phase, ax=ax)
num_left = len([p for p in pm if "left" in p])
Z = np.abs(S)
_, x = np.arange(Z.shape[0])[::-1], np.arange(Z.shape[1])
plt.axvline(x[num_left] - 0.5, color="red")
plt.axhline(x[num_left] - 0.5, color="red")
if show:
plt.show()
def _visualize_overlap_density(
two_modes,
conjugated=True,
x_symmetry=False,
y_symmetry=False,
ax=None,
n_cmap=None,
mode_cmap=None,
num_levels=8,
show=True,
):
import matplotlib.pyplot as plt # fmt: skip
from .mode import Mode # fmt: skip
mode1, mode2 = two_modes
if conjugated:
cross = mode1.Ex * mode2.Hy.conj() - mode1.Ey * mode2.Hx.conj()
else:
cross = mode1.Ex * mode2.Hy - mode1.Ey * mode2.Hx
if x_symmetry:
cross = 0.5 * (cross + cross[::-1])
if y_symmetry:
cross = 0.5 * (cross + cross[:, ::-1])
zeros = np.zeros_like(cross)
overlap = Mode(
neff=mode1.neff,
cs=mode1.cs,
Ex=cross,
Ey=zeros,
Ez=zeros,
Hx=zeros,
Hy=zeros,
Hz=zeros,
)
if ax is None:
W, H = _figsize_visualize_mode(mode1.cs, 5)
_, ax = plt.subplots(1, 3, figsize=(3 * W, H))
field = "Ex" if mode1.te_fraction > 0.5 else "Hx"
mode1._visualize(
title=f"mode 1: {field}",
fields=[field],
ax=ax[0],
n_cmap=n_cmap,
mode_cmap=mode_cmap,
num_levels=num_levels,
show=False,
)
field = "Ex" if mode2.te_fraction > 0.5 else "Hx"
mode2._visualize(
title=f"mode 2: {field}",
fields=[field],
ax=ax[1],
n_cmap=n_cmap,
mode_cmap=mode_cmap,
num_levels=num_levels,
show=False,
)
title = "overlap density" + ("" if conjugated else " (no conjugations)")
p = overlap._visualize(
title=title,
fields=["Ex"],
ax=ax[2],
n_cmap=n_cmap,
mode_cmap=mode_cmap,
num_levels=num_levels,
show=False,
)
if show:
plt.show()
return p
def _visualize_gf_component(comp):
import gdsfactory as gf # fmt: skip
gf.plot(comp) # type: ignore
def _is_two_tuple(obj):
if not isinstance(obj, tuple):
return False
try:
x, y = obj # type: ignore
return True
except Exception:
return False
def _figsize_visualize_mode(cs, W0):
x_min, x_max = cs.mesh.x.min(), cs.mesh.x.max()
y_min, y_max = cs.mesh.y.min(), cs.mesh.y.max()
delta_x = x_max - x_min
delta_y = y_max - y_min
aspect = delta_y / delta_x
W, H = W0 + 1, W0 * aspect + 1
return W, H
def _visualize_modes(
modes,
n_cmap=None,
mode_cmap=None,
num_levels=8,
operation=lambda x: np.abs(x) ** 2,
show=True,
plot_width=6.4,
fields=("Ex", "Hx"),
ax=None,
):
import matplotlib.pyplot as plt # fmt: skip
num_modes = len(modes)
cs = modes[0].cs
W, H = _figsize_visualize_mode(cs, plot_width)
if ax is None:
fig, ax = plt.subplots(
num_modes,
2,
figsize=(2 * W, num_modes * H),
sharex=True,
sharey=True,
squeeze=False,
)
else:
fig = None
for i, m in enumerate(modes):
m._visualize(
title=None,
title_prefix=f"m{i}: ",
fields=list(fields),
ax=ax[i],
n_cmap=n_cmap,
mode_cmap=mode_cmap,
num_levels=num_levels,
operation=operation,
show=False,
)
if fig is not None:
fig.subplots_adjust(hspace=0, wspace=2 / (2 * W))
if show:
plt.show()
def _visualize_base_model(obj, **kwargs):
return obj._visualize(**kwargs)
def _is_mode_list(obj: Any) -> bool:
from .mode import Mode # fmt: skip
return isinstance(obj, Iterable) and all(isinstance(o, Mode) for o in obj)
def _is_structure_3d_list(obj: Any) -> bool:
from .structures import Structure3D # fmt: skip
return isinstance(obj, Iterable) and all(isinstance(o, Structure3D) for o in obj)
def _is_base_model(obj: Any) -> bool:
from .base_model import BaseModel # fmt: skip
return isinstance(obj, BaseModel)
def _is_mode_overlap(obj: Any) -> bool:
from .mode import Mode # fmt: skip
return _is_two_tuple(obj) and all(isinstance(o, Mode) for o in obj)
def _is_s_matrix(obj: Any):
return (
(
isinstance(obj, np.ndarray)
or (DeviceArray is not None and isinstance(obj, DeviceArray))
)
and obj.ndim == 2
and obj.shape[0] > 1
and obj.shape[1] > 1
)
def _is_s_pm_matrix(obj):
return _is_two_tuple(obj) and _is_s_matrix(obj[0]) and isinstance(obj[1], dict)
def _is_gf_component(obj):
return gf is not None and isinstance(obj, gf.Component)
VISUALIZATION_MAPPING: dict[Callable, Callable] = {
_is_base_model: _visualize_base_model,
_is_mode_list: _visualize_modes,
_is_structure_3d_list: _visualize_structures,
_is_mode_overlap: _visualize_overlap_density,
_is_s_matrix: _visualize_s_matrix,
_is_s_pm_matrix: _visualize_s_pm_matrix,
_is_gf_component: _visualize_gf_component,
}
[docs]def visualize(obj: Any, **kwargs: Any):
"""visualize any meow object
Args:
obj: the meow object to visualize
**kwargs: extra configuration to visualize the object
Note:
Most meow objects have a `._visualize` method.
Check out its help to see which kwargs are accepted.
"""
if plt is None:
print(obj)
return
for check_func, vis_func in VISUALIZATION_MAPPING.items():
if check_func(obj):
return vis_func(obj, **kwargs)
print(obj)
vis = visualize # shorthand for visualize