"""B-ASIC Operation Module.

Contains the base for operations that are used by B-ASIC.
"""
import collections
import collections.abc
import itertools as it
from abc import abstractmethod
from math import trunc
from numbers import Number
from typing import NewType, List, Dict, Sequence, Iterable, Mapping, MutableMapping, Optional, Any, Set, Union

from b_asic.signal import Signal
from b_asic.port import SignalSourceProvider, InputPort, OutputPort
from b_asic.graph_component import GraphComponent, AbstractGraphComponent, Name
# Key identifying one output of an operation in result/delay maps
# (built by Operation.key()).
ResultKey = NewType("ResultKey", str)
# Maps result keys to computed output values; a value may be None when no
# result is available (see Operation.current_output()).
ResultMap = Mapping[ResultKey, Optional[Number]]
MutableResultMap = MutableMapping[ResultKey, Optional[Number]]
# Maps result keys to the current values of intermediate delays
# (see Operation.evaluate_output()).
DelayMap = Mapping[ResultKey, Number]
MutableDelayMap = MutableMapping[ResultKey, Number]


class Operation(GraphComponent, SignalSourceProvider):
    """Operation interface.

    Operations are graph components that perform a certain function.
    They are connected to each other by signals through their input/output
    ports.

    Operations can be evaluated independently using evaluate_output().
    Operations may specify how to truncate inputs through truncate_input().
    """

    @abstractmethod
    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        """Overloads the addition operator to make it return a new Addition operation
        object that is connected to the self and other objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        """Overloads the addition operator to make it return a new Addition operation
        object that is connected to the self and other objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        """Overloads the subtraction operator to make it return a new Subtraction operation
        object that is connected to the self and other objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        """Overloads the subtraction operator to make it return a new Subtraction operation
        object that is connected to the self and other objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overloads the multiplication operator to make it return a new Multiplication operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantMultiplication operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        """Overloads the multiplication operator to make it return a new Multiplication operation
        object that is connected to the self and other objects. If other is a number then
        returns a ConstantMultiplication operation object instead.
        """
        raise NotImplementedError

    @abstractmethod
    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        """Overloads the division operator to make it return a new Division operation
        object that is connected to the self and other objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        """Overloads the division operator to make it return a new Division operation
        object that is connected to the self and other objects.
        """
        raise NotImplementedError

    @abstractmethod
    def __lshift__(self, src: SignalSourceProvider) -> Signal:
        """Overloads the left shift operator to make it connect the provided signal source
        to this operation's input, assuming it has exactly 1 input port.
        Returns the new signal.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def input_count(self) -> int:
        """Get the number of input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def output_count(self) -> int:
        """Get the number of output ports."""
        raise NotImplementedError

    @abstractmethod
    def input(self, index: int) -> InputPort:
        """Get the input port at the given index."""
        raise NotImplementedError

    @abstractmethod
    def output(self, index: int) -> OutputPort:
        """Get the output port at the given index."""
        raise NotImplementedError

    @property
    @abstractmethod
    def inputs(self) -> Sequence[InputPort]:
        """Get all input ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def outputs(self) -> Sequence[OutputPort]:
        """Get all output ports."""
        raise NotImplementedError

    @property
    @abstractmethod
    def input_signals(self) -> Iterable[Signal]:
        """Get all the signals that are connected to this operation's input ports,
        in no particular order.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def output_signals(self) -> Iterable[Signal]:
        """Get all the signals that are connected to this operation's output ports,
        in no particular order.
        """
        raise NotImplementedError

    @abstractmethod
    def key(self, index: int, prefix: str = "") -> ResultKey:
        """Get the key used to access the output of a certain output of this operation
        from the output parameter passed to current_output(s) or evaluate_output(s).
        """
        raise NotImplementedError

    @abstractmethod
    def current_output(self, index: int, delays: Optional[DelayMap] = None,
                       prefix: str = "") -> Optional[Number]:
        """Get the current output at the given index of this operation, if available.

        The delays parameter will be used for lookup.
        The prefix parameter will be used as a prefix for the key string when looking for delays.
        See also: current_outputs, evaluate_output, evaluate_outputs.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_output(self, index: int, input_values: Sequence[Number],
                        results: Optional[MutableResultMap] = None,
                        delays: Optional[MutableDelayMap] = None, prefix: str = "",
                        bits_override: Optional[int] = None, truncate: bool = True) -> Number:
        """Evaluate the output at the given index of this operation with the given input values.

        The results parameter will be used to store any results (including intermediate results) for caching.
        The delays parameter will be used to get the current value of any intermediate delays that are
        encountered, and be updated with their new values.
        The prefix parameter will be used as a prefix for the key string when storing results/delays.
        The bits_override parameter specifies a word length override when truncating inputs which ignores
        the word length specified by the input signal.
        The truncate parameter specifies whether input truncation should be enabled in the first place.
        If set to False, input values will be used directly without any bit truncation.
        See also: evaluate_outputs, current_output, current_outputs.
        """
        raise NotImplementedError

    @abstractmethod
    def current_outputs(self, delays: Optional[DelayMap] = None,
                        prefix: str = "") -> Sequence[Optional[Number]]:
        """Get all current outputs of this operation, if available.

        See current_output for more information.
        """
        raise NotImplementedError

    @abstractmethod
    def evaluate_outputs(self, input_values: Sequence[Number],
                         results: Optional[MutableResultMap] = None,
                         delays: Optional[MutableDelayMap] = None, prefix: str = "",
                         bits_override: Optional[int] = None,
                         truncate: bool = True) -> Sequence[Number]:
        """Evaluate all outputs of this operation given the input values.

        See evaluate_output for more information.
        """
        raise NotImplementedError

    @abstractmethod
    def split(self) -> Iterable["Operation"]:
        """Split the operation into multiple operations.

        If splitting is not possible, this may return a list containing only the operation itself.
        """
        raise NotImplementedError

    @abstractmethod
    def to_sfg(self) -> "SFG":
        """Convert the operation into its corresponding SFG.

        If the operation is composed by multiple operations, the operation will be split.
        """
        raise NotImplementedError

    @abstractmethod
    def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
        """Get the input indices of all inputs in this operation whose values are required
        in order to evaluate the output at the given output index.
        """
        raise NotImplementedError

    @abstractmethod
    def truncate_input(self, index: int, value: Number, bits: int) -> Number:
        """Truncate the value to be used as input at the given index to a certain bit length."""
        raise NotImplementedError

    @property
    @abstractmethod
    def latency(self) -> int:
        """Get the latency of the operation, which is the longest time it takes from one of
        the operation's input ports to one of the operation's output ports.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def latency_offsets(self) -> Dict[str, Optional[int]]:
        """Get a dictionary with the latency offsets of all the operation's ports.

        Input ports are keyed as 'in' + index and output ports as 'out' + index,
        e.g. {'in0': 0, 'in1': 0, 'out0': 2}. A port whose latency offset has not
        been set maps to None.
        """
        raise NotImplementedError

    @abstractmethod
    def set_latency(self, latency: int) -> None:
        """Set the latency of the operation to the specified integer value.

        This sets the latency offsets of the operation's input ports to 0 and the
        latency offsets of the operation's output ports to the specified value.
        The latency cannot be a negative integer.
        """
        raise NotImplementedError

    @abstractmethod
    def set_latency_offsets(self, latency_offsets: Dict[str, int]) -> None:
        """Set the latency offsets for the ports specified in the latency_offsets dictionary.

        The dictionary should be {'in0': 2, 'out1': 4} if you want to set the latency offset
        of the input port with index 0 to 2, and the latency offset of the output port with
        index 1 to 4.
        """
        raise NotImplementedError


class AbstractOperation(Operation, AbstractGraphComponent):
    """Generic abstract operation base class.

    Concrete operations should normally derive from this to get the default
    behavior.
    """

    _input_ports: List[InputPort]
    _output_ports: List[OutputPort]

    def __init__(self, input_count: int, output_count: int, name: Name = "",
                 input_sources: Optional[Sequence[Optional[SignalSourceProvider]]] = None,
                 latency: Optional[int] = None,
                 latency_offsets: Optional[Dict[str, int]] = None):
        """Construct an operation with the given input/output count.

        A list of input sources may be specified to automatically connect
        to the input ports; None entries leave the corresponding port unconnected.
        If provided, the number of sources must match the number of inputs.

        The latency offsets may also be specified to be initialized. Explicit
        latency_offsets are applied first; a given latency then only fills in
        ports that still have no offset (inputs get 0, outputs get the latency).

        Raises:
            ValueError: if len(input_sources) differs from input_count.
            AssertionError: if latency is negative.
        """
        super().__init__(name)

        self._input_ports = [InputPort(self, i) for i in range(input_count)]
        self._output_ports = [OutputPort(self, i) for i in range(output_count)]

        # Connect given input sources, if any.
        if input_sources is not None:
            source_count = len(input_sources)
            if source_count != input_count:
                raise ValueError(
                    f"Wrong number of input sources supplied to Operation (expected {input_count}, got {source_count})")
            for i, src in enumerate(input_sources):
                if src is not None:
                    self._input_ports[i].connect(src.source)

        if latency_offsets is not None:
            self.set_latency_offsets(latency_offsets)

        if latency is not None:
            # Set the latency of the rest of ports with no latency_offset.
            assert latency >= 0, "Negative latency entered"
            for inp in self.inputs:
                if inp.latency_offset is None:
                    inp.latency_offset = 0
            for outp in self.outputs:
                if outp.latency_offset is None:
                    outp.latency_offset = latency

    @abstractmethod
    def evaluate(self, *inputs) -> Any:  # pylint: disable=arguments-differ
        """Evaluate the operation and generate a list of output values given a list of input values."""
        raise NotImplementedError

    def __add__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Constant, Addition
        return Addition(self, Constant(src) if isinstance(src, Number) else src)

    def __radd__(self, src: Union[SignalSourceProvider, Number]) -> "Addition":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Constant, Addition
        return Addition(Constant(src) if isinstance(src, Number) else src, self)

    def __sub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Constant, Subtraction
        return Subtraction(self, Constant(src) if isinstance(src, Number) else src)

    def __rsub__(self, src: Union[SignalSourceProvider, Number]) -> "Subtraction":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Constant, Subtraction
        return Subtraction(Constant(src) if isinstance(src, Number) else src, self)

    def __mul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Multiplication, ConstantMultiplication
        return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(self, src)

    def __rmul__(self, src: Union[SignalSourceProvider, Number]) -> "Union[Multiplication, ConstantMultiplication]":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Multiplication, ConstantMultiplication
        return ConstantMultiplication(src, self) if isinstance(src, Number) else Multiplication(src, self)

    def __truediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Constant, Division
        return Division(self, Constant(src) if isinstance(src, Number) else src)

    def __rtruediv__(self, src: Union[SignalSourceProvider, Number]) -> "Division":
        # Import here to avoid circular imports.
        from b_asic.core_operations import Constant, Division
        return Division(Constant(src) if isinstance(src, Number) else src, self)

    def __lshift__(self, src: SignalSourceProvider) -> Signal:
        if self.input_count != 1:
            diff = "more" if self.input_count > 1 else "less"
            raise TypeError(
                f"{self.__class__.__name__} cannot be used as a destination because it has {diff} than 1 input")
        return self.input(0).connect(src)

    def __str__(self) -> str:
        """Get a string representation of this operation.

        Extends the graph component's string with two dicts, one for inputs and
        one for outputs, mapping each port index to '-' when the port has no
        signals, or else to a list with one id per signal: the graph id of the
        operation at the signal's far end, or the signal's own graph id when the
        far end is not connected ("no_id" when the id is empty).
        """
        def describe(ports, far_end) -> dict:
            described = dict()
            for i, port in enumerate(ports):
                if port.signal_count == 0:
                    described[i] = '-'
                    # Keep going: later ports may still be connected.
                    continue
                ids = []
                for signal in port.signals:
                    end = far_end(signal)
                    if end:
                        ids.append(end.operation.graph_id or "no_id")
                    else:
                        ids.append(signal.graph_id or "no_id")
                described[i] = ids
            return described

        inputs_dict = describe(self.inputs, lambda signal: signal.source)
        outputs_dict = describe(self.outputs, lambda signal: signal.destination)
        return super().__str__() + f", \tinputs: {str(inputs_dict)}, \toutputs: {str(outputs_dict)}"

    @property
    def input_count(self) -> int:
        return len(self._input_ports)

    @property
    def output_count(self) -> int:
        return len(self._output_ports)

    def input(self, index: int) -> InputPort:
        return self._input_ports[index]

    def output(self, index: int) -> OutputPort:
        return self._output_ports[index]

    @property
    def inputs(self) -> Sequence[InputPort]:
        return self._input_ports

    @property
    def outputs(self) -> Sequence[OutputPort]:
        return self._output_ports

    @property
    def input_signals(self) -> Iterable[Signal]:
        return [signal for port in self.inputs for signal in port.signals]

    @property
    def output_signals(self) -> Iterable[Signal]:
        return [signal for port in self.outputs for signal in port.signals]

    def key(self, index: int, prefix: str = "") -> ResultKey:
        # Single-output operations are keyed by the prefix alone (or the index
        # when there is no prefix); multi-output ones append ".<index>".
        key = prefix
        if self.output_count != 1:
            if key:
                key += "."
            key += str(index)
        elif not key:
            key = str(index)
        return key

    def current_output(self, index: int, delays: Optional[DelayMap] = None,
                       prefix: str = "") -> Optional[Number]:
        # Stateless operations have no current output by default.
        return None

    def evaluate_output(self, index: int, input_values: Sequence[Number],
                        results: Optional[MutableResultMap] = None,
                        delays: Optional[MutableDelayMap] = None, prefix: str = "",
                        bits_override: Optional[int] = None, truncate: bool = True) -> Number:
        """Evaluate the output at the given index; every output is stored in results if given.

        Raises:
            IndexError: if index is not a valid output index.
            ValueError: if the number of input values does not match input_count.
            RuntimeError: if evaluate() returns the wrong number of outputs or a
                value that is neither a Sequence nor a Number.
        """
        if index < 0 or index >= self.output_count:
            raise IndexError(
                f"Output index out of range (expected 0-{self.output_count - 1}, got {index})")
        if len(input_values) != self.input_count:
            raise ValueError(
                f"Wrong number of input values supplied to operation (expected {self.input_count}, got {len(input_values)})")

        values = self.evaluate(
            *(self.truncate_inputs(input_values, bits_override) if truncate else input_values))
        if isinstance(values, collections.abc.Sequence):
            if len(values) != self.output_count:
                raise RuntimeError(
                    f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got {len(values)})")
        elif isinstance(values, Number):
            if self.output_count != 1:
                raise RuntimeError(
                    f"Operation evaluated to incorrect number of outputs (expected {self.output_count}, got 1)")
            values = (values,)
        else:
            raise RuntimeError(
                f"Operation evaluated to invalid type (expected Sequence/Number, got {values.__class__.__name__})")

        if results is not None:
            for i in range(self.output_count):
                results[self.key(i, prefix)] = values[i]
        return values[index]

    def current_outputs(self, delays: Optional[DelayMap] = None,
                        prefix: str = "") -> Sequence[Optional[Number]]:
        return [self.current_output(i, delays, prefix) for i in range(self.output_count)]

    def evaluate_outputs(self, input_values: Sequence[Number],
                         results: Optional[MutableResultMap] = None,
                         delays: Optional[MutableDelayMap] = None, prefix: str = "",
                         bits_override: Optional[int] = None,
                         truncate: bool = True) -> Sequence[Number]:
        return [self.evaluate_output(i, input_values, results, delays, prefix, bits_override, truncate)
                for i in range(self.output_count)]

    def split(self) -> Iterable[Operation]:
        """Split by evaluating the operation symbolically on Input placeholders.

        Returns the resulting operations, or [self] when evaluate() does not
        produce operations (or raises TypeError/ValueError on them).
        """
        # Import here to avoid circular imports.
        from b_asic.special_operations import Input
        try:
            # One distinct placeholder per input; `[Input()] * n` would reuse a
            # single object for every input.
            result = self.evaluate(*[Input() for _ in range(self.input_count)])
            if isinstance(result, collections.abc.Sequence) and all(isinstance(e, Operation) for e in result):
                return result
            if isinstance(result, Operation):
                return [result]
        except (TypeError, ValueError):
            # Deliberate best effort: such operations are treated as unsplittable.
            pass
        return [self]

    def to_sfg(self) -> "SFG":
        """Convert the operation into an SFG.

        The operation is first evaluated on Input placeholders; if that raises
        TypeError, the SFG wraps a copy of this operation instead, with one
        Output per output port.
        """
        # Import here to avoid circular imports.
        from b_asic.special_operations import Input, Output
        from b_asic.signal_flow_graph import SFG

        inputs = [Input() for _ in range(self.input_count)]
        try:
            last_operations = self.evaluate(*inputs)
            if isinstance(last_operations, Operation):
                last_operations = [last_operations]
            outputs = [Output(o) for o in last_operations]
        except TypeError:
            operation_copy: Operation = self.copy_component()
            inputs = []
            for i in range(self.input_count):
                _input = Input()
                operation_copy.input(i).connect(_input)
                inputs.append(_input)
            # Connect each output port explicitly; Output(operation_copy) only
            # works for operations with exactly one output.
            outputs = [Output(operation_copy.output(i)) for i in range(self.output_count)]
        return SFG(inputs=inputs, outputs=outputs)

    def copy_component(self, *args, **kwargs) -> GraphComponent:
        """Copy the component and carry over the latency offsets of every port."""
        new_component: Operation = super().copy_component(*args, **kwargs)
        for i, inp in enumerate(self.inputs):
            new_component.input(i).latency_offset = inp.latency_offset
        for i, outp in enumerate(self.outputs):
            new_component.output(i).latency_offset = outp.latency_offset
        return new_component

    def inputs_required_for_output(self, output_index: int) -> Iterable[int]:
        if output_index < 0 or output_index >= self.output_count:
            raise IndexError(
                f"Output index out of range (expected 0-{self.output_count - 1}, got {output_index})")
        # By default, assume each output depends on all inputs.
        return list(range(self.input_count))

    @property
    def neighbors(self) -> Iterable[GraphComponent]:
        return list(self.input_signals) + list(self.output_signals)

    @property
    def preceding_operations(self) -> Iterable[Operation]:
        """Returns an Iterable of all Operations that are connected to this Operations input ports."""
        return [signal.source.operation for signal in self.input_signals if signal.source]

    @property
    def subsequent_operations(self) -> Iterable[Operation]:
        """Returns an Iterable of all Operations that are connected to this Operations output ports."""
        return [signal.destination.operation for signal in self.output_signals if signal.destination]

    @property
    def source(self) -> OutputPort:
        if self.output_count != 1:
            diff = "more" if self.output_count > 1 else "less"
            raise TypeError(
                f"{self.__class__.__name__} cannot be used as an input source because it has {diff} than 1 output")
        return self.output(0)

    def truncate_input(self, index: int, value: Number, bits: int) -> Number:
        # Keep the lowest `bits` bits of the integer part of the value.
        return int(value) & ((2 ** bits) - 1)

    def truncate_inputs(self, input_values: Sequence[Number],
                        bits_override: Optional[int] = None) -> Sequence[Number]:
        """Truncate the values to be used as inputs to the bit lengths specified by the respective signals connected to each input.

        bits_override, when given, replaces the signal-specified bit length for
        every input. An input with neither an override nor a bit length from its
        first connected signal is passed through unchanged.

        Raises:
            TypeError: if a complex value would have to be truncated.
        """
        args = []
        for i, input_port in enumerate(self.inputs):
            value = input_values[i]
            bits = bits_override
            if bits is None and input_port.signal_count >= 1:
                bits = input_port.signals[0].bits
            # Test the resolved bit length, so signal-specified lengths apply too.
            if bits is not None:
                if isinstance(value, complex):
                    raise TypeError(
                        f"Complex value cannot be truncated to {bits} bits as requested by the signal connected to input #{i}")
                value = self.truncate_input(i, value, bits)
            args.append(value)
        return args

    @property
    def latency(self) -> int:
        """Largest output-offset minus input-offset over all (output, input) port pairs.

        Raises:
            ValueError: if any port has no latency offset.
        """
        if None in [inp.latency_offset for inp in self.inputs] or None in [outp.latency_offset for outp in self.outputs]:
            raise ValueError(
                "All native offsets have to set to a non-negative value to calculate the latency.")

        return max(((outp.latency_offset - inp.latency_offset) for outp, inp in it.product(self.outputs, self.inputs)))

    @property
    def latency_offsets(self) -> Dict[str, Optional[int]]:
        latency_offsets = {f"in{i}": inp.latency_offset for i, inp in enumerate(self.inputs)}
        latency_offsets.update(
            (f"out{i}", outp.latency_offset) for i, outp in enumerate(self.outputs))
        return latency_offsets

    def set_latency(self, latency: int) -> None:
        assert latency >= 0, "Negative latency entered."
        for inport in self.inputs:
            inport.latency_offset = 0
        for outport in self.outputs:
            outport.latency_offset = latency

    def set_latency_offsets(self, latency_offsets: Dict[str, int]) -> None:
        """Apply offsets keyed 'in<index>' / 'out<index>' (case-insensitive)."""
        for port_str, latency_offset in latency_offsets.items():
            port_str = port_str.lower()
            if port_str.startswith("in"):
                index_str = port_str[2:]
                # isdecimal (unlike isdigit) only accepts strings int() can parse.
                assert index_str.isdecimal(), "Incorrectly formatted index in string, expected 'in' + index"
                self.input(int(index_str)).latency_offset = latency_offset
            elif port_str.startswith("out"):
                index_str = port_str[3:]
                assert index_str.isdecimal(), "Incorrectly formatted index in string, expected 'out' + index"
                self.output(int(index_str)).latency_offset = latency_offset
            else:
                raise ValueError(
                    "Incorrectly formatted string, expected 'in' + index or 'out' + index")