Generative Query Network (GQN) in PyTorch as described in "Neural Scene Representation and Rendering"

Overview

Update 2019/06/24: A model trained on 10% of the Shepard-Metzler dataset has been added, the following notebook explains the main features of this model: nbviewer

Generative Query Network

This is a PyTorch implementation of the Generative Query Network (GQN) described in the DeepMind paper "Neural scene representation and rendering" by Eslami et al. For an introduction to the model and problem described in the paper look at the article by DeepMind.

The current implementation generalises to any of the datasets described in the paper. However, currently, only the Shepard-Metzler dataset has been implemented. To use this dataset you can use the provided script in

sh scripts/data.sh data-dir batch-size

The model can be trained in full by in accordance to the paper by running the file run-gqn.py or by using the provided training script

sh scripts/gpu.sh data-dir

Implementation

The implementation shown in this repository consists of all of the representation architectures described in the paper along with the generative model that is similar to the one described in "Towards conceptual compression" by Gregor et al.

Additionally, this repository also contains implementations of the DRAW model and the ConvolutionalDRAW model both described by Gregor et al.

Comments
  • Training time and testing demo

    Training time and testing demo

    Hi Jesper,

    Thank you for your great code of gqn in real image, I am a little curious about the following issues: How many epochs it use to train a model on real image? How many training data do you use (percentage of full training dataset)? Can you show a testing demo?

    Thank you very much!

    Best wishes, Mingjia Chen

    opened by mjchen611 22
  • ConvLSTM did not concat hidden from last round

    ConvLSTM did not concat hidden from last round

    In the structure presented in the paper, the hidden from last round is concat with input and then proceed for other operation. But it seems your LSTM did not use the hidden information from previous round.

    opened by Tom-the-Cat 7
  • Bad images in training

    Bad images in training

    While playing around with the sm5 dataset, I noticed some of them are badly rendered. individualimage Not sure if this will pose any problem for training, just wanted to point this out.

    opened by versatran01 7
  • Question about generator

    Question about generator

    In the top docstring of generator.py, you mentioned that

    The inference-generator architecture is conceptually
    similar to the encoder-decoder pair seen in variational
    autoencoders.
    

    I don't quite understand this part and I would really appreciate if you could explain a bit or point me at some related aritcles. For the generator I can see how it is similar to a decoder, where it takes latent z, query viewpoint v, and aggregated representation r and eventually output the image x_mu.

    But I'm a bit confused by the inference being the conterpart of encoder.

    opened by versatran01 7
  • Loss Change

    Loss Change

    Dear wohlert,

    May I consult you several questions?

    1. I tried to train this network on Mazes Data from https://github.com/deepmind/gqn-datasets. Actually it just contains 5% data, which is around 110000, instead of the full data. Is it right?

    2. I trained 30000 steps, but the elbo loss only converged to 6800 which has a big difference compared to around 7 in the supplementary. So may I ask what is the approximate value do you achieve on the data you used?

    3. From the visualisation based on Question 2, the reconstruction seems to be reasonable. But the sampling results is quite bad. Do you meet the same problem?

    Many thanks, Bing

    opened by BingCS 5
  • Questions on data preparing

    Questions on data preparing

    Hi, Wohlert:

    After the data conversion with your scripts, I visualize some of the images in the *.pt found pictures like this Figure_1-1

    What's wrong with that Also I'm confused about your batch operation , say if you batch the sequences as you convert them, does it mean that you won't batch them again when use dataloader?

    Thanks

    opened by Kyridiculous2 5
  • Training crashes at the same spot for both Shepard Metzler datasets

    Training crashes at the same spot for both Shepard Metzler datasets

    Some context:

    • I downloaded and converted the datasets via data.sh and set batch size to 12. Note that I am using TensorFlow 1.14 for reading the tfrecord files and converting them.
    • I use gpu.sh to run the training script. I set the batch size to either of [1,12,36,72] and DataParallel to True to use 4 GPUs

    But after a shrot time I get the following errors if I use any batch size higher than 1. This happens on iterations 40, 13 and 6 with batch sizes 12, 36 and 72. This happens for both Shepard Metzler datasets. Why I am getting these errors? Does batch size 1 on the training code mean reading one of the .pt.gz files? If so, setting batch size to 1 in the training script should actually mean 12. Would that be correct?

    Here's what I get for the data set with 5 parts when I set batch size to 36 for instance:

    Epoch [1/200]: [13/1856]   1%|▊                                                                                                                       , elbo=-2.1e+4, kl=827, mu=5e-6, sigma=2 [00:21<52:34]Current run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Engine run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Traceback (most recent call last):
      File "../run-gqn.py", line 183, in <module>
        trainer.run(train_loader, args.n_epochs)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 850, in run
        return self._internal_run()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 952, in _internal_run
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 937, in _internal_run
        hours, mins, secs = self._run_once_on_dataset()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 705, in _run_once_on_dataset
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 655, in _run_once_on_dataset
        batch = next(self._dataloader_iter)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 801, in __next__
        return self._process_data(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
        data.reraise()
      File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 385, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    
    opened by Amir-Arsalan 4
  • AttributeError: 'int' object has no attribute 'size'

    AttributeError: 'int' object has no attribute 'size'

    in draw.py, I get this error at the line 118 (batch_size = z.size(0)) Sorry if this is obvious, thanks for help anyway.

    ~ % pip show torch :( Name: torch Version: 1.0.1.post2

    opened by DRM-Free 4
  • Increase dimension of viewpoint and representation

    Increase dimension of viewpoint and representation

    Thanks for this implementation. One question I have is when increasing the dimension of viewpoint and representation, you use torch.repeat. Is there any reason for this? Can one possibly use interpolate?

    In the original paper it says "when concatenating viewpoint v to an image or feature map, its values are ‘broadcast’ in the spatial dimensions to obtain the correct size. "

    The word 'broadcast' is not precisely defined, hence the question.

    opened by versatran01 4
  • Learning rate change

    Learning rate change

    Regarding line 113 of run-gqn.py. Does this change the learning rate of the Adam optimizer? This post shows something different

    https://stackoverflow.com/questions/48324152/pytorch-how-to-change-the-learning-rate-of-an-optimizer-at-any-given-moment-no

    opened by david-bernstein 4
  • Using the rooms data?

    Using the rooms data?

    I wanted to try your code on the rooms data but during conversion, I get these errors. What could I be doing wrong? Note that for the rooms data with moving camera I set the number of camera parameters to 7:

    Traceback (most recent call last):
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
        return list(map(*args))
      File "tfrecord-converter.py", line 66, in convert
        for i, batch in enumerate(batch_process(record)):
      File "tfrecord-converter.py", line 29, in chunk
        for first in iterator:
      File "tfrecord-converter.py", line 40, in process
        'cameras': tf.FixedLenFeature(shape=SEQ_DIM * POSE_DIM, dtype=tf.float32)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1019, in parse_single_example
        serialized, features, example_names, name
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1063, in parse_single_example_v2_unoptimized
        return parse_single_example_v2(serialized, features, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2089, in parse_single_example_v2
        dense_defaults, dense_shapes, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2206, in _parse_single_example_v2_raw
        name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1164, in parse_single_example
        ctx=_ctx)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1260, in parse_single_example_eager_fallback
        attrs=_attrs, ctx=_ctx, name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 67, in quick_execute
        six.raise_from(core._status_to_exception(e.code, message), None)
      File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "tfrecord-converter.py", line 98, in <module>
        pool.map(f, records)
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 266, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
        raise self._value
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    
    opened by Amir-Arsalan 3
Releases(0.1)
Owner
Jesper Wohlert
Jesper Wohlert
This repository contains code and data for "On the Multimodal Person Verification Using Audio-Visual-Thermal Data"

trimodal_person_verification This repository contains the code, and preprocessed dataset featured in "A Study of Multimodal Person Verification Using

ISSAI 7 Aug 31, 2022
Official code of paper "PGT: A Progressive Method for Training Models on Long Videos" on CVPR2021

PGT Code for paper PGT: A Progressive Method for Training Models on Long Videos. Install Run pip install -r requirements.txt. Run python setup.py buil

Bo Pang 27 Mar 30, 2022
OBG-FCN - implementation of 'Object Boundary Guided Semantic Segmentation'

OBG-FCN This repository is to reproduce the implementation of 'Object Boundary Guided Semantic Segmentation' in http://arxiv.org/abs/1603.09742 Object

Jiu XU 3 Mar 11, 2019
Python implementation of Bayesian optimization over permutation spaces.

Bayesian Optimization over Permutation Spaces This repository contains the source code and the resources related to the paper "Bayesian Optimization o

Aryan Deshwal 9 Dec 23, 2022
CROSS-LINGUAL ABILITY OF MULTILINGUAL BERT: AN EMPIRICAL STUDY

M-BERT-Study CROSS-LINGUAL ABILITY OF MULTILINGUAL BERT: AN EMPIRICAL STUDY Motivation Multilingual BERT (M-BERT) has shown surprising cross lingual a

CogComp 1 Feb 28, 2022
Implementation of C-RNN-GAN.

Implementation of C-RNN-GAN. Publication: Title: C-RNN-GAN: Continuous recurrent neural networks with adversarial training Information: http://mogren.

Olof Mogren 427 Dec 25, 2022
Detecting and Tracking Small and Dense Moving Objects in Satellite Videos: A Benchmark

This dataset is a large-scale dataset for moving object detection and tracking in satellite videos, which consists of 40 satellite videos captured by Jilin-1 satellite platforms.

Qingyong 87 Dec 22, 2022
Deep Learning to Improve Breast Cancer Detection on Screening Mammography

Shield: This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. Deep Learning to Improve Breast

Li Shen 305 Jan 03, 2023
Turning SymPy expressions into JAX functions

sympy2jax Turn SymPy expressions into parametrized, differentiable, vectorizable, JAX functions. All SymPy floats become trainable input parameters. S

Miles Cranmer 38 Dec 11, 2022
🔥3D-RecGAN in Tensorflow (ICCV Workshops 2017)

3D Object Reconstruction from a Single Depth View with Adversarial Learning Bo Yang, Hongkai Wen, Sen Wang, Ronald Clark, Andrew Markham, Niki Trigoni

Bo Yang 125 Nov 26, 2022
A Traffic Sign Recognition Project which can help the driver recognise the signs via text as well as audio. Can be used at Night also.

Traffic-Sign-Recognition In this report, we propose a Convolutional Neural Network(CNN) for traffic sign classification that achieves outstanding perf

Mini Project 64 Nov 19, 2022
Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation

Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation This paper has been accepted and early accessed

Yun Liu 39 Sep 20, 2022
Heat transfer problemas solved using python

heat-transfer Heat transfer problems solved using python isolation-convection.py compares the temperature distribution on the problem as shown in the

2 Nov 14, 2021
Neural Magic Eye: Learning to See and Understand the Scene Behind an Autostereogram, arXiv:2012.15692.

Neural Magic Eye Preprint | Project Page | Colab Runtime Official PyTorch implementation of the preprint paper "NeuralMagicEye: Learning to See and Un

Zhengxia Zou 56 Jul 15, 2022
Acoustic mosquito detection code with Bayesian Neural Networks

HumBugDB Acoustic mosquito detection with Bayesian Neural Networks. Extract audio or features from our large-scale dataset on Zenodo. This repository

31 Nov 28, 2022
Code for ACL2021 long paper: Knowledgeable or Educated Guess? Revisiting Language Models as Knowledge Bases

LANKA This is the source code for paper: Knowledgeable or Educated Guess? Revisiting Language Models as Knowledge Bases (ACL 2021, long paper) Referen

Boxi Cao 30 Oct 24, 2022
Unofficial pytorch implementation of 'Image Inpainting for Irregular Holes Using Partial Convolutions'

pytorch-inpainting-with-partial-conv Official implementation is released by the authors. Note that this is an ongoing re-implementation and I cannot f

Naoto Inoue 525 Jan 01, 2023
Tensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuning And private Server services

Tensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuning

MaCan 4.2k Dec 29, 2022
Parameter Efficient Deep Probabilistic Forecasting

PEDPF Parameter Efficient Deep Probabilistic Forecasting (PEDPF) is a repository containing code to run experiments for several deep learning based pr

Olivier Sprangers 10 Jun 13, 2022
The object detection pipeline is based on Ultralytics YOLOv5

AYOLOv2 The main goal of this repository is to rewrite the object detection pipeline with a better code structure for better portability and adaptabil

153 Dec 22, 2022