From e202399243836b57a4b347b1bf59a876e94d8798 Mon Sep 17 00:00:00 2001
From: adaja901 <adaja901@student.liu.se>
Date: Thu, 21 May 2020 22:03:26 +0200
Subject: [PATCH] working not for cases when a delay is in the graph

---
 b_asic/schema.py | 69 +++++++++++++++++++++++++++++++++++-------------
 1 file changed, 51 insertions(+), 18 deletions(-)

diff --git a/b_asic/schema.py b/b_asic/schema.py
index 931e4295..e3cec775 100644
--- a/b_asic/schema.py
+++ b/b_asic/schema.py
@@ -8,6 +8,7 @@ from typing import Dict, List
 from b_asic.signal_flow_graph import SFG
 from b_asic.graph_component import GraphID
 from b_asic.operation import Operation
+from b_asic.special_operations import *
 
 
 class Schema:
@@ -48,8 +49,6 @@ class Schema:
             else:
                 self._schedule_time = schedule_time
 
-        self.get_memory_elements()
-
     def start_time_of_operation(self, op_id: GraphID):
         """Get the start time of the operation with the specified by the op_id."""
         assert op_id in self._start_times, "No operation with the specified op_id in this schema."
@@ -86,7 +85,6 @@ class Schema:
                     # Schedule the operation if it doesn't have a start time yet.
                     op_start_time = 0
                     for inport in op.inputs:
-                        print(inport.operation.graph_id)
                         assert len(inport.signals) == 1, "Error in scheduling, dangling input port detected."
                         assert inport.signals[0].source is not None, "Error in scheduling, signal with no source detected."
                         source_port = inport.signals[0].source
@@ -110,24 +108,59 @@ class Schema:
                     self._start_times[op.graph_id] = op_start_time
 
     def get_memory_elements(self):
-        pl = self._sfg.get_precedence_list()
+        operation_orderd = self._sfg.get_operations_topological_order()
+        
+        for op in operation_orderd:
+            if isinstance(op, Input) or isinstance(op, Output):
+                pass
+            
+            for key in self._start_times:
+                if op.graph_id == key:
+                    for i in range(len(op.outputs)):
+                        time_list = []
+                        start_time = self._start_times.get(op.graph_id)+op.outputs[i].latency_offset
+                        time_list.append(start_time)
+                        for j in range(len(op.outputs[i].signals)):
+                            new_op = self.get_op_after_delay(op.outputs[i].signals[j].destination.operation, op.outputs[i].signals[j].destination)
+                            
+                            end_start_time = self._start_times.get(new_op[0].graph_id)
+                            end_start_time_latency_offset = new_op[1].latency_offset
+                            
+                            if end_start_time_latency_offset is None:
+                                end_start_time_latency_offset = 0
+                            if end_start_time is None:
+                                end_time = self._schedule_time
+                            else:
+                                end_time = end_start_time + end_start_time_latency_offset
+                            
+                            time_list.append(end_time)
+                            read_name = op.name
+                            write_name = new_op[0].name
+                            key_name = read_name + "->" + write_name
+                            self._memory_elements[key_name] = time_list
+                            
+
+
+    def get_op_after_delay(self, op, destination):
+        if isinstance(op, Delay):
+            for i in range(len(op.outputs[0].signals)):
+                connected_op = op.outputs[0].signals[i].destination.operation
+                dest = op.outputs[0].signals[i].destination
+                return self.get_op_after_delay(connected_op, dest)
+        
+        return [op, destination]
 
-        for port_list in pl:
-            for port in port_list:
-                time_list = []
-                for key in self._start_times:
-                    if port.operation.graph_id == key:
-                        for i in range(len(port.operation.outputs)-1):
-                            time_list.append(self.start_times.get(port.operation.graph_id)+op.outputs[i].latency_offset)
-                            for j in range(len(port.operation.outputs[i].signals)-1):
-                                print(self._start_times.get(port.operation.outputs[i].signals[j].destination.operation.op_id))
-                                time_list.append(self._start_times.get(port.operation.outputs[i].signals[j].destination.operation.op_id))
-                                self._memory_elements[port.operation.outputs[i].signals[j].type_name] = time_list
-                                time_list.pop() # remove end time if an output port has several signals as it's source
 
     def print_memory_elements(self):
+        self.get_memory_elements()
+        output_string = ""
         for key in self._memory_elements:
-            for value in self._memory_elements[key]: 
-                print(key, value)
+            output_string += key
+            output_string += ": start time: " 
+            output_string += str(self._memory_elements[key][0])
+            output_string += " end time: "
+            output_string += str(self._memory_elements[key][1])
+            output_string += '\n'
+        print(output_string)
 
 
-- 
GitLab