diff --git a/b_asic/__init__.py b/b_asic/__init__.py index 4ae6652b9acfe05f80c7a9224de470cf5ffbd09c..fc787edfb699d24f8751655692fb05ddac83f0d7 100644 --- a/b_asic/__init__.py +++ b/b_asic/__init__.py @@ -14,4 +14,4 @@ from b_asic.schema import * from b_asic.signal_flow_graph import * from b_asic.signal import * from b_asic.simulation import * -from b_asic.traverse_tree import * +from b_asic.utilities import * diff --git a/b_asic/abstract_operation.py b/b_asic/abstract_operation.py index ab8438a58218862a987fdb88026a39a1df7124f9..558022a7b12304c2052d3263f20c9d2339ad8531 100644 --- a/b_asic/abstract_operation.py +++ b/b_asic/abstract_operation.py @@ -4,13 +4,14 @@ TODO: More info. """ from abc import abstractmethod -from typing import List, Dict, Optional, Any +from typing import List, Set, Dict, Optional, Any from numbers import Number from b_asic.port import InputPort, OutputPort from b_asic.signal import Signal from b_asic.operation import Operation from b_asic.simulation import SimulationState, OperationState +from b_asic.utilities import breadth_first_search from b_asic.abstract_graph_component import AbstractGraphComponent class AbstractOperation(Operation, AbstractGraphComponent): @@ -106,4 +107,8 @@ class AbstractOperation(Operation, AbstractGraphComponent): return neighbours + def traverse(self) -> Operation: + """Traverse the operation tree and return a generator with start point in the operation.""" + return breadth_first_search(self) + # TODO: More stuff. diff --git a/b_asic/traverse_tree.py b/b_asic/traverse_tree.py deleted file mode 100644 index dc00371eaddbbaba0592d31325dbdda9efad09f7..0000000000000000000000000000000000000000 --- a/b_asic/traverse_tree.py +++ /dev/null @@ -1,43 +0,0 @@ -"""@package docstring -B-ASIC Operation Tree Traversing Module. -""" - -from typing import List, Optional -from collections import deque - -from b_asic.operation import Operation - - -class Traverse: - """Traverse operation tree.""" - - def __init__(self, operation: Operation): - """Construct a TraverseTree.""" - self._initial_operation = operation - - def _breadth_first_search(self, start: Operation) -> List[Operation]: - """Use breadth first search to traverse the operation tree.""" - visited: List[Operation] = [start] - queue = deque([start]) - while queue: - operation = queue.popleft() - for n_operation in operation.neighbours: - if n_operation not in visited: - visited.append(n_operation) - queue.append(n_operation) - - return visited - - def traverse(self, type_: Optional[Operation] = None) -> List[Operation]: - """Traverse the the operation tree and return operation where type matches. - If the type is None then return the entire tree. - - Keyword arguments: - type_-- the operation type to search for (default None) - """ - - operations: List[Operation] = self._breadth_first_search(self._initial_operation) - if type_ is not None: - operations = [oper for oper in operations if isinstance(oper, type_)] - - return operations diff --git a/b_asic/utilities.py b/b_asic/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..25707ff8ceceb241bae93081387910220da25d18 --- /dev/null +++ b/b_asic/utilities.py @@ -0,0 +1,21 @@ +"""@package docstring +B-ASIC Operation Module. +TODO: More info. +""" + +from typing import Set +from collections import deque + +from b_asic.operation import Operation + +def breadth_first_search(start: Operation) -> Operation: + """Use breadth first search to traverse the operation tree.""" + visited: Set[Operation] = {start} + queue = deque([start]) + while queue: + operation = queue.popleft() + yield operation + for n_operation in operation.neighbours: + if n_operation not in visited: + visited.add(n_operation) + queue.append(n_operation) diff --git a/test/conftest.py b/test/conftest.py index 986af94cc7341f48ba736e6f9d934c8eb706c079..66ee9630ea4ac0a05b446f4dedbfe68549a1191e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,2 +1,3 @@ import pytest -from test.fixtures.signal import * \ No newline at end of file +from test.fixtures.signal import * +from test.fixtures.operation_tree import * \ No newline at end of file diff --git a/test/fixtures/operation_tree.py b/test/fixtures/operation_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc0bb27fac9fa8d1a024d67373469c52f8a45dc --- /dev/null +++ b/test/fixtures/operation_tree.py @@ -0,0 +1,44 @@ +from b_asic.core_operations import Addition, Constant +from b_asic.signal import Signal + +import pytest + +@pytest.fixture +def operation(): + return Constant(2) + +def create_operation(_type, dest_oper, index, **kwargs): + oper = _type(**kwargs) + oper_signal = Signal() + oper._output_ports[0].connect(oper_signal) + + dest_oper._input_ports[index].connect(oper_signal) + return oper + +@pytest.fixture +def operation_tree(): + add_oper = Addition() + create_operation(Constant, add_oper, 0, value=2) + create_operation(Constant, add_oper, 1, value=3) + return add_oper + +@pytest.fixture +def large_operation_tree(): + add_oper = Addition() + add_oper_2 = Addition() + + const_oper = create_operation(Constant, add_oper, 0, value=2) + create_operation(Constant, add_oper, 1, value=3) + + create_operation(Constant, add_oper_2, 0, value=4) + create_operation(Constant, add_oper_2, 1, value=5) + + add_oper_3 = Addition() + add_oper_signal = Signal(add_oper, add_oper_3) + add_oper._output_ports[0].connect(add_oper_signal) + add_oper_3._input_ports[0].connect(add_oper_signal) + + add_oper_2_signal = Signal(add_oper_2, add_oper_3) + add_oper_2._output_ports[0].connect(add_oper_2_signal) + add_oper_3._input_ports[1].connect(add_oper_2_signal) + return const_oper diff --git a/test/traverse/test_traverse_tree.py b/test/traverse/test_traverse_tree.py index 57e8a67befc512146859a8999152ff5c679b4588..2c1d08fe555df06ff86845812ca5df4fef4b5c92 100644 --- a/test/traverse/test_traverse_tree.py +++ b/test/traverse/test_traverse_tree.py @@ -1,78 +1,29 @@ -""" -TODO: - - Rewrite to more clean code, not so repetitive - - Update when signals and id's has been merged. -""" - from b_asic.core_operations import Constant, Addition -from b_asic.signal import Signal +from b_asic.signal import Signal from b_asic.port import InputPort, OutputPort -from b_asic.traverse_tree import Traverse import pytest -@pytest.fixture -def operation(): - return Constant(2) - -def create_operation(_type, dest_oper, index, **kwargs): - oper = _type(**kwargs) - oper_signal = Signal() - oper._output_ports[0].connect(oper_signal) - - dest_oper._input_ports[index].connect(oper_signal) - return oper - -@pytest.fixture -def operation_tree(): - add_oper = Addition() - - const_oper = create_operation(Constant, add_oper, 0, value=2) - const_oper_2 = create_operation(Constant, add_oper, 1, value=3) - - return add_oper - -@pytest.fixture -def large_operation_tree(): - add_oper = Addition() - add_oper_2 = Addition() - - const_oper = create_operation(Constant, add_oper, 0, value=2) - const_oper_2 = create_operation(Constant, add_oper, 1, value=3) - - const_oper_3 = create_operation(Constant, add_oper_2, 0, value=4) - const_oper_4 = create_operation(Constant, add_oper_2, 1, value=5) - - add_oper_3 = Addition() - add_oper_signal = Signal(add_oper, add_oper_3) - add_oper._output_ports[0].connect(add_oper_signal) - add_oper_3._input_ports[0].connect(add_oper_signal) - - add_oper_2_signal = Signal(add_oper_2, add_oper_3) - add_oper_2._output_ports[0].connect(add_oper_2_signal) - add_oper_3._input_ports[1].connect(add_oper_2_signal) - return const_oper - def test_traverse_single_tree(operation): - traverse = Traverse(operation) - assert traverse.traverse() == [operation] + """Traverse a tree consisting of one operation.""" + constant = Constant(None) + assert list(constant.traverse()) == [constant] def test_traverse_tree(operation_tree): - traverse = Traverse(operation_tree) - assert len(traverse.traverse()) == 3 + """Traverse a basic addition tree with two constants.""" + assert len(list(operation_tree.traverse())) == 3 def test_traverse_large_tree(large_operation_tree): - traverse = Traverse(large_operation_tree) - assert len(traverse.traverse()) == 7 + """Traverse a larger tree.""" + assert len(list(large_operation_tree.traverse())) == 7 def test_traverse_type(large_operation_tree): - traverse = Traverse(large_operation_tree) - assert len(traverse.traverse(Addition)) == 3 - assert len(traverse.traverse(Constant)) == 4 + traverse = list(large_operation_tree.traverse()) + assert len(list(filter(lambda type_: isinstance(type_, Addition), traverse))) == 3 + assert len(list(filter(lambda type_: isinstance(type_, Constant), traverse))) == 4 def test_traverse_loop(operation_tree): add_oper_signal = Signal() operation_tree._output_ports[0].connect(add_oper_signal) operation_tree._input_ports[0].connect(add_oper_signal) - traverse = Traverse(operation_tree) - assert len(traverse.traverse()) == 2 \ No newline at end of file + assert len(list(operation_tree.traverse())) == 2