Skip to content
Snippets Groups Projects
operation.py 6.66 KiB
Newer Older
  • Learn to ignore specific revisions
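    """@package docstring

    B-ASIC Operation Module.

    Defines the Operation interface and AbstractOperation, the generic base
    class that most concrete operation implementations derive from.
    """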
    
    from abc import abstractmethod
    from collections import deque
    from numbers import Number
    from typing import List, Dict, Optional, Any, Set, Iterator, TYPE_CHECKING

    from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
    from b_asic.simulation import SimulationState, OperationState
    from b_asic.signal import Signal

    if TYPE_CHECKING:
        from b_asic.port import InputPort, OutputPort
    
    class Operation(GraphComponent):
        """Operation interface.

        An operation owns a list of input ports and a list of output ports,
        holds a dictionary of named parameters, can be evaluated against a
        simulation state, and may be split into several simpler operations.
        """

        @abstractmethod
        def inputs(self) -> "List[InputPort]":
            """Get a list of all input ports."""
            raise NotImplementedError

        @abstractmethod
        def outputs(self) -> "List[OutputPort]":
            """Get a list of all output ports."""
            raise NotImplementedError

        @abstractmethod
        def input_count(self) -> int:
            """Get the number of input ports."""
            raise NotImplementedError

        @abstractmethod
        def output_count(self) -> int:
            """Get the number of output ports."""
            raise NotImplementedError

        @abstractmethod
        def input(self, i: int) -> "InputPort":
            """Get the input port at index i."""
            raise NotImplementedError

        @abstractmethod
        def output(self, i: int) -> "OutputPort":
            """Get the output port at index i."""
            raise NotImplementedError

        @abstractmethod
        def params(self) -> Dict[str, Optional[Any]]:
            """Get a dictionary of all parameter values, keyed by parameter name."""
            raise NotImplementedError

        @abstractmethod
        def param(self, name: str) -> Optional[Any]:
            """Get the value of the parameter called ``name``.

            Returns None if the parameter is not defined.
            """
            raise NotImplementedError

        @abstractmethod
        def set_param(self, name: str, value: Any) -> None:
            """Set the value of the parameter called ``name``.

            The parameter must already be defined.
            """
            raise NotImplementedError

        @abstractmethod
        def evaluate_outputs(self, state: "SimulationState") -> List[Number]:
            """Simulate the circuit until its iteration count matches that of
            the simulation state, then return the resulting output vector.
            """
            raise NotImplementedError

        @abstractmethod
        def split(self) -> "List[Operation]":
            """Split the operation into multiple operations.

            If splitting is not possible, this may return a list containing
            only the operation itself.
            """
            raise NotImplementedError

        @property
        @abstractmethod
        def neighbors(self) -> "List[Operation]":
            """Return all operations that are connected by signals to this operation.

            If no neighbors are found, this returns an empty list.
            """
            raise NotImplementedError
    
    
    
    class AbstractOperation(Operation, AbstractGraphComponent):
        """Generic abstract operation class which most implementations will derive from.

        Stores the operation's input ports, output ports and named parameters,
        and implements the Operation interface on top of them. Subclasses supply
        the actual computation by overriding ``evaluate``.
        """

        _input_ports: List["InputPort"]
        _output_ports: List["OutputPort"]
        _parameters: Dict[str, Optional[Any]]

        def __init__(self, name: Name = ""):
            super().__init__(name)
            self._input_ports = []
            self._output_ports = []
            self._parameters = {}

        @abstractmethod
        def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
            """Evaluate the operation on the given input values.

            Each positional argument is one input value, in input-port order.
            ``evaluate_outputs`` expects the result to be a sequence holding one
            value per output port.
            """
            raise NotImplementedError

        def inputs(self) -> List["InputPort"]:
            """Return a shallow copy of the input port list."""
            return self._input_ports.copy()

        def outputs(self) -> List["OutputPort"]:
            """Return a shallow copy of the output port list."""
            return self._output_ports.copy()

        def input_count(self) -> int:
            """Return the number of input ports."""
            return len(self._input_ports)

        def output_count(self) -> int:
            """Return the number of output ports."""
            return len(self._output_ports)

        def input(self, i: int) -> "InputPort":
            """Return the input port at index ``i`` (IndexError if out of range)."""
            return self._input_ports[i]

        def output(self, i: int) -> "OutputPort":
            """Return the output port at index ``i`` (IndexError if out of range)."""
            return self._output_ports[i]

        def params(self) -> Dict[str, Optional[Any]]:
            """Return a shallow copy of the parameter dictionary."""
            return self._parameters.copy()

        def param(self, name: str) -> Optional[Any]:
            """Return the value of parameter ``name``, or None if it is not defined."""
            return self._parameters.get(name)

        def set_param(self, name: str, value: Any) -> None:
            """Set the value of the already-defined parameter ``name``.

            Raises AssertionError if ``name`` is not a defined parameter. The
            check is an ``assert`` and is therefore skipped under ``python -O``.
            """
            assert name in self._parameters, f"Undefined parameter: {name!r}"
            self._parameters[name] = value

        def evaluate_outputs(self, state: SimulationState) -> List[Number]:
            """Evaluate this operation until it has caught up with ``state``.

            While this operation's recorded iteration is behind
            ``state.iteration``, each input value is pulled from its source
            operation (evaluated recursively). ``evaluate`` is then applied to
            those values, the iteration counter is advanced, and every operation
            fed by an output port is evaluated as well. Returns the most recently
            stored output values.
            """
            # TODO: Check implementation.
            input_count: int = self.input_count()
            output_count: int = self.output_count()
            # Catch subclasses whose count methods disagree with the stored ports.
            assert input_count == len(self._input_ports), \
                "input_count() does not match the number of input ports"
            assert output_count == len(self._output_ports), \
                "output_count() does not match the number of output ports"

            self_state: OperationState = state.operation_states[self]

            while self_state.iteration < state.iteration:
                input_values: List[Number] = []
                for port in self._input_ports:
                    # NOTE(review): this reads `.signal`, `.operation` and
                    # `.port_index`, whereas `neighbors` reads `.signals` and
                    # `.source.operation`; confirm against the port/signal API.
                    source: Signal = port.signal
                    input_values.append(
                        source.operation.evaluate_outputs(state)[source.port_index])

                # evaluate() takes one positional argument per input value.
                self_state.output_values = self.evaluate(*input_values)
                assert len(self_state.output_values) == output_count, \
                    "evaluate() did not return one value per output port"
                self_state.iteration += 1

                # Propagate forward to every operation fed by this one; the same
                # port/signal attributes as `neighbors` are used here.
                for port in self._output_ports:
                    for signal in port.signals:
                        signal.destination.operation.evaluate_outputs(state)

            return self_state.output_values

        def split(self) -> List[Operation]:
            """Split the operation by calling ``evaluate`` with its input ports as arguments.

            If the result is a non-empty list or tuple consisting only of
            operations, it is returned as a list; otherwise ``[self]`` is
            returned.
            """
            # TODO: Check implementation.
            results = self.evaluate(*self._input_ports)
            if (isinstance(results, (list, tuple)) and results
                    and all(isinstance(e, Operation) for e in results)):
                return list(results)
            return [self]

        @property
        def neighbors(self) -> List[Operation]:
            """Operations connected to this one by a signal.

            Sources of input signals come first, then destinations of output
            signals. An operation is listed once per connecting signal, so the
            list may contain duplicates. It is empty when nothing is connected.
            """
            neighbors: List[Operation] = []
            for port in self._input_ports:
                neighbors.extend(signal.source.operation for signal in port.signals)
            for port in self._output_ports:
                neighbors.extend(signal.destination.operation for signal in port.signals)
            return neighbors

        def traverse(self) -> Iterator[Operation]:
            """Traverse the operation graph breadth-first, starting at this operation.

            Returns a generator that yields ``self`` first and then every
            operation reachable through ``neighbors``, each exactly once.
            """
            return self._breadth_first_search()

        def _breadth_first_search(self) -> Iterator[Operation]:
            """Yield operations reachable through ``neighbors`` in breadth-first order."""
            # Operations must be hashable because they are stored in a set.
            visited: Set[Operation] = {self}
            queue = deque([self])
            while queue:
                operation = queue.popleft()
                yield operation
                for n_operation in operation.neighbors:
                    if n_operation not in visited:
                        visited.add(n_operation)
                        queue.append(n_operation)