Source code for sax.nn.loss

""" SAX Loss Functions """

from __future__ import annotations

from typing import Dict, cast

import jax.numpy as jnp

from ..saxtypes import ComplexArrayND


def mse(x: ComplexArrayND, y: ComplexArrayND) -> float:
    """mean squared error

    Averages the squared magnitude of the element-wise difference
    ``x - y`` over all elements.

    Args:
        x: first array.
        y: second array.

    Returns:
        The mean of ``abs(x - y) ** 2``.
    """
    # abs() gives the magnitude, so the squared error stays real-valued
    # when the difference is complex.
    magnitude = abs(x - y)
    squared_error = magnitude**2
    return cast(float, squared_error.mean())
def huber_loss(x: ComplexArrayND, y: ComplexArrayND, delta: float = 0.5) -> float:
    """huber loss

    Smooth (pseudo-Huber) variant: the mean over all elements of
    ``delta**2 * (sqrt(1 + (abs(x - y) / delta)**2) - 1)``.

    Args:
        x: first array.
        y: second array.
        delta: scale by which the residual magnitude is divided; it is a
            divisor, so it must be non-zero.

    Returns:
        The mean of the per-element loss.
    """
    scaled_residual = abs(x - y) / delta
    root_term = (1.0 + scaled_residual**2) ** 0.5
    per_element = (delta**2) * (root_term - 1.0)
    return cast(float, per_element.mean())
def l2_reg(weights: Dict[str, ComplexArrayND]) -> float:
    """L2 regularization loss

    Mean of ``|w| ** 2`` over every element of the arrays in ``weights``
    whose key starts with ``"w"`` or ``"b"``; all other entries are ignored.

    Args:
        weights: mapping from parameter name to array.

    Returns:
        The summed squared magnitudes of the selected arrays divided by
        their total element count, or ``0.0`` when no elements are selected.
    """
    numel = 0
    loss = 0.0
    for name, w in weights.items():
        # startswith() with a tuple also copes with an empty key, where
        # indexing name[0] would raise IndexError.
        if not name.startswith(("w", "b")):
            continue
        numel = numel + w.size
        loss = loss + (jnp.abs(w) ** 2).sum()
    if numel == 0:
        # Nothing to regularize: avoid dividing by zero.
        return 0.0
    return loss / numel