Skip to content
Snippets Groups Projects
Commit aac6c47a authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Add testing for operation errors

parent 1baba58a
No related branches found
No related tags found
1 merge request!213Mixfixes
Pipeline #90042 passed
......@@ -62,6 +62,10 @@ class Constant(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
class Addition(AbstractOperation):
"""
......@@ -410,9 +414,7 @@ class Min(AbstractOperation):
def evaluate(self, a, b):
if isinstance(a, complex) or isinstance(b, complex):
raise ValueError(
"core_operations.Min does not support complex numbers."
)
raise ValueError("core_operations.Min does not support complex numbers.")
return a if a < b else b
......@@ -457,9 +459,7 @@ class Max(AbstractOperation):
def evaluate(self, a, b):
if isinstance(a, complex) or isinstance(b, complex):
raise ValueError(
"core_operations.Max does not support complex numbers."
)
raise ValueError("core_operations.Max does not support complex numbers.")
return a if a > b else b
......@@ -589,8 +589,7 @@ class ConstantMultiplication(AbstractOperation):
latency_offsets: Optional[Dict[str, int]] = None,
execution_time: Optional[int] = None,
):
"""Construct a ConstantMultiplication operation with the given value.
"""
"""Construct a ConstantMultiplication operation with the given value."""
super().__init__(
input_count=1,
output_count=1,
......
......@@ -988,7 +988,7 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for inputs"
"Missing latencies for input(s)"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
......@@ -999,8 +999,8 @@ class AbstractOperation(Operation, AbstractGraphComponent):
if any(val is None for val in latency_offsets):
raise ValueError(
"Missing latencies for outputs"
f" {[i for i in latency_offsets if i is not None]}"
"Missing latencies for output(s)"
f" {[i for (i, latency) in enumerate(latency_offsets) if latency is None]}"
)
return cast(List[int], latency_offsets)
......
......@@ -44,6 +44,10 @@ class Input(AbstractOperation):
def evaluate(self):
return self.param("value")
@property
def latency(self) -> int:
return self.latency_offsets["out0"]
@property
def value(self) -> Num:
"""Get the current value of this input."""
......@@ -56,9 +60,7 @@ class Input(AbstractOperation):
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
# Doc-string inherited
return (
(
......@@ -122,9 +124,7 @@ class Output(AbstractOperation):
def get_plot_coordinates(
self,
) -> Tuple[
Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]
]:
) -> Tuple[Tuple[Tuple[float, float], ...], Tuple[Tuple[float, float], ...]]:
# Doc-string inherited
return (
((0, 0), (0, 1), (0.25, 1), (0.5, 0.5), (0.25, 0), (0, 0)),
......@@ -139,6 +139,10 @@ class Output(AbstractOperation):
# doc-string inherited
return tuple()
@property
def latency(self) -> int:
return self.latency_offsets["in0"]
class Delay(AbstractOperation):
"""
......@@ -174,9 +178,7 @@ class Delay(AbstractOperation):
self, index: int, delays: Optional[DelayMap] = None, prefix: str = ""
) -> Optional[Num]:
if delays is not None:
return delays.get(
self.key(index, prefix), self.param("initial_value")
)
return delays.get(self.key(index, prefix), self.param("initial_value"))
return self.param("initial_value")
def evaluate_output(
......@@ -190,9 +192,7 @@ class Delay(AbstractOperation):
truncate: bool = True,
) -> Num:
if index != 0:
raise IndexError(
f"Output index out of range (expected 0-0, got {index})"
)
raise IndexError(f"Output index out of range (expected 0-0, got {index})")
if len(input_values) != 1:
raise ValueError(
"Wrong number of inputs supplied to SFG for evaluation"
......
......@@ -94,10 +94,10 @@ def sfg_two_inputs_two_outputs_independent_with_cmul():
in1 = Input("IN1")
in2 = Input("IN2")
c1 = Constant(3, "C1")
add1 = Addition(in2, c1, "ADD1", 7)
cmul3 = ConstantMultiplication(2, add1, "CMUL3", 3)
cmul1 = ConstantMultiplication(5, in1, "CMUL1", 5)
cmul2 = ConstantMultiplication(4, cmul1, "CMUL2", 4)
add1 = Addition(in2, c1, "ADD1", 7, execution_time=2)
cmul3 = ConstantMultiplication(2, add1, "CMUL3", 3, execution_time=1)
cmul1 = ConstantMultiplication(5, in1, "CMUL1", 5, execution_time=3)
cmul2 = ConstantMultiplication(4, cmul1, "CMUL2", 4, execution_time=1)
out1 = Output(cmul2, "OUT1")
out2 = Output(cmul3, "OUT2")
return SFG(inputs=[in1, in2], outputs=[out1, out2])
......
......@@ -204,6 +204,10 @@ class TestLatency:
"out1": 9,
}
def test_set_latency_negative(self):
with pytest.raises(ValueError, match="Latency cannot be negative"):
Butterfly(latency=-1)
class TestExecutionTime:
def test_execution_time_constructor(self):
......@@ -292,9 +296,16 @@ class TestIOCoordinates:
bfly = Butterfly()
bfly.set_latency_offsets({"in0": 3, "out1": 5})
with pytest.raises(ValueError, match="Missing latencies for inputs \\[1\\]"):
with pytest.raises(
ValueError, match="Missing latencies for input\\(s\\) \\[1\\]"
):
bfly.get_input_coordinates()
with pytest.raises(
ValueError, match="Missing latencies for output\\(s\\) \\[0\\]"
):
bfly.get_output_coordinates()
class TestSplit:
def test_simple_case(self):
......
......@@ -410,21 +410,19 @@ class TestTimeResolution:
start_times_names = {}
for op_id, start_time in schedule._start_times.items():
op_name = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(
op_id
).name
start_times_names[op_name] = start_time
op = sfg_two_inputs_two_outputs_independent_with_cmul.find_by_id(op_id)
start_times_names[op.name] = (start_time, op.latency, op.execution_time)
assert start_times_names == {
"C1": 0,
"IN1": 0,
"IN2": 0,
"CMUL1": 0,
"CMUL2": 30,
"ADD1": 0,
"CMUL3": 42,
"OUT1": 54,
"OUT2": 60,
"C1": (0, 0, None),
"IN1": (0, 0, None),
"IN2": (0, 0, None),
"CMUL1": (0, 30, 18),
"CMUL2": (30, 24, 6),
"ADD1": (0, 42, 12),
"CMUL3": (42, 18, 6),
"OUT1": (54, 0, None),
"OUT2": (60, 0, None),
}
assert 6 * old_schedule_time == schedule.schedule_time
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment