Skip to content
Snippets Groups Projects
Commit cf430402 authored by Robier Al Kaadi's avatar Robier Al Kaadi :penguin:
Browse files

Add is_commutative and is_distributive

parent d625926b
Branches commutative
No related tags found
1 merge request!457Add is_commutative and is_distributive
Pipeline #134600 passed
......@@ -1053,6 +1053,7 @@ class MAD(AbstractOperation):
_execution_time: Optional[int]
is_swappable = True
is_distributive = True
def __init__(
self,
......
......@@ -31,7 +31,10 @@ class ColorButton(QPushButton):
self.set_color(self._default)
def set_color(self, color: QColor):
"""Set new color."""
"""Set new color.
color : QColor
The new color of the button.
"""
if color != self._color:
self._color = color
self._color_changed.emit(color)
......@@ -42,7 +45,10 @@ class ColorButton(QPushButton):
self.setStyleSheet("")
def set_text_color(self, color: QColor):
"""Set text color."""
"""Set text color.
color : QColor
The new color of the text in the button.
"""
self.setStyleSheet(f"color: {color.name()};")
@property
......
......@@ -425,6 +425,22 @@ class Operation(GraphComponent, SignalSourceProvider):
"""
raise NotImplementedError
@property
@abstractmethod
def is_commutative(self) -> bool:
"""
Return True if the operation is commutative.
"""
raise NotImplementedError
@property
@abstractmethod
def is_distributive(self) -> bool:
"""
Return True if the operation is distributive.
"""
raise NotImplementedError
@property
@abstractmethod
def is_swappable(self) -> bool:
......@@ -1064,6 +1080,29 @@ class AbstractOperation(Operation, AbstractGraphComponent):
input_.connected_source.operation.is_constant for input_ in self.inputs
)
@property
def is_commutative(self) -> bool:
"""
Checks if the operation is commutative.
An operation is commutative if the order of the inputs does not change the result.
For example, addition is commutative because `a + b == b + a`, but subtraction is not
because `a - b != b - a`.
Returns:
bool: True if the operation is commutative, False otherwise.
"""
# doc-string inherited
if self.input_count == 2:
return self.is_swappable
return False
@property
def is_distributive(self) -> bool:
# doc-string inherited
return False
@property
def is_swappable(self) -> bool:
# doc-string inherited
......
......@@ -2111,6 +2111,153 @@ class SFG(AbstractOperation):
def is_constant(self) -> bool:
return all(output.is_constant for output in self._output_operations)
@property
def is_commutative(self) -> bool:
"""
Checks if all operations in the SFG are commutative.
An operation is considered commutative if it is not in B-ASIC Special Operations Module,
and its `is_commutative` property is `True`.
Returns
-------
bool: `True` if all operations are commutative, `False` otherwise.
"""
return all(
(
op.is_commutative
if op.type_name() not in ["in", "out", "t", "c"]
else True
)
for op in self.split()
)
@property
def is_distributive(self) -> bool:
"""
Checks if the SFG is distributive.
An operation is considered distributive if it can be applied to each element of a set separately.
For example, multiplication is distributive over addition, meaning that `a * (b + c)` is equivalent to `a * b + a * c`.
Returns
-------
bool: True if the SFG is distributive, False otherwise.
Examples
--------
>>> Mad_op = MAD(Input(), Input(), Input()) # Creates an instance of the Mad operation, MAD is defined in b_asic.core_operations
>>> Mad_sfg = Mad_op.to_sfg() # The operation is turned into a sfg
>>> Mad_sfg.is_distributive # True # if the distributive property holds, False otherwise
"""
structures = []
operations = self.get_operations_topological_order()
for op in operations:
if not (
op.type_name() == "in"
or op.type_name() == "out"
or op.type_name() == "c"
or op.type_name() == "t"
):
structures.append(op)
return (
all(self.has_distributive_structure(op) for op in structures)
if len(structures) > 1
else False
)
def has_distributive_structure(self, op: Operation) -> bool:
"""
Checks if the SFG contains distributive structures.
NOTE: a*b + c = a*(b + c/a) is considered distributive. Meaning that an algorithm transformation would require an additionat operation.
Parameters:
----------
op : Operation
The operation that is the start of the structure to check for distributivity.
Returns:
-------
bool: True if a distributive structures is found, False otherwise.
"""
# TODO Butterfly and SymmetricTwoportAdaptor needs to be converted to a SF using to_sfg() in order to be evaluated
if op.type_name() == 'mac':
return True
elif op.type_name() in ['mul', 'div']:
for subsequent_op in op.subsequent_operations:
if subsequent_op.type_name() in [
'add',
'sub',
'addsub',
'min',
'max',
'sqrt',
'abs',
'rec',
'out',
't',
]:
return True
elif subsequent_op.type_name() in ['mul', 'div']:
for subsequent_op in subsequent_op.subsequent_operations:
return self.has_distributive_structure(subsequent_op)
else:
return False
elif op.type_name() in ['cmul', 'shift', 'rshift', 'lshift']:
for subsequent_op in op.subsequent_operations:
if subsequent_op.type_name() in [
'add',
'sub',
'addsub',
'min',
'max',
'out',
't',
]:
return True
elif subsequent_op.type_name() in ['cmul', 'shift', 'rshift', 'lshift']:
for subsequent_op in subsequent_op.subsequent_operations:
return self.has_distributive_structure(subsequent_op)
else:
return False
elif op.type_name() in ['add', 'sub', 'addsub']:
for subsequent_op in op.subsequent_operations:
if subsequent_op.type_name() in [
'mul',
'div',
'min',
'max',
'out',
'cmul',
't',
]:
return True
elif subsequent_op.type_name() in ['add', 'sub', 'addsub']:
for subsequent_op in subsequent_op.subsequent_operations:
return self.has_distributive_structure(subsequent_op)
else:
return False
elif op.type_name() in ['min', 'max']:
for subsequent_op in op.subsequent_operations:
if subsequent_op.type_name() in [
'add',
'sub',
'addsub',
'mul',
'div',
'cmul',
'out',
't',
]:
return True
elif subsequent_op.type_name() in ['min', 'max']:
for subsequent_op in subsequent_op.subsequent_operations:
return self.has_distributive_structure(subsequent_op)
else:
return False
return False
def get_used_type_names(self) -> List[TypeName]:
"""Get a list of all TypeNames used in the SFG."""
ret = list({op.type_name() for op in self.operations})
......
......@@ -1611,6 +1611,36 @@ class TestIsConstant:
assert not sfg_nested.is_constant
class TestIsCommutative:
def test_single_accumulator(self, sfg_simple_accumulator: SFG):
assert sfg_simple_accumulator.is_commutative
def test_sfg_accumulation(self, sfg_accumulator: SFG):
assert not sfg_accumulator.is_commutative
class TestIsDistributive:
def test_single_accumulator(self, sfg_simple_accumulator: SFG):
assert not sfg_simple_accumulator.is_distributive
assert sfg_simple_accumulator.has_distributive_structure(
sfg_simple_accumulator.find_by_id('add0')
)
def test_sfg_accumulation(self, sfg_accumulator: SFG):
assert sfg_accumulator.is_distributive
def test_sfg_with_Chain_of_Additions(self): # value*(a+b+c)
a = Input()
b = Input()
c = Input()
add1 = a + b
add2 = add1 + c
mul = ConstantMultiplication(2, add2)
out1 = Output(mul)
sfg_simple = SFG(inputs=[a, b, c], outputs=[out1])
assert sfg_simple.is_distributive
class TestSwapIOOfOperation:
def do_test(self, sfg: SFG, graph_id: GraphID):
NUM_TESTS = 5
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment