Skip to content
Snippets Groups Projects
Commit f7ad2d1c authored by Oscar Gustafsson's avatar Oscar Gustafsson :bicyclist:
Browse files

Initial work on constant propagation

parent 457e040c
No related tags found
1 merge request!137Initial work on constant propagation
Pipeline #88922 passed
......@@ -5,13 +5,13 @@ Contains some of the most commonly used mathematical operations.
"""
from numbers import Number
from typing import Dict, Optional
from typing import Dict, Iterable, Optional, Set
from numpy import abs as np_abs
from numpy import conjugate, sqrt
from b_asic.graph_component import Name, TypeName
from b_asic.operation import AbstractOperation
from b_asic.operation import AbstractOperation, Operation
from b_asic.port import SignalSourceProvider
......@@ -125,6 +125,14 @@ class Addition(AbstractOperation):
def evaluate(self, a, b):
return a + b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class Subtraction(AbstractOperation):
"""
......@@ -185,6 +193,14 @@ class Subtraction(AbstractOperation):
def evaluate(self, a, b):
return a - b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class AddSub(AbstractOperation):
r"""
......@@ -266,6 +282,20 @@ class AddSub(AbstractOperation):
"""Set if operation is add."""
self.set_param("is_add", is_add)
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
def _propagate_constant_parameters(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
print(f"Can turn into {'Addition' if self.is_add else 'Subtraction'}")
return
class Multiplication(AbstractOperation):
r"""
......@@ -327,6 +357,17 @@ class Multiplication(AbstractOperation):
def evaluate(self, a, b):
return a * b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
if any(c == 1.0 for c in constants):
print("One input is 1!")
print("Can turn into ConstantMultiplication")
class Division(AbstractOperation):
r"""
......@@ -368,6 +409,17 @@ class Division(AbstractOperation):
def evaluate(self, a, b):
return a / b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
numerator, denominator = constants
if numerator == 0.0:
print("Result is 0!")
if denominator is not None:
print("Can turn into ConstantMultiplication")
class Min(AbstractOperation):
r"""
......@@ -619,6 +671,15 @@ class ConstantMultiplication(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
def _propagate_constant_parameters(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
if self.value == 0.0:
print("Constant is zero!")
if self.value == 1.0:
print("Constant is zero!")
return
class Butterfly(AbstractOperation):
r"""
......@@ -661,6 +722,14 @@ class Butterfly(AbstractOperation):
def evaluate(self, a, b):
return a + b, a - b
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
class MAD(AbstractOperation):
r"""
......@@ -700,6 +769,19 @@ class MAD(AbstractOperation):
def evaluate(self, a, b, c):
return a * b + c
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
a, b, c = constants
if a == 0.0 or b == 0.0:
print("One multiplier input is zero!")
if a == 1.0 or b == 1.0:
print("One multiplier input is one!")
if any(c == 0.0):
print("Adder input is zero!")
class SymmetricTwoportAdaptor(AbstractOperation):
r"""
......@@ -752,6 +834,21 @@ class SymmetricTwoportAdaptor(AbstractOperation):
"""Set the constant value of this operation."""
self.set_param("value", value)
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
if any(c == 0.0 for c in constants):
print("One input is 0!")
def _propagate_constant_parameters(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
if self.value == 0.0:
print("Constant is zero!")
return
class Reciprocal(AbstractOperation):
r"""
......
......@@ -19,6 +19,7 @@ from typing import (
NewType,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
......@@ -435,6 +436,30 @@ class Operation(GraphComponent, SignalSourceProvider):
def _check_all_latencies_set(self) -> None:
raise NotImplementedError
@abstractmethod
def _propagate_constants(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
raise NotImplementedError
@abstractmethod
def _constant_inputs(self) -> Iterable[Number]:
raise NotImplementedError
@abstractmethod
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def _propagate_constant_parameters(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
raise NotImplementedError
class AbstractOperation(Operation, AbstractGraphComponent):
"""
......@@ -1093,3 +1118,43 @@ class AbstractOperation(Operation, AbstractGraphComponent):
for k in range(len(self.outputs))
]
return input_coordinates, output_coordinates
def _propagate_constants(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
# Must be implemented per operation, so just return otherwise
constants = self._constant_inputs()
if all(c is None for c in constants):
return
if all(c is not None for c in constants):
res = self.evalute(*constants)
print(f"Result is {res}!")
if any(c is not None for c in constants):
# This is operation dependent
self._propagate_some_constants(constants, valid_operations)
return
def _constant_inputs(self) -> Iterable[Optional[Number]]:
from b_asic.core_operations import Constant
ret = []
for port in self._input_ports:
if port.connected_source is None:
ret.append(None)
elif isinstance(port.connected_source.operation, Constant):
ret.append(port.connected_source.operation.value)
else:
ret.append(None)
return ret
def _propagate_some_constants(
self,
constants: Iterable[Optional[Number]],
valid_operations: Optional[Set["Operation"]] = None,
) -> None:
return
def _propagate_constant_parameters(
self, valid_operations: Optional[Set["Operation"]] = None
) -> None:
return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment