Newer
Older
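"""B-ASIC Signal Flow Graph Module.

Defines the SFG class, an operation built from a graph of other operations,
and the GraphIDGenerator used to assign IDs to the components it contains.

TODO: More info.
"""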
from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
from numbers import Number
Angus Lothian
committed
from collections import defaultdict
from b_asic.port import SignalSourceProvider, OutputPort
from b_asic.operation import Operation, AbstractOperation
Angus Lothian
committed
from b_asic.graph_component import GraphComponent, Name, TypeName
from b_asic.special_operations import Input, Output
# Numeric part of a graph ID; counted separately for every type name.
GraphIDNumber = NewType("GraphIDNumber", int)
# Identifier of a component inside a graph: its type name followed by a number.
GraphID = NewType("GraphID", str)
class GraphIDGenerator:
    """Hand out sequential graph IDs, keeping a separate counter per type name."""

    # Last number handed out for each type name; unseen names start at the offset.
    _next_id_number: DefaultDict[TypeName, GraphIDNumber]

    def __init__(self, id_number_offset: GraphIDNumber = 0):
        # Every type name begins counting from the same offset.
        def start_value():
            return id_number_offset

        self._next_id_number = defaultdict(start_value)

    def next_id(self, type_name: TypeName) -> GraphID:
        """Advance the counter for ``type_name`` and return the resulting ID.

        The counter is incremented before use, so the first ID for a type
        name is that name followed by ``id_number_offset + 1``.
        """
        number = self._next_id_number[type_name] + 1
        self._next_id_number[type_name] = number
        return type_name + str(number)
Angus Lothian
committed
class SFG(AbstractOperation):
    """Signal flow graph.

    An SFG is itself an operation. On construction it copies (without their
    original connections) the operations and signals reachable from the given
    input signals/operations, output signals/operations and extra operations,
    then rewires the copies among themselves. Evaluation assigns values to the
    copied Input operations and recursively evaluates the source feeding each
    copied Output operation.
    """

    # Copied components, keyed by the graph ID generated when they were added.
    _components_by_id: Dict[GraphID, GraphComponent]
    # Copied components grouped by name (several components may share a name).
    _components_by_name: DefaultDict[Name, List[GraphComponent]]
    # Copies of the SFG's input operations, in input-index order.
    _input_operations: List[Input]
    # Copies of the SFG's output operations, in output-index order.
    _output_operations: List[Output]
    # Every original component that has been copied, mapped to its copy.
    # A mapping (not just a set) is needed so that a signal can be connected
    # to the copy of an operation that was already added earlier.
    _original_components_to_new: Dict[GraphComponent, GraphComponent]
    # Original signals entering the SFG, mapped to their input index.
    _original_input_signals: Dict[Signal, int]
    # Original signals leaving the SFG, mapped to their output index.
    _original_output_signals: Dict[Signal, int]

    def __init__(self, input_signals: Sequence[Signal] = (), output_signals: Sequence[Signal] = (),
                 inputs: Sequence[Input] = (), outputs: Sequence[Output] = (), operations: Sequence[Operation] = (),
                 id_number_offset: GraphIDNumber = 0, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
        """Build an SFG by copying the graph reachable from the given endpoints.

        Keyword arguments:
        input_signals: Signals entering the SFG; a new Input operation is created for each.
        output_signals: Signals leaving the SFG; a new Output operation is created for each.
        inputs: Existing Input operations to copy; indexed after input_signals.
        outputs: Existing Output operations to copy; indexed after output_signals.
        operations: Extra operations from which to search the graph.
        id_number_offset: Offset for the generated graph ID numbers.
        name: Name of the SFG.
        input_sources: Forwarded unchanged to AbstractOperation.

        Raises ValueError if an input signal has no destination, an output
        signal has no source, or the reachable graph contains an unconnected
        input port or a dangling signal.
        """
        super().__init__(
            input_count=len(input_signals) + len(inputs),
            output_count=len(output_signals) + len(outputs),
            name=name,
            input_sources=input_sources)

        self._components_by_id = dict()
        self._components_by_name = defaultdict(list)
        self._graph_id_generator = GraphIDGenerator(id_number_offset)
        self._input_operations = []
        self._output_operations = []
        self._original_components_to_new = {}
        self._original_input_signals = {}
        self._original_output_signals = {}

        # Setup input operations and signals.
        for i, s in enumerate(input_signals):
            self._input_operations.append(
                self._add_component_copy_unconnected(Input()))
            self._original_input_signals[s] = i
        for i, op in enumerate(inputs, len(input_signals)):
            self._input_operations.append(
                self._add_component_copy_unconnected(op))
            for s in op.output(0).signals:
                self._original_input_signals[s] = i

        # Setup output operations and signals.
        for i, s in enumerate(output_signals):
            self._output_operations.append(
                self._add_component_copy_unconnected(Output()))
            self._original_output_signals[s] = i
        for i, op in enumerate(outputs, len(output_signals)):
            self._output_operations.append(
                self._add_component_copy_unconnected(op))
            for s in op.input(0).signals:
                self._original_output_signals[s] = i

        # Search the graph inwards from each input signal.
        for s, i in self._original_input_signals.items():
            if s.destination is None:
                raise ValueError(
                    f"Input signal #{i} is missing destination in SFG")
            if s in self._original_components_to_new:
                continue
            if s in self._original_output_signals:
                # The signal runs straight from an SFG input to an SFG output,
                # so there is no operation in between to search from.
                new_signal = self._add_component_copy_unconnected(s)
                new_signal.set_source(self._input_operations[i].output(0))
                new_signal.set_destination(
                    self._output_operations[self._original_output_signals[s]].input(0))
            elif s.destination.operation not in self._original_components_to_new:
                self._add_operation_copy_recursively(s.destination.operation)

        # Search the graph inwards from each output signal.
        for s, i in self._original_output_signals.items():
            if s in self._original_components_to_new:
                # Already connected (e.g. an input-to-output pass-through).
                continue
            if s.source is None:
                raise ValueError(
                    f"Output signal #{i} is missing source in SFG")
            if s.source.operation not in self._original_components_to_new:
                self._add_operation_copy_recursively(s.source.operation)

        # Search the graph outwards from each operation.
        for op in operations:
            if op not in self._original_components_to_new:
                self._add_operation_copy_recursively(op)

    @property
    def type_name(self) -> TypeName:
        return "sfg"

    def evaluate(self, *args):
        """Evaluate the SFG with one value per input, in input order.

        Returns None when the SFG has no outputs, the single value when it has
        exactly one, and otherwise a list with one value per output.
        Raises ValueError if the number of arguments differs from input_count.
        """
        self._assign_input_values(args)
        result = [self._evaluate_source(op.input(0).signals[0].source)
                  for op in self._output_operations]
        n = len(result)
        return None if n == 0 else result[0] if n == 1 else result

    def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
        """Evaluate only output ``i`` for the given input values.

        Returns a list of length output_count in which only index ``i`` is
        filled in; every other entry is None.
        Raises ValueError if the number of input values differs from input_count.
        """
        assert 0 <= i < self.output_count, "Output index out of range"
        # Previously input_values was ignored, so stale input values were used.
        self._assign_input_values(input_values)
        result = [None] * self.output_count
        result[i] = self._evaluate_source(
            self._output_operations[i].input(0).signals[0].source)
        return result

    def split(self) -> Iterable[Operation]:
        """Return the copied components of this SFG that are operations."""
        return filter(lambda comp: isinstance(comp, Operation), self._components_by_id.values())

    @property
    def components(self) -> Iterable[GraphComponent]:
        """Get all components of this graph."""
        return self._components_by_id.values()

    def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
        """Find a graph object based on the entered Graph ID and return it. If no graph
        object with the entered ID was found then return None.

        Keyword arguments:
        graph_id: Graph ID of the wanted object.
        """
        return self._components_by_id.get(graph_id)

    def find_by_name(self, name: Name) -> List[GraphComponent]:
        """Find all graph objects that have the entered name and return them
        in a list. If no graph object with the entered name was found then return an
        empty list.

        Keyword arguments:
        name: Name of the wanted object.
        """
        return self._components_by_name.get(name, [])

    def _assign_input_values(self, values: Sequence[Number]) -> None:
        """Store one value per input in the copied Input operations, in input order."""
        if len(values) != self.input_count:
            raise ValueError(
                "Wrong number of inputs supplied to SFG for evaluation")
        for value, input_op in zip(values, self._input_operations):
            input_op.value = value

    def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent:
        """Copy ``original_comp`` without connections, register the copy and return it."""
        assert original_comp not in self._original_components_to_new, "Tried to add duplicate SFG component"
        new_comp = original_comp.copy_unconnected()
        self._original_components_to_new[original_comp] = new_comp
        self._components_by_id[self._graph_id_generator.next_id(
            new_comp.type_name)] = new_comp
        self._components_by_name[new_comp.name].append(new_comp)
        return new_comp

    def _get_or_add_operation_copy(self, original_op: Operation) -> Operation:
        """Return the copy of ``original_op``, copying it (and its connected graph) first if needed."""
        new_op = self._original_components_to_new.get(original_op)
        if new_op is None:
            new_op = self._add_operation_copy_recursively(original_op)
        return new_op

    def _add_operation_copy_recursively(self, original_op: Operation) -> Operation:
        """Copy ``original_op`` and, depth-first, every operation connected to it.

        Each original signal is copied exactly once. The end attached to the
        operation being visited is connected before recursing, so the visit
        from the signal's other end sees it as already copied and skips it.
        Returns the copy of ``original_op``.
        """
        # Add a copy of the operation without any connections.
        new_op = self._add_component_copy_unconnected(original_op)

        # Connect input ports.
        for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs):
            if original_input_port.signal_count < 1:
                raise ValueError("Unconnected input port in SFG")
            for original_signal in original_input_port.signals:
                # Already copied and connected from its other end.
                if original_signal in self._original_components_to_new:
                    continue
                is_sfg_input = original_signal in self._original_input_signals
                if not is_sfg_input and original_signal.source is None:
                    raise ValueError(
                        "Dangling signal without source in SFG")
                new_signal = self._add_component_copy_unconnected(
                    original_signal)
                new_signal.set_destination(new_input_port)
                if is_sfg_input:
                    # The signal comes from one of the SFG's own inputs.
                    input_op = self._input_operations[self._original_input_signals[original_signal]]
                    new_signal.set_source(input_op.output(0))
                else:
                    new_connected_op = self._get_or_add_operation_copy(
                        original_signal.source.operation)
                    new_signal.set_source(new_connected_op.output(
                        original_signal.source.index))

        # Connect output ports.
        for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs):
            for original_signal in original_output_port.signals:
                # Already copied and connected from its other end.
                if original_signal in self._original_components_to_new:
                    continue
                is_sfg_output = original_signal in self._original_output_signals
                if not is_sfg_output and original_signal.destination is None:
                    raise ValueError(
                        "Dangling signal without destination in SFG")
                new_signal = self._add_component_copy_unconnected(
                    original_signal)
                new_signal.set_source(new_output_port)
                if is_sfg_output:
                    # The signal goes to one of the SFG's own outputs.
                    output_op = self._output_operations[self._original_output_signals[original_signal]]
                    new_signal.set_destination(output_op.input(0))
                else:
                    new_connected_op = self._get_or_add_operation_copy(
                        original_signal.destination.operation)
                    new_signal.set_destination(new_connected_op.input(
                        original_signal.destination.index))

        return new_op

    def _evaluate_source(self, src: OutputPort) -> Number:
        """Recursively evaluate the value produced at output port ``src``.

        Only the first signal of each input port is followed. Nothing is
        memoized, so a shared source is re-evaluated for every consumer, and a
        cycle in the graph would recurse without bound.
        NOTE(review): this returns whatever the operation's evaluate_output
        returns; SFG.evaluate_output returns a list, so a nested SFG may not
        yield a plain Number here -- confirm against AbstractOperation.
        """
        input_values = [self._evaluate_source(input_port.signals[0].source)
                        for input_port in src.operation.inputs]
        return src.operation.evaluate_output(src.index, input_values)