Skip to content
Snippets Groups Projects
Commit ba89f480 authored by Jacob Wahlman's avatar Jacob Wahlman :ok_hand:
Browse files

Merged operation id system with traversing and signal

parents 0133ee94 664b8044
No related branches found
No related tags found
1 merge request!2Integrated ID system, traversing and som signal tests
Pipeline #10016 passed
......@@ -36,7 +36,7 @@ class BasicOperation(Operation):
Evaluate the operation and generate a list of output values given a list of input values.
"""
pass
def inputs(self) -> List[InputPort]:
return self._input_ports.copy()
......@@ -97,4 +97,13 @@ class BasicOperation(Operation):
return results
return [self]
@property
def neighbours(self) -> List[Operation]:
neighbours: List[Operation] = []
for port in self._output_ports + self._input_ports:
for signal in port.signals():
neighbours += [signal.source.operation, signal.destination.operation]
return neighbours
# TODO: More stuff.
......@@ -30,8 +30,8 @@ class Constant(BasicOperation):
"""
Construct a Constant.
"""
super().__init__(identifier)
self._output_ports = [OutputPort()] # TODO: Generate appropriate ID for ports.
super().__init__()
self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports.
self._parameters["value"] = value
def evaluate(self, inputs: list) -> list:
......@@ -50,7 +50,7 @@ class Addition(BasicOperation):
"""
Construct an Addition.
"""
super().__init__(self)
super().__init__()
self._input_ports = [InputPort(1), InputPort(1)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports.
......@@ -59,7 +59,7 @@ class Addition(BasicOperation):
def get_op_name(self) -> GraphIDType:
return "add"
class ConstantMultiplication(BasicOperation):
"""
......@@ -71,7 +71,7 @@ class ConstantMultiplication(BasicOperation):
"""
Construct a ConstantMultiplication.
"""
super().__init__(identifier)
super().__init__()
self._input_ports = [InputPort(1)] # TODO: Generate appropriate ID for ports.
self._output_ports = [OutputPort(1)] # TODO: Generate appropriate ID for ports.
self._parameters["coefficient"] = coefficient
......
......@@ -12,7 +12,7 @@ GraphIDNumber = NewType("GraphIDNumber", int)
class GraphIDGenerator:
"""
A class that generates Graph IDs for objects.
A class that generates Graph IDs for objects.
"""
_next_id_number: DefaultDict[GraphIDType, GraphIDNumber]
......
......@@ -88,7 +88,7 @@ class Operation(ABC):
"""
Simulate the circuit until its iteration count matches that of the simulation state,
then return the resulting output vector.
"""
"""
pass
@abstractmethod
......@@ -104,5 +104,12 @@ class Operation(ABC):
"""Returns a string representing the operation name of the operation."""
pass
@abstractmethod
def neighbours(self) -> "List[Operation]":
"""
Return all operations that are connected by signals to this operation.
If no neighbours are found this returns an empty list
"""
# TODO: More stuff.
"""
B-ASIC Operation Tree Traversing Module.
TODO:
- Get a first operation or? an entire operation tree
- For each start point, follow it to the next operation from it's out port.
- If we are searching for a specific operation end.
- If we are searching for a specific type of operation add the operation to a list and continue.
- When we no more out ports can be traversed return results and end.
"""
from typing import List, Optional
from collections import deque
from b_asic.operation import Operation
class Traverse:
"""Traverse operation tree.
TODO:
- More info.
- Check if a datastructure other than list suits better as return value.
- Implement the type check for operation.
"""
def __init__(self, operation: Operation):
"""Construct a TraverseTree."""
self._initial_operation = operation
def _breadth_first_search(self, start: Operation) -> List[Operation]:
"""Use breadth first search to traverse the operation tree."""
visited: List[Operation] = [start]
queue = deque([start])
while queue:
operation = queue.popleft()
for n_operation in operation.neighbours:
if n_operation not in visited:
visited.append(n_operation)
queue.append(n_operation)
return visited
def traverse(self, type_: Optional[Operation] = None) -> List[Operation]:
"""Traverse the the operation tree and return operation where type matches.
If the type is None then return the entire tree.
Keyword arguments:
type_-- the operation type to search for (default None)
"""
operations: List[Operation] = self._breadth_first_search(self._initial_operation)
if type_ is not None:
operations = [oper for oper in operations if isinstance(oper, type_)]
return operations
......@@ -6,15 +6,15 @@ Use a fixture for initializing objects and pass them as argument to a test funct
"""
@pytest.fixture
def signal():
source = SignalSource(Addition(0), 1)
dest = SignalDestination(Addition(1), 2)
return Signal(0, source, dest)
source = SignalSource(Addition(), 1)
dest = SignalDestination(Addition(), 2)
return Signal(source, dest)
@pytest.fixture
def signals():
ret = []
for i in range(0,3):
source = SignalSource(Addition(0), 1)
dest = SignalDestination(Addition(1), 2)
ret.append(Signal(i, source, dest))
for _ in range(0,3):
source = SignalSource(Addition(), 1)
dest = SignalDestination(Addition(), 2)
ret.append(Signal(source, dest))
return ret
\ No newline at end of file
......@@ -15,8 +15,8 @@ def test_connect_one_signal_to_port(signal):
def test_change_port_signal():
source = SignalSource(Addition, 1)
dest = SignalDestination(Addition,2)
signal1 = Signal(1, source, dest)
signal2 = Signal(2, source, dest)
signal1 = Signal(source, dest)
signal2 = Signal(source, dest)
port = InputPort(0)
port.connect(signal1)
......
"""
TODO:
- Rewrite to more clean code, not so repetitive
- Update when signals and id's has been merged.
"""
from b_asic.core_operations import Constant, Addition
from b_asic.signal import Signal, SignalSource, SignalDestination
from b_asic.port import InputPort, OutputPort
from b_asic.traverse_tree import Traverse
import pytest
@pytest.fixture
def operation():
return Constant(2)
def create_operation(_type, dest_oper, index, **kwargs):
oper = _type(**kwargs)
oper_signal_source = SignalSource(oper, 0)
oper_signal_dest = SignalDestination(dest_oper, index)
oper_signal = Signal(oper_signal_source, oper_signal_dest)
oper._output_ports[0].connect(oper_signal)
dest_oper._input_ports[index].connect(oper_signal)
return oper
@pytest.fixture
def operation_tree():
add_oper = Addition()
const_oper = create_operation(Constant, add_oper, 0, value=2)
const_oper_2 = create_operation(Constant, add_oper, 1, value=3)
return add_oper
@pytest.fixture
def large_operation_tree():
add_oper = Addition()
add_oper_2 = Addition()
const_oper = create_operation(Constant, add_oper, 0, value=2)
const_oper_2 = create_operation(Constant, add_oper, 1, value=3)
const_oper_3 = create_operation(Constant, add_oper_2, 0, value=4)
const_oper_4 = create_operation(Constant, add_oper_2, 1, value=5)
add_oper_3 = Addition()
add_oper_signal_source = SignalSource(add_oper, 0)
add_oper_signal_dest = SignalDestination(add_oper_3, 0)
add_oper_signal = Signal(add_oper_signal_source, add_oper_signal_dest)
add_oper._output_ports[0].connect(add_oper_signal)
add_oper_3._input_ports[0].connect(add_oper_signal)
add_oper_2_signal_source = SignalSource(add_oper_2, 0)
add_oper_2_signal_dest = SignalDestination(add_oper_3, 1)
add_oper_2_signal = Signal(add_oper_2_signal_source, add_oper_2_signal_dest)
add_oper_2._output_ports[0].connect(add_oper_2_signal)
add_oper_3._input_ports[1].connect(add_oper_2_signal)
return const_oper
def test_traverse_single_tree(operation):
traverse = Traverse(operation)
assert traverse.traverse() == [operation]
def test_traverse_tree(operation_tree):
traverse = Traverse(operation_tree)
assert len(traverse.traverse()) == 3
def test_traverse_large_tree(large_operation_tree):
traverse = Traverse(large_operation_tree)
assert len(traverse.traverse()) == 7
def test_traverse_type(large_operation_tree):
traverse = Traverse(large_operation_tree)
assert len(traverse.traverse(Addition)) == 3
assert len(traverse.traverse(Constant)) == 4
def test_traverse_loop(operation_tree):
add_oper_signal_source = SignalSource(operation_tree, 0)
add_oper_signal_dest = SignalDestination(operation_tree, 0)
add_oper_signal = Signal(add_oper_signal_source, add_oper_signal_dest)
operation_tree._output_ports[0].connect(add_oper_signal)
operation_tree._input_ports[0].connect(add_oper_signal)
traverse = Traverse(operation_tree)
assert len(traverse.traverse()) == 2
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment