diff --git a/b_asic/signal.py b/b_asic/signal.py index b9b80383b09016a6e81bf796e52270660cbb215e..59148a1458d74897ec2245b3d645244959eff9d3 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -3,7 +3,7 @@ B-ASIC Signal Module. Contains the class for representing the connections between operations. """ -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Optional, Union from b_asic.graph_component import ( AbstractGraphComponent, @@ -14,18 +14,37 @@ from b_asic.graph_component import ( if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort + from b_asic.operation import Operation class Signal(AbstractGraphComponent): - """A connection between two ports.""" + """ + A connection between two ports. + + Parameters + ========== + + source : OutputPort or Operation, optional + OutputPort or Operation to connect as source to the signal. + destination : InputPort or Operation, optional + InputPort or Operation to connect as destination to the signal. + bits : int, optional + The word length of the signal. + name : Name, default: "" + The signal name. + + See also + ======== + set_source, set_destination + """ _source: Optional["OutputPort"] _destination: Optional["InputPort"] def __init__( self, - source: Optional["OutputPort"] = None, - destination: Optional["InputPort"] = None, + source: Optional[Union["OutputPort", "Operation"]] = None, + destination: Optional[Union["InputPort", "Operation"]] = None, bits: Optional[int] = None, name: Name = Name(""), ): @@ -53,67 +72,96 @@ class Signal(AbstractGraphComponent): @property def source(self) -> Optional["OutputPort"]: - """Return the source OutputPort of the signal.""" + """The source OutputPort of the signal.""" return self._source @property def destination(self) -> Optional["InputPort"]: - """Return the destination "InputPort" of the signal.""" + """The destination InputPort of the signal.""" return self._destination - def set_source(self, src: "OutputPort") -> None: - """Disconnect the previous source OutputPort of the signal and + def set_source(self, source: Union["OutputPort", "Operation"]) -> None: + """ + Disconnect the previous source OutputPort of the signal and connect to the entered source OutputPort. Also connect the entered - source port to the signal if it hasn't already been connected. + source port to the signal if it has not already been connected. Parameters ========== - src : OutputPort - OutputPort to connect as source to the signal. + source : OutputPort or Operation, optional + OutputPort or Operation to connect as source to the signal. If + Operation, it must have a single output, otherwise a TypeError is + raised. That output is used to extract the OutputPort. """ - if src is not self._source: + from b_asic.operation import Operation + + if isinstance(source, Operation): + if source.output_count != 1: + raise TypeError( + "Can only connect operations with a single output." + f" {source.type_name()} has {source.output_count} outputs." + " Use the output port directly instead." + ) + source = source.output(0) + + if source is not self._source: self.remove_source() - self._source = src - if self not in src.signals: - src.add_signal(self) + self._source = source + if self not in source.signals: + source.add_signal(self) - def set_destination(self, dest: "InputPort") -> None: + def set_destination(self, destination: "InputPort") -> None: """ Disconnect the previous destination InputPort of the signal and connect to the entered destination InputPort. Also connect the entered - destination port to the signal if it hasn't already been connected. + destination port to the signal if it has not already been connected. Parameters ========== - dest : InputPort - InputPort to connect as destination to the signal. + destination : InputPort or Operation + InputPort or Operation to connect as destination to the signal. + If Operation, it must have a single input, otherwise a TypeError + is raised. + """ - if dest is not self._destination: + from b_asic.operation import Operation + + if isinstance(destination, Operation): + if destination.input_count != 1: + raise TypeError( + "Can only connect operations with a single input." + f" {destination.type_name()} has" + f" {destination.input_count} outputs. Use the input port" + " directly instead." + ) + destination = destination.input(0) + + if destination is not self._destination: self.remove_destination() - self._destination = dest - if self not in dest.signals: - dest.add_signal(self) + self._destination = destination + if self not in destination.signals: + destination.add_signal(self) def remove_source(self) -> None: """ Disconnect the source OutputPort of the signal. If the source port still is connected to this signal then also disconnect the source port. """ - src = self._source - if src is not None: + source = self._source + if source is not None: self._source = None - if self in src.signals: - src.remove_signal(self) + if self in source.signals: + source.remove_signal(self) def remove_destination(self) -> None: """Disconnect the destination InputPort of the signal.""" - dest = self._destination - if dest is not None: + destination = self._destination + if destination is not None: self._destination = None - if self in dest.signals: - dest.remove_signal(self) + if self in destination.signals: + destination.remove_signal(self) def dangling(self) -> bool: """ diff --git a/test/test_signal.py b/test/test_signal.py index 30186fd1df4e6757ca1ca61f0e538639e902a296..f7c7c767ef6c2141063fba38c38ed7df4162410f 100644 --- a/test/test_signal.py +++ b/test/test_signal.py @@ -4,10 +4,12 @@ B-ASIC test suit for the signal module which consists of the Signal class. import pytest -from b_asic import InputPort, OutputPort, Signal +from b_asic.core_operations import Addition, Butterfly, ConstantMultiplication +from b_asic.port import InputPort, OutputPort +from b_asic.signal import Signal -def test_signal_creation_and_disconnction_and_connection_changing(): +def test_signal_creation_and_disconnection_and_connection_changing(): in_port = InputPort(None, 0) out_port = OutputPort(None, 1) s = Signal(out_port, in_port) @@ -87,3 +89,60 @@ class TestBits: signal.bits = 10 signal.bits = None assert signal.bits is None + + +def test_create_from_single_input_single_output(): + cm1 = ConstantMultiplication(0.5, name="Foo") + cm2 = ConstantMultiplication(1.5, name="Bar") + signal = Signal(cm1, cm2) + assert signal.destination.operation.name == "Bar" + assert signal.source.operation.name == "Foo" + + add1 = Addition(name="Zig") + + signal.set_source(add1) + + assert signal.source.operation.name == "Zig" + + +def test_signal_errors(): + cm1 = ConstantMultiplication(0.5, name="Foo") + add1 = Addition(name="Zig") + with pytest.raises( + TypeError, + match=( + "Can only connect operations with a single input. add has 2" + " outputs." + ), + ): + _ = Signal(cm1, add1) + + bf = Butterfly() + with pytest.raises( + TypeError, + match=( + "Can only connect operations with a single output. bfly has 2" + " outputs." + ), + ): + _ = Signal(bf, cm1) + + cm2 = ConstantMultiplication(1.5, name="Bar") + signal = Signal(cm1, cm2) + with pytest.raises( + TypeError, + match=( + "Can only connect operations with a single input. add has 2" + " outputs." + ), + ): + signal.set_destination(add1) + + with pytest.raises( + TypeError, + match=( + "Can only connect operations with a single output. bfly has 2" + " outputs." + ), + ): + signal.set_source(bf)