diff --git a/b_asic/basic_operation.py b/b_asic/basic_operation.py index d5d03045eff9491213da6188fedfcb0de76e18a8..4f7b426b995eb390001d39a0785f3dd0841f2563 100644 --- a/b_asic/basic_operation.py +++ b/b_asic/basic_operation.py @@ -5,7 +5,7 @@ TODO: More info. from b_asic.port import InputPort, OutputPort from b_asic.signal import SignalSource, SignalDestination -from b_asic.operation import OperationId, Operation +from b_asic.operation import Operation from b_asic.simulation import SimulationState, OperationState from abc import ABC, abstractmethod from typing import List, Dict, Optional, Any @@ -18,16 +18,14 @@ class BasicOperation(Operation): TODO: More info. """ - _identifier: OperationId _input_ports: List[InputPort] _output_ports: List[OutputPort] _parameters: Dict[str, Optional[Any]] - def __init__(self, identifier: OperationId): + def __init__(self): """ Construct a BasicOperation. """ - self._identifier = identifier self._input_ports = [] self._output_ports = [] self._parameters = {} @@ -39,9 +37,6 @@ class BasicOperation(Operation): """ pass - def identifier(self) -> OperationId: - return self._identifier - def inputs(self) -> List[InputPort]: return self._input_ports.copy() @@ -102,4 +97,13 @@ class BasicOperation(Operation): return results return [self] + @property + def neighbours(self) -> List[Operation]: + neighbours: List[Operation] = [] + for port in self._output_ports + self._input_ports: + for signal in port.signals(): + neighbours += [signal.source.operation, signal.destination.operation] + + return neighbours + # TODO: More stuff. diff --git a/b_asic/core_operations.py b/b_asic/core_operations.py index c766d06c2f190b8a20763bf75f2661cb65df2fd0..bca344cf91a3c21065692a054bd0382dc77724d4 100644 --- a/b_asic/core_operations.py +++ b/b_asic/core_operations.py @@ -4,8 +4,9 @@ TODO: More info. """ from b_asic.port import InputPort, OutputPort -from b_asic.operation import OperationId, Operation +from b_asic.operation import Operation from b_asic.basic_operation import BasicOperation +from b_asic.graph_id import GraphIDType from numbers import Number @@ -15,7 +16,7 @@ class Input(Operation): TODO: More info. """ - # TODO: Implement. + # TODO: Implement all functions. pass @@ -25,17 +26,19 @@ class Constant(BasicOperation): TODO: More info. """ - def __init__(self, identifier: OperationId, value: Number): + def __init__(self, value: Number): """ Construct a Constant. """ - super().__init__(identifier) - self._output_ports = [OutputPort()] # TODO: Generate appropriate ID for ports. + super().__init__() + self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports. self._parameters["value"] = value def evaluate(self, inputs: list) -> list: return [self.param("value")] + def get_op_name(self) -> GraphIDType: + return "const" class Addition(BasicOperation): """ @@ -43,17 +46,20 @@ class Addition(BasicOperation): TODO: More info. """ - def __init__(self, identifier: OperationId): + def __init__(self): """ Construct an Addition. """ - super().__init__(identifier) - self._input_ports = [InputPort(), InputPort()] # TODO: Generate appropriate ID for ports. - self._output_ports = [OutputPort()] # TODO: Generate appropriate ID for ports. + super().__init__() + self._input_ports = [InputPort(1), InputPort(1)] # TODO: Generate appropriate ID for ports. + self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports. def evaluate(self, inputs: list) -> list: return [inputs[0] + inputs[1]] + def get_op_name(self) -> GraphIDType: + return "add" + class ConstantMultiplication(BasicOperation): """ @@ -61,16 +67,19 @@ class ConstantMultiplication(BasicOperation): TODO: More info. """ - def __init__(self, identifier: OperationId, coefficient: Number): + def __init__(self, coefficient: Number): """ Construct a ConstantMultiplication. """ - super().__init__(identifier) - self._input_ports = [InputPort()] # TODO: Generate appropriate ID for ports. - self._output_ports = [OutputPort()] # TODO: Generate appropriate ID for ports. + super().__init__() + self._input_ports = [InputPort(1)] # TODO: Generate appropriate ID for ports. + self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports. self._parameters["coefficient"] = coefficient def evaluate(self, inputs: list) -> list: return [inputs[0] * self.param("coefficient")] + def get_op_name(self) -> GraphIDType: + return "const_mul" + # TODO: More operations. diff --git a/b_asic/graph_id.py b/b_asic/graph_id.py new file mode 100644 index 0000000000000000000000000000000000000000..3f25f5139d9a8388127962015680e73cf01751c0 --- /dev/null +++ b/b_asic/graph_id.py @@ -0,0 +1,30 @@ +""" +B-ASIC Graph ID module for handling IDs of different objects in a graph. +TODO: More info +""" + +from collections import defaultdict +from typing import NewType, Union, DefaultDict + +GraphID = NewType("GraphID", str) +GraphIDType = NewType("GraphIDType", str) +GraphIDNumber = NewType("GraphIDNumber", int) + +class GraphIDGenerator: + """ + A class that generates Graph IDs for objects. + """ + + _next_id_number: DefaultDict[GraphIDType, GraphIDNumber] + + def __init__(self): + self._next_id_number = defaultdict(lambda: 1) # Initalises every key element to 1 + + def get_next_id(self, graph_id_type: GraphIDType) -> GraphID: + """ + Returns the next graph id for a certain graph id type. + """ + graph_id = graph_id_type + str(self._next_id_number[graph_id_type]) + self._next_id_number[graph_id_type] += 1 # Increase the current id number + return graph_id + diff --git a/b_asic/operation.py b/b_asic/operation.py index 731822f1d0cbc06e5ad244c878c5dbaf8847b905..f02cd70054880fb533fa4a2a292fbdaf222ee085 100644 --- a/b_asic/operation.py +++ b/b_asic/operation.py @@ -10,9 +10,7 @@ from typing import NewType, List, Dict, Optional, Any, TYPE_CHECKING if TYPE_CHECKING: from b_asic.port import InputPort, OutputPort from b_asic.simulation import SimulationState - -OperationId = NewType("OperationId", int) - + from b_asic.graph_id import GraphIDType class Operation(ABC): """ @@ -20,14 +18,6 @@ class Operation(ABC): TODO: More info. """ - @abstractmethod - def identifier(self) -> OperationId: - """ - Get the unique identifier. - TODO: Move id info to SFG, remove id class members. - """ - pass - @abstractmethod def inputs(self) -> "List[InputPort]": """ @@ -109,5 +99,17 @@ class Operation(ABC): """ pass + @abstractmethod + def get_op_name(self) -> "GraphIDType": + """Returns a string representing the operation name of the operation.""" + pass + + @abstractmethod + def neighbours(self) -> "List[Operation]": + """ + Return all operations that are connected by signals to this operation. + If no neighbours are found this returns an empty list + """ + # TODO: More stuff. diff --git a/b_asic/signal.py b/b_asic/signal.py index 4fac563faf48cb6a6a7ea585cf7aea44ce259758..6ef55c8d076320d52e5044296af0e3e65617728c 100644 --- a/b_asic/signal.py +++ b/b_asic/signal.py @@ -6,9 +6,6 @@ TODO: More info. from b_asic.operation import Operation from typing import NewType -SignalId = NewType("SignalId", int) - - class SignalSource: """ Handle to a signal source. @@ -50,22 +47,14 @@ class Signal: A connection between two operations consisting of a source and destination handle. TODO: More info. """ - _identifier: SignalId source: SignalSource destination: SignalDestination - def __init__(self, identifier: SignalId, source: SignalSource, destination: SignalDestination): + def __init__(self, source: SignalSource, destination: SignalDestination): """ Construct a Signal. """ - self._identifier = identifier self.source = source self.destination = destination - def identifier(self) -> SignalId: - """ - Get the unique identifier. - """ - return self._identifier - # TODO: More stuff. diff --git a/b_asic/signal_flow_graph.py b/b_asic/signal_flow_graph.py index f9671636dd345fadfebd2a116a9add2466615db8..9d31b04b82509a5749f9899a25cf2091c8492c9e 100644 --- a/b_asic/signal_flow_graph.py +++ b/b_asic/signal_flow_graph.py @@ -3,12 +3,13 @@ B-ASIC Signal Flow Graph Module. TODO: More info. """ -from b_asic.operation import OperationId, Operation +from b_asic.operation import Operation from b_asic.basic_operation import BasicOperation -from b_asic.signal import SignalSource, SignalDestination +from b_asic.signal import Signal, SignalSource, SignalDestination from b_asic.simulation import SimulationState, OperationState -from typing import List +from b_asic.graph_id import GraphIDGenerator, GraphID +from typing import List, Dict, Union, Optional class SFG(BasicOperation): """ @@ -16,22 +17,60 @@ class SFG(BasicOperation): TODO: More info. """ - _operations: List[Operation] + _graph_objects_by_id: Dict[GraphID, Union[Operation, Signal]] + _graph_id_generator: GraphIDGenerator - def __init__(self, identifier: OperationId, input_destinations: List[SignalDestination], output_sources: List[SignalSource]): - """ - Construct a SFG. - """ - super().__init__(identifier) + def __init__(self, input_destinations: List[SignalDestination], output_sources: List[SignalSource]): + """Constructs an SFG.""" + super().__init__() # TODO: Allocate input/output ports with appropriate IDs. - self._operations = [] + + self._graph_objects_by_id = dict # Map Operation ID to Operation objects + self._graph_id_generator = GraphIDGenerator() + # TODO: Traverse the graph between the inputs/outputs and add to self._operations. # TODO: Connect ports with signals with appropriate IDs. def evaluate(self, inputs: list) -> list: return [] # TODO: Implement - def split(self) -> List[Operation]: - return self._operations + def add_operation(self, operation: Operation) -> GraphID: + """Adds the entered operation to the SFG's dictionary of graph objects and + returns a generated GraphID for it. + + Keyword arguments: + operation: Operation to add to the graph. + """ + return self._add_graph_obj(operation, operation.get_op_name()) + + + def add_signal(self, signal: Signal) -> GraphID: + """Adds the entered signal to the SFG's dictionary of graph objects and returns + a generated GraphID for it. + + Keyword argumentst: + signal: Signal to add to the graph. + """ + return self._add_graph_obj(signal, 'sig') + + + def find_by_id(self, graph_id: GraphID) -> Optional[Operation]: + """Finds a graph object based on the entered Graph ID and returns it. If no graph + object with the entered ID was found then returns None. + + Keyword arguments: + graph_id: Graph ID of the wanted object. + """ + if graph_id in self._graph_objects_by_id: + return self._graph_objects_by_id[graph_id] + else: + return None + + + + def _add_graph_obj(self, obj: Union[Operation, Signal], operation_id_type: str): + graph_id = self._graph_id_generator.get_next_id(operation_id_type) + self._graph_objects_by_id[graph_id] = obj + return graph_id + - # TODO: More stuff. diff --git a/b_asic/simulation.py b/b_asic/simulation.py index e219445b38abf7fc755295a4ccf8b6284ce6651b..d3d3aaf6913844f865cc47930802f03b2c6f0d2f 100644 --- a/b_asic/simulation.py +++ b/b_asic/simulation.py @@ -3,7 +3,6 @@ B-ASIC Simulation Module. TODO: More info. """ -from b_asic.operation import OperationId from numbers import Number from typing import List, Dict @@ -31,7 +30,7 @@ class SimulationState: TODO: More info. """ - operation_states: Dict[OperationId, OperationState] + # operation_states: Dict[OperationId, OperationState] iteration: int def __init__(self): diff --git a/b_asic/traverse_tree.py b/b_asic/traverse_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..024a542d558dd7fd140b54120f19ad12c623f9f9 --- /dev/null +++ b/b_asic/traverse_tree.py @@ -0,0 +1,54 @@ +""" +B-ASIC Operation Tree Traversing Module. +TODO: + - Get a first operation or? an entire operation tree + - For each start point, follow it to the next operation from it's out port. + - If we are searching for a specific operation end. + - If we are searching for a specific type of operation add the operation to a list and continue. + - When we no more out ports can be traversed return results and end. +""" + +from typing import List, Optional +from collections import deque + +from b_asic.operation import Operation + + +class Traverse: + """Traverse operation tree. + TODO: + - More info. + - Check if a datastructure other than list suits better as return value. + - Implement the type check for operation. + """ + + def __init__(self, operation: Operation): + """Construct a TraverseTree.""" + self._initial_operation = operation + + def _breadth_first_search(self, start: Operation) -> List[Operation]: + """Use breadth first search to traverse the operation tree.""" + visited: List[Operation] = [start] + queue = deque([start]) + while queue: + operation = queue.popleft() + for n_operation in operation.neighbours: + if n_operation not in visited: + visited.append(n_operation) + queue.append(n_operation) + + return visited + + def traverse(self, type_: Optional[Operation] = None) -> List[Operation]: + """Traverse the the operation tree and return operation where type matches. + If the type is None then return the entire tree. + + Keyword arguments: + type_-- the operation type to search for (default None) + """ + + operations: List[Operation] = self._breadth_first_search(self._initial_operation) + if type_ is not None: + operations = [oper for oper in operations if isinstance(oper, type_)] + + return operations diff --git a/test/fixtures/signal.py b/test/fixtures/signal.py index 5fbdcf2b4e5e50c9728fb50529835a6fb501fc4c..64b96f5541c7866c4540cd3052a5cbeaea5e4400 100644 --- a/test/fixtures/signal.py +++ b/test/fixtures/signal.py @@ -6,15 +6,15 @@ Use a fixture for initializing objects and pass them as argument to a test funct """ @pytest.fixture def signal(): - source = SignalSource(Addition(0), 1) - dest = SignalDestination(Addition(1), 2) - return Signal(0, source, dest) + source = SignalSource(Addition(), 1) + dest = SignalDestination(Addition(), 2) + return Signal(source, dest) @pytest.fixture def signals(): ret = [] - for i in range(0,3): - source = SignalSource(Addition(0), 1) - dest = SignalDestination(Addition(1), 2) - ret.append(Signal(i, source, dest)) + for _ in range(0,3): + source = SignalSource(Addition(), 1) + dest = SignalDestination(Addition(), 2) + ret.append(Signal(source, dest)) return ret \ No newline at end of file diff --git a/test/graph_id/conftest.py b/test/graph_id/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..5871ed8eef2f90304e1f64c12ba17e1915250724 --- /dev/null +++ b/test/graph_id/conftest.py @@ -0,0 +1 @@ +import pytest diff --git a/test/graph_id/test_graph_id_generator.py b/test/graph_id/test_graph_id_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7aeb6cad27e43233a88eb69e58bd89f78a863c5b --- /dev/null +++ b/test/graph_id/test_graph_id_generator.py @@ -0,0 +1,29 @@ +""" +B-ASIC test suite for graph id generator. +""" + +from b_asic.graph_id import GraphIDGenerator, GraphID + +import pytest + +def test_empty_string_generator(): + """Test the graph id generator for an empty string type.""" + graph_id_generator = GraphIDGenerator() + assert graph_id_generator.get_next_id("") == "1" + assert graph_id_generator.get_next_id("") == "2" + + +def test_normal_string_generator(): + """"Test the graph id generator for a normal string type.""" + graph_id_generator = GraphIDGenerator() + assert graph_id_generator.get_next_id("add") == "add1" + assert graph_id_generator.get_next_id("add") == "add2" + +def test_different_strings_generator(): + """Test the graph id generator for different strings.""" + graph_id_generator = GraphIDGenerator() + assert graph_id_generator.get_next_id("sub") == "sub1" + assert graph_id_generator.get_next_id("mul") == "mul1" + assert graph_id_generator.get_next_id("sub") == "sub2" + assert graph_id_generator.get_next_id("mul") == "mul2" + \ No newline at end of file diff --git a/test/port/test_port.py b/test/port/test_port.py index 56cb9be227149c957bf77f73ed4f1301e54fac16..7e1fc9b7589a4955966d3333d336bdb6f3245014 100644 --- a/test/port/test_port.py +++ b/test/port/test_port.py @@ -13,10 +13,10 @@ def test_connect_one_signal_to_port(signal): assert port.signal() == signal def test_change_port_signal(): - source = SignalSource(Addition(0), 1) - dest = SignalDestination(Addition(1),2) - signal1 = Signal(1, source, dest) - signal2 = Signal(2, source, dest) + source = SignalSource(Addition, 1) + dest = SignalDestination(Addition,2) + signal1 = Signal(source, dest) + signal2 = Signal(source, dest) port = InputPort(0) port.connect(signal1) diff --git a/test/signal_flow_graph/conftest.py b/test/signal_flow_graph/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..5871ed8eef2f90304e1f64c12ba17e1915250724 --- /dev/null +++ b/test/signal_flow_graph/conftest.py @@ -0,0 +1 @@ +import pytest diff --git a/test/signal_flow_graph/test_signal_flow_graph.py b/test/signal_flow_graph/test_signal_flow_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..921e8906ff277b85f7d53e68cd55be338c778419 --- /dev/null +++ b/test/signal_flow_graph/test_signal_flow_graph.py @@ -0,0 +1,3 @@ +from b_asic.signal_flow_graph import SFG +from b_asic.core_operations import Addition, Constant +from b_asic.signal import Signal diff --git a/test/traverse/test_traverse_tree.py b/test/traverse/test_traverse_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..d218a6e74aed06fee9691ae78e754367cc6f2988 --- /dev/null +++ b/test/traverse/test_traverse_tree.py @@ -0,0 +1,85 @@ +""" +TODO: + - Rewrite to more clean code, not so repetitive + - Update when signals and id's has been merged. +""" + +from b_asic.core_operations import Constant, Addition +from b_asic.signal import Signal, SignalSource, SignalDestination +from b_asic.port import InputPort, OutputPort +from b_asic.traverse_tree import Traverse + +import pytest + +@pytest.fixture +def operation(): + return Constant(2) + +def create_operation(_type, dest_oper, index, **kwargs): + oper = _type(**kwargs) + oper_signal_source = SignalSource(oper, 0) + oper_signal_dest = SignalDestination(dest_oper, index) + oper_signal = Signal(oper_signal_source, oper_signal_dest) + oper._output_ports[0].connect(oper_signal) + dest_oper._input_ports[index].connect(oper_signal) + return oper + +@pytest.fixture +def operation_tree(): + add_oper = Addition() + + const_oper = create_operation(Constant, add_oper, 0, value=2) + const_oper_2 = create_operation(Constant, add_oper, 1, value=3) + + return add_oper + +@pytest.fixture +def large_operation_tree(): + add_oper = Addition() + add_oper_2 = Addition() + + const_oper = create_operation(Constant, add_oper, 0, value=2) + const_oper_2 = create_operation(Constant, add_oper, 1, value=3) + + const_oper_3 = create_operation(Constant, add_oper_2, 0, value=4) + const_oper_4 = create_operation(Constant, add_oper_2, 1, value=5) + + add_oper_3 = Addition() + add_oper_signal_source = SignalSource(add_oper, 0) + add_oper_signal_dest = SignalDestination(add_oper_3, 0) + add_oper_signal = Signal(add_oper_signal_source, add_oper_signal_dest) + add_oper._output_ports[0].connect(add_oper_signal) + add_oper_3._input_ports[0].connect(add_oper_signal) + + add_oper_2_signal_source = SignalSource(add_oper_2, 0) + add_oper_2_signal_dest = SignalDestination(add_oper_3, 1) + add_oper_2_signal = Signal(add_oper_2_signal_source, add_oper_2_signal_dest) + add_oper_2._output_ports[0].connect(add_oper_2_signal) + add_oper_3._input_ports[1].connect(add_oper_2_signal) + return const_oper + +def test_traverse_single_tree(operation): + traverse = Traverse(operation) + assert traverse.traverse() == [operation] + +def test_traverse_tree(operation_tree): + traverse = Traverse(operation_tree) + assert len(traverse.traverse()) == 3 + +def test_traverse_large_tree(large_operation_tree): + traverse = Traverse(large_operation_tree) + assert len(traverse.traverse()) == 7 + +def test_traverse_type(large_operation_tree): + traverse = Traverse(large_operation_tree) + assert len(traverse.traverse(Addition)) == 3 + assert len(traverse.traverse(Constant)) == 4 + +def test_traverse_loop(operation_tree): + add_oper_signal_source = SignalSource(operation_tree, 0) + add_oper_signal_dest = SignalDestination(operation_tree, 0) + add_oper_signal = Signal(add_oper_signal_source, add_oper_signal_dest) + operation_tree._output_ports[0].connect(add_oper_signal) + operation_tree._input_ports[0].connect(add_oper_signal) + traverse = Traverse(operation_tree) + assert len(traverse.traverse()) == 2 \ No newline at end of file