Skip to content
Snippets Groups Projects
signal_flow_graph.py 9.94 KiB
Newer Older
  • Learn to ignore specific revisions
  • """@package docstring
    
    B-ASIC Signal Flow Graph Module.
    TODO: More info.
    """
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from collections import defaultdict
    from numbers import Number
    from typing import NewType, List, Iterable, Sequence, Dict, Optional, DefaultDict, Set
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from b_asic.port import SignalSourceProvider, OutputPort
    from b_asic.operation import Operation, AbstractOperation
    
    Kevin Scott's avatar
    Kevin Scott committed
    from b_asic.signal import Signal
    
    from b_asic.graph_component import GraphComponent, Name, TypeName
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
    from b_asic.special_operations import Input, Output
    
    
    # Distinct static types for graph IDs and for the running numbers they end in.
    # GraphIDGenerator builds each ID as a type name followed by such a number.
    GraphID = NewType("GraphID", str)
    GraphIDNumber = NewType("GraphIDNumber", int)
    
    
    class GraphIDGenerator:
        """Issue sequential graph IDs, numbered independently for each type name.

        Every ID is the type name immediately followed by a running number.
        """

        # Despite the name, this holds the number most recently handed out for each
        # type name; type names not seen yet start at the constructor's offset.
        _next_id_number: DefaultDict[TypeName, GraphIDNumber]

        def __init__(self, id_number_offset: GraphIDNumber = 0):
            """Create a generator whose first number for any type name is
            id_number_offset + 1.
            """
            self._next_id_number = defaultdict(lambda: id_number_offset)

        def next_id(self, type_name: TypeName) -> GraphID:
            """Advance the counter of type_name and return the resulting graph ID."""
            issued_number = self._next_id_number[type_name] + 1
            self._next_id_number[type_name] = issued_number
            return type_name + str(issued_number)
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
        """Signal flow graph.
        TODO: More info.
        """
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        
        # Copied components of this graph, keyed by their generated graph ID.
        _components_by_id: Dict[GraphID, GraphComponent]
        # The same copies grouped by name; several components may share one name.
        _components_by_name: DefaultDict[Name, List[GraphComponent]]
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
        # Source of fresh graph IDs for the components added to this graph.
        _graph_id_generator: GraphIDGenerator
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        # Copied Input/Output operations, stored in input/output index order.
        _input_operations: List[Input]
        _output_operations: List[Output]
        # Original (caller-supplied) components that already have a copy in this graph.
        _original_components_added: Set[GraphComponent]
        # Original boundary signals mapped to the input/output index they belong to.
        _original_input_signals: Dict[Signal, int]
        _original_output_signals: Dict[Signal, int]
    
        def __init__(self, input_signals: Sequence[Signal] = (), output_signals: Sequence[Signal] = (),
                     inputs: Sequence[Input] = (), outputs: Sequence[Output] = (), operations: Sequence[Operation] = (),
                     id_number_offset: GraphIDNumber = 0, name: Name = "",
                     input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None):
            """Build a signal flow graph out of copies of already connected components.

            Components are never stored directly: each one is added as a copy made
            with copy_unconnected(), registered under a freshly generated graph ID,
            and the copies are then wired together like the originals.

            Keyword arguments:
            input_signals: Signals entering the graph. A new Input operation is created
                for each one; the position in this sequence is its input index.
            output_signals: Signals leaving the graph. A new Output operation is created
                for each one; the position in this sequence is its output index.
            inputs: Existing Input operations to copy. Their input indices continue
                after those of input_signals.
            outputs: Existing Output operations to copy. Their output indices continue
                after those of output_signals.
            operations: Further operations to start the graph search from.
            id_number_offset: Offset handed to the graph ID generator.
            name: Forwarded to the base class initializer.
            input_sources: Forwarded to the base class initializer.

            Raises:
            ValueError: If an input signal has no destination, an output signal has
                no source, or the recursive copy meets an unconnected input port or
                a dangling signal.
            """
            # The sequence defaults are empty tuples rather than lists so that no
            # mutable object is shared between calls; they are only read here.
            super().__init__(
                input_count=len(input_signals) + len(inputs),
                output_count=len(output_signals) + len(outputs),
                name=name,
                input_sources=input_sources)

            self._components_by_id = {}
            self._components_by_name = defaultdict(list)
            self._graph_id_generator = GraphIDGenerator(id_number_offset)
            self._input_operations = []
            self._output_operations = []
            self._original_components_added = set()
            self._original_input_signals = {}
            self._original_output_signals = {}

            # Setup input operations and signals.
            for i, s in enumerate(input_signals):
                self._input_operations.append(self._add_component_copy_unconnected(Input()))
                self._original_input_signals[s] = i
            # Explicit Input operations are numbered after the bare input signals.
            for i, op in enumerate(inputs, len(input_signals)):
                self._input_operations.append(self._add_component_copy_unconnected(op))
                for s in op.output(0).signals:
                    self._original_input_signals[s] = i

            # Setup output operations and signals.
            for i, s in enumerate(output_signals):
                self._output_operations.append(self._add_component_copy_unconnected(Output()))
                self._original_output_signals[s] = i
            # Explicit Output operations are numbered after the bare output signals.
            for i, op in enumerate(outputs, len(output_signals)):
                self._output_operations.append(self._add_component_copy_unconnected(op))
                for s in op.input(0).signals:
                    self._original_output_signals[s] = i

            # Search the graph inwards from each input signal.
            for s, i in self._original_input_signals.items():
                if s.destination is None:
                    raise ValueError(f"Input signal #{i} is missing destination in SFG")
                if s.destination.operation not in self._original_components_added:
                    self._add_operation_copy_recursively(s.destination.operation)

            # Search the graph inwards from each output signal.
            for s, i in self._original_output_signals.items():
                if s.source is None:
                    raise ValueError(f"Output signal #{i} is missing source in SFG")
                if s.source.operation not in self._original_components_added:
                    self._add_operation_copy_recursively(s.source.operation)

            # Search the graph outwards from each operation.
            for op in operations:
                if op not in self._original_components_added:
                    self._add_operation_copy_recursively(op)
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        @property
        def type_name(self) -> TypeName:
            """Return the type name used for signal flow graphs: "sfg"."""
            return "sfg"
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        def evaluate(self, *args):
            """Evaluate the whole graph for one set of input values.

            Each positional argument is stored as the value of the input operation
            at the same position, after which every output is evaluated.

            Returns None when the graph has no outputs, the bare value when it has
            exactly one output, and a list of values otherwise.

            Raises:
            ValueError: If the number of arguments differs from the input count.
            """
            if len(args) != self.input_count:
                raise ValueError("Wrong number of inputs supplied to SFG for evaluation")
            for input_op, value in zip(self._input_operations, args):
                input_op.value = value

            # Each output operation is driven by the source of its first input signal.
            outputs = [
                self._evaluate_source(output_op.input(0).signals[0].source)
                for output_op in self._output_operations
            ]

            if not outputs:
                return None
            if len(outputs) == 1:
                return outputs[0]
            return outputs
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        def evaluate_output(self, i: int, input_values: Sequence[Number]) -> Sequence[Optional[Number]]:
            """Evaluate a single output of this graph for the given input values.

            Keyword arguments:
            i: Index of the output to evaluate.
            input_values: One value per graph input, in input order.

            Returns a list with one slot per output in which only slot i is filled
            in; every other slot is None.

            Raises:
            ValueError: If the number of input values differs from the input count.
            """
            assert i >= 0 and i < self.output_count, "Output index out of range"
            if len(input_values) != self.input_count:
                raise ValueError("Wrong number of inputs supplied to SFG for evaluation")
            # Load the inputs first, just like evaluate() does. Without this the
            # result is computed from whatever values the Input operations happened
            # to hold before the call and input_values is silently ignored.
            for input_op, value in zip(self._input_operations, input_values):
                input_op.value = value

            result = [None] * self.output_count
            result[i] = self._evaluate_source(self._output_operations[i].input(0).signals[0].source)
            return result
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        def split(self) -> Iterable[Operation]:
            """Return a lazy iterable over the components of this graph that are operations."""
            return (
                component
                for component in self._components_by_id.values()
                if isinstance(component, Operation)
            )
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
    
    Ivar Härnqvist's avatar
    Ivar Härnqvist committed
        @property
        def components(self) -> Iterable[GraphComponent]:
            """Return every component registered in this graph, as the values view
            of the ID lookup table.
            """
            registered = self._components_by_id
            return registered.values()
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
        def find_by_id(self, graph_id: GraphID) -> Optional[GraphComponent]:
            """Look up the graph component registered under the given graph ID.

            Returns the component, or None if no component has that ID.

            Keyword arguments:
            graph_id: Graph ID of the wanted object.
            """
            # dict.get already answers None for a missing key.
            return self._components_by_id.get(graph_id)
    
    Jacob Wahlman's avatar
    Jacob Wahlman committed
    
        def find_by_name(self, name: Name) -> List[GraphComponent]:
            """Collect every graph component that carries the given name.

            Returns the list of matching components, or an empty list if no
            component has that name.

            Keyword arguments:
            name: Name of the wanted object.
            """
            by_name = self._components_by_name
            # Test membership instead of indexing so that a miss does not create
            # an empty entry in the lookup table.
            if name in by_name:
                return by_name[name]
            return []
    
        def _add_component_copy_unconnected(self, original_comp: GraphComponent) -> GraphComponent:
            """Register an unconnected copy of original_comp in this graph and return it.

            The original is remembered so that it is not copied a second time, and
            the copy is indexed both under a newly generated graph ID and under
            its name.
            """
            assert original_comp not in self._original_components_added, "Tried to add duplicate SFG component"
            self._original_components_added.add(original_comp)

            copied_comp = original_comp.copy_unconnected()
            graph_id = self._graph_id_generator.next_id(copied_comp.type_name)
            self._components_by_id[graph_id] = copied_comp
            self._components_by_name[copied_comp.name].append(copied_comp)
            return copied_comp
    
        def _add_operation_copy_recursively(self, original_op: Operation) -> Operation:
            """Copy original_op and everything connected to it into this graph.

            An unconnected copy of the operation is added first. The signals on each
            of its input and output ports are then visited: signals registered as
            graph input/output signals are wired to the matching Input/Output
            operation copy, while any other signal that has not been copied yet is
            copied and the operation at its far end is added recursively.

            Returns the copy of original_op.

            Raises:
            ValueError: If an input port has no signal, or a signal lacks the source
                or destination needed to continue the search.
            """
            # NOTE(review): one recursion level is used per newly reached operation,
            # so very long operation chains could hit Python's recursion limit --
            # confirm against the expected graph sizes.

            # Add a copy of the operation without any connections.
            new_op = self._add_component_copy_unconnected(original_op)

            # Connect input ports. Original and copied ports are paired by position.
            for original_input_port, new_input_port in zip(original_op.inputs, new_op.inputs):
                if original_input_port.signal_count < 1:
                    raise ValueError("Unconnected input port in SFG")
                for original_signal in original_input_port.signals:
                    if original_signal in self._original_input_signals: # Check if the signal is one of the SFG's input signals.
                        new_signal = self._add_component_copy_unconnected(original_signal)
                        new_signal.set_destination(new_input_port)
                        # The source is this graph's own Input operation for that index.
                        new_signal.set_source(self._input_operations[self._original_input_signals[original_signal]].output(0))
                    elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added.
                        new_signal = self._add_component_copy_unconnected(original_signal)
                        new_signal.set_destination(new_input_port)
                        if original_signal.source is None:
                            raise ValueError("Dangling signal without source in SFG")
                        # Recursively add the connected operation.
                        new_connected_op = self._add_operation_copy_recursively(original_signal.source.operation)
                        # Attach to the output port with the same index as on the original.
                        new_signal.set_source(new_connected_op.output(original_signal.source.index))

            # Connect output ports. Original and copied ports are paired by position.
            for original_output_port, new_output_port in zip(original_op.outputs, new_op.outputs):
                for original_signal in original_output_port.signals:
                    if original_signal in self._original_output_signals: # Check if the signal is one of the SFG's output signals.
                        new_signal = self._add_component_copy_unconnected(original_signal)
                        new_signal.set_source(new_output_port)
                        # The destination is this graph's own Output operation for that index.
                        new_signal.set_destination(self._output_operations[self._original_output_signals[original_signal]].input(0))
                    elif original_signal not in self._original_components_added: # Only add the signal if it wasn't already added.
                        new_signal = self._add_component_copy_unconnected(original_signal)
                        new_signal.set_source(new_output_port)
                        if original_signal.destination is None:
                            raise ValueError("Dangling signal without destination in SFG")
                        # Recursively add the connected operation.
                        new_connected_op = self._add_operation_copy_recursively(original_signal.destination.operation)
                        # Attach to the input port with the same index as on the original.
                        new_signal.set_destination(new_connected_op.input(original_signal.destination.index))

            return new_op
        
        def _evaluate_source(self, src: OutputPort) -> Number:
            """Evaluate the output port src by first evaluating everything feeding it.

            Every input port of the operation owning src is resolved through the
            source of its first signal, recursively, and the collected values are
            handed to that operation's evaluate_output for the index of src.
            """
            operation = src.operation
            input_values = [
                self._evaluate_source(input_port.signals[0].source)
                for input_port in operation.inputs
            ]
            return operation.evaluate_output(src.index, input_values)