Guide to Google’s STAC: An SSL Framework For Object Detection

The Google Brain team has introduced STAC, a simple semi-supervised learning (SSL) framework for object detection. STAC outperforms supervised baselines trained on the same amount of annotated data while remaining data-efficient and simple to implement. It also opens a new path toward SSL-based visual object detection.

The first challenge in any object detection task is preparing annotated image data. Large benchmark datasets have largely mitigated this issue. The next challenge is the compute and memory needed to train on such huge amounts of data. Research labs and a handful of large companies have the necessary hardware, but what about the many production teams worldwide that prepare models for task-specific deployment? One workaround is to deploy a large model pre-trained on 1,000 object classes, which needs no data preparation and no training compute. But if the task is to detect only 10 classes, there is no real need to deploy a billion-parameter model trained on data from 1,000 classes.

Semi-supervised learning (SSL) goes a long way toward addressing both challenges. SSL needs only a fraction of the data to be annotated and leaves the rest unannotated, which eases the first problem of building a large annotated dataset. SSL models also enable task-specific training on freshly collected data. Moreover, semi-supervised approaches can outperform fully supervised approaches trained on the same annotated subset. In recent years, SSL for image classification has become popular among researchers. Despite some attempts, however, the SSL approach has not been well generalized to object detection.

To this end, Kihyuk Sohn, Zizhao Zhang, Chun-Liang Li, Han Zhang, Chen-Yu Lee, and Tomas Pfister from Google Cloud AI Research and Google Brain have developed STAC, an SSL framework designed specifically for object detection that introduces two new tunable hyperparameters. It follows a consistency-based self-training strategy, training the model efficiently with a small amount of annotated data and strongly augmented unannotated data.

How does STAC really work?

STAC is given a small amount of annotated data and a large amount of unannotated data. Semi-supervised learning combines supervised learning on the annotated data with unsupervised learning on the unannotated data. STAC first trains a supervised detector, such as Faster R-CNN, as its teacher model, using only the annotated images. This completes the supervised stage.

The trained teacher model is then used to predict bounding boxes for the objects in the unannotated images, along with a confidence score for each box. Non-Maximum Suppression (NMS) is applied to these boxes as post-processing. STAC then introduces a hyperparameter that acts as a cut-off (threshold) on the confidence score. Boxes with confidence above this cut-off are retained and the rest are discarded. The retained boxes are called pseudo labels, and images carrying them are called pseudo-labeled images.
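To make the pseudo-labeling step concrete, here is a minimal NumPy sketch of NMS followed by confidence thresholding. It is not code from the STAC repository. It assumes boxes in (x1, y1, x2, y2) pixel format, greedy class-wise NMS with an IoU threshold of 0.5 (our assumption), and a confidence cut-off tau of 0.9, the value passed as TRAIN.CONFIDENCE in the training command of this walkthrough. The helper names (iou, nms, pseudo_labels) are illustrative.

```python
import numpy as np


def iou(box, boxes):
    """IoU between one box and an (N, 4) array of boxes, all as (x1, y1, x2, y2)."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms(boxes, scores, iou_thresh=0.5):
    """Greedy NMS: keep the top-scoring box, drop boxes overlapping it, repeat."""
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)
        order = rest[iou(boxes[best], boxes[rest]) <= iou_thresh]
    return np.array(keep, dtype=int)


def pseudo_labels(boxes, scores, labels, tau=0.9, iou_thresh=0.5):
    """Class-wise NMS, then keep only boxes whose confidence is above tau."""
    kept = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        kept.extend(idx[nms(boxes[idx], scores[idx], iou_thresh)])
    kept = np.array(sorted(kept), dtype=int)
    kept = kept[scores[kept] > tau]  # confidence thresholding
    return boxes[kept], scores[kept], labels[kept]


boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.95, 0.92, 0.60, 0.97])
labels = np.array([1, 1, 1, 2])
pl_boxes, pl_scores, pl_labels = pseudo_labels(boxes, scores, labels)
print(pl_labels.tolist(), pl_scores.tolist())  # [1, 2] [0.95, 0.97]
```

In this toy example, the second class-1 box overlaps the first heavily (IoU 0.81) and is removed by NMS. The isolated class-1 box survives NMS but falls below the 0.9 cut-off, so only two pseudo labels remain.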

The semi-supervised learning strategy in STAC (source)

Strong augmentations are applied to the unannotated images. These include colour transformations, global transformations, box-level geometric transformations, and Cutout, and any geometric transformation is applied to the pseudo boxes as well. The detector is then trained on both the annotated images and the strongly augmented pseudo-labeled images. STAC introduces a tunable loss weight that controls how much the unsupervised (pseudo-label) loss contributes relative to the supervised loss.
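Strong augmentation forces the detector to predict the same (transformed) boxes on a heavily perturbed view of each pseudo-labeled image. The sketch below assumes float images in [0, 1] and uses three simplified stand-ins: a brightness shift for the colour transformation, a horizontal flip for a global geometric transformation, and Cutout. These are illustrative choices, not STAC's exact augmentation policy. Only the geometric transform changes the box coordinates.

```python
import numpy as np


def color_jitter(image, rng, max_delta=0.2):
    """Colour transform stand-in: random brightness shift, clipped to [0, 1]."""
    return np.clip(image + rng.uniform(-max_delta, max_delta), 0.0, 1.0)


def hflip(image, boxes):
    """Global geometric transform: mirror the image and the box x-coordinates."""
    w = image.shape[1]
    new_boxes = boxes.copy()
    new_boxes[:, 0] = w - boxes[:, 2]
    new_boxes[:, 2] = w - boxes[:, 0]
    return image[:, ::-1].copy(), new_boxes


def cutout(image, rng, size=8, fill=0.0):
    """Cutout: blank a random square patch; box coordinates are unchanged."""
    h, w = image.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    out = image.copy()
    out[max(cy - size // 2, 0):cy + size // 2, max(cx - size // 2, 0):cx + size // 2] = fill
    return out


def strong_augment(image, boxes, rng):
    """Apply the stand-in augmentations; only the flip touches the boxes."""
    image = color_jitter(image, rng)
    if rng.random() < 0.5:
        image, boxes = hflip(image, boxes)
    return cutout(image, rng), boxes


rng = np.random.default_rng(0)
img = np.full((32, 48, 3), 0.5)           # H x W x C float image in [0, 1]
bxs = np.array([[4.0, 4.0, 20.0, 16.0]])  # one pseudo box (x1, y1, x2, y2)
aug_img, aug_boxes = strong_augment(img, bxs, rng)
```

Because the flip is random, aug_boxes is either the original box or its mirror image [[28, 4, 44, 16]]. A fuller pipeline would also clip or drop boxes that box-level transforms push outside the image.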

Some strong data augmentation techniques used in STAC (source)

Apart from the teacher model, STAC is simple. It adds just two tunable hyperparameters, one for the bounding-box confidence threshold and another for the unsupervised loss weight, and it requires no human annotation of the unlabeled images.
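The two hyperparameters can be summarised in a few lines. This sketch assumes the combined objective takes the form ℓ = ℓs + λu·ℓu, where ℓs is the detection loss on annotated images and ℓu is the detection loss on strongly augmented pseudo-labeled images. Averaging per-image losses within each term is also an assumption of the sketch. The names STACHyperParams, confident_scores, and total_loss are hypothetical. The defaults (0.9 and 2) mirror TRAIN.CONFIDENCE and TRAIN.WU in the training command of this walkthrough.

```python
from dataclasses import dataclass


@dataclass
class STACHyperParams:
    # tau: confidence cut-off for keeping pseudo boxes (cf. TRAIN.CONFIDENCE=0.9)
    confidence: float = 0.9
    # lambda_u: weight of the unsupervised loss (cf. TRAIN.WU=2)
    unsup_weight: float = 2.0


def confident_scores(scores, hp):
    """Keep only pseudo-box scores strictly above the confidence cut-off."""
    return [s for s in scores if s > hp.confidence]


def total_loss(labeled_losses, unlabeled_losses, hp):
    """l = l_s + lambda_u * l_u, with each term averaged over its images."""
    l_s = sum(labeled_losses) / len(labeled_losses)
    # An empty unlabeled batch contributes no unsupervised loss
    l_u = sum(unlabeled_losses) / len(unlabeled_losses) if unlabeled_losses else 0.0
    return l_s + hp.unsup_weight * l_u


hp = STACHyperParams()
kept = confident_scores([0.95, 0.5, 0.91, 0.9], hp)  # [0.95, 0.91]
loss = total_loss([1.0, 3.0], [0.5, 1.5], hp)         # 2.0 + 2.0 * 1.0 = 4.0
```

Raising the confidence threshold yields fewer but cleaner pseudo labels. Raising the unsupervised weight makes the pseudo-labeled images count more during training.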

Different bounding box confidence thresholding values and their effects on selecting confident bounding boxes (source)

Python Implementation of STAC

The official implementation requires a Python environment with TensorFlow 1.14 and a CUDA runtime, and the reference setup trains on 8 GPUs. This walkthrough largely follows the official STAC source code repository. Clone the source code to the local (or virtual) environment using the following command.

!git clone https://github.com/google-research/ssl_detection.git

Create the environment by installing the dependencies using the following commands.

 %%bash
 cd /content/ssl_detection/
 sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
 virtualenv -p python3 --system-site-packages env3
 . env3/bin/activate
 pip install -r requirements.txt
 python -c 'import tensorflow as tf; print(tf.__version__)'
 # install coco apis
 pip3 install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 

Prepare the COCO 2017 dataset. The following commands download the annotations and the train, validation, and unlabeled image sets, then unzip them once the downloads finish.

 %%bash
 mkdir -p /content/coco/
 cd /content/coco/
 wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
 wget http://images.cocodataset.org/zips/train2017.zip
 wget http://images.cocodataset.org/zips/val2017.zip
 wget http://images.cocodataset.org/zips/unlabeled2017.zip
 unzip annotations_trainval2017.zip -d .
 unzip -q train2017.zip -d .
 unzip -q val2017.zip -d .
 unzip -q unlabeled2017.zip -d . 

Similarly, download the PASCAL VOC 2007 and VOC 2012 datasets and extract the archives using the following commands.

 %%bash
 mkdir -p /content/voc/
 cd /content/voc/
 wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
 wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
 wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
 tar -xf VOCtrainval_06-Nov-2007.tar
 tar -xf VOCtest_06-Nov-2007.tar
 tar -xf VOCtrainval_11-May-2012.tar 

Generate labeled and unlabeled splits from the downloaded COCO training set, for five random seeds and labeled fractions of 1, 2, 5, 10, and 20 percent.

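 %%bash
 cd /content/ssl_detection/prepare_datasets
 # Each (seed, percent) pair runs as a background job
 for seed in 1 2 3 4 5; do
   for percent in 1 2 5 10 20; do
     python3 prepare_coco_data.py --percent $percent --seed $seed &
   done
 done
 # Wait for all background jobs so the cell does not return before the splits exist
 wait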
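The loop above produces 25 splits (5 seeds × 5 labeled percentages). Conceptually, each split draws a random, seed-dependent subset of the training images to keep as labeled and treats the rest as unlabeled. The function below is a hypothetical illustration of that idea on plain image IDs. It does not reproduce prepare_coco_data.py, whose internals may differ.

```python
import random


def split_labeled(image_ids, percent, seed):
    """Hypothetical split: percent% of image IDs labeled, the rest unlabeled."""
    ids = sorted(image_ids)
    rng = random.Random(seed)  # fixed seed makes the split reproducible
    n_labeled = max(1, round(len(ids) * percent / 100))
    labeled = set(rng.sample(ids, n_labeled))
    unlabeled = [i for i in ids if i not in labeled]
    return sorted(labeled), unlabeled


labeled, unlabeled = split_labeled(range(1000), percent=10, seed=1)
print(len(labeled), len(unlabeled))  # 100 900
```

Fixing the seed is what makes results across different percentages and repeated runs comparable. The same seed always selects the same labeled subset.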

Download the JSON files released with STAC for the VOC and COCO datasets.

 %%bash
 cd /content/
 wget https://storage.cloud.google.com/gresearch/ssl_detection/STAC_JSON.tar
 tar -xf STAC_JSON.tar

Download the ImageNet-pretrained ResNet-50 weights used to initialize the Faster R-CNN backbone.

 %%bash
 cd /content/coco/
 wget http://models.tensorpack.com/FasterRCNN/ImageNet-R50-AlignPadding.npz 

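Next, set the labeled and unlabeled dataset names, the directories for model checkpoints and pseudo labels, and the GPUs to use for training. This example uses 8 GPUs, but users can opt for 4, 16, or 32 GPUs instead. Each %%bash cell runs in a fresh shell, so these variables, the cd, and the env3 activation do not carry over to later cells. Run the following training, evaluation, and pseudo-labeling commands in the same cell, or repeat these lines at the top of each cell.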

 %%bash
 cd /content/ssl_detection/detection
 # Labeled and Unlabeled datasets
 DATASET=coco_train2017.1@10  # assumed name: seed 1, 10% labeled split from prepare_coco_data.py
 UNLABELED_DATASET=${DATASET}-unlabeled
 # PATH to save trained models
 CKPT_PATH=result/${DATASET}
 # PATH to save pseudo labels for unlabeled data
 PSEUDO_PATH=${CKPT_PATH}/PSEUDO_DATA
 # Train with 8 GPUs
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 

Train the teacher model, a Faster R-CNN, on the labeled split using the settings above.

 %%bash
 python3 train_stg1.py \
     --logdir ${CKPT_PATH} --simple_path --config \
     BACKBONE.WEIGHTS=/content/coco/ImageNet-R50-AlignPadding.npz \
     DATA.BASEDIR=/content/coco/ \
     DATA.TRAIN="('${DATASET}',)" \
     MODE_MASK=False \
     FRCNN.BATCH_PER_IM=64 \
     PREPROC.TRAIN_SHORT_EDGE_SIZE="[500,800]" \
     TRAIN.EVAL_PERIOD=20 \
     TRAIN.AUGTYPE_LAB='default' 

Evaluate the teacher model using the following commands.

 %%bash
 if [ ! -d ${PSEUDO_PATH} ]; then
     mkdir -p ${PSEUDO_PATH}
 fi
 # model-180000 is the last checkpoint
 # save eval.json at $PSEUDO_PATH
 python3 predict.py \
     --evaluate ${PSEUDO_PATH}/eval.json \
     --load "${CKPT_PATH}"/model-180000 \
     --config \
     DATA.BASEDIR=/content/coco/ \
     DATA.TRAIN="('${UNLABELED_DATASET}',)" 

Once the teacher model is trained, it can be used to prepare pseudo labels for the unannotated images.

 %%bash
 python3 predict.py \
     --predict_unlabeled ${PSEUDO_PATH} \
     --load "${CKPT_PATH}"/model-180000 \
     --config \
     DATA.BASEDIR=/content/coco/ \
     DATA.TRAIN="('${UNLABELED_DATASET}',)" \
     EVAL.PSEUDO_INFERENCE=True 

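Finally, train the STAC detector on the labeled images together with the strongly augmented pseudo-labeled images. TRAIN.AUGTYPE='strong' enables strong augmentation. TRAIN.CONFIDENCE=0.9 and TRAIN.WU=2 appear to set STAC's two hyperparameters: the pseudo-label confidence threshold and the unsupervised loss weight. Training can take many hours depending on the available GPUs and memory.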

 %%bash
 python3 train_stg2.py \
     --logdir=${CKPT_PATH}/STAC --simple_path \
     --pseudo_path=${PSEUDO_PATH} \
     --config \
     BACKBONE.WEIGHTS=/content/coco/ImageNet-R50-AlignPadding.npz \
     DATA.BASEDIR=/content/coco/ \
     DATA.TRAIN="('${DATASET}',)" \
     DATA.UNLABEL="('${UNLABELED_DATASET}',)" \
     MODE_MASK=False \
     FRCNN.BATCH_PER_IM=64 \
     PREPROC.TRAIN_SHORT_EDGE_SIZE="[500,800]" \
     TRAIN.EVAL_PERIOD=20 \
     TRAIN.AUGTYPE_LAB='default' \
     TRAIN.AUGTYPE='strong' \
     TRAIN.CONFIDENCE=0.9 \
     TRAIN.WU=2 

A trained model can be directly deployed for inference.

Performance of STAC

STAC is evaluated on the MS-COCO and PASCAL VOC07 datasets against supervised baseline models. It clearly outperforms supervised models trained either with strong augmentation or without augmentation. Notably, STAC trained with only 5% of the annotated data outperforms a supervised model trained with 10% of the annotated data.

Comparison of STAC with supervised models (source)
Performance of STAC for various values of its newly introduced parameters, along with a comparison with supervised models (source)
