Skip to content
Snippets Groups Projects
Commit f69a4f3c authored by Adam Jakobsson's avatar Adam Jakobsson Committed by Felix Goding
Browse files

Refactor constructor so that Input signals and Output signals are connected to...

Refactor constructor so that Input signals and Output signals are connected to ports before traversal is started, that way edge cases of empty SFG's are easily handled
parent afe22efb
No related branches found
No related tags found
2 merge requests!67WIP: B-ASIC version 1.0.0 hotfix,!65B-ASIC version 1.0.0
......@@ -166,13 +166,14 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if input_sources is not None:
source_count = len(input_sources)
if source_count != input_count:
raise ValueError(f"Operation expected {input_count} input sources but only got {source_count}")
raise ValueError(
f"Operation expected {input_count} input sources but only got {source_count}")
for i, src in enumerate(input_sources):
if src is not None:
self._input_ports[i].connect(src.source)
@abstractmethod
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
def evaluate(self, *inputs) -> Any: # pylint: disable=arguments-differ
"""Evaluate the operation and generate a list of output values given a
list of input values.
"""
......@@ -246,11 +247,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
result = self.evaluate(*input_values)
if isinstance(result, collections.Sequence):
if len(result) != self.output_count:
raise RuntimeError("Operation evaluated to incorrect number of outputs")
raise RuntimeError(
"Operation evaluated to incorrect number of outputs")
return result
if isinstance(result, Number):
if self.output_count != 1:
raise RuntimeError("Operation evaluated to incorrect number of outputs")
raise RuntimeError(
"Operation evaluated to incorrect number of outputs")
return [result]
raise RuntimeError("Operation evaluated to invalid type")
......@@ -296,11 +299,13 @@ class AbstractOperation(Operation, AbstractGraphComponent):
def source(self) -> OutputPort:
if self.output_count != 1:
diff = "more" if self.output_count > 1 else "less"
raise TypeError(f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
raise TypeError(
f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
return self.output(0)
def copy_unconnected(self) -> GraphComponent:
new_comp: AbstractOperation = super().copy_unconnected()
for name, value in self.params.items():
new_comp.set_param(name, deepcopy(value)) # pylint: disable=no-member
new_comp.set_param(name, deepcopy(
value)) # pylint: disable=no-member
return new_comp
......@@ -8,6 +8,7 @@ from copy import copy
from typing import NewType, Optional, List, Iterable, TYPE_CHECKING
from b_asic.signal import Signal
from b_asic.graph_component import Name
if TYPE_CHECKING:
from b_asic.operation import Operation
......@@ -144,22 +145,24 @@ class InputPort(AbstractPort):
"""
return None if self._source_signal is None else self._source_signal.source
def connect(self, src: SignalSourceProvider) -> Signal:
def connect(self, src: SignalSourceProvider, name: Name = "") -> Signal:
"""Connect the provided signal source to this input port by creating a new signal.
Returns the new signal.
"""
assert self._source_signal is None, "Attempted to connect already connected input port."
return Signal(src.source, self) # self._source_signal is set by the signal constructor.
# self._source_signal is set by the signal constructor.
return Signal(source=src.source, destination=self, name=name)
@property
def value_length(self) -> Optional[int]:
"""Get the number of bits that this port should truncate received values to."""
return self._value_length
@value_length.setter
def value_length(self, bits: Optional[int]) -> None:
"""Set the number of bits that this port should truncate received values to."""
assert bits is None or (isinstance(bits, int) and bits >= 0), "Value length must be non-negative."
assert bits is None or (isinstance(
bits, int) and bits >= 0), "Value length must be non-negative."
self._value_length = bits
......@@ -185,7 +188,7 @@ class OutputPort(AbstractPort, SignalSourceProvider):
def add_signal(self, signal: Signal) -> None:
assert signal not in self._destination_signals, "Attempted to add already connected signal."
self._destination_signals.append(signal)
signal.set_source(self)
signal.set_source(self)
def remove_signal(self, signal: Signal) -> None:
assert signal in self._destination_signals, "Attempted to remove already removed signal."
......@@ -195,7 +198,7 @@ class OutputPort(AbstractPort, SignalSourceProvider):
def clear(self) -> None:
for signal in copy(self._destination_signals):
self.remove_signal(signal)
@property
def source(self) -> "OutputPort":
return self
\ No newline at end of file
return self
This diff is collapsed.
"""
B-ASIC test suite for printing a SFG
"""
from b_asic.signal_flow_graph import SFG
from b_asic.core_operations import Addition, Multiplication, Constant, ConstantAddition
from b_asic.port import InputPort, OutputPort
from b_asic.signal import Signal
from b_asic.special_operations import Input, Output
import pytest
class TestPrintSfg:
def test_print_one_addition(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
add1 = Addition(inp1, inp2, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1, inp2], outputs=[out1], name="sf1")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: out1, name: OUT1, input: [s3], output: []\n")
def test_print_add_mul(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
sfg = SFG(inputs=[inp1, inp2, inp3], outputs=[out1], name="mac_sfg")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s1, s2], output: [s5]\nid: in1, name: INP1, input: [], output: [s1]\nid: in2, name: INP2, input: [], output: [s2]\nid: mul1, name: MUL1, input: [s5, s3], output: [s4]\nid: in3, name: INP3, input: [], output: [s3]\nid: out1, name: OUT1, input: [s4], output: []\n")
def test_print_constant(self):
inp1 = Input("INP1")
const1 = Constant(3, "CONST")
add1 = Addition(const1, inp1, "ADD1")
out1 = Output(add1, "OUT1")
sfg = SFG(inputs=[inp1], outputs=[out1], name="sfg")
assert sfg.__str__() == ("id: add1, name: ADD1, input: [s3, s1], output: [s2]\nid: c1, name: CONST, value: 3, input: [], output: [s3]\nid: in1, name: INP1, input: [], output: [s1]\nid: out1, name: OUT1, input: [s2], output: []\n")
\ No newline at end of file
from b_asic import SFG
from b_asic.signal import Signal
from b_asic.core_operations import Addition, Constant
from b_asic.core_operations import Addition, Constant, Multiplication
from b_asic.special_operations import Input, Output
class TestConstructor:
def test_direct_input_to_output_sfg_construction(self):
inp = Input("INP1")
out = Output(None, "OUT1")
out.input(0).connect(inp, "S1")
sfg = SFG(inputs=[inp], outputs=[out])
assert len(list(sfg.components)) == 3
assert sfg.input_count == 1
assert sfg.output_count == 1
def test_same_signal_input_and_output_sfg_construction(self):
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
sig1 = add2.input(0).connect(add1, "S1")
sfg = SFG(input_signals=[sig1], output_signals=[sig1])
assert len(list(sfg.components)) == 3
assert sfg.input_count == 1
assert sfg.output_count == 1
def test_outputs_construction(self, operation_tree):
outp = Output(operation_tree)
sfg = SFG(outputs=[outp])
......@@ -20,13 +44,73 @@ class TestConstructor:
assert sfg.input_count == 0
assert sfg.output_count == 1
def test_operations_construction(self, operation_tree):
sfg1 = SFG(operations=[operation_tree])
sfg2 = SFG(operations=[operation_tree.input(1).signals[0].source.operation])
assert len(list(sfg1.components)) == 5
assert len(list(sfg2.components)) == 5
assert sfg1.input_count == 0
assert sfg2.input_count == 0
assert sfg1.output_count == 0
assert sfg2.output_count == 0
class TestDeepCopy:
def test_deep_copy_no_duplicates(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(inp1, inp2, "ADD1")
mul1 = Multiplication(add1, inp3, "MUL1")
out1 = Output(mul1, "OUT1")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
mac_sfg_deep_copy = mac_sfg.deep_copy()
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_deep_copy.find_by_id(g_id)
assert component.name == component_copy.name
def test_deep_copy(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
mac_sfg_deep_copy = mac_sfg.deep_copy()
for g_id, component in mac_sfg._components_by_id.items():
component_copy = mac_sfg_deep_copy.find_by_id(g_id)
assert component.name == component_copy.name
class TestComponents:
def test_advanced_components(self):
inp1 = Input("INP1")
inp2 = Input("INP2")
inp3 = Input("INP3")
add1 = Addition(None, None, "ADD1")
add2 = Addition(None, None, "ADD2")
mul1 = Multiplication(None, None, "MUL1")
out1 = Output(None, "OUT1")
add1.input(0).connect(inp1, "S1")
add1.input(1).connect(inp2, "S2")
add2.input(0).connect(add1, "S4")
add2.input(1).connect(inp3, "S3")
mul1.input(0).connect(add1, "S5")
mul1.input(1).connect(add2, "S6")
out1.input(0).connect(mul1, "S7")
mac_sfg = SFG(inputs=[inp1, inp2],
outputs=[out1], name="mac_sfg")
assert set([comp.name for comp in mac_sfg.components]) == {
"INP1", "INP2", "INP3", "ADD1", "ADD2", "MUL1", "OUT1", "S1", "S2", "S3", "S4", "S5", "S6", "S7"}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment