diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index d06cbab3dc5356465bdc3c2f2b93c55f26715aed..8902b169c07e600a843e526d44452fb9386ff7a9 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -55,9 +55,9 @@ class Addition(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a + b @@ -78,9 +78,9 @@ class Subtraction(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a - b @@ -101,9 +101,9 @@ class Multiplication(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a * b @@ -124,9 +124,9 @@ class Division(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): return a / b @@ -147,8 +147,7 @@ class SquareRoot(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) - + self._input_ports[0].connect(source1) def evaluate(self, a): return sqrt((complex)(a)) @@ -169,8 +168,7 @@ class ComplexConjugate(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) - + self._input_ports[0].connect(source1) def evaluate(self, a): return conjugate(a) @@ -191,9 +189,9 @@ class Max(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): assert not isinstance(a, complex) and not isinstance(b, complex), \ @@ -216,9 +214,9 @@ class Min(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) if source2 is not None: - self._input_ports[1].connect_to_port(source2) + self._input_ports[1].connect(source2) def evaluate(self, a, b): assert not isinstance(a, complex) and not isinstance(b, complex), \ @@ -241,8 +239,7 @@ class Absolute(AbstractOperation): self._output_ports = [OutputPort(0, self)] if source1 is not None: - self._input_ports[0].connect_to_port(source1) - + self._input_ports[0].connect(source1) def evaluate(self, a): return np_abs(a) @@ -264,7 +261,7 @@ class ConstantMultiplication(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a * self.param("coefficient") @@ -286,7 +283,7 @@ class ConstantAddition(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a + self.param("coefficient") @@ -308,7 +305,7 @@ class ConstantSubtraction(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a - self.param("coefficient") @@ -330,7 +327,7 @@ class ConstantDivision(AbstractOperation): self._parameters["coefficient"] = coefficient if source1 is not None: - self._input_ports[0].connect_to_port(source1) + self._input_ports[0].connect(source1) def evaluate(self, a): return a / self.param("coefficient") diff --git a/b_asic/operation.py b/b_asic/operation.py index fc007ffde91e596fc103e811677d7aab96241fa4..5578e3c48edcf15594d6d1cd71e71a17521eca25 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -109,9 +109,10 @@ class AbstractOperation(Operation, AbstractGraphComponent): self._parameters = {} @abstractmethod - def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ + def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ """Evaluate the operation and generate a list of output values given a - list of input values.""" + list of input values. + """ raise NotImplementedError def inputs(self) -> List["InputPort"]: @@ -139,15 +140,15 @@ class AbstractOperation(Operation, AbstractGraphComponent): return self._parameters.get(name) def set_param(self, name: str, value: Any) -> None: - assert name in self._parameters # TODO: Error message. + assert name in self._parameters # TODO: Error message. self._parameters[name] = value def evaluate_outputs(self, state: SimulationState) -> List[Number]: # TODO: Check implementation. input_count: int = self.input_count() output_count: int = self.output_count() - assert input_count == len(self._input_ports) # TODO: Error message. - assert output_count == len(self._output_ports) # TODO: Error message. + assert input_count == len(self._input_ports) # TODO: Error message. + assert output_count == len(self._output_ports) # TODO: Error message. self_state: OperationState = state.operation_states[self] @@ -155,10 +156,12 @@ class AbstractOperation(Operation, AbstractGraphComponent): input_values: List[Number] = [0] * input_count for i in range(input_count): source: Signal = self._input_ports[i].signal - input_values[i] = source.operation.evaluate_outputs(state)[source.port_index] + input_values[i] = source.operation.evaluate_outputs(state)[ + source.port_index] self_state.output_values = self.evaluate(input_values) - assert len(self_state.output_values) == output_count # TODO: Error message. + # TODO: Error message. + assert len(self_state.output_values) == output_count self_state.iteration += 1 for i in range(output_count): for signal in self._output_ports[i].signals(): @@ -202,3 +205,64 @@ class AbstractOperation(Operation, AbstractGraphComponent): if n_operation not in visited: visited.add(n_operation) queue.append(n_operation) + + def __add__(self, other): + """Overloads the addition operator to make it return a new Addition operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantAddition operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Addition, ConstantAddition + + if isinstance(other, Operation): + return Addition(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantAddition(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + + def __sub__(self, other): + """Overloads the subtraction operator to make it return a new Subtraction operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantSubtraction operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Subtraction, ConstantSubtraction + + if isinstance(other, Operation): + return Subtraction(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantSubtraction(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + + def __mul__(self, other): + """Overloads the multiplication operator to make it return a new Multiplication operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantMultiplication operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Multiplication, ConstantMultiplication + + if isinstance(other, Operation): + return Multiplication(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantMultiplication(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + + def __truediv__(self, other): + """Overloads the division operator to make it return a new Division operation + object that is connected to the self and other objects. If other is a number then + returns a ConstantDivision operation object instead. + """ + # Import here to avoid circular imports. + from b_asic.core_operations import Division, ConstantDivision + + if isinstance(other, Operation): + return Division(self.output(0), other.output(0)) + elif isinstance(other, Number): + return ConstantDivision(other, self.output(0)) + else: + raise TypeError("Other type is not an Operation or a Number.") + diff --git a/test/conftest.py b/test/conftest.py index 66ee9630ea4ac0a05b446f4dedbfe68549a1191e..64f39843c53a4369781a269fd7fc30ad9aa1d255 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,3 +1,4 @@ +from test.fixtures.signal import signal, signals +from test.fixtures.operation_tree import * +from test.fixtures.port import * import pytest -from test.fixtures.signal import * -from test.fixtures.operation_tree import * \ No newline at end of file diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py index 74d3b8c6f34cce87878f82e539c798a8a6dc9b0a..df3fcac35cc495d14bed06ccdfc2a3ebed25616e 100644 --- a/test/fixtures/operation_tree.py +++ b/test/fixtures/operation_tree.py @@ -17,11 +17,10 @@ def create_operation(_type, dest_oper, index, **kwargs): @pytest.fixture def operation_tree(): - """ - Return a addition operation connected with 2 constants. - >---C---+ - ---A - >---C---+ + """Return a addition operation connected with 2 constants. + ---C---+ + ---A + ---C---+ """ add_oper = Addition() create_operation(Constant, add_oper, 0, value=2) @@ -30,15 +29,14 @@ def operation_tree(): @pytest.fixture def large_operation_tree(): - """ - Return a constant operation connected with a large operation tree with 3 other constants and 3 additions. - >---C---+ - ---A---+ - >---C---+ | - +---A - >---C---+ | - ---A---+ - >---C---+ + """Return a constant operation connected with a large operation tree with 3 other constants and 3 additions. + ---C---+ + ---A---+ + ---C---+ | + +---A + ---C---+ | + ---A---+ + ---C---+ """ add_oper = Addition() add_oper_2 = Addition() diff --git a/test/fixtures/port.py b/test/fixtures/port.py new file mode 100644 index 0000000000000000000000000000000000000000..4019b3a2016aa418daeca771f9a2d8bcc4ca6652 --- /dev/null +++ b/test/fixtures/port.py @@ -0,0 +1,10 @@ +import pytest +from b_asic.port import InputPort, OutputPort + +@pytest.fixture +def input_port(): + return InputPort(0, None) + +@pytest.fixture +def output_port(): + return OutputPort(0, None) diff --git a/test/operation/test_abstract_operation.py b/test/operation/test_abstract_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..626a2dc3e5e26fb76d9266dcdd31940681df5c6e --- /dev/null +++ b/test/operation/test_abstract_operation.py @@ -0,0 +1,77 @@ +""" +B-ASIC test suite for the AbstractOperation class. +""" + +from b_asic.core_operations import Addition, ConstantAddition, Subtraction, ConstantSubtraction, \ + Multiplication, ConstantMultiplication, Division, ConstantDivision + +import pytest + + +def test_addition_overload(): + """Tests addition overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + add3 = add1 + add2 + + assert isinstance(add3, Addition) + assert add3.input(0).signals == add1.output(0).signals + assert add3.input(1).signals == add2.output(0).signals + + add4 = add3 + 5 + + assert isinstance(add4, ConstantAddition) + assert add4.input(0).signals == add3.output(0).signals + + +def test_subtraction_overload(): + """Tests subtraction overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + sub1 = add1 - add2 + + assert isinstance(sub1, Subtraction) + assert sub1.input(0).signals == add1.output(0).signals + assert sub1.input(1).signals == add2.output(0).signals + + sub2 = sub1 - 5 + + assert isinstance(sub2, ConstantSubtraction) + assert sub2.input(0).signals == sub1.output(0).signals + + +def test_multiplication_overload(): + """Tests multiplication overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + mul1 = add1 * add2 + + assert isinstance(mul1, Multiplication) + assert mul1.input(0).signals == add1.output(0).signals + assert mul1.input(1).signals == add2.output(0).signals + + mul2 = mul1 * 5 + + assert isinstance(mul2, ConstantMultiplication) + assert mul2.input(0).signals == mul1.output(0).signals + + +def test_division_overload(): + """Tests division overloading for both operation and number argument.""" + add1 = Addition(None, None, "add1") + add2 = Addition(None, None, "add2") + + div1 = add1 / add2 + + assert isinstance(div1, Division) + assert div1.input(0).signals == add1.output(0).signals + assert div1.input(1).signals == add2.output(0).signals + + div2 = div1 / 5 + + assert isinstance(div2, ConstantDivision) + assert div2.input(0).signals == div1.output(0).signals + diff --git a/test/test_core_operations.py b/test/test_core_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..b176b2a6506cc5a1297813f6ddcb6d3589492838 --- /dev/null +++ b/test/test_core_operations.py @@ -0,0 +1,227 @@ +""" +B-ASIC test suite for the core operations. +""" + +from b_asic.core_operations import Constant, Addition, Subtraction, Multiplication, Division, SquareRoot, ComplexConjugate, Max, Min, Absolute, ConstantMultiplication, ConstantAddition, ConstantSubtraction, ConstantDivision + +# Constant tests. +def test_constant(): + constant_operation = Constant(3) + assert constant_operation.evaluate() == 3 + +def test_constant_negative(): + constant_operation = Constant(-3) + assert constant_operation.evaluate() == -3 + +def test_constant_complex(): + constant_operation = Constant(3+4j) + assert constant_operation.evaluate() == 3+4j + +# Addition tests. +def test_addition(): + test_operation = Addition() + constant_operation = Constant(3) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 8 + +def test_addition_negative(): + test_operation = Addition() + constant_operation = Constant(-3) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -8 + +def test_addition_complex(): + test_operation = Addition() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (7+11j) + +# Subtraction tests. +def test_subtraction(): + test_operation = Subtraction() + constant_operation = Constant(5) + constant_operation_2 = Constant(3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 2 + +def test_subtraction_negative(): + test_operation = Subtraction() + constant_operation = Constant(-5) + constant_operation_2 = Constant(-3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -2 + +def test_subtraction_complex(): + test_operation = Subtraction() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-1-1j) + +# Multiplication tests. +def test_multiplication(): + test_operation = Multiplication() + constant_operation = Constant(5) + constant_operation_2 = Constant(3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + +def test_multiplication_negative(): + test_operation = Multiplication() + constant_operation = Constant(-5) + constant_operation_2 = Constant(-3) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 15 + +def test_multiplication_complex(): + test_operation = Multiplication() + constant_operation = Constant((3+5j)) + constant_operation_2 = Constant((4+6j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (-18+38j) + +# Division tests. +def test_division(): + test_operation = Division() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + +def test_division_negative(): + test_operation = Division() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 6 + +def test_division_complex(): + test_operation = Division() + constant_operation = Constant((60+40j)) + constant_operation_2 = Constant((10+20j)) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == (2.8-1.6j) + +# SquareRoot tests. +def test_squareroot(): + test_operation = SquareRoot() + constant_operation = Constant(36) + assert test_operation.evaluate(constant_operation.evaluate()) == 6 + +def test_squareroot_negative(): + test_operation = SquareRoot() + constant_operation = Constant(-36) + assert test_operation.evaluate(constant_operation.evaluate()) == 6j + +def test_squareroot_complex(): + test_operation = SquareRoot() + constant_operation = Constant((48+64j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (8+4j) + +# ComplexConjugate tests. +def test_complexconjugate(): + test_operation = ComplexConjugate() + constant_operation = Constant(3+4j) + assert test_operation.evaluate(constant_operation.evaluate()) == (3-4j) + +def test_test_complexconjugate_negative(): + test_operation = ComplexConjugate() + constant_operation = Constant(-3-4j) + assert test_operation.evaluate(constant_operation.evaluate()) == (-3+4j) + +# Max tests. +def test_max(): + test_operation = Max() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 30 + +def test_max_negative(): + test_operation = Max() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -5 + +# Min tests. +def test_min(): + test_operation = Min() + constant_operation = Constant(30) + constant_operation_2 = Constant(5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == 5 + +def test_min_negative(): + test_operation = Min() + constant_operation = Constant(-30) + constant_operation_2 = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate(), constant_operation_2.evaluate()) == -30 + +# Absolute tests. +def test_absolute(): + test_operation = Absolute() + constant_operation = Constant(30) + assert test_operation.evaluate(constant_operation.evaluate()) == 30 + +def test_absolute_negative(): + test_operation = Absolute() + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == 5 + +def test_absolute_complex(): + test_operation = Absolute() + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == 5.0 + +# ConstantMultiplication tests. +def test_constantmultiplication(): + test_operation = ConstantMultiplication(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 100 + +def test_constantmultiplication_negative(): + test_operation = ConstantMultiplication(5) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -25 + +def test_constantmultiplication_complex(): + test_operation = ConstantMultiplication(3+2j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (1+18j) + +# ConstantAddition tests. +def test_constantaddition(): + test_operation = ConstantAddition(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 25 + +def test_constantaddition_negative(): + test_operation = ConstantAddition(4) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -1 + +def test_constantaddition_complex(): + test_operation = ConstantAddition(3+2j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (6+6j) + +# ConstantSubtraction tests. +def test_constantsubtraction(): + test_operation = ConstantSubtraction(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 15 + +def test_constantsubtraction_negative(): + test_operation = ConstantSubtraction(4) + constant_operation = Constant(-5) + assert test_operation.evaluate(constant_operation.evaluate()) == -9 + +def test_constantsubtraction_complex(): + test_operation = ConstantSubtraction(4+6j) + constant_operation = Constant((3+4j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (-1-2j) + +# ConstantDivision tests. +def test_constantdivision(): + test_operation = ConstantDivision(5) + constant_operation = Constant(20) + assert test_operation.evaluate(constant_operation.evaluate()) == 4 + +def test_constantdivision_negative(): + test_operation = ConstantDivision(4) + constant_operation = Constant(-20) + assert test_operation.evaluate(constant_operation.evaluate()) == -5 + +def test_constantdivision_complex(): + test_operation = ConstantDivision(2+2j) + constant_operation = Constant((10+10j)) + assert test_operation.evaluate(constant_operation.evaluate()) == (5+0j) diff --git a/test/test_graph_id_generator.py b/test/test_graph_id_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b14597eabe6c15695c5c452f69f3deeab56e36d5 --- /dev/null +++ b/test/test_graph_id_generator.py @@ -0,0 +1,28 @@ +""" +B-ASIC test suite for graph id generator. +""" + +from b_asic.graph_id import GraphIDGenerator, GraphID +import pytest + +@pytest.fixture +def graph_id_generator(): + return GraphIDGenerator() + +class TestGetNextId: + def test_empty_string_generator(self, graph_id_generator): + """Test the graph id generator for an empty string type.""" + assert graph_id_generator.get_next_id("") == "1" + assert graph_id_generator.get_next_id("") == "2" + + def test_normal_string_generator(self, graph_id_generator): + """"Test the graph id generator for a normal string type.""" + assert graph_id_generator.get_next_id("add") == "add1" + assert graph_id_generator.get_next_id("add") == "add2" + + def test_different_strings_generator(self, graph_id_generator): + """Test the graph id generator for different strings.""" + assert graph_id_generator.get_next_id("sub") == "sub1" + assert graph_id_generator.get_next_id("mul") == "mul1" + assert graph_id_generator.get_next_id("sub") == "sub2" + assert graph_id_generator.get_next_id("mul") == "mul2" diff --git a/test/test_inputport.py b/test/test_inputport.py new file mode 100644 index 0000000000000000000000000000000000000000..a43240693ac632b48461023536ff46b0ea379c5c --- /dev/null +++ b/test/test_inputport.py @@ -0,0 +1,95 @@ +""" +B-ASIC test suite for Inputport +""" + +import pytest + +from b_asic import InputPort, OutputPort +from b_asic import Signal + +@pytest.fixture +def inp_port(): + return InputPort(0, None) + +@pytest.fixture +def out_port(): + return OutputPort(0, None) + +@pytest.fixture +def out_port2(): + return OutputPort(1, None) + +@pytest.fixture +def dangling_sig(): + return Signal() + +@pytest.fixture +def s_w_source(): + out_port = OutputPort(0, None) + return Signal(source=out_port) + +@pytest.fixture +def sig_with_dest(): + inp_port = InputPort(0, None) + return Signal(destination=out_port) + +@pytest.fixture +def connected_sig(): + out_port = OutputPort(0, None) + inp_port = InputPort(0, None) + return Signal(source=out_port, destination=inp_port) + +def test_connect_then_disconnect(inp_port, out_port): + """Test connect unused port to port.""" + s1 = inp_port.connect(out_port) + + assert inp_port.connected_ports == [out_port] + assert out_port.connected_ports == [inp_port] + assert inp_port.signals == [s1] + assert out_port.signals == [s1] + assert s1.source is out_port + assert s1.destination is inp_port + + inp_port.remove_signal(s1) + + assert inp_port.connected_ports == [] + assert out_port.connected_ports == [] + assert inp_port.signals == [] + assert out_port.signals == [s1] + assert s1.source is out_port + assert s1.destination is None + +def test_connect_used_port_to_new_port(inp_port, out_port, out_port2): + """Does connecting multiple ports to an inputport throw error?""" + inp_port.connect(out_port) + with pytest.raises(AssertionError): + inp_port.connect(out_port2) + +def test_add_signal_then_disconnect(inp_port, s_w_source): + """Can signal be connected then disconnected properly?""" + inp_port.add_signal(s_w_source) + + assert inp_port.connected_ports == [s_w_source.source] + assert s_w_source.source.connected_ports == [inp_port] + assert inp_port.signals == [s_w_source] + assert s_w_source.source.signals == [s_w_source] + assert s_w_source.destination is inp_port + + inp_port.remove_signal(s_w_source) + + assert inp_port.connected_ports == [] + assert s_w_source.source.connected_ports == [] + assert inp_port.signals == [] + assert s_w_source.source.signals == [s_w_source] + assert s_w_source.destination is None + +def test_connect_then_disconnect(inp_port, out_port): + """Can port be connected and then disconnected properly?""" + inp_port.connect(out_port) + + inp_port.disconnect(out_port) + + print("outport signals:", out_port.signals, "count:", out_port.signal_count()) + assert inp_port.signal_count() == 1 + assert len(inp_port.connected_ports) == 0 + assert out_port.signal_count() == 0 diff --git a/test/test_operation.py b/test/test_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..6c37e30bddd0b55ea69ae5b95a341c1ddeb56847 --- /dev/null +++ b/test/test_operation.py @@ -0,0 +1,31 @@ +from b_asic.core_operations import Constant, Addition +from b_asic.signal import Signal +from b_asic.port import InputPort, OutputPort + +import pytest + +class TestTraverse: + def test_traverse_single_tree(self, operation): + """Traverse a tree consisting of one operation.""" + constant = Constant(None) + assert list(constant.traverse()) == [constant] + + def test_traverse_tree(self, operation_tree): + """Traverse a basic addition tree with two constants.""" + assert len(list(operation_tree.traverse())) == 3 + + def test_traverse_large_tree(self, large_operation_tree): + """Traverse a larger tree.""" + assert len(list(large_operation_tree.traverse())) == 7 + + def test_traverse_type(self, large_operation_tree): + traverse = list(large_operation_tree.traverse()) + assert len(list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 + assert len(list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 + + def test_traverse_loop(self, operation_tree): + add_oper_signal = Signal() + operation_tree._output_ports[0].add_signal(add_oper_signal) + operation_tree._input_ports[0].remove_signal(add_oper_signal) + operation_tree._input_ports[0].add_signal(add_oper_signal) + assert len(list(operation_tree.traverse())) == 2 diff --git a/test/test_outputport.py b/test/test_outputport.py new file mode 100644 index 0000000000000000000000000000000000000000..deed7a1e06836600254e3903b8b45a3d05f17cbe --- /dev/null +++ b/test/test_outputport.py @@ -0,0 +1,80 @@ +""" +B-ASIC test suite for OutputPort. +""" +from b_asic import OutputPort, InputPort, Signal +import pytest + +@pytest.fixture +def output_port(): + return OutputPort(0, None) + +@pytest.fixture +def input_port(): + return InputPort(0, None) + +@pytest.fixture +def list_of_input_ports(): + return [InputPort(_, None) for _ in range(0,3)] + +class TestConnect: + def test_multiple_ports(self, output_port, list_of_input_ports): + """Can multiple ports connect to an output port?""" + for port in list_of_input_ports: + output_port.connect(port) + + assert output_port.signal_count() == len(list_of_input_ports) + + def test_same_port(self, output_port, list_of_input_ports): + """Check error handing.""" + output_port.connect(list_of_input_ports[0]) + with pytest.raises(AssertionError): + output_port.connect(list_of_input_ports[0]) + + assert output_port.signal_count() == 2 + +class TestAddSignal: + def test_dangling(self, output_port): + s = Signal() + output_port.add_signal(s) + + assert output_port.signal_count() == 1 + + def test_with_destination(self, output_port, input_port): + s = Signal(destination=input_port) + output_port.add_signal(s) + + assert output_port.connected_ports == [s.destination] + +class TestDisconnect: + def test_multiple_ports(self, output_port, list_of_input_ports): + """Can multiple ports disconnect from OutputPort?""" + for port in list_of_input_ports: + output_port.connect(port) + + for port in list_of_input_ports: + output_port.disconnect(port) + + assert output_port.signal_count() == 3 + assert output_port.connected_ports == [] + +class TestRemoveSignal: + def test_one_signal(self, output_port, input_port): + s = output_port.connect(input_port) + output_port.remove_signal(s) + + assert output_port.signal_count() == 0 + assert output_port.signals == [] + assert output_port.connected_ports == [] + + def test_multiple_signals(self, output_port, list_of_input_ports): + """Can multiple signals disconnect from OutputPort?""" + sigs = [] + + for port in list_of_input_ports: + sigs.append(output_port.connect(port)) + + for sig in sigs: + output_port.remove_signal(sig) + + assert output_port.signal_count() == 0 + assert output_port.signals == [] diff --git a/test/test_signal.py b/test/test_signal.py new file mode 100644 index 0000000000000000000000000000000000000000..ab07eb778ddb693bfc9cfabf6aeb7804038312d5 --- /dev/null +++ b/test/test_signal.py @@ -0,0 +1,62 @@ +""" +B-ASIC test suit for the signal module which consists of the Signal class. +""" + +from b_asic.port import InputPort, OutputPort +from b_asic.signal import Signal + +import pytest + +def test_signal_creation_and_disconnction_and_connection_changing(): + in_port = InputPort(0, None) + out_port = OutputPort(1, None) + s = Signal(out_port, in_port) + + assert in_port.signals == [s] + assert out_port.signals == [s] + assert s.source is out_port + assert s.destination is in_port + + in_port1 = InputPort(0, None) + s.set_destination(in_port1) + + assert in_port.signals == [] + assert in_port1.signals == [s] + assert out_port.signals == [s] + assert s.source is out_port + assert s.destination is in_port1 + + s.remove_source() + + assert out_port.signals == [] + assert in_port1.signals == [s] + assert s.source is None + assert s.destination is in_port1 + + s.remove_destination() + + assert out_port.signals == [] + assert in_port1.signals == [] + assert s.source is None + assert s.destination is None + + out_port1 = OutputPort(0, None) + s.set_source(out_port1) + + assert out_port1.signals == [s] + assert s.source is out_port1 + assert s.destination is None + + s.set_source(out_port) + + assert out_port.signals == [s] + assert out_port1.signals == [] + assert s.source is out_port + assert s.destination is None + + s.set_destination(in_port) + + assert out_port.signals == [s] + assert in_port.signals == [s] + assert s.source is out_port + assert s.destination is in_port