Skip to content
Snippets Groups Projects
Commit 3c6487ba authored by Johan Edstedt's avatar Johan Edstedt :speech_balloon:
Browse files

tracker works

parent dc40e592
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,8 @@ RUN if [ "$NO_GPU" = 0 ]; then \
git clone https://github.com/Parskatt/ByteTrack /workspace/bytetrack && \
pip3 install -e /workspace/bytetrack &&\
bash /workspace/bytetrack/get_bytetrack_pretrained.sh &&\
pip3 install -r /workspace/bytetrack/requirements.txt; fi
pip3 install -r /workspace/bytetrack/requirements.txt &&\
pip3 install cython_bbox; fi
# Install dependencies
COPY src/lhw_interfaces src/lhw_interfaces
......
......@@ -9,6 +9,7 @@ from lhw_vision.bytetracker import ByteTrack
from lhw_vision.utils import save_image
from lhw_vision import VISION_ROOT
import numpy as np
import time
class Tracker(Node):
"""Subscribes to various sources of identified entities (YOLO, etc.),
......@@ -23,11 +24,12 @@ class Tracker(Node):
self.images = {}
self.log.info(str(ByteTrack))
self.tracker = ByteTrack(self.log)
self.image_sub = self.create_subscription(Image, 'image',
self.image_sub = self.create_subscription(Image, '/image',
self.image_callback, 10)
self.tracked_targets = self.create_publisher(Entities, "tracked_targets", 10)
self.bridge = CvBridge()
self.log.info("Now running Tracker")
self.last_time = time.time()
def image_callback(self,image_msg: Image):
""" Callback upon pepper image being updated. If entities with a corresponding timestamp has been received, self.track is called.
......@@ -47,12 +49,17 @@ class Tracker(Node):
Returns:
tracked_targets (Entities): The tracked targets
"""
t = time.time()
self.log.info(f'Waited for {t-self.last_time}.')
self.last_time = t
height, width = image.shape[:2]
t0 = time.perf_counter()
targets = self.tracker.track(image)
self.log.info(str(time.perf_counter()-t0))
msg = Entities()
msg.entities = [Entity(bbox=tar.tlbr,type='person',id=tar.track_id,
sources=0,sources_ids=0,confidence=tar.score)
for tars in targets.values() for tar in tars.values() if tar.time_since_update==0]
sources=[0],sources_ids=[0],confidence=float(tar.score))
for tar in targets if tar.time_since_update==0]
msg.source = msg.sources_types.TRACKER
msg.source_height = int(height)
msg.source_width = int(width)
......
......@@ -2,6 +2,7 @@
import cv2
import torch
from PIL import Image, ImageDraw
from yolox.data.data_augment import preproc
from yolox.exp import get_exp
......@@ -14,6 +15,8 @@ import argparse
import os
import time
from lhw_vision import VISION_ROOT
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
......@@ -65,6 +68,7 @@ def make_parser():
parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
parser.add_argument("--match_thresh", type=int, default=0.8, help="matching threshold for tracking")
parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
parser.add_argument('--mot20', default=False, action="store_true", help='mot20')
return parser
class Predictor(object):
......@@ -95,12 +99,11 @@ class Predictor(object):
img = cv2.imread(img)
else:
img_info["file_name"] = None
orig_img = img
height, width = img.shape[:2]
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img
img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
img_info["ratio"] = ratio
img = torch.from_numpy(img).unsqueeze(0)
......@@ -163,4 +166,15 @@ class ByteTrack:
def track(self,frame):
outputs, img_info = self.predictor.inference(frame, self.timer)
online_targets = self.tracker.update(outputs[0], [img_info['height'], img_info['width']], self.exp.test_size)
return online_targets
\ No newline at end of file
self.visualize(frame, online_targets)
return online_targets
# Renders the tracker output: builds a PIL image from `image`, draws each
# target on it and writes the result to f"{VISION_ROOT}/static/bb_im.jpg".
#   image   -- array accepted by Image.fromarray; track() passes its `frame`
#              argument here (NOTE(review): channel order is not visible in
#              this view -- confirm RGB vs BGR)
#   outputs -- iterable of targets, each exposing `tlbr` and `track_id`
# No value is returned.
def visualize(self,image,outputs):
pil_im = Image.fromarray(image)
draw = ImageDraw.Draw(pil_im)
for tar in outputs:
# tlbr is turned into a plain list and handed to rectangle() as the box coordinates
bb = list(tar.tlbr)
draw.rectangle(bb)
# the first two box values are used as the anchor point of the track-id label
draw.text(bb[:2], f"{tar.track_id}")
# NOTE(review): indentation is lost in this view, so it cannot be told whether
# the save runs once after the loop or once per target -- confirm.
# The path is fixed, so every save replaces the previously written file.
pil_im.save(f"{VISION_ROOT}/static/bb_im.jpg")
......@@ -13,7 +13,9 @@ def depth_to_xyz(bbox,depth,K):
x = z*x'
y = z*y'
"""
h,w = depth.shape
u,v = (bbox[0]+bbox[2])//2,(bbox[1]+bbox[3])//2 # get the middle of the box
v,u = min(max(0,v),h-1),min(max(0,u),w-1)
coord = np.array([u,v,1])
x_prim,y_prim,_ = np.linalg.solve(K,coord)
assert abs(_-1.0) < 1e-4,f"K is not correct"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment