Source code for sax.circuit

""" SAX Circuit Definition """

from __future__ import annotations

import os
import shutil
import sys
from functools import partial
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypedDict, Union

import black
import networkx as nx
import numpy as np

from .backends import circuit_backends
from .netlist import (
    Netlist,
    NetlistDict,
    RecursiveNetlist,
    RecursiveNetlistDict,
    remove_unused_instances,
)
from .saxtypes import Model, Settings, SType, scoo, sdense, sdict
from .utils import (
    _replace_kwargs,
    get_ports,
    get_settings,
    merge_dicts,
    update_settings,
)

try:
    from pydantic.v1 import ValidationError  # type: ignore
except ImportError:
    from pydantic import ValidationError  # type: ignore


[docs]class CircuitInfo(NamedTuple): """Information about the circuit function you created.""" dag: nx.DiGraph models: Dict[str, Model]
[docs]def circuit( netlist: Union[Netlist, NetlistDict, RecursiveNetlist, RecursiveNetlistDict], models: Optional[Dict[str, Model]] = None, backend: str = "default", return_type: str = "sdict", ignore_missing_ports: bool = False, ) -> Tuple[Model, CircuitInfo]: """create a circuit function for a given netlist""" netlist = _ensure_recursive_netlist_dict(netlist) # TODO: do the following two steps *after* recursive netlist parsing. netlist = remove_unused_instances(netlist) netlist, instance_models = _extract_instance_models(netlist) recnet: RecursiveNetlist = _validate_net(netlist) # type: ignore dependency_dag: nx.DiGraph = _validate_dag(_create_dag(recnet, models)) models = _validate_models({**(models or {}), **instance_models}, dependency_dag) backend = _validate_circuit_backend(backend) circuit = None new_models = {} current_models = {} model_names = list(nx.topological_sort(dependency_dag))[::-1] for model_name in model_names: if model_name in models: new_models[model_name] = models[model_name] continue flatnet = recnet.__root__[model_name] current_models.update(new_models) new_models = {} current_models[model_name] = circuit = _flat_circuit( flatnet.instances, flatnet.connections, flatnet.ports, current_models, backend, ignore_missing_ports=ignore_missing_ports, ) assert circuit is not None circuit = _enforce_return_type(circuit, return_type) return circuit, CircuitInfo(dag=dependency_dag, models=current_models)
def _create_dag( netlist: RecursiveNetlist, models: Optional[Dict[str, Any]] = None, ): if models is None: models = {} assert isinstance(models, dict) all_models = {} g = nx.DiGraph() for model_name, subnetlist in netlist.dict()["__root__"].items(): if model_name not in all_models: all_models[model_name] = models.get(model_name, subnetlist) g.add_node(model_name) if model_name in models: continue for instance in subnetlist["instances"].values(): info = instance.get("info", {}) if info and "model" in info: component = info["model"] else: component = instance["component"] if component not in all_models: all_models[component] = models.get(component, None) g.add_node(component) g.add_edge(model_name, component) # we only need the nodes that depend on the parent... parent_node = next(iter(netlist.__root__.keys())) nodes = [parent_node, *nx.descendants(g, parent_node)] g = nx.induced_subgraph(g, nodes) return g def _draw_dag(dag, with_labels=True, **kwargs): _patch_path() if shutil.which("dot"): return nx.draw( dag, nx.nx_pydot.pydot_layout(dag, prog="dot"), with_labels=with_labels, **kwargs, ) else: return nx.draw(dag, _my_dag_pos(dag), with_labels=with_labels, **kwargs) def _patch_path(): os_paths = {p: None for p in os.environ.get("PATH", "").split(os.pathsep)} sys_paths = {p: None for p in sys.path} other_paths = {os.path.dirname(sys.executable): None} os.environ["PATH"] = os.pathsep.join({**os_paths, **sys_paths, **other_paths}) def _my_dag_pos(dag): # inferior to pydot in_degree = {} for k, v in dag.in_degree(): if v not in in_degree: in_degree[v] = [] in_degree[v].append(k) widths = {k: len(vs) for k, vs in in_degree.items()} width = max(widths.values()) horizontal_pos = { k: np.linspace(0, 1, w + 2)[1:-1] * width for k, w in widths.items() } pos = {} for k, vs in in_degree.items(): for x, v in zip(horizontal_pos[k], vs): pos[v] = (x, -k) return pos def _find_root(g): nodes = [n for n, d in g.in_degree() if d == 0] return nodes def _find_leaves(g): nodes = [n for n, d in g.out_degree() if d == 0] return nodes def _validate_models(models, dag): required_models = _find_leaves(dag) missing_models = [m for m in required_models if m not in models] if missing_models: model_diff = { "Missing Models": missing_models, "Given Models": list(models), "Required Models": required_models, } raise ValueError( "Missing models. The following models are still missing to build " f"the circuit:\n{black.format_str(repr(model_diff), mode=black.Mode())}" ) return {**models} # shallow copy def _flat_circuit( instances, connections, ports, models, backend, ignore_missing_ports=False ): analyze_insts_fn, analyze_fn, evaluate_fn = circuit_backends[backend] dummy_instances = analyze_insts_fn(instances, models) inst_port_mode = { k: _port_modes_dict(get_ports(s)) for k, s in dummy_instances.items() } connections = _get_multimode_connections( connections, inst_port_mode, ignore_missing_ports=ignore_missing_ports ) ports = _get_multimode_ports( ports, inst_port_mode, ignore_missing_ports=ignore_missing_ports ) inst2model = {} for k, inst in instances.items(): if inst.info and "model" in inst.info: inst2model[k] = models[str(inst.info["model"])] else: inst2model[k] = models[str(inst.component)] model_settings = {name: get_settings(model) for name, model in inst2model.items()} netlist_settings = { name: { k: v for k, v in (inst.settings or {}).items() if k in model_settings[name] } for name, inst in instances.items() } default_settings = merge_dicts(model_settings, netlist_settings) analyzed = analyze_fn(dummy_instances, connections, ports) def _circuit(**settings: Settings) -> SType: full_settings = merge_dicts(default_settings, settings) full_settings = _forward_global_settings(inst2model, full_settings) full_settings = merge_dicts(full_settings, settings) instances: Dict[str, SType] = {} for inst_name, model in inst2model.items(): instances[inst_name] = model(**full_settings.get(inst_name, {})) S = evaluate_fn(analyzed, instances) return S _replace_kwargs(_circuit, **default_settings) return _circuit def _forward_global_settings(instances, settings): global_settings = {} for k in list(settings.keys()): if k in instances: continue global_settings[k] = settings.pop(k) if global_settings: settings = update_settings(settings, **global_settings) return settings def _port_modes_dict(port_modes): result = {} for port_mode in port_modes: if "@" in port_mode: port, mode = port_mode.split("@") else: port, mode = port_mode, None if port not in result: result[port] = set() if mode is not None: result[port].add(mode) return result def _get_multimode_connections(connections, inst_port_mode, ignore_missing_ports=False): mm_connections = {} for inst_port1, inst_port2 in connections.items(): inst1, port1 = inst_port1.split(",") inst2, port2 = inst_port2.split(",") try: modes1 = inst_port_mode[inst1][port1] except KeyError: if ignore_missing_ports: continue raise RuntimeError( f"Instance {inst1} does not contain port {port1}. Available ports: {list(inst_port_mode[inst1])}." ) try: modes2 = inst_port_mode[inst2][port2] except KeyError: if ignore_missing_ports: continue raise RuntimeError( f"Instance {inst2} does not contain port {port2}. Available ports: {list(inst_port_mode[inst2])}." ) if not modes1 and not modes2: mm_connections[f"{inst1},{port1}"] = f"{inst2},{port2}" elif (not modes1) or (not modes2): raise ValueError( "trying to connect a multimode model to single mode model.\n" "Please update your models dictionary.\n" f"Problematic connection: '{inst_port1}':'{inst_port2}'" ) else: common_modes = modes1.intersection(modes2) for mode in sorted(common_modes): mm_connections[f"{inst1},{port1}@{mode}"] = f"{inst2},{port2}@{mode}" return mm_connections def _get_multimode_ports(ports, inst_port_mode, ignore_missing_ports=False): mm_ports = {} for port, inst_port2 in ports.items(): inst2, port2 = inst_port2.split(",") try: modes2 = inst_port_mode[inst2][port2] except KeyError: if ignore_missing_ports: continue raise RuntimeError( f"Instance {inst2} does not contain port {port2}. Available ports: {list(inst_port_mode[inst2])}" ) if not modes2: mm_ports[port] = f"{inst2},{port2}" else: for mode in sorted(modes2): mm_ports[f"{port}@{mode}"] = f"{inst2},{port2}@{mode}" return mm_ports def _enforce_return_type(model, return_type): stype_func = { "default": lambda x: x, "stype": lambda x: x, "sdict": sdict, "scoo": scoo, "sdense": sdense, }[return_type] return stype_func(model) def _ensure_recursive_netlist_dict(netlist): if not isinstance(netlist, dict): netlist = netlist.dict() if "__root__" in netlist: netlist = netlist["__root__"] if "instances" in netlist: netlist = {"top_level": netlist} netlist = {**netlist} for k, v in netlist.items(): netlist[k] = {**v} return netlist def _extract_instance_models(netlist): models = {} for netname, net in netlist.items(): net = {**net} net["instances"] = {**net["instances"]} for name, inst in net["instances"].items(): if callable(inst): settings = get_settings(inst) if isinstance(inst, partial) and inst.args: raise ValueError( "SAX circuits and netlists don't support partials " "with positional arguments." ) while isinstance(inst, partial): inst = inst.func models[inst.__name__] = inst net["instances"][name] = { "component": inst.__name__, "settings": settings, } netlist[netname] = net return netlist, models def _validate_circuit_backend(backend): backend = backend.lower() # assert valid circuit_backend if backend not in circuit_backends: raise KeyError( f"circuit backend {backend} not found. Allowed circuit backends: " f"{', '.join(circuit_backends.keys())}." ) return backend def _validate_net(netlist: Union[Netlist, RecursiveNetlist]) -> RecursiveNetlist: if isinstance(netlist, dict): try: netlist = Netlist.parse_obj(netlist) except ValidationError: netlist = RecursiveNetlist.parse_obj(netlist) elif isinstance(netlist, Netlist): netlist = RecursiveNetlist(__root__={"top_level": netlist}) return netlist def _validate_dag(dag): nodes = _find_root(dag) if len(nodes) > 1: raise ValueError(f"Multiple top_levels found in netlist: {nodes}") if len(nodes) < 1: raise ValueError("Netlist does not contain any nodes.") if not dag.is_directed(): raise ValueError("Netlist dependency cycles detected!") return dag
[docs]def get_required_circuit_models( netlist: Union[Netlist, NetlistDict, RecursiveNetlist, RecursiveNetlistDict], models: Optional[Dict[str, Model]] = None, ) -> List: """Figure out which models are needed for a given netlist""" if models is None: models = {} assert isinstance(models, dict) netlist = _ensure_recursive_netlist_dict(netlist) # TODO: do the following two steps *after* recursive netlist parsing. netlist = remove_unused_instances(netlist) netlist, _ = _extract_instance_models(netlist) recnet: RecursiveNetlist = _validate_net(netlist) # type: ignore missing_models = {} missing_model_names = [] g = nx.DiGraph() for model_name, subnetlist in recnet.dict()["__root__"].items(): if model_name not in missing_models: missing_models[model_name] = models.get(model_name, subnetlist) g.add_node(model_name) if model_name in models: continue for instance in subnetlist["instances"].values(): info = instance.get("info", {}) if info and "model" in info: component = info["model"] else: component = instance["component"] if (component not in missing_models) and (component not in models): missing_models[component] = models.get(component, None) missing_model_names.append(component) g.add_node(component) g.add_edge(model_name, component) return missing_model_names