From 8b7975190baae9b496d4055c5227d314bddfe2f0 Mon Sep 17 00:00:00 2001 From: Angus Lothian <anglo547@student.liu.se> Date: Wed, 18 Mar 2020 20:36:36 +0100 Subject: [PATCH] Solve pull request comments and change so evaluate function in SFG uses the same interface as the abstract evaluate function --- b_asic/__init__.py | 3 - b_asic/abstract_graph_component.py | 26 --- b_asic/abstract_operation.py | 117 ----------- b_asic/core_operations.py | 17 +- b_asic/graph_component.py | 20 ++ b_asic/operation.py | 120 ++++++++++- b_asic/port.py | 192 ++++++++++-------- b_asic/signal.py | 37 ++-- b_asic/signal_flow_graph.py | 4 +- b_asic/utilities.py | 21 -- .../test_core_operations.py} | 32 +-- test/fixtures/operation_tree.py | 12 +- test/port/test_inputport.py | 27 ++- test/port/test_outputport.py | 31 ++- test/signal/test_signal.py | 12 +- test/traverse/test_traverse_tree.py | 6 +- 16 files changed, 347 insertions(+), 330 deletions(-) delete mode 100644 b_asic/abstract_graph_component.py delete mode 100644 b_asic/abstract_operation.py delete mode 100644 b_asic/utilities.py rename test/{basic_operations/test_basic_operations.py => core_operations/test_core_operations.py} (94%) diff --git a/b_asic/__init__.py b/b_asic/__init__.py index fc787edf..7e40ad52 100644 --- a/b_asic/__init__.py +++ b/b_asic/__init__.py @@ -2,8 +2,6 @@ Better ASIC Toolbox. TODO: More info. """ -from b_asic.abstract_graph_component import * -from b_asic.abstract_operation import * from b_asic.core_operations import * from b_asic.graph_component import * from b_asic.graph_id import * @@ -14,4 +12,3 @@ from b_asic.schema import * from b_asic.signal_flow_graph import * from b_asic.signal import * from b_asic.simulation import * -from b_asic.utilities import * diff --git a/b_asic/abstract_graph_component.py b/b_asic/abstract_graph_component.py deleted file mode 100644 index a0b71b41..00000000 --- a/b_asic/abstract_graph_component.py +++ /dev/null @@ -1,26 +0,0 @@ -"""@package docstring -B-ASIC module for Graph Components of a signal flow graph. -TODO: More info. -""" - -from b_asic.graph_component import GraphComponent, Name - - -class AbstractGraphComponent(GraphComponent): - """Abstract Graph Component class which is a component of a signal flow graph. - - TODO: More info. - """ - - _name: Name - - def __init__(self, name: Name = ""): - self._name = name - - @property - def name(self) -> Name: - return self._name - - @name.setter - def name(self, name: Name) -> None: - self._name = name diff --git a/b_asic/abstract_operation.py b/b_asic/abstract_operation.py deleted file mode 100644 index fc3a9205..00000000 --- a/b_asic/abstract_operation.py +++ /dev/null @@ -1,117 +0,0 @@ -"""@package docstring -B-ASIC Abstract Operation Module. -TODO: More info. -""" - -from abc import abstractmethod -from typing import List, Dict, Optional, Any -from numbers import Number - -from b_asic.port import InputPort, OutputPort -from b_asic.signal import Signal -from b_asic.operation import Operation -from b_asic.simulation import SimulationState, OperationState -from b_asic.utilities import breadth_first_search -from b_asic.abstract_graph_component import AbstractGraphComponent -from b_asic.graph_component import Name - - -class AbstractOperation(Operation, AbstractGraphComponent): - """Generic abstract operation class which most implementations will derive from. - TODO: More info. - """ - - _input_ports: List[InputPort] - _output_ports: List[OutputPort] - _parameters: Dict[str, Optional[Any]] - - def __init__(self, name: Name = ""): - super().__init__(name) - self._input_ports = [] - self._output_ports = [] - self._parameters = {} - - @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ - """Evaluate the operation and generate a list of output values given a - list of input values.""" - raise NotImplementedError - - def inputs(self) -> List[InputPort]: - return self._input_ports.copy() - - def outputs(self) -> List[OutputPort]: - return self._output_ports.copy() - - def input_count(self) -> int: - return len(self._input_ports) - - def output_count(self) -> int: - return len(self._output_ports) - - def input(self, i: int) -> InputPort: - return self._input_ports[i] - - def output(self, i: int) -> OutputPort: - return self._output_ports[i] - - def params(self) -> Dict[str, Optional[Any]]: - return self._parameters.copy() - - def param(self, name: str) -> Optional[Any]: - return self._parameters.get(name) - - def set_param(self, name: str, value: Any) -> None: - assert name in self._parameters # TODO: Error message. - self._parameters[name] = value - - def evaluate_outputs(self, state: SimulationState) -> List[Number]: - # TODO: Check implementation. - input_count: int = self.input_count() - output_count: int = self.output_count() - assert input_count == len(self._input_ports) # TODO: Error message. - assert output_count == len(self._output_ports) # TODO: Error message. - - self_state: OperationState = state.operation_states[self] - - while self_state.iteration < state.iteration: - input_values: List[Number] = [0] * input_count - for i in range(input_count): - source: Signal = self._input_ports[i].signal - input_values[i] = source.operation.evaluate_outputs(state)[source.port_index] - - self_state.output_values = self.evaluate(input_values) - assert len(self_state.output_values) == output_count # TODO: Error message. - self_state.iteration += 1 - for i in range(output_count): - for signal in self._output_ports[i].signals(): - destination: Signal = signal.destination - destination.evaluate_outputs(state) - - return self_state.output_values - - def split(self) -> List[Operation]: - # TODO: Check implementation. - results = self.evaluate(self._input_ports) - if all(isinstance(e, Operation) for e in results): - return results - return [self] - - @property - def neighbours(self) -> List[Operation]: - neighbours: List[Operation] = [] - for port in self._input_ports: - for signal in port.signals: - neighbours.append(signal.source.operation) - - for port in self._output_ports: - for signal in port.signals: - neighbours.append(signal.destination.operation) - - return neighbours - - def traverse(self) -> Operation: - """Traverse the operation tree and return a generator with start point in the operation.""" - return breadth_first_search(self) - - # TODO: More stuff. diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index f64c63db..ce1019f3 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -8,9 +8,10 @@ from typing import Any from numpy import conjugate, sqrt, abs as np_abs from b_asic.port import InputPort, OutputPort from b_asic.graph_id import GraphIDType -from b_asic.abstract_operation import AbstractOperation +from b_asic.operation import AbstractOperation from b_asic.graph_component import Name, TypeName + class Input(AbstractOperation): """Input operation. TODO: More info. @@ -23,7 +24,6 @@ class Input(AbstractOperation): return "in" - class Constant(AbstractOperation): """Constant value operation. TODO: More info. @@ -43,7 +43,6 @@ class Constant(AbstractOperation): return "c" - class Addition(AbstractOperation): """Binary addition operation. TODO: More info. @@ -52,7 +51,7 @@ class Addition(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._input_ports = [InputPort(0, self), InputPort(1, self)] self._output_ports = [OutputPort(0, self)] if source1 is not None: @@ -75,7 +74,7 @@ class Subtraction(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._input_ports = [InputPort(0, self), InputPort(1, self)] self._output_ports = [OutputPort(0, self)] if source1 is not None: @@ -98,7 +97,7 @@ class Multiplication(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._input_ports = [InputPort(0, self), InputPort(1, self)] self._output_ports = [OutputPort(0, self)] if source1 is not None: @@ -121,7 +120,7 @@ class Division(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._input_ports = [InputPort(0, self), InputPort(1, self)] self._output_ports = [OutputPort(0, self)] if source1 is not None: @@ -188,7 +187,7 @@ class Max(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._input_ports = [InputPort(0, self), InputPort(1, self)] self._output_ports = [OutputPort(0, self)] if source1 is not None: @@ -213,7 +212,7 @@ class Min(AbstractOperation): def __init__(self, source1: OutputPort = None, source2: OutputPort = None, name: Name = ""): super().__init__(name) - self._input_ports = [InputPort(0, self), InputPort(0, self)] + self._input_ports = [InputPort(0, self), InputPort(1, self)] self._output_ports = [OutputPort(0, self)] if source1 is not None: diff --git a/b_asic/graph_component.py b/b_asic/graph_component.py index ec39a28d..1987d449 100644 --- a/b_asic/graph_component.py +++ b/b_asic/graph_component.py @@ -32,3 +32,23 @@ class GraphComponent(ABC): def name(self, name: Name) -> None: """Set the name of the graph component to the entered name.""" raise NotImplementedError + + +class AbstractGraphComponent(GraphComponent): + """Abstract Graph Component class which is a component of a signal flow graph. + + TODO: More info. + """ + + _name: Name + + def __init__(self, name: Name = ""): + self._name = name + + @property + def name(self) -> Name: + return self._name + + @name.setter + def name(self, name: Name) -> None: + self._name = name diff --git a/b_asic/operation.py b/b_asic/operation.py index acd26672..75644b73 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -5,13 +5,16 @@ TODO: More info. from abc import abstractmethod from numbers import Number -from typing import List, Dict, Optional, Any, TYPE_CHECKING +from typing import List, Dict, Optional, Any, Set, TYPE_CHECKING +from collections import deque -from b_asic.graph_component import GraphComponent +from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name +from b_asic.simulation import SimulationState, OperationState +from b_asic.signal import Signal if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort - from b_asic.simulation import SimulationState + class Operation(GraphComponent): """Operation interface. @@ -88,3 +91,114 @@ class Operation(GraphComponent): If no neighbours are found this returns an empty list """ raise NotImplementedError + + +class AbstractOperation(Operation, AbstractGraphComponent): + """Generic abstract operation class which most implementations will derive from. + TODO: More info. + """ + + _input_ports: List["InputPort"] + _output_ports: List["OutputPort"] + _parameters: Dict[str, Optional[Any]] + + def __init__(self, name: Name = ""): + super().__init__(name) + self._input_ports = [] + self._output_ports = [] + self._parameters = {} + + @abstractmethod + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + """Evaluate the operation and generate a list of output values given a + list of input values.""" + raise NotImplementedError + + def inputs(self) -> List["InputPort"]: + return self._input_ports.copy() + + def outputs(self) -> List["OutputPort"]: + return self._output_ports.copy() + + def input_count(self) -> int: + return len(self._input_ports) + + def output_count(self) -> int: + return len(self._output_ports) + + def input(self, i: int) -> "InputPort": + return self._input_ports[i] + + def output(self, i: int) -> "OutputPort": + return self._output_ports[i] + + def params(self) -> Dict[str, Optional[Any]]: + return self._parameters.copy() + + def param(self, name: str) -> Optional[Any]: + return self._parameters.get(name) + + def set_param(self, name: str, value: Any) -> None: + assert name in self._parameters # TODO: Error message. + self._parameters[name] = value + + def evaluate_outputs(self, state: SimulationState) -> List[Number]: + # TODO: Check implementation. + input_count: int = self.input_count() + output_count: int = self.output_count() + assert input_count == len(self._input_ports) # TODO: Error message. + assert output_count == len(self._output_ports) # TODO: Error message. + + self_state: OperationState = state.operation_states[self] + + while self_state.iteration < state.iteration: + input_values: List[Number] = [0] * input_count + for i in range(input_count): + source: Signal = self._input_ports[i].signal + input_values[i] = source.operation.evaluate_outputs(state)[source.port_index] + + self_state.output_values = self.evaluate(input_values) + assert len(self_state.output_values) == output_count # TODO: Error message. + self_state.iteration += 1 + for i in range(output_count): + for signal in self._output_ports[i].signals(): + destination: Signal = signal.destination + destination.evaluate_outputs(state) + + return self_state.output_values + + def split(self) -> List[Operation]: + # TODO: Check implementation. + results = self.evaluate(self._input_ports) + if all(isinstance(e, Operation) for e in results): + return results + return [self] + + @property + def neighbours(self) -> List[Operation]: + neighbours: List[Operation] = [] + for port in self._input_ports: + for signal in port.signals: + neighbours.append(signal.source.operation) + + for port in self._output_ports: + for signal in port.signals: + neighbours.append(signal.destination.operation) + + return neighbours + + def traverse(self) -> Operation: + """Traverse the operation tree and return a generator with start point in the operation.""" + return self._breadth_first_search() + + def _breadth_first_search(self) -> Operation: + """Use breadth first search to traverse the operation tree.""" + visited: Set[Operation] = {self} + queue = deque([self]) + while queue: + operation = queue.popleft() + yield operation + for n_operation in operation.neighbours: + if n_operation not in visited: + visited.add(n_operation) + queue.append(n_operation) diff --git a/b_asic/port.py b/b_asic/port.py index eff9db9d..c22053df 100644 --- a/b_asic/port.py +++ b/b_asic/port.py @@ -9,31 +9,25 @@ from typing import NewType, Optional, List from b_asic.operation import Operation from b_asic.signal import Signal -PortId = NewType("PortId", int) - +PortIndex = NewType("PortIndex", int) class Port(ABC): - """Abstract port class. + """Port Interface. - Handles functionality for port id and saves the connection to the parent operation. + TODO: More documentaiton? """ - _port_id: PortId - _operation: Operation - - def __init__(self, port_id: PortId, operation: Operation): - self._port_id = port_id - self._operation = operation - - @property - def id(self) -> PortId: - """Return the unique portid.""" - return self._port_id - @property + @abstractmethod def operation(self) -> Operation: """Return the connected operation.""" - return self._operation + raise NotImplementedError + + @property + @abstractmethod + def index(self) -> PortIndex: + """Return the unique PortIndex.""" + raise NotImplementedError @property @abstractmethod @@ -62,136 +56,168 @@ class Port(ABC): raise NotImplementedError @abstractmethod - def connect_port(self, port: "Port") -> Signal: + def connect(self, port: "Port") -> Signal: """Create and return a signal that is connected to this port and the entered port and connect this port to the signal and the entered port to the signal.""" raise NotImplementedError @abstractmethod - def connect_signal(self, signal: Signal) -> None: + def add_signal(self, signal: Signal) -> None: """Connect this port to the entered signal. If the entered signal isn't connected to this port then connect the entered signal to the port aswell.""" raise NotImplementedError @abstractmethod - def disconnect_signal(self, i: int = 0) -> None: - """Disconnect a signal from the port. If the port is still connected to the entered signal - then the port is disconnected from the the entered signal aswell.""" + def disconnect(self, port: "Port") -> None: + """Disconnect the entered port from the port by removing it from the ports signal. + If the entered port is still connected to this ports signal then disconnect the entered + port from the signal aswell.""" raise NotImplementedError @abstractmethod - def is_connected_to_signal(self, signal: Signal) -> bool: - """Return true if the port is connected to the entered signal else false.""" + def remove_signal(self, signal: Signal) -> None: + """Remove the signal that was entered from the Ports signals. + If the entered signal still is connected to this port then disconnect the + entered signal from the port aswell. + + Keyword arguments: + - signal: Signal to remove. + """ + raise NotImplementedError + + @abstractmethod + def clear(self) -> None: + """Removes all connected signals from the Port.""" raise NotImplementedError -class InputPort(Port): +class AbstractPort(Port): + """Abstract port class. + + Handles functionality for port id and saves the connection to the parent operation. + """ + + _index: int + _operation: Operation + + def __init__(self, index: int, operation: Operation): + self._index = index + self._operation = operation + + @property + def operation(self) -> Operation: + return self._operation + + @property + def index(self) -> PortIndex: + return self._index + + +class InputPort(AbstractPort): """Input port. TODO: More info. """ - _signal: Optional[Signal] + _source_signal: Optional[Signal] - def __init__(self, port_id: PortId, operation: Operation): + def __init__(self, port_id: PortIndex, operation: Operation): super().__init__(port_id, operation) - self._signal = None + self._source_signal = None @property def signals(self) -> List[Signal]: - return [] if self._signal is None else [self._signal] + return [] if self._source_signal is None else [self._source_signal] def signal(self, i: int = 0) -> Signal: assert 0 <= i < self.signal_count(), "Signal index out of bound." - assert self._signal is not None, "No Signal connect to InputPort." - return self._signal + assert self._source_signal is not None, "No Signal connect to InputPort." + return self._source_signal @property def connected_ports(self) -> List[Port]: - return [] if self._signal is None else [self._signal.source] + return [] if self._source_signal is None or self._source_signal.source is None \ + else [self._source_signal.source] def signal_count(self) -> int: - return 0 if self._signal is None else 1 + return 0 if self._source_signal is None else 1 - def connect_port(self, port: "OutputPort") -> Signal: - assert self._signal is None, "Connecting new port to already connected input port." - return Signal(port, self) # self._signal is set by the signal constructor + def connect(self, port: "OutputPort") -> Signal: + assert self._source_signal is None, "Connecting new port to already connected input port." + return Signal(port, self) # self._source_signal is set by the signal constructor - def connect_signal(self, signal: Signal) -> None: - assert self._signal is None, "Connecting new port to already connected input port." - self._signal = signal + def add_signal(self, signal: Signal) -> None: + assert self._source_signal is None, "Connecting new port to already connected input port." + self._source_signal: Signal = signal if self is not signal.destination: # Connect this inputport as destination for this signal if it isn't already. - signal.connect_destination(self) + signal.set_destination(self) - def disconnect_signal(self, i: int = 0) -> None: - assert 0 <= i < self.signal_count(), "Signal Index out of range." - old_signal: Signal = self._signal - self._signal = None + def disconnect(self, port: "OutputPort") -> None: + assert self._source_signal.source is port, "The entered port is not connected to this port." + self._source_signal.remove_source() + + def remove_signal(self, signal: Signal) -> None: + old_signal: Signal = self._source_signal + self._source_signal = None if self is old_signal.destination: # Disconnect the dest of the signal if this inputport currently is the dest - old_signal.disconnect_destination() - old_signal.disconnect_destination() - - def is_connected_to_signal(self, signal: Signal) -> bool: - return self._signal is signal + old_signal.remove_destination() + def clear(self) -> None: + self.remove_signal(self._source_signal) -class OutputPort(Port): +class OutputPort(AbstractPort): """Output port. TODO: More info. """ - _signals: List[Signal] + _destination_signals: List[Signal] - def __init__(self, port_id: PortId, operation: Operation): + def __init__(self, port_id: PortIndex, operation: Operation): super().__init__(port_id, operation) - self._signals = [] + self._destination_signals = [] @property def signals(self) -> List[Signal]: - return self._signals.copy() + return self._destination_signals.copy() def signal(self, i: int = 0) -> Signal: assert 0 <= i < self.signal_count(), "Signal index out of bounds." - return self._signals[i] + return self._destination_signals[i] @property def connected_ports(self) -> List[Port]: - return [signal.destination for signal in self._signals \ + return [signal.destination for signal in self._destination_signals \ if signal.destination is not None] def signal_count(self) -> int: - return len(self._signals) + return len(self._destination_signals) - def connect_port(self, port: InputPort) -> Signal: - return Signal(self, port) # Signal is added to self._signals in signal constructor + def connect(self, port: InputPort) -> Signal: + return Signal(self, port) # Signal is added to self._destination_signals in signal constructor - def connect_signal(self, signal: Signal) -> None: - assert not self.is_connected_to_signal(signal), \ + def add_signal(self, signal: Signal) -> None: + assert signal not in self.signals, \ "Attempting to connect to Signal already connected." - self._signals.append(signal) + self._destination_signals.append(signal) if self is not signal.source: # Connect this outputport to the signal if it isn't already - signal.connect_source(self) - - def disconnect_signal(self, i: int = 0) -> None: - assert 0 <= i < self.signal_count(), "Signal index out of bounds." - old_signal: Signal = self._signals[i] - del self._signals[i] + signal.set_source(self) + + def disconnect(self, port: InputPort) -> None: + assert port in self.connected_ports, "Attempting to disconnect port that isn't connected." + for sig in self._destination_signals: + if sig.destination is port: + sig.remove_destination() + break + + def remove_signal(self, signal: Signal) -> None: + i: int = self._destination_signals.index(signal) + old_signal: Signal = self._destination_signals[i] + del self._destination_signals[i] if self is old_signal.source: - # Disconnect the source of the signal if this outputport currently is the source - old_signal.disconnect_source() - - def disconnect_signal_by_ref(self, signal: Signal) -> None: - """Remove the signal that was entered from the OutputPorts signals. - If the entered signal still is connected to this port then disconnect the - entered signal from the port aswell. - - Keyword arguments: - - signal: Signal to remove. - """ - i: int = self._signals.index(signal) - self.disconnect_signal(i) + old_signal.remove_source() - def is_connected_to_signal(self, signal: Signal) -> bool: - return signal in self._signals # O(n) complexity + def clear(self) -> None: + for signal in self._destination_signals: + self.remove_signal(signal) diff --git a/b_asic/signal.py b/b_asic/signal.py index 917e4af3..64c25948 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -3,13 +3,12 @@ B-ASIC Signal Module. """ from typing import Optional, TYPE_CHECKING -from b_asic.graph_component import TypeName -from b_asic.abstract_graph_component import AbstractGraphComponent -from b_asic.graph_component import Name +from b_asic.graph_component import AbstractGraphComponent, TypeName, Name if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort + class Signal(AbstractGraphComponent): """A connection between two ports.""" @@ -25,10 +24,10 @@ class Signal(AbstractGraphComponent): self._destination = destination if source is not None: - self.connect_source(source) + self.set_source(source) if destination is not None: - self.connect_destination(destination) + self.set_destination(destination) @property def source(self) -> "OutputPort": @@ -40,7 +39,7 @@ class Signal(AbstractGraphComponent): """Return the destination "InputPort" of the signal.""" return self._destination - def connect_source(self, src: "OutputPort") -> None: + def set_source(self, src: "OutputPort") -> None: """Disconnect the previous source OutputPort of the signal and connect to the entered source OutputPort. Also connect the entered source port to the signal if it hasn't already been connected. @@ -48,13 +47,13 @@ class Signal(AbstractGraphComponent): Keyword arguments: - src: OutputPort to connect as source to the signal. """ - self.disconnect_source() + self.remove_source() self._source = src - if not src.is_connected_to_signal(self): + if self not in src.signals: # If the new source isn't connected to this signal then connect it. - src.connect_signal(self) + src.add_signal(self) - def connect_destination(self, dest: "InputPort") -> None: + def set_destination(self, dest: "InputPort") -> None: """Disconnect the previous destination InputPort of the signal and connect to the entered destination InputPort. Also connect the entered destination port to the signal if it hasn't already been connected. @@ -62,34 +61,34 @@ class Signal(AbstractGraphComponent): Keywords argments: - dest: InputPort to connect as destination to the signal. """ - self.disconnect_destination() + self.remove_destination() self._destination = dest - if not dest.is_connected_to_signal(self): + if self not in dest.signals: # If the new destination isn't connected to tis signal then connect it. - dest.connect_signal(self) + dest.add_signal(self) @property def type_name(self) -> TypeName: return "s" - def disconnect_source(self) -> None: + def remove_source(self) -> None: """Disconnect the source OutputPort of the signal. If the source port still is connected to this signal then also disconnect the source port.""" if self._source is not None: old_source: "OutputPort" = self._source self._source = None - if old_source.is_connected_to_signal(self): + if self in old_source.signals: # If the old destination port still is connected to this signal, then disconnect it. - old_source.disconnect_signal_by_ref(self) + old_source.remove_signal(self) - def disconnect_destination(self) -> None: + def remove_destination(self) -> None: """Disconnect the destination InputPort of the signal.""" if self._destination is not None: old_destination: "InputPort" = self._destination self._destination = None - if old_destination.is_connected_to_signal(self): + if self in old_destination.signals: # If the old destination port still is connected to this signal, then disconnect it. - old_destination.disconnect_signal() + old_destination.remove_signal(self) def is_connected(self) -> bool: """Returns true if the signal is connected to both a source and a destination, diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index ab2c3e94..9c08aecc 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -7,7 +7,7 @@ from typing import List, Dict, Optional, DefaultDict from collections import defaultdict from b_asic.operation import Operation -from b_asic.abstract_operation import AbstractOperation +from b_asic.operation import AbstractOperation from b_asic.signal import Signal from b_asic.graph_id import GraphIDGenerator, GraphID from b_asic.graph_component import GraphComponent, Name, TypeName @@ -46,7 +46,7 @@ class SFG(AbstractOperation): # TODO: Traverse the graph between the inputs/outputs and add to self._operations. # TODO: Connect ports with signals with appropriate IDs. - def evaluate(self, inputs: list) -> list: + def evaluate(self, *inputs) -> list: return [] # TODO: Implement def _add_graph_component(self, graph_component: GraphComponent) -> GraphID: diff --git a/b_asic/utilities.py b/b_asic/utilities.py deleted file mode 100644 index 25707ff8..00000000 --- a/b_asic/utilities.py +++ /dev/null @@ -1,21 +0,0 @@ -"""@package docstring -B-ASIC Operation Module. -TODO: More info. -""" - -from typing import Set -from collections import deque - -from b_asic.operation import Operation - -def breadth_first_search(start: Operation) -> Operation: - """Use breadth first search to traverse the operation tree.""" - visited: Set[Operation] = {start} - queue = deque([start]) - while queue: - operation = queue.popleft() - yield operation - for n_operation in operation.neighbours: - if n_operation not in visited: - visited.add(n_operation) - queue.append(n_operation) diff --git a/test/basic_operations/test_basic_operations.py b/test/core_operations/test_core_operations.py similarity index 94% rename from test/basic_operations/test_basic_operations.py rename to test/core_operations/test_core_operations.py index 21561074..1d33bfe1 100644 --- a/test/basic_operations/test_basic_operations.py +++ b/test/core_operations/test_core_operations.py @@ -1,12 +1,12 @@ """ -B-ASIC test suite for the basic operations. +B-ASIC test suite for the core operations. """ from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision from b_asic.signal import Signal import pytest -""" Constant tests. """ +# Constant tests. def test_constant(): constant_operation = Constant(3) assert constant_operation.evaluate() == 3 @@ -19,7 +19,7 @@ def test_constant_complex(): constant_operation = Constant(3+4j) assert constant_operation.evaluate() == 3+4j -""" Addition tests. """ +# Addition tests. def test_addition(): test_operation = Addition() constant_operation = Constant(3) @@ -38,7 +38,7 @@ def test_addition_complex(): constant_operation_2 = Constant((4+6j)) assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) -""" Subtraction tests. """ +# Subtraction tests. def test_subtraction(): test_operation = Subtraction() constant_operation = Constant(5) @@ -57,7 +57,7 @@ def test_subtraction_complex(): constant_operation_2 = Constant((4+6j)) assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) -""" Multiplication tests. """ +# Multiplication tests. def test_multiplication(): test_operation = Multiplication() constant_operation = Constant(5) @@ -76,7 +76,7 @@ def test_multiplication_complex(): constant_operation_2 = Constant((4+6j)) assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) -""" Division tests. """ +# Division tests. def test_division(): test_operation = Division() constant_operation = Constant(30) @@ -95,7 +95,7 @@ def test_division_complex(): constant_operation_2 = Constant((10+20j)) assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) -""" SquareRoot tests. """ +# SquareRoot tests. def test_squareroot(): test_operation = SquareRoot() constant_operation = Constant(36) @@ -111,7 +111,7 @@ def test_squareroot_complex(): constant_operation = Constant((48+64j)) assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) -""" ComplexConjugate tests. """ +# ComplexConjugate tests. def test_complexconjugate(): test_operation = ComplexConjugate() constant_operation = Constant(3+4j) @@ -122,7 +122,7 @@ def test_test_complexconjugate_negative(): constant_operation = Constant(-3-4j) assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) -""" Max tests. """ +# Max tests. def test_max(): test_operation = Max() constant_operation = Constant(30) @@ -135,7 +135,7 @@ def test_max_negative(): constant_operation_2 = Constant(-5) assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 -""" Min tests. """ +# Min tests. def test_min(): test_operation = Min() constant_operation = Constant(30) @@ -148,7 +148,7 @@ def test_min_negative(): constant_operation_2 = Constant(-5) assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 -""" Absolute tests. """ +# Absolute tests. def test_absolute(): test_operation = Absolute() constant_operation = Constant(30) @@ -164,7 +164,7 @@ def test_absolute_complex(): constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 -""" ConstantMultiplication tests. """ +# ConstantMultiplication tests. def test_constantmultiplication(): test_operation = ConstantMultiplication(5) constant_operation = Constant(20) @@ -180,7 +180,7 @@ def test_constantmultiplication_complex(): constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) -""" ConstantAddition tests. """ +# ConstantAddition tests. def test_constantaddition(): test_operation = ConstantAddition(5) constant_operation = Constant(20) @@ -196,7 +196,7 @@ def test_constantaddition_complex(): constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) -""" ConstantSubtraction tests. """ +# ConstantSubtraction tests. def test_constantsubtraction(): test_operation = ConstantSubtraction(5) constant_operation = Constant(20) @@ -212,7 +212,7 @@ def test_constantsubtraction_complex(): constant_operation = Constant((3+4j)) assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) -""" ConstantDivision tests. """ +# ConstantDivision tests. def test_constantdivision(): test_operation = ConstantDivision(5) constant_operation = Constant(20) @@ -226,4 +226,4 @@ def test_constantdivision_negative(): def test_constantdivision_complex(): test_operation = ConstantDivision(2+2j) constant_operation = Constant((10+10j)) - assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) \ No newline at end of file + assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index b97e89d7..74d3b8c6 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -10,9 +10,9 @@ def operation(): def create_operation(_type, dest_oper, index, **kwargs): oper = _type(**kwargs) oper_signal = Signal() - oper._output_ports[0].connect_signal(oper_signal) + oper._output_ports[0].add_signal(oper_signal) - dest_oper._input_ports[index].connect_signal(oper_signal) + dest_oper._input_ports[index].add_signal(oper_signal) return oper @pytest.fixture @@ -51,10 +51,10 @@ def large_operation_tree(): add_oper_3 = Addition() add_oper_signal = Signal(add_oper.output(0), add_oper_3.output(0)) - add_oper._output_ports[0].connect_signal(add_oper_signal) - add_oper_3._input_ports[0].connect_signal(add_oper_signal) + add_oper._output_ports[0].add_signal(add_oper_signal) + add_oper_3._input_ports[0].add_signal(add_oper_signal) add_oper_2_signal = Signal(add_oper_2.output(0), add_oper_3.output(0)) - add_oper_2._output_ports[0].connect_signal(add_oper_2_signal) - add_oper_3._input_ports[1].connect_signal(add_oper_2_signal) + add_oper_2._output_ports[0].add_signal(add_oper_2_signal) + add_oper_3._input_ports[1].add_signal(add_oper_2_signal) return const_oper diff --git a/test/port/test_inputport.py b/test/port/test_inputport.py index 7a78d5f7..a4324069 100644 --- a/test/port/test_inputport.py +++ b/test/port/test_inputport.py @@ -39,9 +39,9 @@ def connected_sig(): inp_port = InputPort(0, None) return Signal(source=out_port, destination=inp_port) -def test_connect_port_then_disconnect(inp_port, out_port): +def test_connect_then_disconnect(inp_port, out_port): """Test connect unused port to port.""" - s1 = inp_port.connect_port(out_port) + s1 = inp_port.connect(out_port) assert inp_port.connected_ports == [out_port] assert out_port.connected_ports == [inp_port] @@ -50,7 +50,7 @@ def test_connect_port_then_disconnect(inp_port, out_port): assert s1.source is out_port assert s1.destination is inp_port - inp_port.disconnect_signal() + inp_port.remove_signal(s1) assert inp_port.connected_ports == [] assert out_port.connected_ports == [] @@ -61,12 +61,13 @@ def test_connect_port_then_disconnect(inp_port, out_port): def test_connect_used_port_to_new_port(inp_port, out_port, out_port2): """Does connecting multiple ports to an inputport throw error?""" - inp_port.connect_port(out_port) + inp_port.connect(out_port) with pytest.raises(AssertionError): - inp_port.connect_port(out_port2) + inp_port.connect(out_port2) -def test_connect_signal_then_disconnect(inp_port, s_w_source): - inp_port.connect_signal(s_w_source) +def test_add_signal_then_disconnect(inp_port, s_w_source): + """Can signal be connected then disconnected properly?""" + inp_port.add_signal(s_w_source) assert inp_port.connected_ports == [s_w_source.source] assert s_w_source.source.connected_ports == [inp_port] @@ -74,7 +75,7 @@ def test_connect_signal_then_disconnect(inp_port, s_w_source): assert s_w_source.source.signals == [s_w_source] assert s_w_source.destination is inp_port - inp_port.disconnect_signal() + inp_port.remove_signal(s_w_source) assert inp_port.connected_ports == [] assert s_w_source.source.connected_ports == [] @@ -82,3 +83,13 @@ def test_connect_signal_then_disconnect(inp_port, s_w_source): assert s_w_source.source.signals == [s_w_source] assert s_w_source.destination is None +def test_connect_then_disconnect(inp_port, out_port): + """Can port be connected and then disconnected properly?""" + inp_port.connect(out_port) + + inp_port.disconnect(out_port) + + print("outport signals:", out_port.signals, "count:", out_port.signal_count()) + assert inp_port.signal_count() == 1 + assert len(inp_port.connected_ports) == 0 + assert out_port.signal_count() == 0 diff --git a/test/port/test_outputport.py b/test/port/test_outputport.py index f48afbdb..ac50818e 100644 --- a/test/port/test_outputport.py +++ b/test/port/test_outputport.py @@ -14,18 +14,33 @@ def test_connect_multiple_signals(inp_ports): out_port = OutputPort(0, None) for port in inp_ports: - out_port.connect_port(port) - + out_port.connect(port) + assert out_port.signal_count() == len(inp_ports) def test_disconnect_multiple_signals(inp_ports): - """Can multiple ports disconnect from an output port?""" + """Can multiple signals disconnect from an output port?""" + out_port = OutputPort(0, None) + + sigs = [] + + for port in inp_ports: + sigs.append(out_port.connect(port)) + + for sig in sigs: + out_port.remove_signal(sig) + + assert out_port.signal_count() == 0 + +def test_disconnect_mulitple_ports(inp_ports): + """Can multiple ports be disconnected from an output port?""" out_port = OutputPort(0, None) for port in inp_ports: - out_port.connect_port(port) - - for _ in inp_ports: - out_port.disconnect_signal(0) + out_port.connect(port) + + for port in inp_ports: + out_port.disconnect(port) - assert out_port.signal_count() == 0 \ No newline at end of file + assert out_port.signal_count() == 3 + assert len(out_port.connected_ports) == 0 \ No newline at end of file diff --git a/test/signal/test_signal.py b/test/signal/test_signal.py index 8c10d1e3..ab07eb77 100644 --- a/test/signal/test_signal.py +++ b/test/signal/test_signal.py @@ -18,7 +18,7 @@ def test_signal_creation_and_disconnction_and_connection_changing(): assert s.destination is in_port in_port1 = InputPort(0, None) - s.connect_destination(in_port1) + s.set_destination(in_port1) assert in_port.signals == [] assert in_port1.signals == [s] @@ -26,14 +26,14 @@ def test_signal_creation_and_disconnction_and_connection_changing(): assert s.source is out_port assert s.destination is in_port1 - s.disconnect_source() + s.remove_source() assert out_port.signals == [] assert in_port1.signals == [s] assert s.source is None assert s.destination is in_port1 - s.disconnect_destination() + s.remove_destination() assert out_port.signals == [] assert in_port1.signals == [] @@ -41,20 +41,20 @@ def test_signal_creation_and_disconnction_and_connection_changing(): assert s.destination is None out_port1 = OutputPort(0, None) - s.connect_source(out_port1) + s.set_source(out_port1) assert out_port1.signals == [s] assert s.source is out_port1 assert s.destination is None - s.connect_source(out_port) + s.set_source(out_port) assert out_port.signals == [s] assert out_port1.signals == [] assert s.source is out_port assert s.destination is None - s.connect_destination(in_port) + s.set_destination(in_port) assert out_port.signals == [s] assert in_port.signals == [s] diff --git a/test/traverse/test_traverse_tree.py b/test/traverse/test_traverse_tree.py index 9f509287..031aeec7 100644 --- a/test/traverse/test_traverse_tree.py +++ b/test/traverse/test_traverse_tree.py @@ -24,7 +24,7 @@ def test_traverse_type(large_operation_tree): def test_traverse_loop(operation_tree): add_oper_signal = Signal() - operation_tree._output_ports[0].connect_signal(add_oper_signal) - operation_tree._input_ports[0].disconnect_signal() - operation_tree._input_ports[0].connect_signal(add_oper_signal) + operation_tree._output_ports[0].add_signal(add_oper_signal) + operation_tree._input_ports[0].remove_signal(add_oper_signal) + operation_tree._input_ports[0].add_signal(add_oper_signal) assert len(list(operation_tree.traverse())) == 2 -- GitLab