Breaking down State-of-the-Art PPO Implementations in JAX

Since its publication in a 2017 paper by OpenAI, Proximal Policy Optimization (PPO) has been widely regarded as one of the state-of-the-art algorithms in Reinforcement Learning. Indeed, PPO has demonstrated remarkable performance across a variety of tasks, from attaining superhuman play against Dota 2 teams to solving a Rubik's Cube with a single robotic hand, all while maintaining three main advantages: simplicity, stability, and sample efficiency.

However, implementing RL algorithms from scratch is notoriously difficult and error-prone, given the many potential sources of error and the implementation details to be aware of.

In this article, we'll break down the clever tricks and programming concepts used in a popular implementation of PPO in JAX. Specifically, we'll focus on the implementation featured in the PureJaxRL library, developed by Chris Lu.

Disclaimer: Rather than diving too deep into theory, this article covers the practical implementation details and (numerous) tricks used in popular versions of PPO. Should you require any reminders about PPO's theory, please refer to the references section at the end of this article. Additionally, all the code (minus the added comments) is copied directly from PureJaxRL for pedagogical purposes.

Proximal Policy Optimization is categorized within the policy gradient family of algorithms, a subset of which includes actor-critic methods. The designation actor-critic reflects the dual components of the model:

- The actor, which learns the policy: it maps observations to a distribution over actions
- The critic, which learns the value function: it estimates the expected return from a given state and is used to evaluate the actor's decisions

Additionally, this implementation pays particular attention to weight initialization in dense layers. Indeed, all dense layers are initialized by orthogonal matrices with specific coefficients. This initialization strategy has been shown to preserve the gradient norms (i.e. scale) during forward passes and backpropagation, leading to smoother convergence and limiting the risks of vanishing or exploding gradients[1].

Orthogonal initialization is used in conjunction with specific scaling coefficients:

- Hidden layers use a scale of √2, which helps preserve the variance of activations across layers
- The actor's output layer uses a small scale of 0.01, so that initial action probabilities are close to uniform
- The critic's output layer uses a scale of 1.0
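To make this concrete, here is a minimal sketch of such a network in Flax, using distrax for the action distribution; the hidden width of 64 units and the tanh activations are assumptions for illustration (the actual PureJaxRL module is configurable):

```python
import distrax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from flax.linen.initializers import constant, orthogonal


class ActorCritic(nn.Module):
    """Actor-critic network with orthogonal initialization (sketch)."""

    action_dim: int

    @nn.compact
    def __call__(self, x):
        # Actor trunk: hidden layers initialized with a gain of sqrt(2)
        a = nn.tanh(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x))
        a = nn.tanh(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(a))
        # Policy head: small gain (0.01) so the initial policy is close to uniform
        logits = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(a)
        pi = distrax.Categorical(logits=logits)

        # Critic trunk and value head: gain sqrt(2) for hidden layers, 1.0 for the output
        c = nn.tanh(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x))
        c = nn.tanh(nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(c))
        value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(c)

        return pi, jnp.squeeze(value, axis=-1)
```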

The training loop is divided into 3 main blocks that share similar coding patterns, taking advantage of JAX's functionalities:

- Trajectory collection: interacting with (vectorized) environments to gather batches of transitions
- Advantage estimation: computing the generalized advantage estimate (GAE) from the collected trajectories
- Update step: computing the PPO loss on shuffled minibatches and updating the network parameters

Before going through each block in detail, here's a quick reminder about the jax.lax.scan function that will show up multiple times throughout the code:

A common programming pattern in JAX consists of defining a function that acts on a single sample and using jax.lax.scan to iteratively apply it to elements of a sequence or an array, while carrying along some state. For instance, we'll apply it to the step function to step our environment N consecutive times while carrying the new state of the environment through each iteration.

In pure Python, we could proceed as follows:
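Using a made-up step function that simply increments a counter and emits an output (purely to illustrate the pattern, not part of PureJaxRL):

```python
def step(carry, x):
    # Hypothetical single-step function: update the carried state, emit an output
    new_carry = carry + 1
    output = carry * 2
    return new_carry, output


# Pure-Python loop: thread the carry manually and collect the outputs
carry = 0
outputs = []
for _ in range(10):
    carry, output = step(carry, None)
    outputs.append(output)
```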

However, we avoid writing such loops in JAX for performance reasons: pure Python loops are unrolled at trace time under JIT compilation, which leads to long compilation times and large, inefficient programs. The alternative is jax.lax.scan, which is roughly equivalent to:
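Continuing the same toy example, the scanned version would be:

```python
import jax

# jax.lax.scan applies `step` 10 times, threading the carry through iterations
# and stacking the per-iteration outputs into a single array.
final_carry, outputs = jax.lax.scan(
    step,        # f: (carry, x) -> (new_carry, y)
    0,           # init: initial value of the carry
    None,        # xs: sequence to scan over (None here, we only use the carry)
    length=10,   # number of iterations, required when xs is None
)
```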

Using jax.lax.scan is more efficient than a Python loop because it allows the transformation to be optimized and executed as a single compiled operation rather than interpreting each loop iteration at runtime.

We can see that the scan function takes multiple arguments:

- f: the function applied at each iteration. It takes the current carry and an element x of the sequence, and returns the updated carry along with an output y
- init: the initial value of the carry
- xs: the sequence to scan over (it can be None, in which case the length argument specifies the number of iterations)

Additionally, scan returns:

- carry: the final carry after the last iteration
- ys: an array stacking the output y of every iteration

Finally, scan can be used in combination with vmap to scan a function over multiple dimensions in parallel. As we'll see in the next section, this allows us to interact with several environments in parallel to collect trajectories rapidly.

As mentioned in the previous section, the trajectory collection block consists of a step function scanned across N iterations. This step function successively:

- Selects an action by sampling from the policy returned by the actor network, while also recording the critic's value estimate and the action's log-probability
- Steps the environment with the selected action
- Stores the resulting transition and returns the updated runner_state

Scanning this function returns the latest runner_state and traj_batch, an array of transition tuples. In practice, transitions are collected from multiple environments in parallel for efficiency, as indicated by the use of jax.vmap(env.step, ...) (for more details about vectorized environments and vmap, refer to my previous article).
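As a rough sketch, assuming a gymnax-style env.step(key, state, action, params) signature, the ActorCritic sketch from earlier stored in network, and a hypothetical Transition NamedTuple holding the per-step data, the step function and its scan could look like this:

```python
import jax


def _env_step(runner_state, unused):
    # Assumed globals (sketch): `network`, `env`, `env_params`,
    # `NUM_ENVS`, `NUM_STEPS`, and a `Transition` NamedTuple.
    train_state, env_state, last_obs, rng = runner_state

    # 1. Select an action by sampling from the current policy
    rng, _rng = jax.random.split(rng)
    pi, value = network.apply(train_state.params, last_obs)
    action = pi.sample(seed=_rng)
    log_prob = pi.log_prob(action)

    # 2. Step every environment in parallel with vmap
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, NUM_ENVS)
    obsv, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0, None)
    )(rng_step, env_state, action, env_params)

    # 3. Store the transition and carry the updated state forward
    transition = Transition(done, action, value, reward, log_prob, last_obs, info)
    runner_state = (train_state, env_state, obsv, rng)
    return runner_state, transition


# Collect NUM_STEPS transitions per environment as a single scanned operation
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, NUM_STEPS)
```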

After collecting trajectories, we need to compute the advantage function, a crucial component of PPO's loss function. The advantage function measures how much better a specific action is compared to the average action in a given state:

A(s_t, a_t) = G_t - V(s_t)

Where G_t is the return at time t and V(s_t) is the value of state s_t.

As the return is generally unknown, we have to approximate the advantage function. A popular solution is Generalized Advantage Estimation (GAE)[3], defined as follows:

Â_t = Σ_{l=0}^{∞} (γλ)^l · δ_{t+l}

With γ the discount factor, λ a parameter that controls the trade-off between bias and variance in the estimate, and δ_t the temporal difference error at time t:

δ_t = r_t + γ · V(s_{t+1}) - V(s_t)

As we can see, the value of the GAE at time t depends on the GAE at future timesteps. Therefore, we compute it backward, starting from the end of a trajectory. For example, for a trajectory of 3 transitions, we would have:

Â_2 = δ_2
Â_1 = δ_1 + (γλ) · δ_2
Â_0 = δ_0 + (γλ) · δ_1 + (γλ)² · δ_2

Which is equivalent to the following recursive form:

Â_t = δ_t + (γλ) · Â_{t+1}

Once again, we use jax.lax.scan on the trajectory batch (this time in reverse order) to iteratively compute the GAE.
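Here is a sketch of that computation, close in spirit to PureJaxRL's GAE function; the default values for gamma and gae_lambda are illustrative, traj_batch is the batch of transitions collected above, and last_val is the critic's value estimate for the final observation:

```python
import jax
import jax.numpy as jnp


def calculate_gae(traj_batch, last_val, gamma=0.99, gae_lambda=0.95):
    def _get_advantages(carry, transition):
        gae, next_value = carry
        done, value, reward = transition.done, transition.value, transition.reward
        # Temporal-difference error; (1 - done) stops bootstrapping at episode ends
        delta = reward + gamma * next_value * (1 - done) - value
        # Recursive form: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + gamma * gae_lambda * (1 - done) * gae
        return (gae, value), gae

    # Scan the trajectory in reverse, starting from the last transition
    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val),
        traj_batch,
        reverse=True,
    )
    # Second output: the value targets, i.e. the returns G_t = A_t + V(s_t)
    return advantages, advantages + traj_batch.value
```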

Note that the function returns advantages + traj_batch.value as a second output, which is equivalent to the return according to the first equation of this section.

The final block of the training loop defines the loss function, computes its gradient, and performs gradient descent on minibatches. Similarly to previous sections, the update step is an arrangement of several functions in a hierarchical order:

- update_epoch: the outer function, which shuffles the trajectory batch, splits it into minibatches, and applies the update to each of them
- update_minibatch: computes the gradient of the loss on a single minibatch and updates the model parameters
- loss_fn: the innermost function, which defines and computes the PPO loss

Let's break them down one by one, starting from the innermost function of the update step.

This function aims to define and compute the PPO loss, originally defined as:

L^CLIP(θ) = Ê_t [ min( r_t(θ) · Â_t , clip(r_t(θ), 1 - ε, 1 + ε) · Â_t ) ]

Where:

- r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t) is the probability ratio between the current policy and the policy used to collect the trajectories
- Â_t is the estimated advantage at time t
- ε is the clipping coefficient, which limits how far the policy can move in a single update

However, the PureJaxRL implementation features some tricks and differences compared to the original PPO paper[4]:

- The value loss is also clipped, penalizing value predictions that drift too far from the values recorded during trajectory collection
- Advantages are standardized (zero mean, unit variance) at the minibatch level
- The final objective combines the clipped surrogate loss, the value loss, and an entropy bonus, weighted by dedicated coefficients

Here's what the loss function looks like:
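The following sketch is in the spirit of PureJaxRL's loss function; the hyperparameter defaults (clip_eps, vf_coef, ent_coef), the apply_fn argument, and the Transition field names are illustrative assumptions rather than the verbatim original code:

```python
import jax.numpy as jnp


def loss_fn(params, apply_fn, traj_batch, gae, targets,
            clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    # Rerun the network on the stored observations
    pi, value = apply_fn(params, traj_batch.obs)
    log_prob = pi.log_prob(traj_batch.action)

    # Clipped value loss: penalize value predictions that drift too far
    # from the values recorded during trajectory collection
    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
        -clip_eps, clip_eps
    )
    value_losses = jnp.square(value - targets)
    value_losses_clipped = jnp.square(value_pred_clipped - targets)
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

    # Probability ratio between the current policy and the data-collection policy
    ratio = jnp.exp(log_prob - traj_batch.log_prob)
    # Standardize advantages at the minibatch level
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    # Clipped surrogate objective (negated, since we minimize the loss)
    loss_actor1 = ratio * gae
    loss_actor2 = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()

    # Entropy bonus to encourage exploration
    entropy = pi.entropy().mean()

    total_loss = loss_actor + vf_coef * value_loss - ent_coef * entropy
    return total_loss, (value_loss, loss_actor, entropy)
```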

The update_minibatch function is essentially a wrapper around loss_fn used to compute its gradient over the trajectory batch and update the model parameters stored in train_state.
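Assuming a flax TrainState (which bundles the parameters, the apply function, and the optimizer) together with the loss_fn sketch above, this wrapper could look roughly as follows:

```python
import jax


def _update_minbatch(train_state, batch_info):
    traj_batch, advantages, targets = batch_info

    # value_and_grad with has_aux=True returns ((total_loss, aux), grads),
    # differentiating with respect to the first argument (the parameters)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (total_loss, aux), grads = grad_fn(
        train_state.params, train_state.apply_fn, traj_batch, advantages, targets
    )

    # Apply the gradients with the optimizer stored in train_state
    train_state = train_state.apply_gradients(grads=grads)
    return train_state, (total_loss, aux)
```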

Finally, update_epoch wraps update_minibatch and applies it to minibatches. Once again, jax.lax.scan is used to apply the update function to all minibatches iteratively.
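A sketch of this epoch-level function is shown below; NUM_MINIBATCHES and MINIBATCH_SIZE are assumed hyperparameters, and the batch is first flattened from (num_steps, num_envs, ...) into a single leading dimension before shuffling:

```python
import jax
import jax.numpy as jnp


def _update_epoch(update_state, unused):
    train_state, traj_batch, advantages, targets, rng = update_state

    # Draw a fresh permutation of the whole batch for this epoch
    rng, _rng = jax.random.split(rng)
    batch_size = MINIBATCH_SIZE * NUM_MINIBATCHES  # assumed hyperparameters
    permutation = jax.random.permutation(_rng, batch_size)

    batch = (traj_batch, advantages, targets)
    # Flatten (num_steps, num_envs, ...) into (batch_size, ...)
    batch = jax.tree_util.tree_map(
        lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
    )
    # Shuffle, then split into NUM_MINIBATCHES minibatches
    shuffled = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape((NUM_MINIBATCHES, -1) + x.shape[1:]), shuffled
    )

    # Apply the minibatch update iteratively over all minibatches
    train_state, losses = jax.lax.scan(_update_minbatch, train_state, minibatches)
    return (train_state, traj_batch, advantages, targets, rng), losses
```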

From there, we can wrap all of the previous functions in an update_step function and use scan one last time for N steps to complete the training loop.

A global view of the training loop looks like this: an outer update_step function, scanned for the total number of updates, which (1) collects trajectories by scanning the step function, (2) computes the GAE on the resulting batch, and (3) scans update_epoch, which in turn scans update_minibatch over the shuffled minibatches.

We can now run a fully compiled training loop using jax.jit(train)(rng), or even train multiple agents in parallel on different random seeds using jax.vmap(train)(rngs).

There we have it! We covered the essential building blocks of the PPO training loop as well as common programming patterns in JAX.

To go further, I highly recommend reading the full training script in detail and running example notebooks on the PureJaxRL repository.

Thank you very much for your support, until next time

Full training script, PureJaxRL, Chris Lu, 2023

[1] Explaining and illustrating orthogonal initialization for recurrent neural networks, Smerity, 2016

[2] Initializing neural networks, DeepLearning.ai

[3] Generalized Advantage Estimation in Reinforcement Learning, Siwei Causevic, Towards Data Science, 2023

[4] Proximal Policy Optimization Algorithms, Schulman et al., OpenAI, 2017
