""" SAX Additive Backend """
from __future__ import annotations
from typing import Any, Dict, Tuple
import jax.numpy as jnp
import networkx as nx
from ..netlist import Component
from ..saxtypes import Model, SDict, sdict
def analyze_instances_additive(
    instances: Dict[str, Component],
    models: Dict[str, Model],
) -> Dict[str, SDict]:
    """Build a placeholder S-dict for every instance of the circuit.

    Each instance is first normalized to a ``Component`` (a plain mapping is
    passed as ``Component(**value)``). Its model name is ``info["model"]``
    when that entry exists and is a string, otherwise its ``component``
    field. Every distinct model is called once with no arguments and the
    result is normalized with ``sdict``; instances that use the same model
    share the same resulting S-dict object.

    Args:
        instances: instance name -> ``Component`` (or its keyword mapping).
        models: model name -> model callable.

    Returns:
        instance name -> S-dict produced by that instance's model.

    Raises:
        KeyError: if an instance refers to a model name absent from ``models``.
    """

    def _model_name(component: Component) -> str:
        # A string ``info["model"]`` takes precedence over the component name.
        info = component.info
        if info and "model" in info and isinstance(info["model"], str):
            return str(info["model"])
        return str(component.component)

    components = {
        name: inst if isinstance(inst, Component) else Component(**inst)
        for name, inst in instances.items()
    }
    # Resolve each instance's model name exactly once.
    model_of = {name: _model_name(comp) for name, comp in components.items()}

    # Call every distinct model once, in deterministic first-use order.
    dummy_models: Dict[str, SDict] = {}
    for model_name in dict.fromkeys(model_of.values()):
        try:
            model = models[model_name]
        except KeyError as err:
            raise KeyError(
                f"model {model_name!r} is not present in `models` "
                f"(available: {sorted(models)})"
            ) from err
        dummy_models[model_name] = sdict(model())

    return {name: dummy_models[model_name] for name, model_name in model_of.items()}
def analyze_circuit_additive(
    analyzed_instances: Dict[str, SDict],
    connections: Dict[str, str],
    ports: Dict[str, str],
) -> Any:
    """Prepare the circuit for ``evaluate_circuit_additive``.

    The additive backend does no precomputation: ``analyzed_instances`` is
    ignored and the netlist wiring is forwarded untouched.

    Returns:
        The tuple ``(connections, ports)`` containing the very same objects
        that were passed in.
    """
    analyzed = (connections, ports)
    return analyzed
def evaluate_circuit_additive(
    analyzed: Any,
    instances: Dict[str, SDict],
) -> SDict:
    """Evaluate a circuit for the given sdicts.

    Builds an undirected ``networkx`` graph whose nodes are
    ``(instance, port)`` pairs (external ports use instance ``""``), collapses
    pass-through instance ports, and then, for every ordered pair of external
    ports (a port paired with itself included), collects the admissible
    paths between them.

    Args:
        analyzed: the ``(connections, ports)`` tuple from
            ``analyze_circuit_additive``.
        instances: instance name -> S-dict of lengths.

    Returns:
        ``(source, target)`` -> list of 1-D length arrays, one per admissible
        path. Port pairs with no admissible path are left out.
    """
    connections, ports = analyzed

    graph = nx.Graph()
    graph.add_edges_from(_graph_edges(instances, connections, ports))
    _prune_internal_output_nodes(graph)

    # Named ``result`` so the module-level ``sdict`` helper is not shadowed.
    result = {}
    for src in ports:
        for dst in ports:
            candidate_paths = _get_possible_paths(
                graph, source=("", src), target=("", dst)
            )
            if candidate_paths:
                result[src, dst] = _path_lengths(graph, candidate_paths)
    return result
def _split_port(port: str) -> Tuple[str, str]:
try:
instance, port = port.split(",")
except ValueError:
(port,) = port.split(",")
instance = ""
return instance, port
def _graph_edges(
    instances: Dict[str, SDict],
    connections: Dict[str, str],
    ports: Dict[str, str],
):
    """Build the edge list of the additive circuit graph.

    Nodes are ``(instance, port)`` tuples produced by ``_split_port``; external
    ports use instance ``""``.

    - Every connection and every external-port mapping contributes ``"C"``
      edges with a zero-length 1-D array, in both directions. They are applied
      in the order connections then ports, forward then reverse, and a later
      mapping overwrites an earlier one for the same source node.
    - Then, for each instance that appears in any ``"C"`` edge, every
      ``(p1, p2)`` key of its ``sdict`` adds an ``"S"`` edge between
      ``(instance, p1)`` and ``(instance, p2)``, carrying the flattened float
      length array.

    Returns:
        A list of ``(node1, node2, attributes)`` triples suitable for
        ``Graph.add_edges_from``.
    """
    zero = jnp.array([0.0], dtype=float)

    partner = {}
    for mapping in (connections, ports):
        partner.update({_split_port(a): _split_port(b) for a, b in mapping.items()})
        partner.update({_split_port(b): _split_port(a) for a, b in mapping.items()})

    edges = [
        (node, other, {"type": "C", "length": zero})
        for node, other in partner.items()
    ]

    # Ordered, de-duplicated instance names: left endpoints first, then right.
    owners = dict.fromkeys(node[0] for node, _, _ in edges)
    owners.update(dict.fromkeys(other[0] for _, other, _ in edges))
    owners.pop("", None)  # external ports don't belong to an instance

    for owner in owners:
        for (p1, p2), length in sdict(instances[owner]).items():
            edges.append(
                (
                    (owner, p1),
                    (owner, p2),
                    {"type": "S", "length": jnp.asarray(length, dtype=float).ravel()},
                )
            )
    return edges
def _prune_internal_output_nodes(graph):
    """Collapse pass-through instance port nodes, in place.

    The function repeatedly looks for a node ``(instance, port)`` that meets
    all of the following:

    - its instance is not ``""``;
    - it has exactly two neighbours;
    - both of its edges are ``"C"`` edges (a missing ``"type"`` counts as
      ``"C"``).

    Each such node is removed and its two neighbours are joined directly by
    a zero-length ``"C"`` edge. External-port nodes are never removed.

    Returns:
        The same ``graph`` object, after modification.
    """
    # A 1-D array like the other "C" edges: ``_path_lengths`` indexes edge
    # lengths with ``[:, None]``, which fails on a plain ``0.0`` float.
    zero = jnp.array([0.0], dtype=float)
    pruned = True
    while pruned:
        pruned = False
        # Snapshot the adjacency because the graph is mutated inside the loop.
        for (instance, port), neighbors in list(graph.adjacency()):
            if (
                instance != ""
                and len(neighbors) == 2
                and all(props.get("type", "C") == "C" for props in neighbors.values())
            ):
                graph.remove_node((instance, port))
                graph.add_edge(*neighbors.keys(), type="C", length=zero)
                pruned = True
                # Restart the scan: the adjacency snapshot is now stale.
                break
    return graph
def _get_possible_paths(graph, source, target):
    """Return the simple edge paths from ``source`` to ``target`` worth keeping.

    A path is rejected when it traverses two ``"S"`` edges back to back.
    ``"S"`` edges join two ports of the same instance, so such a path would
    chain internal port-to-port links without leaving that instance. Missing
    edge data falls back to type ``"C"``.

    Returns:
        A list of edge paths (each a list of ``(node1, node2)`` pairs), in the
        order ``nx.all_simple_edge_paths`` yields them.
    """
    fallback = {"type": "C", "length": 0.0}

    def _is_allowed(edge_path):
        previous = "C"
        for n1, n2 in edge_path:
            current = graph.get_edge_data(n1, n2, fallback)["type"]
            if previous == "S" and current == "S":
                return False
            previous = current
        return True

    return [
        edge_path
        for edge_path in nx.all_simple_edge_paths(graph, source, target)
        if _is_allowed(edge_path)
    ]
def _path_lengths(graph, paths):
    """Compute every total length obtainable along each path.

    An edge may carry several candidate lengths (a 1-D array). The total for a
    path is the flattened outer sum over its edges, i.e. one entry per
    combination of edge lengths. Edges without data, or without a
    ``"length"`` entry, contribute zero.

    Returns:
        A list with one 1-D float array per input path. An empty path yields
        ``[0.0]``.
    """
    zero = jnp.array([0.0], dtype=float)
    default_edge_data = {"type": "C", "length": zero}
    lengths = []
    for path in paths:
        length = zero
        for edge in path:
            edge_data = graph.get_edge_data(*edge, default_edge_data)
            # Coerce to a flat float array. Some edges may store a plain scalar
            # (e.g. ``0.0``), and a scalar cannot be indexed with ``[:, None]``.
            edge_length = jnp.asarray(edge_data.get("length", zero), dtype=float).ravel()
            # Outer sum: each accumulated total plus each candidate edge length.
            length = (length[None, :] + edge_length[:, None]).ravel()
        lengths.append(length)
    return lengths