"""@package docstring B-ASIC Signal Flow Graph Module. TODO: More info. """ from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set from numbers import Number from collections import defaultdict from b_asic.port import SignalSourceProvider, OutputPort from b_asic.operation import Operation, AbstractOperation from b_asic.signal import Signal from b_asic.graph_component import GraphComponent, Name, TypeName from b_asic.special_operations import Input, Output GraphID = NewType("GraphID", str) GraphIDNumber = NewType("GraphIDNumber", int) class GraphIDGenerator: """A class that generates Graph IDs for objects.""" _next_id_number: DefaultDict[TypeName, GraphIDNumber] def __init__(self, id_number_offset: GraphIDNumber = 0): self._next_id_number = defaultdict(lambda: id_number_offset) def next_id(self, type_name: TypeName) -> GraphID: """Return the next graph id for a certain graph id type.""" self._next_id_number[type_name] += 1 return type_name + str(self._next_id_number[type_name]) class SFG(AbstractOperation): """Signal flow graph. TODO: More info. """ _components_by_id: Dict[GraphID, GraphComponent] _components_by_name: DefaultDict[Name, List[GraphComponent]] _graph_id_generator: GraphIDGenerator _input_operations: List[Input] _output_operations: List[Output] _original_components_added: Set[GraphComponent] _original_input_signals: Dict[Signal, int] _original_output_signals: Dict[Signal, int] def __init__(self, input_signals: Sequence[Signal] = [], output_signals: Sequence[Signal] = [], \ inputs: Sequence[Input] = [], outputs: Sequence[Output] = [], operations: Sequence[Operation] = [], \ id_number_offset: GraphIDNumber = 0, name: Name = "", \ input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None): super().__init__( input_count = len(input_signals) + len(inputs), output_count = len(output_signals) + len(outputs), name = name, input_sources = input_sources) self._components_by_id = dict() self._components_by_name = defaultdict(list) self._graph_id_generator = GraphIDGenerator(id_number_offset) self._input_operations = [] self._output_operations = [] self._original_components_added = set() self._original_input_signals = {} self._original_output_signals = {} # Setup input operations and signals. for i, s in enumerate(input_signals): self._input_operations.append(self._add_component_copy_unconnected(Input())) self._original_input_signals[s] = i for i, op in enumerate(inputs, len(input_signals)): self._input_operations.append(self._add_component_copy_unconnected(op)) for s in op.output(0).signals: self._original_input_signals[s] = i # Setup output operations and signals. for i, s in enumerate(output_signals): self._output_operations.append(self._add_component_copy_unconnected(Output())) self._original_output_signals[s] = i for i, op in enumerate(outputs, len(output_signals)): self._output_operations.append(self._add_component_copy_unconnected(op)) for s in op.input(0).signals: self._original_output_signals[s] = i # Search the graph inwards from each input signal. for s, i in self._original_input_signals.items(): if s.destination is None: raise ValueError(f"Input signal #{i} is missing destination in SFG") if s.destination.operation not in self._original_components_added: self._add_operation_copy_recursively(s.destination.operation) # Search the graph inwards from each output signal. for s, i in self._original_output_signals.items(): if s.source is None: raise ValueError(f"Output signal #{i} is missing source in SFG") if s.source.operation not in self._original_components_added: self._add_operation_copy_recursively(s.source.operation) # Search the graph outwards from each operation. for op in operations: if op not in self._original_components_added: self._add_operation_copy_recursively(op) @property def type_name(self) -> TypeName: return "sfg" def evaluate(self, *args): if len(args) != self.input_count: raise ValueError("Wrong number of inputs supplied to SFG for evaluation") for arg, op in zip(args, self._input_operations): op.value = arg result = [] for op in self._output_operations: result.append(self._evaluate_source(op.input(0).signals[0].source)) n = len(result) return None if n == 0 else result[0] if n == 1 else result def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]: assert i >= 0 and i < self.output_count, "Output index out of range" result = [None] * self.output_count result[i] = self._evaluate_source(self._output_operations[i].input(0).signals[0].source) return result def split(self) -> Iterable[Operation]: return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values()) @property def components(self) -> Iterable[GraphComponent]: """Get all components of this graph.""" return self._components_by_id.values() def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]: """Find a graph object based on the entered Graph ID and return it. If no graph object with the entered ID was found then return None. Keyword arguments: graph_id: Graph ID of the wanted object. """ return self._components_by_id.get(graph_id, None) def find_by_name(self, name: Name) -> List[GraphComponent]: """Find all graph objects that have the entered name and return them in a list. If no graph object with the entered name was found then return an empty list. Keyword arguments: name: Name of the wanted object. """ return self._components_by_name.get(name, []) def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent: assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component" self._original_components_added.add(original_comp) new_comp = original_comp.copy_unconnected() self._components_by_id[self._graph_id_generator.next_id(new_comp.type_name)] = new_comp self._components_by_name[new_comp.name].append(new_comp) return new_comp def _add_operation_copy_recursively(self, original_op: Operation) -> Operation: # Add a copy of the operation without any connections. new_op = self._add_component_copy_unconnected(original_op) # Connect input ports. for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs): if original_input_port.signal_count < 1: raise ValueError("Unconnected input port in SFG") for original_signal in original_input_port.signals: if original_signal in self._original_input_signals: # Check if the signal is one of the SFG's input signals. new_signal = self._add_component_copy_unconnected(original_signal) new_signal.set_destination(new_input_port) new_signal.set_source(self._input_operations[self._original_input_signals[original_signal]].output(0)) elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. new_signal = self._add_component_copy_unconnected(original_signal) new_signal.set_destination(new_input_port) if original_signal.source is None: raise ValueError("Dangling signal without source in SFG") # Recursively add the connected operation. new_connected_op = self._add_operation_copy_recursively(original_signal.source.operation) new_signal.set_source(new_connected_op.output(original_signal.source.index)) # Connect output ports. for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs): for original_signal in original_output_port.signals: if original_signal in self._original_output_signals: # Check if the signal is one of the SFG's output signals. new_signal = self._add_component_copy_unconnected(original_signal) new_signal.set_source(new_output_port) new_signal.set_destination(self._output_operations[self._original_output_signals[original_signal]].input(0)) elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added. new_signal = self._add_component_copy_unconnected(original_signal) new_signal.set_source(new_output_port) if original_signal.destination is None: raise ValueError("Dangling signal without destination in SFG") # Recursively add the connected operation. new_connected_op = self._add_operation_copy_recursively(original_signal.destination.operation) new_signal.set_destination(new_connected_op.input(original_signal.destination.index)) return new_op def _evaluate_source(self, src: OutputPort) -> Number: input_values = [] for input_port in src.operation.inputs: input_src = input_port.signals[0].source input_values.append(self._evaluate_source(input_src)) return src.operation.evaluate_output(src.index, input_values)