Note for RL Stable Baselines


sudo apt-get update && sudo apt-get install cmake libopenmpi-dev zlib1g-dev
pip install stable-baselines[mpi]
pip install pyglet==1.3.1  # 1.4.1 will fail rendering


Import Libraries


import gym
import numpy as np
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

Import Gym Environment






PPO是一种同策略(on-policy)算法,这意味着用于更新网络的轨迹必须由最新的策略采集。它的样本效率通常低于DQN、SAC或TD3等异策略(off-policy)算法,但就实际运行时间(wall-clock time)而言则要快得多。

env = gym.make('CartPole-v1')
# A vectorized environment makes multi-process training easy; PPO requires a vectorized environment
env = DummyVecEnv([lambda: env])

Train a PPO Agent

# :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
model = PPO2(MlpPolicy, env, verbose=0)
# A bare positional name after a keyword argument is a syntax error,
# so the logging interval is passed explicitly by keyword.
model.learn(total_timesteps=1000, log_interval=50)


The policy class to use will be inferred and the environment will be automatically created. This works because both are registered.

# :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
model = PPO2('MlpPolicy', "CartPole-v1", verbose=1).learn(1000, log_interval=50)

Train a SAC Agent

from stable_baselines import SAC

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
model.learn(total_timesteps=50000, log_interval=50)

Train a DQN Agent


Vanilla DQN: DQN without Extensions

from stable_baselines import DQN

# Deactivate all the DQN extensions to get the original (vanilla) version.
# In practice, it is recommended to keep them activated.
kwargs = {'double_q': False, 'prioritized_replay': False, 'policy_kwargs': dict(dueling=False)}

# Note that the MlpPolicy of DQN is different from the one of PPO,
# but stable-baselines handles that automatically if you pass a string
dqn_model = DQN('MlpPolicy', 'CartPole-v1', verbose=1, **kwargs)
dqn_model.learn(total_timesteps=10000, log_interval=50)


Saving and Loading Models


import os

# Create save dir
save_dir = "./model/gym/"
os.makedirs(save_dir, exist_ok=True)

model = PPO2('MlpPolicy', 'Pendulum-v0', verbose=0).learn(8000)
# Save the model under save_dir; without this call the load below has nothing to read
model.save(save_dir + "/PPO2_tutorial")

# sample an observation from the environment
obs = model.env.observation_space.sample()

# Check prediction before saving
print("pre saved", model.predict(obs, deterministic=True))

del model # delete trained model to demonstrate loading

loaded_model = PPO2.load(save_dir + "/PPO2_tutorial")
# Check that the prediction is the same after loading (for the same observation)
print("loaded", loaded_model.predict(obs, deterministic=True))



import os

from stable_baselines import A2C
from stable_baselines.common.vec_env import DummyVecEnv

# Create save dir
save_dir = "./model/gym/"
os.makedirs(save_dir, exist_ok=True)

model = A2C('MlpPolicy', 'Pendulum-v0', verbose=0, gamma=0.9, n_steps=20).learn(8000)
# Save the model under save_dir so that it can be loaded back below
model.save(save_dir + "/A2C_tutorial")

del model # delete trained model to demonstrate loading

# load the model, and when loading set verbose to 1
loaded_model = A2C.load(save_dir + "/A2C_tutorial", verbose=1)

# show the saved hyper-parameters
print("loaded:", "gamma =", loaded_model.gamma, "n_steps =", loaded_model.n_steps)

# as the environment is not serializable, we need to set a new instance of the environment
loaded_model.set_env(DummyVecEnv([lambda: gym.make('Pendulum-v0')]))
# and continue training
loaded_model.learn(8000)

Visualize The Trained Agent

# NOTE(review): the model is loaded without an `env` argument, so `get_env()` may
# return None here — confirm an environment is attached before running this loop.
model = SAC.load('saved_model_name')
env = model.get_env()
obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(action)
    # Draw the current frame; without this call nothing is visualized
    env.render()

Gym Wrappers

Anatomy of a Gym Wrapper

Wrappers are used to transform an environment in a modular way.




class TimeLimitWrapper(gym.Wrapper):
    """
    Gym wrapper that ends an episode once a fixed number of steps is reached.

    :param env: (gym.Env) Gym environment that will be wrapped
    :param max_steps: (int) Max number of steps per episode
    """

    def __init__(self, env, max_steps=100):
        # Call the parent constructor, so we can access self.env later
        super(TimeLimitWrapper, self).__init__(env)
        self.max_steps = max_steps
        # Counter of steps per episode
        self.current_step = 0

    def reset(self):
        """
        Reset the environment and the per-episode step counter.

        :return: whatever the wrapped environment's ``reset()`` returns
        """
        self.current_step = 0
        return self.env.reset()

    def step(self, action):
        """
        Step the wrapped environment and enforce the step limit.

        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional information
        """
        self.current_step += 1
        obs, reward, done, info = self.env.step(action)
        # Overwrite the done signal when the step limit is reached
        if self.current_step >= self.max_steps:
            done = True
            # Update the info dict to signal that the limit was exceeded
            info['time_limit_reached'] = True
        return obs, reward, done, info


from gym.envs.classic_control.pendulum import PendulumEnv

# Here we create the environment directly because gym.make() already wrap the environment in a TimeLimit wrapper otherwise
env = PendulumEnv()
# Wrap the environment
env = TimeLimitWrapper(env, max_steps=100)



Add Remaining Time to Observation Space for Fixed Length Episode

from gym.wrappers import TimeLimit

class TimeFeatureWrapper(gym.Wrapper):
    """
    Add remaining time to observation space for fixed length episodes.

    :param env: (gym.Env) Environment with a ``gym.spaces.Box`` observation space
    :param max_steps: (int) Max number of steps of an episode
        if it is not wrapped in a TimeLimit object.
    :param test_mode: (bool) In test mode, the time feature is constant
        (fixed at 1.0 in ``_get_obs``). This allows checking that the agent did
        not overfit this feature, learning a deterministic pre-defined sequence
        of actions.
    """

    def __init__(self, env, max_steps=1000, test_mode=False):
        assert isinstance(env.observation_space, gym.spaces.Box)
        # Add a time feature to the observation: one extra dimension bounded by [0, 1]
        low, high = env.observation_space.low, env.observation_space.high
        low, high = np.concatenate((low, [0])), np.concatenate((high, [1.]))
        env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        super(TimeFeatureWrapper, self).__init__(env)

        # Prefer the limit already enforced by a TimeLimit wrapper; only fall
        # back to `max_steps` when the environment is not wrapped in one.
        if isinstance(env, TimeLimit):
            self._max_steps = env._max_episode_steps
        else:
            self._max_steps = max_steps
        self._current_step = 0
        self._test_mode = test_mode

    def reset(self):
        """
        Reset the step counter and the wrapped environment.

        :return: (np.ndarray) first observation with the time feature appended
        """
        self._current_step = 0
        return self._get_obs(self.env.reset())

    def step(self, action):
        """
        Step the wrapped environment and append the time feature to the observation.

        :param action: action forwarded unchanged to the wrapped environment
        :return: (np.ndarray, float, bool, dict) augmented observation, reward, done, info
        """
        self._current_step += 1
        obs, reward, done, info = self.env.step(action)
        return self._get_obs(obs), reward, done, info

    def _get_obs(self, obs):
        """
        Concatenate the time feature to the current observation.

        :param obs: (np.ndarray)
        :return: (np.ndarray)
        """
        # Remaining time is more general
        time_feature = 1 - (self._current_step / self._max_steps)
        if self._test_mode:
            time_feature = 1.0
        # Optionally: concatenate [time_feature, time_feature ** 2]
        return np.concatenate((obs, [time_feature]))

Train on Atari Games



from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines import ACER

if __name__ == '__main__':

    # There already exists an environment generator
    # that will make and wrap atari environments correctly.
    # Here we are also multiprocessing training (num_env=4 => 4 processes)
    env = make_atari_env('PongNoFrameskip-v4', num_env=4, seed=0)
    # Frame-stacking with 4 frames
    env = VecFrameStack(env, n_stack=4)

    model = ACER('CnnPolicy', env, verbose=1)
    # Train before rolling out; otherwise the loop below only runs an untrained policy
    model.learn(total_timesteps=25000)

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        # Display the game frames while the agent plays
        env.render()

Mujoco: Normalizing Input Features

对输入特征进行标准化对于成功训练RL智能体可能至关重要(默认情况下,只有图像会被缩放,其他类型的输入不会),例如在Mujoco上进行训练时。为此,存在一个包装器(VecNormalize),它会计算输入特征的滑动平均值和标准差(对奖励也可以这样做)。

import os

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines import PPO2

env = DummyVecEnv([lambda: gym.make("Reacher-v2")])
# Automatically normalize the input features
env = VecNormalize(env, norm_obs=True, norm_reward=False,
                   clip_obs=10.)

model = PPO2(MlpPolicy, env)
model.learn(total_timesteps=2000)

# Don't forget to save the VecNormalize statistics when saving the agent
# NOTE(review): `VecNormalize.save(path)` assumes a recent stable-baselines 2.x release — confirm the installed version
log_dir = "/tmp/"
model.save(log_dir + "ppo_reacher")
env.save(os.path.join(log_dir, "vec_normalize.pkl"))

## Custom Policy Network

Stable baseline提供了针对图像的CNNPolicies以及其他类型输入的MlpPolicies的策略网络。


  1. 通过在实例化model时传入$policy_kwargs$参数实现,Code Link.
  2. 通过构建新的策略网络类来实现,Code Link.
    1. Warning: When loading a model with a custom policy, you must pass the custom policy explicitly when loading the model.
    2. 为了代码简洁,可以使用registered class来实现,Code Link
  3. 通过完全重新定义整个神经网络实现,Code Link


  • net_arch=[128, 128]: 代表值网络和策略网络都是共享前面的两层128,128的神经元
  • net_arch=[128, dict(vf=[256], pi=[16])]: 表示共享一层128个神经元,之后值网络是一层256个神经元,而策略网络是一层16个神经元


import gym

from stable_baselines.common.policies import FeedForwardPolicy
from stable_baselines import A2C

# Custom MLP policy of three layers of size 128 each
class CustomPolicy(FeedForwardPolicy):
    """
    Feed-forward policy with separate policy (pi) and value (vf) networks,
    each made of three hidden layers of 128 units.
    """

    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs,
                                           net_arch=[dict(pi=[128, 128, 128], vf=[128, 128, 128])],
                                           feature_extraction="mlp")

model = A2C(CustomPolicy, 'LunarLander-v2', verbose=1)
# Train the agent
model.learn(total_timesteps=100000)
Recurrent Policies



Note: Here the net_arch parameter takes an additional (mandatory) ‘lstm’ entry within the shared network section. The LSTM is shared between value network and policy network.

# LstmPolicy is the base class below and is not imported anywhere else in these notes
from stable_baselines.common.policies import LstmPolicy


class CustomLSTMPolicy(LstmPolicy):
    """
    Recurrent policy whose shared network is a dense layer of 8 units followed
    by the (mandatory) 'lstm' entry, then separate value (vf) and policy (pi) heads.
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                         net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])],
                         layer_norm=True, feature_extraction="mlp", **_kwargs)
from stable_baselines import PPO2

# For recurrent policies, with PPO2, the number of environments run in parallel
# should be a multiple of nminibatches.
model = PPO2('MlpLstmPolicy', 'CartPole-v1', nminibatches=1, verbose=1)

# Retrieve the env
env = model.get_env()

obs = env.reset()
# Passing state=None to the predict function means
# it is the initial state
state = None
# When using VecEnv, done is a vector
done = [False for _ in range(env.num_envs)]
for _ in range(1000):
    # We need to pass the previous state and a mask for recurrent policies
    # to reset the lstm state when a new episode begins
    action, state = model.predict(obs, state=state, mask=done)
    obs, reward, done, _ = env.step(action)
    # Note: with VecEnv, env.reset() is automatically called
    # Show the env
    env.render()

Hindsight Experience Replay

针对Goal-conditioned连续控制的环境任务,rl-baseline提供了HER的方法,可以直接使用,See more in HER

Continual Learning

对于两个类似的环境,可以先在A环境训练model,然后将其参数作为B环境的model初始值继续训练,See more in Continual learning

Multiprocessing of Environment

Vectorized Environments and Imports

Vectorized Environments are a method for stacking multiple independent environments into a single environment. Instead of training an RL agent on 1 environment per step, it allows us to train it on n environments per step. This provides two benefits:

  • Agent experience can be collected more quickly
  • The experience will contain a more diverse range of states, it usually improves exploration

Stable-Baselines provides two types of Vectorized Environment:

  • SubprocVecEnv which run each environment in a separate process
  • DummyVecEnv which run all environment on the same process

In practice, DummyVecEnv is usually faster than SubprocVecEnv because of communication delays that subprocesses have.

Define an Environment Function



from stable_baselines.common import set_global_seeds

def make_env(env_id, rank, seed=0):
    """
    Utility function for multi-processed env.

    Returns a zero-argument factory so that the environment itself is only
    built when the factory is called.

    :param env_id: (str) the environment ID
    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :return: (callable) function that creates and seeds the environment
    """
    def _init():
        env = gym.make(env_id)
        # Important: use a different seed for each environment
        env.seed(seed + rank)
        return env
    return _init


  • 并行多进程的数量是通过num_cpu变量设置的
  • 因为我们使用了矢量化的环境(SubprocVecEnv),发送到多个环境的动作是array(one action per process),同理observations, rewards and dones are arrays
# SubprocVecEnv and ACKTR are not imported anywhere else in these notes
from stable_baselines import ACKTR
from stable_baselines.common.vec_env import SubprocVecEnv

env_id = "CartPole-v1"
num_cpu = 4  # Number of processes to use
# Create the vectorized environment: one factory (and one process) per rank
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])

model = ACKTR(MlpPolicy, env, verbose=0)

Note: When using SubprocVecEnv, users must wrap the code in an `if __name__ == "__main__":` block if using the forkserver or spawn start method (the default on Windows). On Linux, the default start method is fork, which is not thread safe and can create deadlocks.

See more in Multi-processing.

Callbacks and Hyper-Parameter Tuning

Hyper-Parameter Tuning




A Functional Approach


Stop Training after Two Episode


def simple_callback(_locals, _globals):
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2).

    Lets training continue on the first call and stops it on the second.

    :param _locals: (dict)
    :param _globals: (dict)
    :return: (bool) True to continue training, False to stop
    """
    # get callback variables, with default values if uninitialized
    callback_vars = get_callback_vars(_locals["self"], called=False)

    if not callback_vars["called"]:
        print("callback - first call")
        callback_vars["called"] = True
        return True  # returns True, training continues.

    # Only reached once "called" has been set by a previous invocation
    print("callback - second call")
    return False  # returns False, training stops.

model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
model.learn(8000, callback=simple_callback)

Auto Saving Best Model

import os

import numpy as np

from stable_baselines.bench import Monitor
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.results_plotter import load_results, ts2xy

def auto_save_callback(_locals, _globals):
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2).

    Every 20 calls, reads the Monitor logs from the module-level ``log_dir`` and
    saves the model whenever the mean of the last 100 logged rewards improves.

    :param _locals: (dict)
    :param _globals: (dict)
    :return: (bool) always True, so training continues
    """
    # get callback variables, with default values if uninitialized
    callback_vars = get_callback_vars(_locals["self"], n_steps=0, best_mean_reward=-np.inf)

    # Only evaluate once every 20 calls
    if callback_vars["n_steps"] % 20 == 0:
        # Evaluate policy training performance
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if len(x) > 0:
            mean_reward = np.mean(y[-100:])

            # New best model: remember the score and save the agent
            if mean_reward > callback_vars["best_mean_reward"]:
                callback_vars["best_mean_reward"] = mean_reward
                # Example for saving best model
                print("Saving new best model at {} timesteps".format(x[-1]))
                _locals['self'].save(log_dir + 'best_model')
    callback_vars["n_steps"] += 1
    return True

# Create log dir
log_dir = "./model/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
env = gym.make('CartPole-v1')
env = Monitor(env, log_dir, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

model = A2C('MlpPolicy', env, verbose=0)
model.learn(total_timesteps=10000, callback=auto_save_callback)

Realtime Plotting of Performance

import matplotlib.pyplot as plt
import numpy as np

def plotting_callback(_locals, _globals):
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2).

    Plots the Monitor data found in the module-level ``log_dir`` and refreshes
    the same figure on every subsequent call.

    :param _locals: (dict)
    :param _globals: (dict)
    :return: (bool) always True, so training continues
    """
    # get callback variables, with default values if uninitialized
    callback_vars = get_callback_vars(_locals["self"], plot=None)

    # get the monitor's data
    x, y = ts2xy(load_results(log_dir), 'timesteps')
    if callback_vars["plot"] is None:  # make the plot
        # Interactive mode so that showing the figure does not block training
        plt.ion()
        fig = plt.figure(figsize=(6, 3))
        ax = fig.add_subplot(111)
        line, = ax.plot(x, y)
        callback_vars["plot"] = (line, ax, fig)
        plt.show()
    else:  # update and rescale the plot
        line, ax, fig = callback_vars["plot"]
        line.set_data(x, y)
        # Recompute data limits from the new line data before rescaling
        ax.relim()
        ax.set_xlim([_locals["total_timesteps"] * -0.02,
                     _locals["total_timesteps"] * 1.02])
        ax.autoscale_view(True, True, True)
        # Setting the data alone does not repaint the figure
        fig.canvas.draw()
    return True

# Create log dir
log_dir = "./model/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
env = gym.make('MountainCarContinuous-v0')
env = Monitor(env, log_dir, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

model = PPO2('MlpPolicy', env, verbose=0)
model.learn(20000, callback=plotting_callback)

Progress Bar

from tqdm.auto import tqdm


# this callback uses the 'with' block, allowing for correct initialization and destruction
class progressbar_callback(object):
    """
    Context manager that owns a tqdm progress bar and yields a training callback
    which keeps the bar in sync with the model's ``num_timesteps``.

    :param total_timesteps: (int) total number of timesteps shown as the bar's length
    """

    def __init__(self, total_timesteps):  # init object with total timesteps
        self.pbar = None
        self.total_timesteps = total_timesteps

    def __enter__(self):  # create the progress bar and callback, return the callback
        self.pbar = tqdm(total=self.total_timesteps)

        def callback_progressbar(local_, global_):
            self.pbar.n = local_["self"].num_timesteps
            # update(0) refreshes the display without advancing the counter set above
            self.pbar.update(0)

        return callback_progressbar

    def __exit__(self, exc_type, exc_val, exc_tb):  # close the callback
        self.pbar.n = self.total_timesteps
        self.pbar.update(0)
        # Release the bar so that later output is not garbled
        self.pbar.close()

# TD3 is not imported anywhere else in these notes
from stable_baselines import TD3

model = TD3('MlpPolicy', 'Pendulum-v0', verbose=0)
with progressbar_callback(2000) as callback: # this guarantees that the tqdm progress bar closes correctly
    model.learn(2000, callback=callback)



def compose_callback(*callback_funcs):
    """
    Combine several callbacks into a single one.

    Every callback is always invoked, in order. The combined callback returns
    False as soon as any of them returned exactly ``False``; a ``None`` return
    (legacy callbacks) counts as "continue training".

    :param callback_funcs: callables taking ``(_locals, _globals)``
    :return: (callable) the composed callback
    """
    def _callback(_locals, _globals):
        # Build the full list first so that no callback is skipped
        outcomes = [func(_locals, _globals) for func in callback_funcs]
        return not any(outcome is False for outcome in outcomes)
    return _callback

# Create log dir
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
env = gym.make('CartPole-v1')
env = Monitor(env, log_dir, allow_early_resets=True)
env = DummyVecEnv([lambda: env])

model = PPO2('MlpPolicy', env, verbose=0)
with progressbar_callback(10000) as progress_callback:
    model.learn(10000, callback=compose_callback(progress_callback, plotting_callback, auto_save_callback))


class EvalCallback(object):
    """
    Callback for evaluating an agent.

    :param eval_env: (gym.Env) The environment used for evaluation
    :param n_eval_episodes: (int) The number of episodes to test the agent
    :param eval_freq: (int) Evaluate the agent every eval_freq call of the callback.
    :param log_dir: (str) Path prefix under which the best model is saved
    """

    def __init__(self, eval_env, n_eval_episodes=5, eval_freq=20, log_dir='./model/gym/'):
        super(EvalCallback, self).__init__()
        self.eval_env = eval_env
        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.n_calls = 0
        self.best_mean_reward = -np.inf
        self.log_dir = log_dir

    def _mean_episode_reward(self, model):
        """
        Run ``n_eval_episodes`` episodes on ``eval_env`` and average their returns.

        :param model: object exposing ``predict(obs)``
        :return: (float) mean episode return
        """
        episode_returns = []
        for _ in range(self.n_eval_episodes):
            obs = self.eval_env.reset()
            done = False
            episode_return = 0.0
            while not done:
                action, _ = model.predict(obs)
                obs, reward, done, _ = self.eval_env.step(action)
                # assumes a vectorized env holding a single environment, so only
                # the first entry is read — TODO confirm for multi-env setups
                episode_return += float(np.ravel(reward)[0])
                done = bool(np.ravel(done)[0])
            episode_returns.append(episode_return)
        return float(np.mean(episode_returns))

    def __call__(self, locals_, globals_):
        """
        This method will be called by the model.

        :param locals_: (dict)
        :param globals_: (dict)
        :return: (bool) always True, so training continues
        """
        # Get the self object of the model
        self_ = locals_['self']

        if self.n_calls % self.eval_freq == 0:
            # Evaluate the agent
            mean_reward = self._mean_episode_reward(self_)

            # Save the agent and update self.best_mean_reward
            if self.best_mean_reward < mean_reward:
                self.best_mean_reward = mean_reward
                print("Evaluation: Best mean reward in the evaluated env: {:.2f}, saved it.".format(self.best_mean_reward))
                self_.save(self.log_dir + 'best_evaluated_model_with_reward_' + str(self.best_mean_reward))

        self.n_calls += 1
        return True

env = gym.make("CartPole-v1")
env = DummyVecEnv([lambda: env])

# Env for evaluating the agent
eval_env = gym.make("CartPole-v1")
eval_env = DummyVecEnv([lambda: eval_env])

# Create log dir
log_dir = "./model/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create the callback object
callback = EvalCallback(eval_env, n_eval_episodes=5, eval_freq=20, log_dir=log_dir)

# Create the RL model
model = PPO2('MlpPolicy', env, verbose=0)

# Train the RL model
model.learn(int(100000), callback=callback, log_interval=1000)

Using Custom Environment

要想在个人自定义的环境中使用RL baselines库,只需要环境满足gym接口就行,即必须继承自OpenAI的Gym类并实现其方法。

Gym Basics


  • reset(): 在每一轮轨迹迭代过程中的最开始时调用,将环境复位,返回一个observation
  • step(action): 调用该函数后环境执行对应action,返回下一个时刻的observation、即时奖励、轨迹迭代过程是否结束以及额外信息
  • (Optional) render(mode='human'): 用来渲染显示环境,在colab等没有显示器的环境下,无法直接使用human模式,需要改为mode='rgb_array',以间接的方式获取环境中的image


  • observation_space: 环境返回的状态空间属性,表示状态类型、大小
  • action_space: 表示动作空间属性


  • gym.spaces.Box: 是一个位于 $\mathbb{R}^n$ 的(可能是无界的)盒子。 具体来说,Box表示n个闭区间的笛卡尔积。 每个区间的形式为 $[a,b]$、$(-\infty,b]$、$[a,\infty)$ 或 $(-\infty,\infty)$。 示例:可以使用Box空间描述一维矢量或图像的observation
observation_space = spaces.Box(low=0, high=255, shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
  • gym.spaces.Discrete: 一个处于$[0, 1, … , n-1]$的离散空间, 示例:如果你有两个动作(left and right),那么可以将动作空间表示为Discrete(2),第一个动作为0, 第二个动作为1

RL Stable Baselines提供了一个新的方法,可以用来检测自定义环境是否完全符合gym接口:

from stable_baselines.common.env_checker import check_env

# `CustomEnv` and its arguments are placeholders for your own environment class
env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)


import gym

env = gym.make('CartPole-v1')

# Box(4,) means that it is a Vector with 4 components
print("Observation space: ", env.observation_space)
print("Shape is: ", env.observation_space.shape)

# Discrete(2) means that there is two discrete actions
print("Action space: ", env.action_space)
print("Shape is: ", env.action_space.shape)

# The reset method is called at the beginning of an episode
obs = env.reset()

# Sample a random action
action = env.action_space.sample()
print("Sampled action: ", action)

# Execute the action using step
obs, reward, done, info = env.step(action)

# Note that obs is a numpy array
# Info is an empty dict for now but can contain any debugging info
# Reward is a scalar
print("The observation is {}, its shape is {}, reward is {}, done: {}, info: {}".format(obs, obs.shape, reward, done,
                                                                                         info))

See more Link.

Tensorboard Integration

要想在rl baselines中使用Tensorboard,只需要:

  1. 在model定义时添加log路径
  2. 在load的函数中添加log路径,因为save模型时默认不保存log地址
# Passing the tensorboard_log argument is enough to enable tensorboard support
# The A2C constructor can vectorize the environment directly from its string id, because the id is registered
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log='./model/tensorboard/a2c_cartpole/')

# A custom name can be given to the training log; it defaults to the algorithm name
model.learn(total_timesteps=100, tb_log_name='first_run')
# By default, saving the model does not save the tensorboard log path
model.save('./model/gym/a2c_cartpole.pkl')

# Delete the model to demonstrate loading
del model

# The A2C algorithm requires a vectorized environment to run
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])

# Set the log path explicitly
model = A2C.load('./model/gym/a2c_cartpole.pkl', env=env, tensorboard_log='./model/tensorboard/a2c_cartpole/')
print('Loaded model.')

# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False)


tensorboard --logdir ./model/tensorboard/a2c_cartpole/

可以自定义打印记录和输出的tensor或者自定义变量,也可以将所有输出到terminal的内容输出到tensorboard中,Code Link.

RL Baselines Zoo

It’s a collection of pre-trained Reinforcement Learning agents using Stable-Baselines. It also provides basic scripts for training, evaluating agents, tuning hyper-parameters and recording videos.

Pre-training (Behavior Cloning)


行为克隆(Behavior Cloning)将模仿学习的问题(即使用专家演示)视为监督学习问题。 也就是说,在给定专家轨迹(观察-行动对)的情况下,对策略网络进行了训练以重现专家行为:对于给定的观察,策略所采取的行动必须是专家所采取的行动。

Generate Expert Trajectories



Note: 使用图片作为输入,由于图片占用内存,不能全部作为numpy对象存储(可能导致内存溢出), 因此来自专家的图片数据需要放到一个文件夹中。


from stable_baselines import DQN
from stable_baselines.gail import generate_expert_traj

# Pre-train a DQN network and generate 10 trajectories to serve as expert demonstrations
model = DQN('MlpPolicy', 'CartPole-v1', verbose=1)
# Train a DQN agent for 1e5 timesteps and generate 10 trajectories
# data will be saved in a numpy archive named `expert_dqn_cartpole.npz`
generate_expert_traj(model, './model/expert_trajectories/expert_dqn_cartpole', n_timesteps=int(1e5), n_episodes=10)


# Here the expert is a random agent
# but it can be any python function, e.g. a PID controller
def dummy_expert(_obs):
    """
    Random agent. It samples actions randomly
    from the action space of the environment.

    Relies on the module-level ``env`` being defined when it is called.

    :param _obs: (np.ndarray) Current observation (ignored)
    :return: (np.ndarray) action taken by the expert
    """
    return env.action_space.sample()

# Data will be saved in a numpy archive named `dummy_expert_cartpole.npz`
# when using something different than an RL expert,
# you must pass the environment object explicitly
generate_expert_traj(dummy_expert, save_path='./model/expert_trajectories/dummy_expert_cartpole', env=env,
                     n_episodes=10)

Pre-Train a Model using Behavior Cloning

from stable_baselines import PPO2
from stable_baselines.gail import ExpertDataset

# Using only one expert trajectory
# you can specify `traj_limitation=-1` for using the whole dataset
dataset = ExpertDataset(expert_path='./model/expert_trajectories/expert_dqn_cartpole.npz',
                        traj_limitation=1, batch_size=128)

model = PPO2('MlpPolicy', 'CartPole-v1', verbose=1)

# Pre-train the PPO2 model
model.pretrain(dataset, n_epochs=10000)

# As an option, you can train the RL agent

Data Structure of the Expert Dataset

The expert dataset is a .npz archive. The data is saved in Python dictionary format with the keys: actions, episode_returns, rewards, obs, episode_starts.

In case of images, obs contains the relative path to the images.

obs, actions: shape (N * L, ) + S

where N = # episodes, L = episode length and S is the environment observation/action space.

S = (1, ) for discrete space.

Dealing with NaNs and infs

在训练RL的过程中,模型训练有可能因为NaN或inf的存在而完全崩溃,主要是因为在IEEE制定的浮点计算标准中,共有5种情况会导致NaN或inf,详见IEEE Standard for Floating-Point Arithmetic (IEEE 754)

五种情况中,只有Division by zero会产生异常,其他的都会很快地通过反向传播影响训练。

  • Python会提出异常:ZeroDivisionError: float division by zero,但是忽略其他四种情况
  • Numpy默认提出警告:RuntimeWarning: invalid value encountered但是不会停止代码运行
  • Tensorflow直接忽视五种情况,不做任何提醒

针对该问题,stable baselines提供了一个VecCheckNan wrapper. It will monitor the actions, observations, and rewards, indicating what action or observation caused it and from what. 该装饰器非常重要,Code Link.


from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan
# Create environment
env = DummyVecEnv([lambda: NanAndInfEnv()])
env = VecCheckNan(env, raise_exception=True)

Auxiliary Functions

Making GIF

import imageio
import numpy as np

from stable_baselines import A2C

model = A2C("MlpPolicy", "LunarLander-v2").learn(100000)

images = []
obs = model.env.reset()
img = model.env.render(mode='rgb_array')
for i in range(350):
    # Store the frame rendered before this step; without it the GIF would be empty
    images.append(img)
    action, _ = model.predict(obs)
    obs, _, _, _ = model.env.step(action)
    img = model.env.render(mode='rgb_array')

# Keep every other frame to halve the size of the GIF
imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i % 2 == 0], fps=29)

Making Video

当在Google Colab中运行时,由于没有显示器,无法直接渲染环境,因此需要一个fake display(虚拟显示器),但是切记在本地有显示器的环境下运行时将其注释掉。

# Set up fake display; otherwise rendering will fail
import os
os.system("Xvfb :1 -screen 0 1024x768x24 &")
os.environ['DISPLAY'] = ':1'

Record Video

from stable_baselines.common.vec_env import VecVideoRecorder

def record_video(env_id, model, video_length=500, prefix='', video_folder='videos/'):
    """
    Record a video of ``model`` acting in a fresh instance of ``env_id``.

    :param env_id: (str) id passed to ``gym.make``
    :param model: (RL model) object exposing ``predict(obs)``
    :param video_length: (int) number of steps to record
    :param prefix: (str) prefix of the video file name
    :param video_folder: (str) folder in which the video is written
    """
    eval_env = DummyVecEnv([lambda: gym.make(env_id)])
    # Start the video at step=0 and record `video_length` steps
    eval_env = VecVideoRecorder(eval_env, video_folder=video_folder,
                                record_video_trigger=lambda step: step == 0, video_length=video_length,
                                name_prefix=prefix)

    obs = eval_env.reset()
    for _ in range(video_length):
        action, _ = model.predict(obs)
        obs, _, _, _ = eval_env.step(action)

    # Close the video recorder
    eval_env.close()

Show Video

import base64
from pathlib import Path

from IPython import display as ipythondisplay

def show_videos(video_path='', prefix=''):
    """
    Embed every matching ``.mp4`` file as a base64 ``<video>`` tag and display
    the result in the notebook.

    :param video_path: (str) Path to the folder containing videos
    :param prefix: (str) Filter the videos, showing only the ones starting with this prefix
    """
    html = []
    for mp4 in Path(video_path).glob("{}*.mp4".format(prefix)):
        video_b64 = base64.b64encode(mp4.read_bytes())
        html.append('''<video alt="{}" autoplay
                    loop controls style="height: 400px;">
                    <source src="data:video/mp4;base64,{}" type="video/mp4" />
                </video>'''.format(mp4, video_b64.decode('ascii')))
    # Render the collected tags; otherwise the list is built and then discarded
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(html)))

Helper Function

def evaluate(model, num_episodes=100):
    """
    Evaluate a RL agent

    :param model: (BaseRLModel object) the RL Agent
    :param num_episodes: (int) number of episodes to evaluate it
    :return: (float) Mean reward for the last num_episodes
    """
    # This function will only work for a single Environment
    env = model.get_env()
    all_episode_rewards = []
    for _ in range(num_episodes):
        episode_rewards = []
        done = False
        obs = env.reset()
        while not done:
            # _states are only useful when using LSTM policies
            action, _states = model.predict(obs)
            # here, action, rewards and done are arrays
            # because we are using vectorized env
            obs, reward, done, info = env.step(action)
            episode_rewards.append(reward)

        # Episode return = sum of the rewards collected during the episode
        all_episode_rewards.append(np.sum(episode_rewards))

    mean_episode_reward = np.mean(all_episode_rewards)
    print("Mean reward:", mean_episode_reward, "Num episodes:", num_episodes)

    return mean_episode_reward

Suppress the Warning Message

# Filter tensorflow version warnings
import os
# Set before `import tensorflow` below so that the setting is in place when it loads
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}
import warnings
# Silence FutureWarning first, then every other Warning subclass
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
# NOTE(review): `tf` and `logging` are imported but not used in the visible lines —
# a call that sets the TensorFlow logger level may have been lost; confirm.
import tensorflow as tf
import logging

Get Callback Variables


def get_callback_vars(model, **kwargs):
    """
    Helps store variables for the callback functions

    The variables live in a dict attached to the model as ``_callback_vars``.
    Existing entries are never overwritten; ``kwargs`` only supplies the
    initial value of names that are not stored yet.

    :param model: (BaseRLModel)
    :param **kwargs: initial values of the callback variables
    :return: (dict) the stored dict itself, so mutations persist across calls
    """
    # save the callback variables as an attribute of the model
    if not hasattr(model, "_callback_vars"):
        model._callback_vars = dict(**kwargs)
    else:  # make sure all the kwargs are in the callback variables
        for name, val in kwargs.items():
            model._callback_vars.setdefault(name, val)
    return model._callback_vars  # return dict reference (mutable)

