ScaDaMaLe WASP-UU 2022 - Student Group Project 13 - Distributed Reinforcement Learning

Project description

We extend the Proximal Policy Optimization (PPO) and Soft Actor-Critic (SAC) reinforcement learning algorithms from the stable_baselines3 library to make them run in a distributed fashion.

We use Horovod, a distributed deep learning training framework that supports several ML libraries, including PyTorch through the horovod.torch package. More importantly, it can run on top of Spark using sparkdl.HorovodRunner, which launches Spark jobs that execute training functions written against the Horovod API. The main change needed to use Horovod is to wrap the original torch.optim.Optimizer in a horovod.torch.DistributedOptimizer.
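
To make the role of the distributed optimizer concrete, here is a minimal pure-Python simulation of gradient averaging across workers. It is not Horovod code: the function names and the toy dataset are assumptions made for illustration, and the averaging function only stands in for what an averaging allreduce achieves.

```python
# Simulated data-parallel gradient averaging (illustration only; no Horovod calls).

def mse_gradient(w, shard):
    """Gradient of mean((w*x - y)^2) with respect to w over one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)

def allreduce_average(values):
    """Stand-in for an averaging allreduce: every worker receives the mean."""
    return sum(values) / len(values)

# Toy dataset y = 3x, split into two equal shards (one per simulated worker).
data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0), (4.0, 12.0)]
shards = [data[:2], data[2:]]

w = 0.0
local_grads = [mse_gradient(w, shard) for shard in shards]
synced_grad = allreduce_average(local_grads)
full_batch_grad = mse_gradient(w, data)

print(local_grads)      # per-worker gradients differ: [-15.0, -75.0]
print(synced_grad)      # averaged gradient: -45.0
print(full_batch_grad)  # full-batch gradient: -45.0
```

For equal-size shards, the averaged gradient equals the gradient over the full batch, which is why every worker can apply the same update after the sync.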

Reinforcement Learning

Definition: Reinforcement Learning is one of three basic machine learning paradigms, alongside supervised learning and unsupervised learning. It is a machine learning training method based on rewarding desired behaviors and/or punishing undesired ones. In general, a reinforcement learning agent is able to perceive and interpret its environment, take actions and learn through trial and error.

Environment: Gymnasium, where the classic “agent-environment loop” is implemented. The agent performs some actions in the environment (usually by passing some control inputs to the environment, e.g. torque inputs of motors) and observes how the environment’s state changes. One such action-observation exchange is referred to as a timestep.
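
The agent-environment loop can be sketched without any library. ToyEnv, its reward, and the towards_zero policy below are hypothetical and deliberately simpler than the real Gym interface (which also returns extra information from step); the sketch only shows the action-observation exchange that makes up a timestep.

```python
import random

class ToyEnv:
    """Hypothetical 1-D environment: the agent must move a point to position 0."""

    def reset(self, seed=None):
        self.rng = random.Random(seed)
        self.position = self.rng.choice([-3, -2, 2, 3])
        self.steps = 0
        return self.position

    def step(self, action):
        # action is -1 (move left) or +1 (move right)
        self.position += action
        self.steps += 1
        reward = -abs(self.position)  # closer to 0 is better
        done = self.position == 0 or self.steps >= 10
        return self.position, reward, done

def run_episode(env, policy, seed=0):
    """One episode of the agent-environment loop; returns the total reward."""
    observation = env.reset(seed=seed)
    total_reward, done = 0.0, False
    while not done:
        action = policy(observation)                  # agent acts
        observation, reward, done = env.step(action)  # environment responds
        total_reward += reward                        # one timestep completed
    return total_reward

def towards_zero(observation):
    """Hand-written policy: always step towards position 0."""
    return -1 if observation > 0 else 1

print(run_episode(ToyEnv(), towards_zero))
```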

Example tasks:

Stable Baselines3

Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the latest major version of Stable Baselines.

Github repository: https://github.com/DLR-RM/stable-baselines3

Paper: https://jmlr.org/papers/volume22/20-1364/20-1364.pdf

PPO:

The Proximal Policy Optimization algorithm combines ideas from A2C (having multiple workers) and TRPO (it uses a trust region to improve the actor).

The main idea is that after an update, the new policy should not be too far from the old policy. To achieve this, PPO uses clipping to avoid too large an update.

SAC:

Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

SAC is the successor of Soft Q-Learning (SQL) and incorporates the double Q-learning trick from TD3. A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
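
The trade-off between return and entropy can be illustrated with a small sketch. It assumes a discrete two-action policy for simplicity (SAC itself targets continuous actions), and the weight alpha and the function names are hypothetical.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a discrete action distribution."""
    return sum(-p * math.log(p) for p in probs if p > 0.0)

def soft_objective(expected_return, probs, alpha):
    """Expected return plus alpha times the policy entropy."""
    return expected_return + alpha * entropy(probs)

deterministic = [1.0, 0.0]
uniform = [0.5, 0.5]

print(entropy(deterministic))  # 0.0: no randomness
print(entropy(uniform))        # log(2): maximal for two actions

# With equal expected return, the more random policy scores higher.
print(soft_objective(1.0, deterministic, alpha=0.2))
print(soft_objective(1.0, uniform, alpha=0.2))
```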

Get all necessary imports.

import csv
import gym
import horovod.torch as hvd
import numpy as np
import os
import shutil
import torch as th
import warnings

from datetime import datetime
from gym import spaces
from matplotlib import pyplot as plt
from sparkdl import HorovodRunner
from torch.nn import functional as F
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

from stable_baselines3 import PPO, SAC
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_parameters_by_name, get_schedule_fn, polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy

And global variables

LOG_DIR = '/dbfs/drl_logs'

Distributed Proximal Policy Optimization, DPPO

Theoretical Background

PPO is a reinforcement learning algorithm, which is trained using the following objective (that should be maximized)

\[ L^{\text{CLIP}}(\theta) = \mathbb E_t \big[ \text{min}(r_t(\theta)A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t) \big], \]

Let's go through what this means. First, \(\pi_{\theta}(a_t|s_t)\) is the "policy": a neural network that we train to estimate the "best" action to take in any given state. Next, \(r_t(\theta) = \frac{\pi_{\theta}(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}\) is the ratio between the probability that the current network chooses a certain action and the probability that a previous version of the network chose it. Finally, \(A_t = -V(s_t) + r_t + \gamma r_{t+1} + \dots\) is the estimated advantage, where \(r_t\) without an argument denotes the reward at timestep t (not the ratio \(r_t(\theta)\)) and \(\gamma\) is the discount factor. Intuitively, the advantage measures how much better the policy performed than expected by the value network V. The intuition is that we want advantages to be as large as possible, and this matters most when the previous policy was unlikely to have yielded the result. However, we still want the new parameters to not go too far from the old ones. To ensure this, we "clip" the ratio, so that the policy network receives no gradient once the ratio has moved past a threshold set by \(\epsilon\). See image below:

[Figure: illustration of the clipped objective]

Here it is clear that the network stops receiving gradients once it has improved sufficiently at all timesteps t. A further detail is that the value network is trained with an additional value loss, which penalizes the error between the predicted values and the observed returns. The algorithm can be summarized as follows.

  1. Gather a "dataset" of state-action-reward tuples using the current policy (hereafter referred to as rollouts).
  2. Train the policy to maximize the clipped advantage objective, and train the value network to predict the values so that the advantage can be estimated accurately (hereafter referred to as training).
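
The clipped objective and the advantage estimate can be evaluated by hand for single samples. The sketch below is plain Python with hypothetical function names; epsilon = 0.2 and the toy rewards are chosen only for illustration.

```python
def clip(x, low, high):
    return max(low, min(x, high))

def clipped_surrogate(ratio, advantage, epsilon=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    return min(unclipped, clipped)

def advantage_estimate(rewards, value, gamma=0.99):
    """Discounted sum of observed rewards minus the value prediction V(s_t)."""
    discounted_return = sum(gamma ** k * r for k, r in enumerate(rewards))
    return discounted_return - value

adv = advantage_estimate(rewards=[1.0, 1.0, 1.0], value=2.0, gamma=0.5)
print(adv)  # 1 + 0.5 + 0.25 - 2 = -0.25

# Positive advantage: the objective stops growing once the ratio exceeds 1 + eps.
for ratio in (0.9, 1.0, 1.2, 1.5, 2.0):
    print(ratio, clipped_surrogate(ratio, advantage=1.0))

# Negative advantage: the objective is flat once the ratio drops below 1 - eps,
# but a ratio above 1 is still penalized in full.
for ratio in (0.5, 0.8, 1.0, 1.5):
    print(ratio, clipped_surrogate(ratio, advantage=-1.0))
```

Where the objective is flat in the ratio, its derivative with respect to the policy parameters is zero, which is the "no gradient" region described above.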

Practical Implementation

Looking at the above algorithm, it is immediately clear that the rollouts are embarrassingly parallel, since no communication is needed between nodes. During training, however, the model parameters are updated by the optimizer, which requires gradients to be synced between nodes. One further caveat: the stable_baselines3 implementation uses in-place gradient clipping. This is a non-linear operation, which means that combining clipped gradients across nodes is not the same as clipping the combined gradient. Hence we first need to sync the gradients and only then perform clipping. By default, Horovod syncs the gradients at optimizer.step(), but it also permits manual syncing, which we have implemented in the code below. We summarize the communication needed between nodes during rollouts and training in the figure below:

[Figure: communication between nodes during rollouts and training]
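
The caveat about the order of clipping and syncing can be checked numerically. The following sketch uses plain Python lists as gradients and an element-wise mean as the sync; the helper names and the numbers are assumptions made for illustration.

```python
import math

def clip_by_norm(grad, max_norm):
    """Rescale a gradient vector so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    return [g * max_norm / norm for g in grad]

def average(vectors):
    """Element-wise mean of equally long vectors (the effect of an averaging sync)."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]

# Gradients computed independently on two simulated workers.
worker_grads = [[3.0, 4.0], [0.3, 0.4]]  # norms 5.0 and 0.5
max_norm = 1.0

clip_then_sync = average([clip_by_norm(g, max_norm) for g in worker_grads])
sync_then_clip = clip_by_norm(average(worker_grads), max_norm)

print(clip_then_sync)  # roughly [0.45, 0.6], norm 0.75
print(sync_then_clip)  # roughly [0.6, 0.8], norm 1.0
```

The two orders give different update directions and magnitudes, so clipping must be applied to the synced gradient to match single-process behaviour.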

We extend the original PPO class from stable_baselines3. We modify the _setup_model function to wrap the optimizer in a hvd.DistributedOptimizer. We also had to change the train function which performs the gradient step. This change was necessary to account for the gradient clipping, which must happen after averaging the gradients across machines.

class DPPO(PPO):
    """
    Distributed Proximal Policy Optimization (DPPO)
    """

    def _setup_model(self) -> None:
        super()._setup_model()
        self.policy.optimizer = hvd.DistributedOptimizer(self.policy.optimizer, named_parameters=self.policy.named_parameters())
        hvd.broadcast_parameters(self.policy.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.policy.optimizer, root_rank=0)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Sync gradients across processes first, so that the non-linear clipping acts on the synced gradients
                self.policy.optimizer.synchronize()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                # Gradients are already synced, so skip the automatic sync inside step()
                with self.policy.optimizer.skip_synchronize():
                    self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

Next, we define our distributed training function that Horovod will run on each process. We make it general since we want to use it for different algorithms. It takes an algorithm class algo as input. Furthermore, it takes the environment name, type of policy, and the number of time steps, as well as any keyword arguments to be passed to the algorithm class. We do any logging or model saving only on the main process, i.e. rank 0.

def train_hvd(algo, env_name="Pendulum-v1", policy="MlpPolicy", total_timesteps=100_000, **kwargs):
    
    # Initialize Horovod
    hvd.init()
    
    # Create environment, model, and run training
    env = gym.make(env_name)
    # Log reward etc. only on the main process by wrapping the environment in a Monitor
    if hvd.rank() == 0:
        env = Monitor(env, os.path.join(LOG_DIR, env_name, algo.__name__ + '-' + str(hvd.size())))
    model = algo(policy, env, **kwargs)
    model.learn(total_timesteps=total_timesteps, log_interval=1)
    
    # Save model only on process 0
    if hvd.rank() == 0:
        exp_time = datetime.now().strftime('%Y-%m-%d_%H%M%S')
        log_dir = os.path.join(LOG_DIR, exp_time)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        # DBFS doesn't support zip, so create the file on the driver and then move it to DBFS
        filename = f'{algo.__name__}_{env_name}.zip'
        model.save(filename)
        shutil.move(filename, f"{log_dir}/")
    
    env.close()

Single-process PPO

First, let us train a PPO agent in a single process using the original implementation from stable_baselines3 (i.e. not using our DPPO class or train_hvd function).

Start with some parameters:

# The environment where to test our distributed PPO algorithm
ppo_env_name = 'CartPole-v1'
#ppo_env_name = 'Pendulum-v1'
#ppo_env_name = 'MountainCarContinuous-v0'
#ppo_env_name = 'BipedalWalker-v3'

# PPO parameters
ppo_total_timesteps = 100_000
ppo_learning_rate = 1e-3 # default: 3e-4
ppo_n_steps = 4096 * 8 # default: 2048
ppo_batch_size = 4096 # default: 64

# How many processes to use for distributed training
ppo_world_sizes = [1, 2, 4, 8]

Train the PPO agent.

# Create a gym environment and wrap it in a Monitor to track the reward
env = gym.make(ppo_env_name)
env = Monitor(env, os.path.join(LOG_DIR, ppo_env_name, 'PPO'))

# Define our PPO agent
model = PPO(
    "MlpPolicy",
    env,
    learning_rate=ppo_learning_rate,
    n_steps=ppo_n_steps,
    batch_size=ppo_batch_size,
    verbose=1)

# Train our agent
model.learn(total_timesteps=ppo_total_timesteps, log_interval=1)
env.close()

Multi-process DPPO

Train our distributed DPPO algorithm for different numbers of processes (world_size). To have comparable training settings across runs, we divide the per-process rollout buffer size n_steps by the number of processes, so that the total buffer size stays the same. We also multiply the learning rate learning_rate by the number of processes, since the total batch size becomes larger. Note that this means that one gradient step in the multi-process case is essentially equivalent to world_size single-process gradient steps. We take this into account when plotting the results later, to validate that the training settings actually are equivalent.

# Loop over number of processes and do an experiment for each
for world_size in ppo_world_sizes:

    # Initialize the HorovodRunner
    hr = HorovodRunner(np=world_size, driver_log_verbosity='all')

    # Launch the spark job on our training function
    hr.run(
        train_hvd,
        algo = DPPO,
        env_name = ppo_env_name,
        policy = "MlpPolicy",
        total_timesteps = ppo_total_timesteps,
        learning_rate = ppo_learning_rate * world_size,
        n_steps = ppo_n_steps // world_size,
        batch_size = ppo_batch_size,
        verbose = 1
    )
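
The scaling rule used in the loop can be tabulated to check that the total rollout size stays constant. The helper below is a hypothetical convenience function, not part of stable_baselines3; it uses the same base values as the PPO parameters above.

```python
def scale_ppo_settings(n_steps, learning_rate, batch_size, world_size):
    """Per-process settings that keep the total rollout size fixed."""
    per_process_steps = n_steps // world_size
    return {
        "n_steps": per_process_steps,                 # per-process rollout length
        "learning_rate": learning_rate * world_size,  # scaled with the process count
        "minibatches_per_epoch": per_process_steps // batch_size,
    }

base_n_steps = 4096 * 8
base_lr = 1e-3
base_batch_size = 4096

for world_size in (1, 2, 4, 8):
    settings = scale_ppo_settings(base_n_steps, base_lr, base_batch_size, world_size)
    total_rollout = settings["n_steps"] * world_size
    print(world_size, settings, total_rollout)
```

With these values, each process performs 8, 4, 2, and 1 minibatch updates per epoch for 1, 2, 4, and 8 processes respectively, while the total rollout size remains 32768.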

PPO and DPPO results

We compare the different runs by plotting the reward against per-process training steps, total number of training steps, and wall time. We expect faster convergence for a larger number of processes when looking at per-process training steps, and the same reward for the same number of total training steps, since these should correspond to the same training setting. We also expect faster convergence in wall time with more processes. Finally, we compute the speedup for the different numbers of processes.

Define a function to get the logs from a run.

def get_logs(log_file, as_dict=False):
    steps, rewards, ep_lengths, times = [], [], [], []
    with open(log_file) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        for i, row in enumerate(csv_reader):
            if i < 2: # skip first two rows (headers)
                continue
            rewards.append(float(row[0]))
            ep_lengths.append(int(row[1]))
            times.append(float(row[2]))
            steps.append((steps[-1] if len(steps) != 0 else 0) + ep_lengths[-1])
    if as_dict:
        return {'steps': steps,
                'rewards': rewards,
                'ep_lengths': ep_lengths,
                'times': times}
    return steps, rewards, ep_lengths, times
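
To show what get_logs expects, the following sketch parses an in-memory file laid out like a Monitor log: one metadata line, one header line, then one row per episode with reward (r), episode length (l), and elapsed time (t). The values are made up for illustration.

```python
import csv
import io

# Made-up contents in the layout that get_logs expects.
sample = io.StringIO(
    '#{"t_start": 0.0, "env_id": "CartPole-v1"}\n'
    "r,l,t\n"
    "21.0,21,0.5\n"
    "35.0,35,1.1\n"
    "12.0,12,1.4\n"
)

steps, rewards = [], []
for i, row in enumerate(csv.reader(sample)):
    if i < 2:  # skip the metadata and header lines
        continue
    rewards.append(float(row[0]))
    # cumulative number of environment steps at the end of each episode
    steps.append((steps[-1] if steps else 0) + int(row[1]))

print(steps)    # [21, 56, 68]
print(rewards)  # [21.0, 35.0, 12.0]
```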

Define a function to plot the reward.

def plot_results(env_name, algo_names, x_axis='steps', num_procs=None, smooth=None):
    
    # Read logs
    logs = {}
    if isinstance(algo_names, str):
        algo_names = [algo_names]  # accept a single run name as well as a list of names
    for algo_name in algo_names:
        logs[algo_name] = get_logs(os.path.join(LOG_DIR, env_name, f'{algo_name}.monitor.csv'), as_dict=True)
    
    # Plot reward
    for algo_name in algo_names:

        # Scale x-axis if number of processes is given
        if num_procs is not None:
            scale_by_num_proc = num_procs[algo_name]
        else:
            scale_by_num_proc = 1
        
        # Smooth and plot the reward
        reward = logs[algo_name]['rewards']
        if smooth is not None:
            reward = [reward[0]]*smooth + reward + [reward[-1]]*smooth
            reward = np.convolve(reward, [1/(smooth*2 + 1)]*(smooth*2 + 1), mode='valid')
        plt.plot(scale_by_num_proc*np.array(logs[algo_name][x_axis]), reward)
    
    plt.legend(algo_names)
    plt.title(env_name)
    plt.ylabel('reward')
    if x_axis == 'times':
        if num_procs is not None:
            plt.xlabel('total cpu time [s]')
        else:
            plt.xlabel('wall time [s]')
    elif x_axis == 'steps':
        if num_procs is not None:
            plt.xlabel('total steps')
        else:
            plt.xlabel('steps per process')
    else:
        plt.xlabel(x_axis)
    plt.gcf().set_size_inches(8, 4)
    plt.show()
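
The smoothing in plot_results is a centered moving average with the first and last reward repeated at the edges, so the smoothed curve has the same length as the input. A pure-Python sketch of the same computation (the function name is hypothetical):

```python
def smooth_rewards(rewards, smooth):
    """Centered moving average with window 2*smooth + 1 and edge padding."""
    padded = [rewards[0]] * smooth + list(rewards) + [rewards[-1]] * smooth
    window = 2 * smooth + 1
    return [sum(padded[i:i + window]) / window for i in range(len(rewards))]

rewards = [0.0, 3.0, 6.0, 9.0, 12.0]
smoothed = smooth_rewards(rewards, smooth=1)
print(smoothed)  # [1.0, 3.0, 6.0, 9.0, 11.0]
```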

Define a function to plot the speedup.

def plot_speedup(env_name, runs):
    algo_name = runs[0]

    # Compute wall time, total steps, and speedup
    wall_times = []
    total_steps = []
    speedups = []
    nprocs = []

    # Get logs from single-process run
    logs1 = get_logs(os.path.join(LOG_DIR, env_name, f'D{algo_name}-1.monitor.csv'), as_dict=True)
    wall_time1 = logs1['times'][-1]
    total_step1 = logs1['steps'][-1]

    # Compute results for each run
    for run in runs:
        logs = get_logs(os.path.join(LOG_DIR, env_name, f'{run}.monitor.csv'), as_dict=True)
        wall_times.append(logs['times'][-1])
        total_steps.append(logs['steps'][-1])
        if '-' in run:
            nprocs.append(int(run.split('-')[1]))
            total_steps[-1] *= nprocs[-1]
            speedup = wall_time1 / wall_times[-1]
            speedup *= total_steps[-1] / total_step1 # adjust for different total steps
            speedups.append(round(speedup, 2))

    # Print run times and total steps
    print('Run times')
    for run, wall_time, total_step in zip(runs, wall_times, total_steps):
        print(f'{run}: {round(wall_time, 2)} s ({total_step} total steps)')

    # Print speedup
    print('\nSpeedups')
    for speedup, nproc in zip(speedups, nprocs):
        print(f'D{algo_name}-{nproc}: {speedup}')

    # Plot speedup as a bar plot
    x = np.arange(len(speedups))  # the label locations
    width = 0.35  # the width of the bars
    fig, ax = plt.subplots()
    rects1 = ax.bar(x - width/2, speedups, width, label='speedup')
    rects2 = ax.bar(x + width/2, nprocs, width, label='ideal')

    # Add some text for labels, title and custom x-axis tick labels, etc.
    ax.set_ylabel('Speedup')
    ax.set_xlabel('Number of processes')
    ax.set_title(f'D{algo_name} speedup on {env_name}')
    plt.xticks(x, nprocs)
    ax.legend()
    ax.bar_label(rects1, padding=3)
    ax.bar_label(rects2, padding=3)
    fig.tight_layout()
    plt.show()
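
The speedup computed in plot_speedup is the wall-time ratio to the single-process run, corrected for runs that ended with a different total number of steps. The numbers below are hypothetical and are not results from our experiments.

```python
def adjusted_speedup(wall_time_1, total_steps_1, wall_time_n, total_steps_n):
    """Wall-time ratio, corrected for runs that processed different step counts."""
    return (wall_time_1 / wall_time_n) * (total_steps_n / total_steps_1)

# Same total steps, half the wall time.
print(adjusted_speedup(100.0, 100_000, 50.0, 100_000))  # 2.0
# Half the wall time while processing 20% more steps.
print(adjusted_speedup(100.0, 100_000, 50.0, 120_000))  # 2.4
```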

Reward plots

Finally, we plot the reward.

ppo_runs = ['PPO'] + [f'DPPO-{ws}' for ws in ppo_world_sizes]
ppo_nprocs = {**{'PPO': 1}, **{f'DPPO-{ws}': ws for ws in ppo_world_sizes}}

plot_results(ppo_env_name, ppo_runs, x_axis='steps', smooth=5)
plot_results(ppo_env_name, ppo_runs, x_axis='steps', num_procs=ppo_nprocs, smooth=5)
plot_results(ppo_env_name, ppo_runs, x_axis='times', smooth=5)

Speedup

Plot the speedup for the different number of processes.

plot_speedup(ppo_env_name, ppo_runs)

GIF

Here is what our trained agent looks like.

Left: Not trained (the environment resets each time the pole falls too low) Right: Trained with PPO

Distributed Soft Actor-Critic, DSAC

SAC is an off-policy reinforcement learning algorithm. The main differences between it and PPO are the following:

  1. A larger set of networks is used, and they do not share the same optimizer.
  2. Instead of "rollouts", the agent takes a single "step" in the environment, and immediately updates the weights using a "replay buffer". Note that the algorithm is called an "off-policy algorithm" because state-action pairs in the replay buffer may have been generated by previous policies, not the current policy.

Regarding point 1., in practice this forces us to wrap multiple optimizers. Furthermore, due to quirks of Horovod (see: https://github.com/horovod/horovod/issues/1417), we are forced to sync gradients explicitly after every backward pass.

Regarding point 2., SAC alternates between updating a rolling buffer and updating model weights using the buffer. We illustrate this below.

[Figure: SAC alternating between updating the rolling buffer and updating the model weights]
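
The alternation between filling a rolling buffer and updating from it can be sketched in a few lines. ToyReplayBuffer is a hypothetical stand-in, not the stable_baselines3 ReplayBuffer, and the "update" is only counted rather than performed.

```python
import random
from collections import deque

class ToyReplayBuffer:
    """Hypothetical fixed-size buffer; the oldest transitions are dropped first."""

    def __init__(self, capacity, seed=0):
        self.storage = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add(self, transition):
        self.storage.append(transition)

    def sample(self, batch_size):
        # Uniform sampling: the batch may mix transitions from old and new policies.
        return self.rng.sample(list(self.storage), batch_size)

buffer = ToyReplayBuffer(capacity=4)
learning_starts, batch_size = 3, 2
updates = 0

for step in range(6):
    # 1. Take a single environment step and store the transition.
    #    (policy_version stands in for "which policy generated this data".)
    buffer.add({"step": step, "policy_version": updates})
    # 2. Once enough data is collected, update the weights from a sampled batch.
    if len(buffer.storage) >= learning_starts:
        batch = buffer.sample(batch_size)
        updates += 1

print(len(buffer.storage))                  # 4: capacity reached, oldest dropped
print([t["step"] for t in buffer.storage])  # [2, 3, 4, 5]
print(updates)                              # 4
```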

We extend the original SAC class from stable_baselines3. Unlike in DPPO, we modify the __init__ function to wrap multiple optimizers. The wrapping is a bit different since the actor and critic networks may or may not share the feature extractor. Additionally, a third optimizer is used for an entropy coefficient (scalar), which may or may not be used. SAC does not use gradient clipping, but we still needed to modify the train function to synchronize the actor and critic optimizers with each other.

class DSAC(SAC):
    """
    Distributed Soft Actor-Critic (DSAC)
    """

    def __init__(
        self,
        policy,
        env,
        learning_rate = 3e-4,
        buffer_size = 1_000_000,  # 1e6
        learning_starts = 100,
        batch_size = 256,
        tau = 0.005,
        gamma = 0.99,
        train_freq = 1,
        gradient_steps = 1,
        action_noise = None,
        replay_buffer_class = None,
        replay_buffer_kwargs = None,
        optimize_memory_usage = False,
        ent_coef = "auto",
        target_update_interval = 1,
        target_entropy = "auto",
        use_sde = False,
        sde_sample_freq = -1,
        use_sde_at_warmup = False,
        tensorboard_log = None,
        create_eval_env = False,
        policy_kwargs = None,
        verbose = 0,
        seed = None,
        device = "auto",
        _init_setup_model = True,
    ):

        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise,
            replay_buffer_class,
            replay_buffer_kwargs,
            optimize_memory_usage,
            ent_coef,
            target_update_interval,
            target_entropy,
            use_sde,
            sde_sample_freq,
            use_sde_at_warmup,
            tensorboard_log,
            create_eval_env,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
        )
        
        # Wrap optimizers in Horovod
        # Actor optimizer
        self.actor.optimizer = hvd.DistributedOptimizer(self.actor.optimizer, named_parameters=self.actor.named_parameters())
        hvd.broadcast_parameters(self.actor.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.actor.optimizer, root_rank=0)
        
        # Critic optimizer
        if self.policy.share_features_extractor:
            critic_parameters = [(name, param) for name, param in self.critic.named_parameters() if "features_extractor" not in name]
        else: # used by default
            critic_parameters = self.critic.named_parameters()
        self.critic.optimizer = hvd.DistributedOptimizer(self.critic.optimizer, named_parameters=critic_parameters)
        hvd.broadcast_parameters(self.critic.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.critic.optimizer, root_rank=0)
        
        # Entropy coefficient optimizer
        if self.ent_coef_optimizer is not None:
            self.ent_coef_optimizer = hvd.DistributedOptimizer(self.ent_coef_optimizer, named_parameters=[("log_ent_coef", self.log_ent_coef)])
            hvd.broadcast_parameters([self.log_ent_coef], root_rank=0)
            hvd.broadcast_optimizer_state(self.ent_coef_optimizer, root_rank=0)
    
    # Need to redefine this function to synchronize multiple optimizers
    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute the next Q values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(replay_data.observations, replay_data.actions)

            # Compute critic loss
            critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()
            self.actor.optimizer.synchronize() # <----- diff from original function

            # Compute actor loss
            # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
            # Min over all critic networks
            q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()
            self.critic.optimizer.synchronize() # <----- diff from original function

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

Single-process SAC

Train a SAC agent in a single process using the original implementation from stable_baselines3 (i.e. not using our DSAC class or train_hvd function).

Start with some parameters:

# The environment where to test our distributed SAC algorithm
#sac_env_name = 'CartPole-v1'
sac_env_name = 'Pendulum-v1'
#sac_env_name = 'MountainCarContinuous-v0'
#sac_env_name = 'BipedalWalker-v3'

# SAC parameters
sac_total_timesteps = 10_000
sac_learning_rate = 1e-3 # default: 3e-4
sac_buffer_size = 4096 * 8 # default: 1_000_000
sac_learning_starts = 2048 # default: 100
sac_batch_size = 2048 # default: 256

# How many processes to use for distributed training
sac_world_sizes = [1, 2, 4, 8]

Train the SAC agent.

# Create a gym environment and wrap it in a Monitor to track the reward
env = gym.make(sac_env_name)
env = Monitor(env, os.path.join(LOG_DIR, sac_env_name, 'SAC'))

# Define the SAC agent
model = SAC(
    "MlpPolicy",
    env,
    learning_rate=sac_learning_rate,
    buffer_size=sac_buffer_size,
    learning_starts=sac_learning_starts,
    batch_size=sac_batch_size,
    verbose=1)

# Train the agent
model.learn(total_timesteps=sac_total_timesteps, log_interval=1)
env.close()

Multi-process DSAC

Train our distributed DSAC algorithm for different numbers of processes (world_size). Again, to have comparable training settings across runs, we divide the per-process replay buffer size (buffer_size) by the number of processes, and multiply the learning rate (learning_rate) by the number of processes. We use the same train_hvd function as for PPO.

for world_size in sac_world_sizes:
    hr = HorovodRunner(np=world_size, driver_log_verbosity='all')
    hr.run(
        train_hvd,
        algo = DSAC,
        env_name = sac_env_name,
        policy = "MlpPolicy",
        total_timesteps = sac_total_timesteps,
        learning_rate = sac_learning_rate * world_size,
        buffer_size = sac_buffer_size // world_size,
        learning_starts = sac_learning_starts // world_size,
        batch_size = sac_batch_size,
        verbose = 1
    )

SAC/DSAC results

We compare the different runs by plotting the reward over per-process training steps, total number of training steps, and wall time.

Reward plots

Plot the reward.

sac_runs = ['SAC'] + [f'DSAC-{ws}' for ws in sac_world_sizes]
sac_nprocs = {**{'SAC': 1}, **{f'DSAC-{ws}': ws for ws in sac_world_sizes}}

plot_results(sac_env_name, sac_runs, x_axis='steps')
plot_results(sac_env_name, sac_runs, x_axis='steps', num_procs=sac_nprocs)
plot_results(sac_env_name, sac_runs, x_axis='times')

Speedup

Plot the speedup for the different number of processes.

plot_speedup(sac_env_name, sac_runs)

GIF

Here is what our trained agent looks like.

Left: Not trained (random policy) Right: Trained with SAC

Conclusions

Reinforcement learning is scalable, as we have demonstrated with the two methods PPO and SAC. To increase the scalability, we need to reduce the communication overhead of the gradient synchronizations relative to the computation performed within each process. This can be done either by increasing the batch size (together with the learning rate) or by using a larger model. The former increases the amount of forward and backward computation per gradient synchronization, while the latter increases the computational cost of each pass.
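
A toy cost model makes this argument concrete. It assumes that each gradient step consists of local computation followed by one synchronization of fixed cost; the function name and the numbers are hypothetical and are not measurements from our experiments.

```python
def parallel_efficiency(compute_time, sync_time):
    """Fraction of each gradient step spent on useful computation."""
    return compute_time / (compute_time + sync_time)

sync_time = 1.0  # hypothetical cost of one gradient synchronization
for compute_time in (1.0, 4.0, 9.0):
    # More computation per synchronization gives a higher efficiency.
    print(compute_time, parallel_efficiency(compute_time, sync_time))
```

Under this model, efficiency rises from 0.5 to 0.8 to 0.9 as the computation per synchronization grows, which is the effect that a larger batch size or a larger model aims for.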