Semi-Supervised Learning with Ladder Networks in Keras. Get 98% test accuracy on MNIST with just 100 labeled examples !

Overview

Semi-Supervised Learning with Ladder Networks in Keras

This is an implementation of Ladder Network in Keras. Ladder network is a model for semi-supervised learning. Refer to the paper titled Semi-Supervised Learning with Ladder Networks by A Rasmus, H Valpola, M Honkala,M Berglund, and T Raiko

This implementation was used in the official code of our paper Unsupervised Clustering using Pseudo-semi-supervised Learning . The code can be found here and the blog post can be found here

The model achives 98% test accuracy on MNIST with just 100 labeled examples.

The code only works with Tensorflow backend.

Requirements

  • Python 2.7+/3.6+
  • Tensorflow (1.4.0)
  • numpy
  • keras (2.1.4)

Note that other versions of tensorflow/keras should also work.

How to use

Load the dataset

from keras.datasets import mnist
import keras
import random

# get the dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 28*28).astype('float32')/255.0
x_test = x_test.reshape(10000, 28*28).astype('float32')/255.0

y_train = keras.utils.to_categorical( y_train )
y_test = keras.utils.to_categorical( y_test )

# only select 100 training samples 
idxs_annot = range( x_train.shape[0])
random.seed(0)
random.shuffle( idxs_annot )
idxs_annot = idxs_annot[ :100 ]

x_train_unlabeled = x_train
x_train_labeled = x_train[ idxs_annot ]
y_train_labeled = y_train[ idxs_annot  ]

Repeat the labeled dataset to match the shapes

n_rep = x_train_unlabeled.shape[0] / x_train_labeled.shape[0]
x_train_labeled_rep = np.concatenate([x_train_labeled]*n_rep)
y_train_labeled_rep = np.concatenate([y_train_labeled]*n_rep)

Initialize the model

from ladder_net import get_ladder_network_fc
inp_size = 28*28 # size of mnist dataset 
n_classes = 10
model = get_ladder_network_fc( layer_sizes = [ inp_size , 1000, 500, 250, 250, 250, n_classes ]  )

Train the model

model.fit([ x_train_labeled_rep , x_train_unlabeled   ] , y_train_labeled_rep , epochs=100)

Get the test accuracy

from sklearn.metrics import accuracy_score
y_test_pr = model.test_model.predict(x_test , batch_size=100 )

print "test accuracy" , accuracy_score(y_test.argmax(-1) , y_test_pr.argmax(-1)  )
Owner
Divam Gupta
Graduate student at Carnegie Mellon University | Former Research Fellow at Microsoft Research
Divam Gupta
TensorFlow Implementation of Unsupervised Cross-Domain Image Generation

Domain Transfer Network (DTN) TensorFlow implementation of Unsupervised Cross-Domain Image Generation. Requirements Python 2.7 TensorFlow 0.12 Pickle

Yunjey Choi 864 Dec 30, 2022
This porject is intented to build the most accurate model for predicting the porbability of loan default

Estimating-Loan-Default-Probability IBA ML2 Mid-project / Kaggle Competition This porject is intented to build the most accurate model for predicting

Adil Gahramanov 1 Jan 24, 2022
DeepFaceLab fork which provides IPython Notebook to use DFL with Google Colab

DFL-Colab — DeepFaceLab fork for Google Colab This project provides you IPython Notebook to use DeepFaceLab with Google Colaboratory. You can create y

779 Jan 05, 2023
Event-forecasting - Event Forecasting Algorithms With Python

event-forecasting Event Forecasting Algorithms Theory Correlating events in comp

Intellia ICT 4 Feb 15, 2022
Object detection and instance segmentation toolkit based on PaddlePaddle.

Object detection and instance segmentation toolkit based on PaddlePaddle.

9.3k Jan 02, 2023
Supercharging Imbalanced Data Learning WithCausal Representation Transfer

ECRT: Energy-based Causal Representation Transfer Code for Supercharging Imbalanced Data Learning With Energy-basedContrastive Representation Transfer

Zidi Xiu 11 May 02, 2022
HackBMU-5.0-Team-Ctrl-Alt-Elite - HackBMU 5.0 Team Ctrl Alt Elite

HackBMU-5.0-Team-Ctrl-Alt-Elite The search is over. We present to you ‘Health-A-

3 Feb 19, 2022
using yolox+deepsort for object-tracker

YOLOX_deepsort_tracker yolox+deepsort实现目标跟踪 最新的yolox尝尝鲜~~(yolox正处在频繁更新阶段,因此直接链接yolox仓库作为子模块) Install Clone the repository recursively: git clone --rec

245 Dec 26, 2022
A weakly-supervised scene graph generation codebase. The implementation of our CVPR2021 paper ``Linguistic Structures as Weak Supervision for Visual Scene Graph Generation''

README.md shall be finished soon. WSSGG 0 Overview 1 Installation 1.1 Faster-RCNN 1.2 Language Parser 1.3 GloVe Embeddings 2 Settings 2.1 VG-GT-Graph

Keren Ye 35 Nov 20, 2022
PyTorch implementation of SmoothGrad: removing noise by adding noise.

SmoothGrad implementation in PyTorch PyTorch implementation of SmoothGrad: removing noise by adding noise. Vanilla Gradients SmoothGrad Guided backpro

SSKH 143 Jan 05, 2023
Machine Learning Toolkit for Kubernetes

Kubeflow the cloud-native platform for machine learning operations - pipelines, training and deployment. Documentation Please refer to the official do

Kubeflow 12.1k Jan 03, 2023
HiFT: Hierarchical Feature Transformer for Aerial Tracking (ICCV2021)

HiFT: Hierarchical Feature Transformer for Aerial Tracking Ziang Cao, Changhong Fu, Junjie Ye, Bowen Li, and Yiming Li Our paper is Accepted by ICCV 2

Intelligent Vision for Robotics in Complex Environment 55 Nov 23, 2022
Official implementation of "Implicit Neural Representations with Periodic Activation Functions"

Implicit Neural Representations with Periodic Activation Functions Project Page | Paper | Data Vincent Sitzmann*, Julien N. P. Martel*, Alexander W. B

Vincent Sitzmann 1.4k Jan 06, 2023
The implementation of the lifelong infinite mixture model

Lifelong infinite mixture model 📋 This is the implementation of the Lifelong infinite mixture model 📋 Accepted by ICCV 2021 Title : Lifelong Infinit

Fei Ye 5 Oct 20, 2022
[CVPR2021 Oral] FFB6D: A Full Flow Bidirectional Fusion Network for 6D Pose Estimation.

FFB6D This is the official source code for the CVPR2021 Oral work, FFB6D: A Full Flow Biderectional Fusion Network for 6D Pose Estimation. (Arxiv) Tab

Yisheng (Ethan) He 201 Dec 28, 2022
Current state of supervised and unsupervised depth completion methods

Awesome Depth Completion Table of Contents About Sparse-to-Dense Depth Completion Current State of Depth Completion Unsupervised VOID Benchmark Superv

224 Dec 28, 2022
EMNLP'2021: Simple Entity-centric Questions Challenge Dense Retrievers

EntityQuestions This repository contains the EntityQuestions dataset as well as code to evaluate retrieval results from the the paper Simple Entity-ce

Princeton Natural Language Processing 119 Sep 28, 2022
Statsmodels: statistical modeling and econometrics in Python

About statsmodels statsmodels is a Python package that provides a complement to scipy for statistical computations including descriptive statistics an

statsmodels 8.1k Jan 02, 2023
Setup freqtrade/freqUI on Heroku

UNMAINTAINED - REPO MOVED TO https://github.com/p-zombie/freqtrade Creating the app git clone https://github.com/joaorafaelm/freqtrade.git && cd freqt

João 51 Aug 29, 2022
Repository to run object detection on a model trained on an autonomous driving dataset.

Autonomous Driving Object Detection on the Raspberry Pi 4 Description of Repository This repository contains code and instructions to configure the ne

Ethan 51 Nov 17, 2022