PyTorch-based library for ranking predicted bounding boxes using the user's text or image prompts.

Overview

pytorch_clip_bbox: an implementation of CLIP-guided bounding box ranking for object detection.

Object detection models are usually trained to detect common classes of objects such as "car", "person", "cup", or "bottle". Sometimes, however, we need to detect more specific classes such as "lady in the red dress", "bottle of whiskey", or "where is my red cup" instead of "person", "bottle", and "cup" respectively. One way to solve this problem is to train more complex detectors for these classes. We propose text-driven object detection instead, which can detect any class that can be described in natural language. This library ranks the bounding boxes predicted by a detector against text or image descriptions of such classes.
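The ranking idea can be illustrated with a minimal sketch. It assumes that each box crop and each prompt has already been encoded into an embedding vector (for example by CLIP), and that boxes are ordered by cosine similarity to the prompt. The function names and the toy 3-dimensional vectors are hypothetical and are not the library's internals.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def rank_boxes(prompt_embedding, crop_embeddings, top_k=1):
    """Return (index, score) pairs for the top_k crops most similar to the prompt."""
    scores = [cosine_similarity(prompt_embedding, e) for e in crop_embeddings]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in order[:top_k]]

# Toy embeddings standing in for real image/text features (assumption).
prompt = [1.0, 0.0, 0.0]
crops = [[0.0, 1.0, 0.0], [0.9, 0.1, 0.0], [0.5, 0.5, 0.0]]
best = rank_boxes(prompt, crops, top_k=2)
print([i for i, _ in best])
```

Here the second crop points almost in the same direction as the prompt, so it is ranked first, followed by the third.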

Install package

pip install pytorch_clip_bbox

Install the latest version

pip install --upgrade git+https://github.com/bes-dev/pytorch_clip_bbox.git

Features

  • The library supports multiple prompts (images or texts) as targets for filtering.
  • The library automatically detects the language of the input text and translates it with Google Translate, so prompts can be written in multiple languages.
  • The library supports the original CLIP model by OpenAI and ruCLIP model by SberAI.
  • Simple integration with different object detection models.

Usage

We provide examples of integrating our library with popular object detectors such as YOLOv5 and MaskRCNN. Please follow the examples link for more.
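Whatever the detector, the integration step is the same: keep the confident detections and pass their boxes as lists of integer pixel coordinates. A small detector-agnostic sketch of that step follows. The helper name `select_boxes` is hypothetical, and the inputs are plain Python lists rather than the output of any particular model.

```python
def select_boxes(boxes, scores, confidence=0.7):
    """Keep boxes whose score exceeds `confidence`, truncated to integer pixel coordinates."""
    kept = []
    for box, score in zip(boxes, scores):
        if score > confidence:
            kept.append([int(v) for v in box])
    return kept

# Two made-up detections in (x1, y1, x2, y2) format with their scores.
boxes = [[10.6, 20.2, 110.9, 220.0], [5.0, 5.0, 50.0, 50.0]]
scores = [0.92, 0.40]
print(select_boxes(boxes, scores))
```

Only the first detection passes the default threshold of 0.7, which matches the default of the `--confidence` argument in the MaskRCNN example.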

Simple example to integrate pytorch_clip_bbox with MaskRCNN model

$ pip install wheel cython opencv-python numpy torch torchvision pytorch_clip_bbox
import argparse
import random
import cv2
import numpy as np
import torch
import torchvision.transforms as T
import torchvision
from pytorch_clip_bbox import ClipBBOX

def get_coloured_mask(mask):
    colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
    r = np.zeros_like(mask).astype(np.uint8)
    g = np.zeros_like(mask).astype(np.uint8)
    b = np.zeros_like(mask).astype(np.uint8)
    c = random.choice(colours)  # pick any of the colours, including the last one
    r[mask == 1], g[mask == 1], b[mask == 1] = c
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask, c

def main(args):
    # build detector
    detector = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval().to(args.device)
    clip_bbox = ClipBBOX(clip_type=args.clip_type).to(args.device)
    # add prompts
    if args.text_prompt is not None:
        for prompt in args.text_prompt.split(","):
            clip_bbox.add_prompt(text=prompt)
    if args.image_prompt is not None:
        image = cv2.cvtColor(cv2.imread(args.image_prompt), cv2.COLOR_BGR2RGB)
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
        image = image / 255.0  # scale pixel values to [0, 1]
        clip_bbox.add_prompt(image=image)
    image = cv2.imread(args.image)
    pred = detector([
        T.ToTensor()(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).to(args.device)
    ])
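    pred_score = pred[0]['scores'].detach().cpu().numpy()
    # indices of detections whose score exceeds the confidence threshold
    above = np.flatnonzero(pred_score > args.confidence)
    if above.size == 0:
        print("No detections above the confidence threshold.")
        return
    num_keep = int(above[-1]) + 1
    boxes = [[int(b) for b in box] for box in pred[0]['boxes'].detach().cpu().numpy()][:num_keep]
    # squeeze only the channel axis so a single detection keeps its leading dimension
    masks = (pred[0]['masks'] > 0.5).squeeze(1).detach().cpu().numpy()[:num_keep]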
    ranking = clip_bbox(image, boxes, top_k=args.top_k)
    for key in ranking.keys():
        if key == "loss":
            continue
        for box in ranking[key]["ranking"]:
            mask, color = get_coloured_mask(masks[box["idx"]])
            image = cv2.addWeighted(image, 1, mask, 0.5, 0)
            x1, y1, x2, y2 = box["rect"]
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 6)
            cv2.rectangle(image, (x1, y1), (x2, y1-100), color, -1)
            cv2.putText(image, ranking[key]["src"], (x1 + 5, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 4, (0, 0, 0), thickness=5)
    if args.output_image is None:
        cv2.imshow("image", image)
        cv2.waitKey()
    else:
        cv2.imwrite(args.output_image, image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--image", type=str, help="Input image.")
    parser.add_argument("--device", type=str, default="cuda:0", help="inference device.")
    parser.add_argument("--confidence", type=float, default=0.7, help="confidence threshold [MaskRCNN].")
    parser.add_argument("--text-prompt", type=str, default=None, help="Text prompt.")
    parser.add_argument("--image-prompt", type=str, default=None, help="Image prompt.")
    parser.add_argument("--clip-type", type=str, default="clip_vit_b32", help="Type of CLIP model [ruclip, clip_vit_b32, clip_vit_b16].")
    parser.add_argument("--top-k", type=int, default=1, help="top_k predictions will be returned.")
    parser.add_argument("--output-image", type=str, default=None, help="Output image name.")
    args = parser.parse_args()
    main(args)
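The drawing loop in the example reads the ranking result as a dictionary: the "loss" key is skipped, and every other entry carries the prompt text in "src" and a "ranking" list whose items have "idx" and "rect". Assuming only that structure, which is inferred from the example and not from the library's documentation, the result can be flattened into rows for logging or further processing. The helper `flatten_ranking`, the key "prompt_0", and the mock values are hypothetical.

```python
def flatten_ranking(ranking):
    """Turn the ranking dictionary into a list of (prompt, box index, rect) rows."""
    rows = []
    for key, entry in ranking.items():
        if key == "loss":
            continue  # skipped in the example above as well
        for box in entry["ranking"]:
            rows.append((entry["src"], box["idx"], tuple(box["rect"])))
    return rows

# Mock result shaped like the one consumed by the drawing loop (assumption).
mock = {
    "loss": 0.1,
    "prompt_0": {"src": "red cup", "ranking": [{"idx": 2, "rect": [10, 20, 30, 40]}]},
}
rows = flatten_ranking(mock)
print(rows)
```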
Owner
Sergei Belousov