Skip to content
Snippets Groups Projects
abstract_operation.py 3.91 KiB
Newer Older
  • Learn to ignore specific revisions
  • """@package docstring
    
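    Defines AbstractOperation, a generic base class implementing the common
    port-, parameter- and simulation-handling parts of the Operation interface.
    TODO: More info.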
    """
    
    
    from abc import abstractmethod
    from numbers import Number
    from typing import List, Set, Dict, Optional, Any

    from b_asic.port import InputPort, OutputPort
    from b_asic.signal import Signal
    from b_asic.operation import Operation
    from b_asic.simulation import SimulationState, OperationState
    from b_asic.utilities import breadth_first_search
    from b_asic.abstract_graph_component import AbstractGraphComponent

    
    class AbstractOperation(Operation, AbstractGraphComponent):
        """Generic abstract operation class which most implementations will derive from.

        Holds the operation's input ports, output ports and named parameters, and
        implements the bookkeeping parts of the Operation interface on top of them
        (port access, parameter access, simulation, splitting and traversal).
        Concrete subclasses must implement evaluate().
        TODO: More info.
        """

        _input_ports: List[InputPort]
        _output_ports: List[OutputPort]
        _parameters: Dict[str, Optional[Any]]

        def __init__(self, **kwds):
            # Forward all keyword arguments to the next classes in the MRO
            # (Operation, AbstractGraphComponent).
            super().__init__(**kwds)
            # All three start empty here; presumably subclasses populate them.
            self._input_ports = []
            self._output_ports = []
            self._parameters = {}

        @abstractmethod
        def evaluate(self, inputs: list) -> list:
            """Evaluate the operation and generate a list of output values given a list of input values."""
            raise NotImplementedError

        def inputs(self) -> List[InputPort]:
            """Return a shallow copy of the input port list (the ports themselves are shared)."""
            return self._input_ports.copy()

        def outputs(self) -> List[OutputPort]:
            """Return a shallow copy of the output port list (the ports themselves are shared)."""
            return self._output_ports.copy()

        def input_count(self) -> int:
            """Return the number of input ports."""
            return len(self._input_ports)

        def output_count(self) -> int:
            """Return the number of output ports."""
            return len(self._output_ports)

        def input(self, i: int) -> InputPort:
            """Return the input port at index i (IndexError if out of range)."""
            return self._input_ports[i]

        def output(self, i: int) -> OutputPort:
            """Return the output port at index i (IndexError if out of range)."""
            return self._output_ports[i]

        def params(self) -> Dict[str, Optional[Any]]:
            """Return a shallow copy of the parameter name -> value mapping."""
            return self._parameters.copy()

        def param(self, name: str) -> Optional[Any]:
            """Return the value of parameter ``name``.

            Returns None both for an unknown name and for a parameter whose stored
            value is None; use params() to tell the two apart.
            """
            return self._parameters.get(name)

        def set_param(self, name: str, value: Any) -> None:
            """Set the value of the already-existing parameter ``name``.

            Raises AssertionError if ``name`` is not a known parameter. Like any
            assert, the check is skipped when Python runs with -O.
            """
            assert name in self._parameters, f"Unknown parameter: {name!r}"
            self._parameters[name] = value

        def evaluate_outputs(self, state: SimulationState) -> List[Number]:
            """Evaluate this operation's outputs for the current simulation iteration.

            While this operation's recorded iteration lags behind state.iteration:
            pull the input values by recursively evaluating the operations feeding
            each input port, call evaluate(), store the result in this operation's
            OperationState, advance its iteration counter, then ask every downstream
            operation to evaluate as well. Returns the stored output values.
            """
            # TODO: Check implementation.
            output_count: int = self.output_count()
            self_state: OperationState = state.operation_states[self]

            while self_state.iteration < state.iteration:
                input_values: List[Number] = []
                for port in self._input_ports:
                    # NOTE(review): this reads .operation / .port_index directly off
                    # the connected Signal, while `neighbours` reaches the operation
                    # via signal.source.operation -- confirm against the Signal API.
                    source: Signal = port.signal
                    input_values.append(
                        source.operation.evaluate_outputs(state)[source.port_index])

                self_state.output_values = self.evaluate(input_values)
                assert len(self_state.output_values) == output_count, (
                    f"evaluate() returned {len(self_state.output_values)} values, "
                    f"expected {output_count}")
                self_state.iteration += 1

                # Push the new values forward. As in `neighbours`, signal.destination
                # is the port at the far end, so the operation to evaluate is reached
                # through its .operation attribute (ports do not evaluate themselves).
                for port in self._output_ports:
                    for signal in port.signals:
                        signal.destination.operation.evaluate_outputs(state)

            return self_state.output_values

        def split(self) -> List[Operation]:
            """Try to decompose this operation into other operations.

            Calls evaluate() with this operation's input ports in place of numeric
            values. If every returned element is an Operation, those operations are
            returned; otherwise the operation cannot be split and [self] is returned.
            """
            # TODO: Check implementation.
            # Pass a copy so an evaluate() implementation cannot mutate the
            # internal port list.
            results = self.evaluate(self._input_ports.copy())
            # NOTE(review): all() over an empty result is True, so an evaluate()
            # returning no values makes split() return [] -- confirm this is intended.
            if all(isinstance(e, Operation) for e in results):
                return results
            return [self]

        @property
        def neighbours(self) -> List[Operation]:
            """Return the operations connected to this one through any signal.

            Contains the source operation of every signal on each input port, then
            the destination operation of every signal on each output port, in port
            order. An operation connected by several signals appears once per signal.
            """
            neighbours: List[Operation] = []
            for port in self._input_ports:
                neighbours.extend(signal.source.operation for signal in port.signals)
            for port in self._output_ports:
                neighbours.extend(signal.destination.operation for signal in port.signals)
            return neighbours

        def traverse(self) -> Operation:
            """Traverse the connected operations with breadth_first_search, starting at this operation.

            Returns whatever breadth_first_search returns (documented as a generator).
            """
            # NOTE(review): the annotation says Operation but the docstring promises a
            # generator -- confirm breadth_first_search's return type and fix the hint.
            return breadth_first_search(self)

        # TODO: More stuff.