Skip to content
Snippets Groups Projects
Commit 6d81535e authored by Mikael Henriksson's avatar Mikael Henriksson :runner:
Browse files

process.py: add ABC MemoryProcess with support for spliting based on read times

parent 8037c620
No related branches found
No related tags found
1 merge request!381process.py: add ABC MemoryProcess with support for spliting based on read times
Pipeline #97351 passed
......@@ -55,7 +55,7 @@ def memory_based_storage(
# Write the input port specification
f.write(f'{2*VHDL_TAB}-- Memory port I/O\n')
read_ports: set[Port] = set(
sum((mv.read_ports for mv in collection), ())
read_port for mv in collection for read_port in mv.read_ports
) # type: ignore
for idx, read_port in enumerate(read_ports):
port_name = read_port if isinstance(read_port, int) else read_port.name
......
"""B-ASIC classes representing resource usage."""
from typing import Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, cast
from b_asic.operation import Operation
from b_asic.port import InputPort, OutputPort
......@@ -105,7 +105,149 @@ class OperatorProcess(Process):
return f"OperatorProcess({self.start_time}, {self.operation}, {self.name!r})"
class MemoryVariable(Process):
class MemoryProcess(Process):
"""
Intermediate class (abstract) for memory processes.
Different from regular :class:`Processe` objects, :class:`MemoryProcess` objects
can contain multiple read accesses and can be split into two new
:class:`MemoryProcess` objects based on these read times.
Parameters
----------
write_time : int
Start time of process.
life_times : list of int
List of ints representing times after ``start_time`` this process is accessed.
name : str, default=""
Name of the process.
"""
def __init__(
self,
write_time: int,
life_times: List[int],
name: str = "",
):
pass
self._life_times = life_times
super().__init__(
start_time=write_time,
execution_time=max(self._life_times),
name=name,
)
@property
def read_times(self) -> List[int]:
return list(self.start_time + read for read in self._life_times)
@property
def life_times(self) -> List[int]:
return self._life_times
@property
def reads(self) -> Dict[Any, int]:
raise NotImplementedError("MultiReadProcess should be derived from")
@property
def read_ports(self) -> List[Any]:
raise NotImplementedError("MultiReadProcess should be derived from")
@property
def write_port(self) -> Any:
raise NotImplementedError("MultiReadProcess should be derived from")
def split_on_length(
self,
length: int = 0,
) -> Tuple[Optional["MemoryProcess"], Optional["MemoryProcess"]]:
"""
Split this :class:`MemoryProcess` into two new :class:`MemoryProcess` objects,
based on lifetimes of the read accesses.
Parameters
----------
length : int, default: 0
The life time length to split on. Length is inclusive for the smaller
process.
Returns
-------
Two-tuple where the first element is a :class:`MemoryProcess` consisting
of reads with read times smaller than or equal to ``length`` (or None if no such
reads exists), and vice-versa for the other tuple element.
"""
reads = self.reads
short_reads = {k: v for k, v in filter(lambda t: t[1] <= length, reads.items())}
long_reads = {k: v for k, v in filter(lambda t: t[1] > length, reads.items())}
short_process = None
long_process = None
if short_reads:
# Create a new Process of type self (which is a derived variant of
# MultiReadProcess) by calling the self constructor
short_process = type(self)(
self.start_time, # type: ignore
self.write_port, # type: ignore
short_reads, # type: ignore
self.name, # type: ignore
)
if long_reads:
# Create a new Process of type self (which is a derived variant of
# MultiReadProcess) by calling the self constructor
long_process = type(self)(
self.start_time, # type: ignore
self.write_port, # type: ignore
long_reads, # type: ignore
self.name, # type: ignore
)
return short_process, long_process
def _add_life_time(self, life_time: int):
"""
Add a lifetime to this :class:`~b_asic.process.MultiReadProcess` set of
lifetimes.
If the lifetime specified by ``life_time`` is already in this
:class:`~b_asic.process.MultiReadProcess`, nothing happens
After adding a lifetime from this :class:`~b_asic.process.MultiReadProcess`,
the execution time is re-evaluated.
Parameters
----------
life_time : int
The lifetime to add to this :class:`~b_asic.process.MultiReadProcess`.
"""
if life_time not in self.life_times:
self._life_times.append(life_time)
self._execution_time = max(self.life_times)
def _remove_life_time(self, life_time: int):
"""
Remove a lifetime from this :class:`~b_asic.process.MultiReadProcess`
set of lifetimes.
After removing a lifetime from this :class:`~b_asic.process.MultiReadProcess`,
the execution time is re-evaluated.
Raises :class:`KeyError` if the specified lifetime is not a lifetime of this
:class:`~b_asic.process.MultiReadProcess`.
Parameters
----------
life_time : int
The lifetime to remove from this :class:`~b_asic.process.MultiReadProcess`.
"""
if life_time not in self.life_times:
raise KeyError(
f"Process {self.name}: {life_time} not in life_times: {self.life_times}"
)
else:
self._life_times.remove(life_time)
self._execution_time = max(self.life_times)
class MemoryVariable(MemoryProcess):
"""
Object that corresponds to a memory variable.
......@@ -130,13 +272,12 @@ class MemoryVariable(Process):
reads: Dict[InputPort, int],
name: Optional[str] = None,
):
self._read_ports = tuple(reads.keys())
self._life_times = tuple(reads.values())
self._read_ports = list(reads.keys())
self._reads = reads
self._write_port = write_port
super().__init__(
start_time=write_time,
execution_time=max(self._life_times),
write_time=write_time,
life_times=list(reads.values()),
name=name,
)
......@@ -145,11 +286,7 @@ class MemoryVariable(Process):
return self._reads
@property
def life_times(self) -> Tuple[int, ...]:
return self._life_times
@property
def read_ports(self) -> Tuple[InputPort, ...]:
def read_ports(self) -> List[InputPort]:
return self._read_ports
@property
......@@ -163,12 +300,36 @@ class MemoryVariable(Process):
f" {reads!r}, {self.name!r})"
)
@property
def read_times(self) -> Tuple[int, ...]:
return tuple(self.start_time + read for read in self._life_times)
def split_on_length(
self,
length: int = 0,
) -> Tuple[Optional["MemoryVariable"], Optional["MemoryVariable"]]:
"""
Split this :class:`MemoryVariable` into two new :class:`MemoryVariable` objects,
based on lifetimes of read accesses.
Parameters
----------
length : int, default: 0
The lifetime length to split on. Length is inclusive for the smaller
process.
Returns
-------
Two-tuple where the first element is a :class:`MemoryVariable` consisting
of reads with read times smaller than or equal to ``length`` (or None if no such
reads exists), and vice-versa for the other tuple element.
"""
# This method exists only for documentation purposes and for generating correct
# type annotations when calling it. Just call super().split_on_length() in here.
short_process, long_process = super().split_on_length(length)
return (
cast(Optional["MemoryVariable"], short_process),
cast(Optional["MemoryVariable"], long_process),
)
class PlainMemoryVariable(Process):
class PlainMemoryVariable(MemoryProcess):
"""
Object that corresponds to a memory variable which only use numbers for ports.
......@@ -196,8 +357,7 @@ class PlainMemoryVariable(Process):
reads: Dict[int, int],
name: Optional[str] = None,
):
self._read_ports = tuple(reads.keys())
self._life_times = tuple(reads.values())
self._read_ports = list(reads.keys())
self._write_port = write_port
self._reads = reads
if name is None:
......@@ -205,8 +365,8 @@ class PlainMemoryVariable(Process):
PlainMemoryVariable._name_cnt += 1
super().__init__(
start_time=write_time,
execution_time=max(self._life_times),
write_time=write_time,
life_times=list(reads.values()),
name=name,
)
......@@ -215,11 +375,7 @@ class PlainMemoryVariable(Process):
return self._reads
@property
def life_times(self) -> Tuple[int, ...]:
return self._life_times
@property
def read_ports(self) -> Tuple[int, ...]:
def read_ports(self) -> List[int]:
return self._read_ports
@property
......@@ -233,9 +389,33 @@ class PlainMemoryVariable(Process):
f" {reads!r}, {self.name!r})"
)
@property
def read_times(self) -> Tuple[int, ...]:
return tuple(self.start_time + read for read in self._life_times)
def split_on_length(
self,
length: int = 0,
) -> Tuple[Optional["PlainMemoryVariable"], Optional["PlainMemoryVariable"]]:
"""
Split this :class:`PlainMemoryVariable` into two new
:class:`PlainMemoryVariable` objects, based on lifetimes of read accesses.
Parameters
----------
length : int, default: 0
The lifetime length to split on. Length is inclusive for the smaller
process.
Returns
-------
Two-tuple where the first element is a :class:`PlainMemoryVariable` consisting
of reads with read times smaller than or equal to ``length`` (or None if no such
reads exists), and vice-versa for the other tuple element.
"""
# This method exists only for documentation purposes and for generating correct
# type annotations when calling it. Just call super().split_on_length() in here.
short_process, long_process = super().split_on_length(length)
return (
cast(Optional["PlainMemoryVariable"], short_process),
cast(Optional["PlainMemoryVariable"], long_process),
)
# Static counter for default names
_name_cnt = 0
......@@ -11,7 +11,13 @@ from matplotlib.ticker import MaxNLocator
from b_asic._preferences import LATENCY_COLOR, WARNING_COLOR
from b_asic.codegen.vhdl.common import is_valid_vhdl_identifier
from b_asic.process import MemoryVariable, OperatorProcess, PlainMemoryVariable, Process
from b_asic.process import (
MemoryProcess,
MemoryVariable,
OperatorProcess,
PlainMemoryVariable,
Process,
)
from b_asic.types import TypeName
# Default latency coloring RGB tuple
......@@ -1272,10 +1278,20 @@ class ProcessCollection:
if process.execution_time <= length:
short.append(process)
else:
long.append(process)
return ProcessCollection(
short, schedule_time=self.schedule_time
), ProcessCollection(long, schedule_time=self.schedule_time)
if isinstance(process, MemoryProcess):
# Split this MultiReadProcess into two new processes
p_short, p_long = process.split_on_length(length)
if p_short is not None:
short.append(p_short)
if p_long is not None:
long.append(p_long)
else:
# Not a MultiReadProcess: has only a single read
long.append(process)
return (
ProcessCollection(short, self.schedule_time, self._cyclic),
ProcessCollection(long, self.schedule_time, self._cyclic),
)
def generate_register_based_storage_vhdl(
self,
......
......@@ -10,8 +10,8 @@ def test_PlainMemoryVariable():
assert mem.write_port == 0
assert mem.start_time == 3
assert mem.execution_time == 2
assert mem.life_times == (1, 2)
assert mem.read_ports == (4, 5)
assert mem.life_times == [1, 2]
assert mem.read_ports == [4, 5]
assert repr(mem) == "PlainMemoryVariable(3, 0, {4: 1, 5: 2}, 'Var. 0')"
mem2 = PlainMemoryVariable(2, 0, {4: 2, 5: 3}, 'foo')
......@@ -39,3 +39,35 @@ def test_MemoryVariables(secondorder_iir_schedule):
def test_OperatorProcess_error(secondorder_iir_schedule):
with pytest.raises(ValueError, match="does not have an execution time specified"):
_ = secondorder_iir_schedule.get_operations()
def test_MultiReadProcess():
mv = PlainMemoryVariable(3, 0, {0: 1, 1: 2, 2: 5}, name="MV")
with pytest.raises(KeyError, match=r'Process MV: 3 not in life_times: \[1, 2, 5\]'):
mv._remove_life_time(3)
assert mv.life_times == [1, 2, 5]
assert mv.execution_time == 5
mv._remove_life_time(5)
assert mv.life_times == [1, 2]
assert mv.execution_time == 2
mv._add_life_time(4)
assert mv.execution_time == 4
assert mv.life_times == [1, 2, 4]
mv._add_life_time(4)
assert mv.life_times == [1, 2, 4]
def test_split_on_length():
mv = PlainMemoryVariable(3, 0, {0: 1, 1: 2, 2: 5}, name="MV")
short, long = mv.split_on_length(2)
assert short is not None and long is not None
assert short.start_time == 3 and long.start_time == 3
assert short.execution_time == 2 and long.execution_time == 5
assert short.reads == {0: 1, 1: 2}
assert long.reads == {2: 5}
short, long = mv.split_on_length(0)
assert short is None
assert long is not None
......@@ -232,3 +232,28 @@ class TestProcessCollectionPlainMemoryVariable:
match="MV0 has execution time greater than the schedule time",
):
collection.split_on_execution_time(heuristic)
def test_split_on_length(self):
# Test 1: Exclude a zero-time access time
collection = ProcessCollection(
collection=[PlainMemoryVariable(0, 1, {0: 1, 1: 2, 2: 3})],
schedule_time=4,
)
short, long = collection.split_on_length(0)
assert len(short) == 0 and len(long) == 1
for split_time in [1, 2]:
short, long = collection.split_on_length(split_time)
assert len(short) == 1 and len(long) == 1
short, long = collection.split_on_length(3)
assert len(short) == 1 and len(long) == 0
# Test 2: Include a zero-time access time
collection = ProcessCollection(
collection=[PlainMemoryVariable(0, 1, {0: 0, 1: 1, 2: 2, 3: 3})],
schedule_time=4,
)
short, long = collection.split_on_length(0)
assert len(short) == 1 and len(long) == 1
for split_time in [1, 2]:
short, long = collection.split_on_length(split_time)
assert len(short) == 1 and len(long) == 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment