Skip to content
Snippets Groups Projects
Commit 1a8eb53b authored by Johan Edstedt's avatar Johan Edstedt :speech_balloon:
Browse files

Merge branch '80-consider-tracking-alternative' into 'master'

Resolve "Consider tracking alternative"

Closes #80

See merge request liuhomewreckers/liu-home-wreckers!18
parents ba0db039 3c6487ba
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,10 @@ find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(rosidl_default_generators REQUIRED)
# FIX INCLUDE DIRS FOR NUMPY
set(PYTHON_INCLUDE_DIRS ${PYTHON_INCLUDE_DIRS} /usr/local/lib/python3.8/dist-packages/numpy/core/include)
include_directories(${PYTHON_INCLUDE_DIRS})
# ADD NEW INTERFACES HERE!!!
rosidl_generate_interfaces(${PROJECT_NAME}
"msg_intelligence/Behaviour.msg"
......
......@@ -3,22 +3,23 @@ ARG NO_GPU=0
# Install pytorch
RUN if [ "$NO_GPU" = 0 ]; then \
pip3 install torch torchvision; \
pip3 install torch torchvision timm; \
else \
pip3 install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html; fi
# Install Densesense
RUN python3 -m pip install Cython setuptools av scipy git+https://github.com/Axelwickm/DenseSense.git
# Install DenseSense (if no GPU)
RUN if [ "$NO_GPU" = 0 ]; then \
RUN if [ "$NO_GPU" = 2 ]; then \
python3 -m pip install Cython setuptools av scipy git+https://github.com/Axelwickm/DenseSense.git && \
python3 -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.6/index.html && \
cd / && git clone https://github.com/facebookresearch/detectron2; fi
# Install Yolo5
RUN apt update && apt install -y ffmpeg libsm6 libxext6
RUN if [ "$NO_GPU" = 0 ]; then \
RUN if [ "$NO_GPU" = 2 ]; then \
git clone https://github.com/ultralytics/yolov5 /workspace/yolov5 && \
cd /workspace/yolov5 && \
pip3 install pandas seaborn && \
python3 -c "from models.experimental import attempt_load; attempt_load('yolov5l.pt')" ; fi
# Install deep-person-reid
......@@ -26,6 +27,14 @@ RUN if [ "$NO_GPU" = 0 ]; then \
git clone https://github.com/KaiyangZhou/deep-person-reid /workspace/deep-person-reid/ && \
python3 -m pip install gdown timm; fi
# Install ByteTrack
RUN if [ "$NO_GPU" = 0 ]; then \
git clone https://github.com/Parskatt/ByteTrack /workspace/bytetrack && \
pip3 install -e /workspace/bytetrack &&\
bash /workspace/bytetrack/get_bytetrack_pretrained.sh &&\
pip3 install -r /workspace/bytetrack/requirements.txt &&\
pip3 install cython_bbox; fi
# Install dependencies
COPY src/lhw_interfaces src/lhw_interfaces
COPY src/lhw_vision src/lhw_vision
......
......@@ -6,8 +6,15 @@ def generate_launch_description():
return LaunchDescription([
Node(
package='lhw_vision',
executable='yolo_node'
executable='bytetrack_node'
),
Node(
package='lhw_vision',
executable='image_to_world_node'
),
]
)
"""
Node(
package='lhw_vision',
executable='densesense'
......@@ -18,7 +25,6 @@ def generate_launch_description():
),
Node(
package='lhw_vision',
executable='image_to_world_node'
executable='yolo_node'
),
]
)
"""
\ No newline at end of file
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image,CompressedImage
from lhw_interfaces.msg import Entity, Entities
from builtin_interfaces.msg import Time
from cv_bridge import CvBridge, CvBridgeError
from collections import OrderedDict
from lhw_vision.bytetracker import ByteTrack
from lhw_vision.utils import save_image
from lhw_vision import VISION_ROOT
import numpy as np
import time
class Tracker(Node):
"""Subscribes to various sources of identified entities (YOLO, etc.),
give them an ID which is persistent over time, cleans up messy input, fills in when data is missing,
and compiles it all onto a published topic ("tracked_entities"). All of the published entities on
this topic contain a list of where they originate (might have multiple origins)
and their original IDs so that the original data might be queried by other parts of the system.
"""
def __init__(self):
super().__init__("tracker")
self.log = self.get_logger()
self.images = {}
self.log.info(str(ByteTrack))
self.tracker = ByteTrack(self.log)
self.image_sub = self.create_subscription(Image, '/image',
self.image_callback, 10)
self.tracked_targets = self.create_publisher(Entities, "tracked_targets", 10)
self.bridge = CvBridge()
self.log.info("Now running Tracker")
self.last_time = time.time()
def image_callback(self,image_msg: Image):
""" Callback upon pepper image being updated. If entities with a corresponding timestamp has been received, self.track is called.
If not, then add to self.images
Args:
image_msg: Image from Pepper
"""
timestamp = image_msg.header.stamp
image = self.bridge.imgmsg_to_cv2(image_msg,desired_encoding='rgb8')
self.track(image, timestamp)
def track(self, image: np.ndarray, timestamp: Time) -> Entities:
"""Tracks any detected objects
Args:
image (np.ndarray): Image from Pepper.
entities (dict): The attribute name to use on each object of the iterable.
timestamp (Time): The received time
Returns:
tracked_targets (Entities): The tracked targets
"""
t = time.time()
self.log.info(f'Waited for {t-self.last_time}.')
self.last_time = t
height, width = image.shape[:2]
t0 = time.perf_counter()
targets = self.tracker.track(image)
self.log.info(str(time.perf_counter()-t0))
msg = Entities()
msg.entities = [Entity(bbox=tar.tlbr,type='person',id=tar.track_id,
sources=[0],sources_ids=[0],confidence=float(tar.score))
for tar in targets if tar.time_since_update==0]
msg.source = msg.sources_types.TRACKER
msg.source_height = int(height)
msg.source_width = int(width)
msg.time = timestamp
self.tracked_targets.publish(msg)
def main(args=None):
rclpy.init(args=args)
tracker_node = Tracker()
rclpy.spin(tracker_node)
tracker_node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
from .bytetrack import ByteTrack
\ No newline at end of file
import cv2
import torch
from PIL import Image, ImageDraw
from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
import argparse
import os
import time
from lhw_vision import VISION_ROOT
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
def make_parser():
parser = argparse.ArgumentParser("ByteTrack Tracker!")
parser.add_argument("-expn", "--experiment-name", type=str, default=None)
parser.add_argument("-n", "--name", type=str, default=None, help="model name")
# exp file
parser.add_argument(
"-f",
"--exp_file",
default='/workspace/bytetrack/exps/example/mot/yolox_l_mix_det.py',
type=str,
help="pls input your expriment description file",
)
parser.add_argument("-c", "--ckpt", default='/tmp/bytetrack_l_mot17.pth.tar', type=str, help="ckpt for eval")
parser.add_argument(
"--device",
default="gpu",
type=str,
help="device to run our model, can either be cpu or gpu",
)
parser.add_argument("--conf", default=None, type=float, help="test conf")
parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
parser.add_argument("--tsize", default=None, type=int, help="test img size")
parser.add_argument(
"--fp16",
dest="fp16",
default=False,
action="store_true",
help="Adopting mix precision evaluating.",
)
parser.add_argument(
"--fuse",
dest="fuse",
default=False,
action="store_true",
help="Fuse conv and bn for testing.",
)
parser.add_argument(
"--trt",
dest="trt",
default=False,
action="store_true",
help="Using TensorRT model for testing.",
)
# tracking args
parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
parser.add_argument("--match_thresh", type=int, default=0.8, help="matching threshold for tracking")
parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
parser.add_argument('--mot20', default=False, action="store_true", help='mot20')
return parser
class Predictor(object):
def __init__(
self,
model,
exp,
trt_file=None,
decoder=None,
device="cpu",
fp16=False
):
self.model = model
self.decoder = decoder
self.num_classes = exp.num_classes
self.confthre = exp.test_conf
self.nmsthre = exp.nmsthre
self.test_size = exp.test_size
self.device = device
self.fp16 = fp16
self.rgb_means = (0.485, 0.456, 0.406)
self.std = (0.229, 0.224, 0.225)
def inference(self, img, timer):
img_info = {"id": 0}
if isinstance(img, str):
img_info["file_name"] = os.path.basename(img)
img = cv2.imread(img)
else:
img_info["file_name"] = None
orig_img = img
height, width = img.shape[:2]
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img
img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
img_info["ratio"] = ratio
img = torch.from_numpy(img).unsqueeze(0)
img = img.float()
img = img.cuda()
with torch.no_grad():
timer.tic()
outputs = self.model(img)
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre, self.nmsthre
)
#logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info
class ByteTrack:
def __init__(self, logger) -> None:
self.logger = logger
args,unknown_args = make_parser().parse_known_args()
self.logger.info(str(args))
self.exp = get_exp(args.exp_file, args.name)
self.timer = Timer()
self.tracker = BYTETracker(args, frame_rate=30)
if not args.experiment_name:
args.experiment_name = self.exp.exp_name
self.frame_id = 0
if args.conf is not None:
self.exp.test_conf = args.conf
if args.nms is not None:
self.exp.nmsthre = args.nms
if args.tsize is not None:
self.exp.test_size = (args.tsize, args.tsize)
model = self.exp.get_model()
model.cuda()
model.eval()
if not args.trt:
ckpt_file = args.ckpt
self.logger.info("loading checkpoint")
ckpt = torch.load(ckpt_file)#, map_location="cpu")
# load the model state dict
model.load_state_dict(ckpt["model"])
self.logger.info("loaded checkpoint done.")
trt_file = None
decoder = None
self.predictor = Predictor(model, self.exp, trt_file, decoder, args.device, args.fp16)
self.tracker = BYTETracker(args, frame_rate=10)
self.timer = Timer()
self.frame_id = 0
def track(self,frame):
outputs, img_info = self.predictor.inference(frame, self.timer)
online_targets = self.tracker.update(outputs[0], [img_info['height'], img_info['width']], self.exp.test_size)
self.visualize(frame, online_targets)
return online_targets
def visualize(self,image,outputs):
pil_im = Image.fromarray(image)
draw = ImageDraw.Draw(pil_im)
for tar in outputs:
bb = list(tar.tlbr)
draw.rectangle(bb)
draw.text(bb[:2], f"{tar.track_id}")
pil_im.save(f"{VISION_ROOT}/static/bb_im.jpg")
......@@ -13,7 +13,9 @@ def depth_to_xyz(bbox,depth,K):
x = z*x'
y = z*y'
"""
h,w = depth.shape
u,v = (bbox[0]+bbox[2])//2,(bbox[1]+bbox[3])//2 # get the middle of the box
v,u = min(max(0,v),h-1),min(max(0,u),w-1)
coord = np.array([u,v,1])
x_prim,y_prim,_ = np.linalg.solve(K,coord)
assert abs(_-1.0) < 1e-4,f"K is not correct"
......
......@@ -69,7 +69,7 @@ class Tracker(Node):
height, width = image.shape[:2]
targets = self.tracker.track(entities,image)
msg = Entities()
msg.entities = [Entity(bbox=tar['bbox'],type=tar['type'],id=tar['tracked_id'],
msg.entities = [Entity(bbox=tar.tlbr,type=tar['type'],id=tar['tracked_id'],
sources=tar['sources'],sources_ids=[tar['id']],confidence=tar['confidence'])
for tars in targets.values() for tar in tars.values() if tar['missing']==0]
msg.source = msg.sources_types.TRACKER
......
......@@ -27,6 +27,7 @@ setup(
'image_to_world_node = lhw_vision.image_to_world_node:main',
'tracker_node = lhw_vision.tracker_node:main',
'yolo_node = lhw_vision.yolo_node:main',
'bytetrack_node = lhw_vision.bytetrack_node:main',
'densesense = lhw_vision.densesense_wrapper:main',
'youtube_node = lhw_vision.youtube_node:main',
'webcam_node = lhw_vision.webcam_node:main',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment