
Hands-On Guide To Train RL Agents using Stable BaseLines on Atari Gym Environment

Reinforcement learning is continuously being made easier by OpenAI. As part of its mission to develop and promote friendly AI that benefits humanity, OpenAI released Baselines, a set of reference implementations of RL algorithms. Stable-Baselines, a fork of that project created by the Robotics Lab U2IS (INRIA Flowers team) at ENSTA Paris, provides a Scikit-Learn-like coding structure that gives a unified style for programming RL agents.

Most of the RL algorithms from OpenAI Baselines are presented with a unified API that supports learning, saving and loading RL agents, and much more. This API provides an easy-to-use platform to practice and learn state-of-the-art algorithms. The following features are covered using Google Colab and are part of the getting-started guide.

The notebooks in that guide cover all the features offered by Stable-Baselines, and this article walks through those features in detail.

Making the Agent Learn

In this hands-on guide, we will train an RL agent with a state-of-the-art algorithm in a few lines of code using the Stable-Baselines API. The play session of the trained agent will also be recorded in .gif or .mp4 format.

The snippet below uses a random agent to play DemonAttackNoFrameskip-v4 and records the gameplay in .mp4 format.

"""Random Agent Video"""

env_id = 'DemonAttackNoFrameskip-v4'

video_folder = '/gdrive/My Drive/videos'

video_length = 300

env = DummyVecEnv([lambda: gym.make(env_id)])

obs = env.reset()

# Record the video starting at the first step

env = VecVideoRecorder(env,video_folder,
record_video_trigger=lambda x:x ==0,
video_length=video_length,
name_prefix="random-agent-{}".format(env_id))

env.reset()

for _ in range(video_length + 1):
action = [env.action_space.sample()]
obs, _, _, _ = env.step(action)

# Save the video

env.close()

Have a look at the video to see the agent playing.

Training an agent to play the same game is just a matter of calling the learn function, much like Scikit-Learn's fit and predict functions, as sketched below.
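
As a minimal sketch of that workflow (CartPole-v1 is used here purely for illustration, not the Atari game trained later), learn plays the role of fit and predict returns the agent's action for an observation:

import gym
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

# A simple environment just to illustrate the Scikit-Learn-like workflow
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = PPO2("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)    # analogous to fit
obs = env.reset()
action, _states = model.predict(obs)  # analogous to predict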

In our previous article, we used Gym Retro to train a rational agent, but the challenge there was to write our own agent and make it learn to play in the Retro Gym environment.

With Stable-Baselines, we can use the Atari games environment (similar to Gym Retro) as well as use a pre-trained agent to play in the environment.

Nevertheless, the pre-trained agent must be trained for some episodes/timesteps in the present environment to get the expected results.

Using Stable-Baselines, training an RL agent on Atari games is straightforward thanks to the make_atari_env helper function. It handles the preprocessing and multiprocessing needed to set up the environment.
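
For reference, a minimal sketch of the helper on its own could look like this (VecFrameStack is an optional extra not used in the snippet below; stacking the last four frames is a common choice for Atari):

from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack

# Create 4 preprocessed copies of the Atari environment running in parallel
env = make_atari_env('DemonAttackNoFrameskip-v4', num_env=4, seed=0)
# Optionally stack the last 4 frames so the policy can infer motion
env = VecFrameStack(env, n_stack=4)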

The agent below has been trained using PPO2 (Proximal Policy Optimization) with a CNN policy, and it performs far better than a random agent.

The snippet below uses the trained agent to play DemonAttackNoFrameskip-v4 and records the gameplay in .mp4 format.

"""Trained RL Agent"""

import imageio

import time

import numpy as np

from stable_baselines.ddpg.policies import CnnPolicy

from stable_baselines.common.policies import MlpLstmPolicy, CnnLstmPolicy

from stable_baselines import A2C, PPO2

env = make_atari_env('DemonAttackNoFrameskip-v4', num_env=16, seed=0)

env = DummyVecEnv([lambda: gym.make('DemonAttackNoFrameskip-v4')])

model = PPO2("CnnPolicy", env, verbose=1)

s_time = time.time()

model.learn(total_timesteps=int(1e5))

e_time = time.time()

print(f"Total Run-Time : , {round(((e_time - s_time) * 1000), 3)} ms")

video_length = 3750

# Record the video starting at the first step

env = VecVideoRecorder(env, video_folder,

    record_video_trigger=lambda x: x == 1000,     video_length=video_length,

    name_prefix="trained-agent-{}".format(env_id))

env.reset()

for _ in range(video_length + 1):

  action = [env.action_space.sample()]

  obs, _, _, _ = env.step(action)

# Save the video

env.close()

Have a look at the video to see the trained agent playing.



The RL agent was trained on Google Colab using TPUs as hardware accelerators for 100,000 timesteps. The total time to complete the training was 1 hour and 30 minutes.

API Features

1. Hyperparameter Tuning

The API also provides flexibility to access, modify and reload the model hyperparameters.

One can leverage this feature to assess a large set of models across different network structures.
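
As a brief sketch (the hyperparameter values below are only illustrative, not tuned ones), hyperparameters are passed as keyword arguments when the model is created, and the trained model can be saved and reloaded for further experiments:

from stable_baselines import PPO2
from stable_baselines.common.cmd_util import make_atari_env

env = make_atari_env('DemonAttackNoFrameskip-v4', num_env=16, seed=0)

# Illustrative hyperparameter values, not tuned ones
model = PPO2("CnnPolicy", env,
             gamma=0.99,             # discount factor
             n_steps=128,            # steps collected per environment per update
             learning_rate=2.5e-4,
             ent_coef=0.01,
             verbose=1)
model.learn(total_timesteps=int(1e5))
model.save("ppo2_demonattack")

# Reload the saved agent later to assess it or continue experimenting
model = PPO2.load("ppo2_demonattack", env=env)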

2. Continual Learning

The API also provides a Continual Learning feature, which helps to switch the learning environment and re-learn on the new environment, as sketched below.
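
A minimal sketch of this, assuming the second game has observation and action spaces matching the first (SpaceInvaders is used here only as an illustration):

from stable_baselines import PPO2
from stable_baselines.common.cmd_util import make_atari_env

# Learn on one game first
env = make_atari_env('DemonAttackNoFrameskip-v4', num_env=16, seed=0)
model = PPO2("CnnPolicy", env, verbose=1)
model.learn(total_timesteps=int(1e5))

# Switch the learning environment and re-learn on the new game
new_env = make_atari_env('SpaceInvadersNoFrameskip-v4', num_env=16, seed=0)
model.set_env(new_env)
model.learn(total_timesteps=int(1e5))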

3. Custom Policy

Defining custom policies is also possible by subclassing any available policy, as shown below.

from stable_baselines.common.policies import FeedForwardPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C

# Custom MLP policy of three layers of size 128 each
class CustomPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs,
                                           net_arch=[dict(pi=[128, 128, 128],
                                                          vf=[128, 128, 128])],
                                           feature_extraction="mlp")

model = A2C(CustomPolicy, 'LunarLander-v2', verbose=1)

# Train the agent
model.learn(total_timesteps=100000)
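
A shorter alternative worth noting (a sketch relying on the policy_kwargs argument rather than a new policy class) passes the same network architecture directly to the model constructor:

from stable_baselines import A2C

# Same three-layer architecture specified via policy_kwargs instead of subclassing
model = A2C('MlpPolicy', 'LunarLander-v2', verbose=1,
            policy_kwargs=dict(net_arch=[dict(pi=[128, 128, 128],
                                              vf=[128, 128, 128])]))
model.learn(total_timesteps=100000)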

This attempt to open-source RL algorithms in a unified way not only provides an easy-to-use platform but also lets the research community work with complex environments and benchmark various state-of-the-art RL algorithms.



Anurag Upadhyaya

Experienced Data Scientist with a demonstrated history of working in Industrial IOT (IIOT), Industry 4.0, Power Systems and Manufacturing domain. I have experience in designing robust solutions for various clients using Machine Learning, Artificial Intelligence, and Deep Learning. I have been instrumental in developing end to end solutions from scratch and deploying them independently at scale.
