
Evolution of LLMs
The landscape of language models (LMs) has evolved dramatically since the introduction of the Transformer architecture in 2017. Here we will explore the
- mathematical foundations
- architectural innovations
- training breakthroughs
We will talk about everything: the code, math, and ideas that revolutionized NLP.
Additionally, you can treat this blog as a sort of part 2 to my original blog on transformers, which you can check out here.
How this blog is structured
We will go year by year, going through the revolutionary ideas introduced by each paper.
At the beginning of each section, I have added the abstract as well as the authors. I have done this to show you the people who were behind each idea, as well as what they felt was the main contribution of their paper.
Below that, I have provided the link to the original paper as well as my own implementation of it. After that, there is a quick summary section, which you can skim if you feel like you already know the crux of the idea.
Note: All the quick summaries are AI generated, and may contain some mistakes. The core content is all human generated though, so it definitely contains mistakes :)
After that, each section contains intuition, code, and mathematical explanation (wherever required) for each idea. I have tried to add all the prerequisite knowledge wherever possible (for example, the PPO section contains a derivation of policy gradient methods, as well as an explanation of TRPO). I have provided links to resources wherever I felt I could not provide enough background or do sufficient justice to the source material.
Additionally, there has been a lot of innovation in vision modeling, TTS, image generation, video generation, etc., each of which deserves its own blog (and there will be one!! I promise you that). As this is primarily an LLM blog, I will just give a quick intro and links to some groundbreaking innovations from other areas of ML.
Note: Do not take for granted all the hardware, data, and benchmark innovations. Though I will only briefly mention them, I implore you to explore them further if they interest you. This blog is strictly restricted to breakthroughs in Large Language Models, and mostly open-source ones. Even though current models by OpenAI, Anthropic, Google, etc. are amazing, not much is known about them publicly, so we will only talk about them briefly.
The AI timeline
This is a timeline of the most influential work. To read about more architectures that were huge at the time but died down eventually, consider going through the Transformer catalog.
The blog “Transformer models: an introduction and catalog — 2023 Edition” helped me immensely while making the timeline. Additionally this blog was helpful too.
Links post-2017 are broken as this is still a work in progress.
2017
2018
- Universal Language Model Fine-tuning for Text Classification
- Deep contextualized word representations
- Improving Language Understanding by Generative Pre-Training
- SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing
- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
2019
- Language Models are Unsupervised Multitask Learners
- RoBERTa: A Robustly Optimized BERT Pretraining Approach
- DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
- BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
- XLNet: Generalized Autoregressive Pretraining for Language Understanding
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
- Generating Long Sequences with Sparse Transformers
2020
- Reformer: The Efficient Transformer
- Longformer: The Long-Document Transformer
- GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
- Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
- Big Bird: Transformers for Longer Sequences
- GPT-3
- Rethinking Attention with Performers
- T5
- Measuring Massive Multitask Language Understanding
- ZeRO (Zero Redundancy Optimizer)
- ELECTRA
- Switch Transformer
- Scaling Laws
2021
- RoFormer: Enhanced Transformer with Rotary Position Embedding
- Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM
- Transcending Scaling Laws with 0.1% Extra Compute
- Improving language models by retrieving from trillions of tokens
- CLIP
- Dall-e
- FSDP
- HumanEval
- LoRA
- Self-Instruct: Aligning Language Models with Self-Generated Instructions
- PaLM
- Gopher (DeepMind)
- Megatron-Turing NLG
2022
- EFFICIENTLY SCALING TRANSFORMER INFERENCE
- Fast Inference from Transformers via Speculative Decoding
- Chinchilla
- Chain-of-thought prompting
- InstructGPT
- BLOOM
- Emergent Abilities of Large Language Models
- Flash Attention
- Grouped-query attention
- ALiBi position encoding
- DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale
- Claude 1
- FLAN (Fine-tuned LAnguage Net) (Google)
- Red Teaming Language Models with Language Models
- HELM (Holistic Evaluation of Language Models)
- DALL-E 2 (OpenAI)
- Stable Diffusion (Stability AI)
- GPTQ
- Beyond the Imitation Game: Quantifying and extrapolating the capabilities of language models
- Minerva
- ChatGPT
2023
- Efficient Memory Management for Large Language Model Serving with PagedAttention
- QLoRA: Efficient Finetuning of Quantized LLMs
- Parameter-Efficient Fine-Tuning Methods for Pretrained Language Models: A Critical Review and Assessment
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration
- Generative Agents: Interactive Simulacra of Human Behavior
- Voyager: An Open-Ended Embodied Agent with Large Language Models
- Universal and Transferable Adversarial Attacks on Aligned Language Models
- Tree of Thoughts: Deliberate Problem Solving with Large Language Models
- MPT
- WizardLM: Empowering Large Language Models to Follow Complex Instructions
- DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales
- GPT-4
- Mistral 7b
- LLaMA
- Mixtral 8x7B
- LLaMA 2
- Vicuna (LMSYS)
- Alpaca
- Direct Preference Optimization (DPO)
- Constitutional AI
- Toy Models of Superposition
- Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
- PaLM 2
- LAION-5B (LAION)
- LIMA
- Mamba
- LLaVA (Visual Instruction Tuning)
- Claude 1/Claude 2
- Gemini
- Qwen
- Qwen-VL
- Phi-1
- Reinforced Self-Training (ReST) for Language Modeling
- The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits
2024
- Gemma
- Gemma 2
- Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference
- TinyLlama: An Open-Source Small Language Model
- ModernBERT
- Jamba: A Hybrid Transformer-Mamba Language Model
- Claude 3
- LLaMA 3
- Gemini 1.5
- Qwen 2
- phi-2/phi-3
- OpenAI o1
- RSO (Reinforced Self-training with Online feedback)
- SPIN (Self-Play Fine-Tuning)
- DBRX
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
- Qwen 2.5 (Alibaba)
- DeepSeek 2.5 (DeepSeek)
- Claude 3.5 Sonnet (Anthropic)
- DeepSeek-R1 (DeepSeek)
- Phi 3
- Phi 4
- Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models
2025
Note: I am releasing this blog early as a preview to get feedback from the community. It is still a work in progress and I plan to explain as well as implement each paper from each year. Do let me know your thoughts through my socials, or in the comments below!!!
2017: The Foundation Year
Transformer
Link to paper: Attention is all you need
Link to implementation: [WORK IN PROGRESS]
Quick Summary
This is the famous “Attention Is All You Need” paper by Vaswani et al. that introduced the Transformer architecture - a groundbreaking neural network model that revolutionized natural language processing.
Key Innovation
The paper proposes replacing traditional recurrent neural networks (RNNs) and convolutional networks with a model based entirely on attention mechanisms. The core insight is that self-attention can capture dependencies between words regardless of their distance in a sequence, without needing to process them sequentially.
Architecture Highlights
- Encoder-Decoder Structure: 6 layers each, with multi-head self-attention and feed-forward networks
- Multi-Head Attention: Uses 8 parallel attention heads to capture different types of relationships
- Positional Encoding: Sine/cosine functions to inject sequence order information
- No Recurrence: Enables much better parallelization during training
Results
The Transformer achieved state-of-the-art performance on machine translation tasks:
- 28.4 BLEU on English-to-German (WMT 2014)
- 41.8 BLEU on English-to-French
- Trained significantly faster than previous models (12 hours vs. days/weeks)
Impact
This architecture became the foundation for modern language models like BERT, GPT, and others. The paper’s core principle - that attention mechanisms alone are sufficient for high-quality sequence modeling - fundamentally changed how we approach NLP tasks.
The work demonstrated superior performance while being more parallelizable and interpretable than previous sequence-to-sequence models.
THE foundational paper that introduced some key ideas such as:
- Scaled dot-product attention
- Multi-head attention mechanism
- Positional encodings
- Layer normalization
- Masked attention for autoregressive models
We have talked deeply about each of these topics previously and I implore you to check that part out here.
Problem
Sequential models like RNNs and LSTMs process text word-by-word, creating a fundamental bottleneck: each word must wait for the previous word to be processed. This sequential nature makes training painfully slow and prevents the model from understanding long-range dependencies effectively.
For example, in the sentence “The cat that lived in the house with the red door was hungry”, by the time the model reaches “was hungry”, it has largely forgotten about “The cat” due to the vanishing gradient problem. The model struggles to connect distant but related words.
Solution
The Transformer replaced sequential processing with parallel attention mechanisms. Instead of processing words one-by-one, it looks at all words simultaneously and uses attention to determine which words are most relevant to each other, regardless of their distance in the sentence.
This attention-based approach allows the model to directly connect “The cat” with “was hungry” in a single step, while also enabling massive parallelization during training - turning what used to take weeks into hours.
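To make this concrete, here is a minimal sketch of scaled dot-product attention in PyTorch (my own illustration, not the paper's code); the tensor shapes and the toy input are assumptions made up for the example.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(queries, keys, values, mask=None):
    """Minimal sketch: attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = queries.size(-1)
    # Similarity between every query position and every key position
    scores = queries @ keys.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask is True are hidden from attention
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)   # each row sums to 1
    return weights @ values               # weighted sum of value vectors

# Every token can attend to every other token in a single step,
# regardless of how far apart they are in the sequence.
x = torch.randn(2, 10, 64)                    # (batch, seq_len, d_k), toy input
out = scaled_dot_product_attention(x, x, x)   # self-attention
print(out.shape)                              # torch.Size([2, 10, 64])
```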
Training a Transformer
This is one topic that we didn’t talk about extensively, so let’s go over it, because that is where the true beauty of GPT lies: how to train over huge amounts of data.
We will build up iteratively, starting small and going massive as we approach the GPT paper.
This blog helped me with this section.
Preparing the data
The original Transformer was trained for neural machine translation using English-German sentence pairs. The data preparation involves several crucial steps:
```python
# Data preparation
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

def build_vocab(sentences):
    # Minimal helper: map every unique token to an integer id (0 is reserved for padding)
    vocab = {"<PAD>": 0}
    for sentence in sentences:
        for token in sentence.split():
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

def prepare_training_data(sentences):
    # 1. Add special tokens
    processed_sentences = []
    for sentence in sentences:
        processed_sentences.append("<START> " + sentence + " <EOS>")

    # 2. Create vocabulary
    vocab = build_vocab(processed_sentences)
    vocab_size = len(vocab)

    # 3. Convert to tensor sequences
    sequences = []
    for sentence in processed_sentences:
        tokens = sentence.split()
        sequence = torch.tensor([vocab[token] for token in tokens])
        sequences.append(sequence)

    # 4. Pad sequences
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded_sequences, vocab_size
```
- Special tokens (`<START>` and `<EOS>`): These tell the model where sentences begin and end. The `<START>` token signals the decoder to begin generation, while `<EOS>` teaches it when to stop. Without these, the model wouldn’t know sentence boundaries. As we move through the years, we will see how the special tokens used in LLMs have evolved as well. For example, think about what happens inside an LLM when it encounters a token that it hasn’t seen during training, like a Chinese character.
- Vocabulary creation: The vocabulary maps every unique word/token in the training data to a number. This is how we convert text (which computers can’t process) into numerical tensors (which they can). The vocabulary size determines the final layer dimension of our model.
- Tensor conversion: Neural networks work with numbers, not words. Each word gets replaced by its vocabulary index, creating sequences of integers that can be fed into the model.
- Padding: Sentences have different lengths, but neural networks need fixed-size inputs for batch processing. Padding with zeros makes all sequences the same length, enabling efficient parallel computation.
Key Training Innovations
The Transformer introduced several training techniques that became standard:
Teacher Forcing with Masking
```python
# During training, the decoder sees the target sequence shifted by one position
encoder_input = source_sequence
decoder_input = target_sequence[:, :-1]   # Remove last token
decoder_output = target_sequence[:, 1:]   # Remove first token

# Look-ahead mask prevents seeing future tokens
def create_look_ahead_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask.bool()

mask = create_look_ahead_mask(decoder_input.size(1))
```
Why this works: Teacher forcing trains the decoder to predict the next token given all previous tokens, without requiring separate training data. The input-output shift creates a “next token prediction” task from translation pairs. The look-ahead mask ensures the model can’t “cheat” by seeing future tokens during training - it must learn to predict based only on past context, just like during real inference.
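As a quick sanity check (my own toy example, not from the paper), here is what the look-ahead mask above produces for a length-4 sequence; `True` marks the future positions each token is not allowed to attend to. The function is redefined so the snippet runs on its own.

```python
import torch

def create_look_ahead_mask(seq_len):
    # Same mask as above: upper triangle (strictly above the diagonal) is blocked
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask.bool()

print(create_look_ahead_mask(4))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```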
Custom Learning Rate Schedule
The paper introduced a specific learning rate schedule that warms up and then decays:
```python
# Learning rate schedule from the paper
class TransformerLRScheduler:
    def __init__(self, optimizer, d_model=512, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_count = 0

    def step(self):
        self.step_count += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self):
        arg1 = self.step_count ** -0.5
        arg2 = self.step_count * (self.warmup_steps ** -1.5)
        return (self.d_model ** -0.5) * min(arg1, arg2)
```
Why this schedule: The warmup phase gradually increases the learning rate, preventing the model from making drastic weight updates early in training when gradients are noisy. After warmup, the learning rate decays proportionally to the inverse square root of the step number, allowing for fine-tuning as training progresses. This schedule was crucial for training stability with the Transformer’s deep architecture.
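To see the warmup-then-decay shape numerically, here is a small check (my own sketch) that evaluates the same formula as `get_lr` above at a few steps:

```python
# lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
d_model, warmup_steps = 512, 4000

def lr_at(step):
    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

for step in [1, 100, 1000, 4000, 8000, 40000]:
    print(f"step {step:>6}: lr = {lr_at(step):.6f}")
# The learning rate rises linearly until step 4000, then decays like 1/sqrt(step)
```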
Padding Masks for Loss Computation
```python
import torch.nn.functional as F

def masked_loss(target, prediction, pad_token=0):
    # Don't compute loss on padding tokens
    mask = (target != pad_token).float()

    # Reshape for cross entropy
    prediction = prediction.view(-1, prediction.size(-1))
    target = target.view(-1)
    mask = mask.view(-1)

    # Compute cross entropy loss
    loss = F.cross_entropy(prediction, target, reduction='none')
    masked = loss * mask
    return masked.sum() / mask.sum()
```
Why masking matters: Padding tokens (zeros) are artificial - they don’t represent real words. Computing loss on them would teach the model incorrect patterns and waste computational resources. The mask ensures we only compute loss on actual content, leading to more meaningful gradients and better learning. This also prevents the model from learning to predict padding tokens, which would be useless during inference.
Training Configuration
The original paper used these hyperparameters:
- Model size: 512 dimensions (base model)
- Attention heads: 8
- Encoder/Decoder layers: 6 each
- Feed-forward dimension: 2048
- Dropout: 0.1
- Optimizer: Adam with custom learning rate schedule
- Training time: ~12 hours on 8 P100 GPUs
The Training Loop
```python
import torch
import torch.nn as nn
from torch.optim import Adam

def train_step(model, optimizer, scheduler, encoder_input, decoder_input, decoder_output):
    model.train()
    optimizer.zero_grad()

    # Forward pass
    prediction = model(encoder_input, decoder_input)

    # Compute masked loss and accuracy (masked_accuracy is analogous to masked_loss:
    # it ignores padding positions when measuring token-level accuracy)
    loss = masked_loss(decoder_output, prediction)
    accuracy = masked_accuracy(decoder_output, prediction)

    # Backward pass
    loss.backward()
    optimizer.step()
    scheduler.step()

    return loss.item(), accuracy.item()

# Main training loop
model = TransformerModel(src_vocab_size, tgt_vocab_size, d_model=512)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
scheduler = TransformerLRScheduler(optimizer, d_model=512)

step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        src_batch, tgt_batch = batch

        # Prepare inputs
        encoder_input = src_batch[:, 1:]       # Remove START token
        decoder_input = tgt_batch[:, :-1]      # Remove EOS token
        decoder_output = tgt_batch[:, 1:]      # Remove START token

        loss, accuracy = train_step(
            model, optimizer, scheduler,
            encoder_input, decoder_input, decoder_output
        )

        step += 1
        if step % 100 == 0:
            print(f'Epoch {epoch}, Step {step}, Loss: {loss:.4f}, Acc: {accuracy:.4f}')
```
Why This Training Approach Worked
- Parallelization: Unlike RNNs, all positions could be computed simultaneously
- Stable Gradients: Layer normalization and residual connections prevented vanishing gradients
- Efficient Attention: Scaled dot-product attention was computationally efficient
- Smart Masking: Prevented information leakage while enabling parallel training
This training methodology laid the groundwork for scaling to the massive language models we see today. The key insight was that with proper masking and attention mechanisms, you could train much larger models much faster than sequential architectures allowed.
While the original Transformer showed the power of attention-based training, it was still limited to translation tasks with paired data. The real revolution came when researchers realized they could use similar training techniques on massive amounts of unlabeled text data - setting the stage for GPT and the era of large language models.
RLHF - Reinforcement Learning from Human Preferences
Link to paper: Deep reinforcement learning from human preferences
Link to implementation: [WORK IN PROGRESS]
Quick Summary
This paper presents a method for training reinforcement learning (RL) agents using human feedback instead of explicitly defined reward functions. Here’s a high-level overview:
The authors address a fundamental challenge in RL: for many complex tasks, designing appropriate reward functions is difficult or impossible. Instead of requiring engineers to craft these functions, they develop a system where:
- Humans compare short video clips of agent behavior (1-2 seconds)
- These comparisons train a reward predictor model
- The agent optimizes its policy using this learned reward function
Key contributions:
- They show this approach can solve complex RL tasks using feedback on less than 1% of the agent’s interactions
- This dramatically reduces the human oversight required, making it practical for state-of-the-art RL systems
- They demonstrate training novel behaviors with just about an hour of human time
- Their approach works across domains including Atari games and simulated robot locomotion
The technique represents a significant advance in aligning AI systems with human preferences, addressing concerns about misalignment between AI objectives and human values. By having humans evaluate agent behavior directly, the system learns rewards that better capture what humans actually want.
As mind-boggling as it sounds, the famed RLHF algorithm came out in 2017, the same year “Attention Is All You Need” came out. Let us understand the ideas put forth and why it was such a big deal.
(If you are not familiar with the idea of RL, I recommend checking out this short course by Hugging Face.)
Problem
Designing reward functions for complex behaviors is nearly impossible. How do you mathematically define “write a helpful response” or “be creative but truthful”? Traditional RL requires explicit numerical rewards for every action, but many desirable behaviors are subjective and context-dependent.
For example, it’s impossible to write code that scores joke quality, but humans can easily compare two jokes and say which is funnier.
Solution
One possible solution is to allow a human to provide feedback on the agent’s current behavior and use this feedback to define the task. But this poses another problem: it would require hundreds of hours of human time as well as domain expertise. The researchers discovered that preference modeling by a human, even on a small subset of interactions, provided great results.
An ideal solution will:
- Enable us to solve tasks for which we can recognize the desired behavior, but not necessarily demonstrate or describe it
- Allow systems to learn from non-expert users
- Scale to large problems
- Be economical
In their experiment, the researchers asked labellers to compare short video clips of the agent’s behavior. They found that by using a small sample of clips they were able to train the system to behave as desired.
Image sourced from paper
The human observes the agent acting in the environment and gives feedback. This feedback is fed to the reward predictor, which converts it into a numerical reward. The reward is passed to the RL algorithm, which updates the agent based on the feedback and the observations from the environment, and this in turn changes the agent’s actions.
This sounds simple enough in principle, but how do you teach a model to learn from these preferences? This is the problem of reward modeling.
Note: We will be talking more in depth about RL algorithms in the next section. RL topics are rather complicated and are usually discussed at the end, after an LLM has been trained, so you can skip this part for now if it feels daunting.
Reward predictor in RLHF
The following blogs helped me while writing this section
The reward predictor is trained to predict which of two given trajectories (σ¹, σ²) will be preferred by a human.
Example:
Imagine two robot trajectories:
- Trajectory A: Robot goes directly to the goal
- Trajectory B: Robot roams around then goes to the goal
A human would prefer A (more efficient). The reward model learns to assign higher values to the observation-action pairs in trajectory A, eventually learning that “efficient movement” correlates with human preference.
Reward predictor equation
\[\hat{P}\left[\sigma^{1} \succ \sigma^{2}\right]=\frac{\exp \sum \hat{r}\left(o_{t}^{1}, a_{t}^{1}\right)}{\exp \sum \hat{r}\left(o_{t}^{1}, a_{t}^{1}\right)+\exp \sum \hat{r}\left(o_{t}^{2}, a_{t}^{2}\right)}\]

It is trained using cross-entropy loss to match human preferences:

\[\operatorname{loss}(\hat{r})=-\sum_{\left(\sigma^{1}, \sigma^{2}, \mu\right) \in D} \mu(1) \log \hat{P}\left[\sigma^{1} \succ \sigma^{2}\right]+\mu(2) \log \hat{P}\left[\sigma^{2} \succ \sigma^{1}\right]\]

Mathematical Notation
- $\hat{P}\left[\sigma^{1} \succ \sigma^{2}\right]$: Predicted probability that trajectory segment $\sigma^{1}$ is preferred over trajectory segment $\sigma^{2}$
- $\hat{r}$: The learned reward function
- $o_{t}^{i}$: Observation at time $t$ in trajectory segment $i$
- $a_{t}^{i}$: Action at time $t$ in trajectory segment $i$
- $\sigma^{i}$: Trajectory segment $i$ (a sequence of observation-action pairs)
- $\exp$: Exponential function
- $\sum$: Summation over all timesteps in the trajectory segment
- $\operatorname{loss}(\hat{r})$: Cross-entropy loss function for the reward model
- $D$: Dataset of human preference comparisons
- $\mu$: Distribution over $\{1,2\}$ indicating human preference
- $\mu(1)$: Probability that human preferred segment 1
- $\mu(2)$: Probability that human preferred segment 2
- $\log$: Natural logarithm
Let us understand the Reward Function Fitting Process
The Preference-Predictor Model
Instead of directly creating a reward function (which rewards an agent when it does the desired behavior and punishes it otherwise), the authors created a preference predictor, which predicts which of two given sequences of actions will be preferred by a human.
The Mathematical Formulation (Equation 1)
The equation P̂[σ¹ ≻ σ²] represents the predicted probability that a human would prefer trajectory segment σ¹ over segment σ².
Breaking down the formula:
- $\sigma^{[1]}$ and $\sigma^{[2]}$ are two different trajectory segments (short video clips of agent behavior)
- $o_{t}^{[i]}$ and $a_{t}^{[i]}$ represent the observation and action at time $t$ in trajectory $i$
- $\hat{r}(o_{t}^{[i]}, a_{t}^{[i]})$ is the estimated reward for that observation-action pair
- The formula uses the softmax function (exponential normalization):
This means:
- Sum up all the predicted rewards along trajectory 1
- Sum up all the predicted rewards along trajectory 2
- Apply exponential function to both sums
- The probability of preferring trajectory 1 is the ratio of exp(sum1) to the total exp(sum1) + exp(sum2)
The Loss Function
The goal is to find parameters for r̂ that make its predictions match the actual human preferences:
\[\operatorname{loss}(\hat{r}) = -\sum_{\left(\sigma^{[1]}, \sigma^{[2]}, \mu\right) \in D} \left[\mu([1])\log \hat{P}\left[\sigma^{[1]} \succ \sigma^{[2]}\right] + \mu([2])\log \hat{P}\left[\sigma^{[2]} \succ \sigma^{[1]}\right]\right]\]

Where:
- $\left(\sigma^{[1]}, \sigma^{[2]}, \mu\right) \in D$ means we’re summing over all the comparison data in our dataset $D$
- $\mu$ is a distribution over $\{1,2\}$ indicating which segment the human preferred
- If the human strictly preferred segment 1, then $\mu([1]) = 1$ and $\mu([2]) = 0$
- If the human strictly preferred segment 2, then $\mu([1]) = 0$ and $\mu([2]) = 1$
- If the human found them equal, then $\mu([1]) = \mu([2]) = 0.5$
This is the standard cross-entropy loss function used in classification problems, measuring how well our predicted probabilities match the actual human judgments.
Consider reading this beautiful blog on Entropy by Christopher Olah, if you wish to gain a deeper understanding of cross-entropy.
The Bradley-Terry Model Connection
Note from Wikipedia: The Bradley–Terry model is a probability model for the outcome of pairwise comparisons between items, teams, or objects. Given a pair of items $i$ and $j$ drawn from some population, it estimates the probability that the pairwise comparison $i > j$ turns out true, as
\[\Pr(i>j) = \frac{p_i}{p_i + p_j}\]where $p_i$ is a positive real-valued score assigned to individual $i$. The comparison $i > j$ can be read as “i is preferred to j”, “i ranks higher than j”, or “i beats j”, depending on the application.
This approach is based on the Bradley-Terry model, which is a statistical model for paired comparisons. It’s similar to:
-
The Elo rating system in chess: Players have ratings, and the difference in ratings predicts the probability of one player beating another.
-
In this case: Trajectory segments have “ratings” (the sum of rewards), and the difference in ratings predicts the probability of a human preferring one segment over another.
In essence, the reward function learns to assign higher values to states and actions that humans tend to prefer, creating a preference scale that can be used to guide the agent’s behavior.
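Here is a minimal PyTorch sketch of this Bradley-Terry-style preference loss (my own illustration, not the paper's code); `reward_model`, the feature dimensions, and the toy data are assumptions made up for the example.

```python
import torch
import torch.nn as nn

# Hypothetical reward model: maps an (observation, action) feature vector to a scalar reward
reward_model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))

def preference_loss(segment_1, segment_2, mu):
    """Cross-entropy loss over pairwise preferences.

    segment_1, segment_2: (batch, timesteps, 16) observation-action features
    mu: (batch, 2) human preference distribution, e.g. [1, 0], [0, 1] or [0.5, 0.5]
    """
    # Sum predicted rewards along each trajectory segment
    r1 = reward_model(segment_1).squeeze(-1).sum(dim=1)   # (batch,)
    r2 = reward_model(segment_2).squeeze(-1).sum(dim=1)   # (batch,)
    # P[sigma1 > sigma2] = exp(r1) / (exp(r1) + exp(r2))  -> softmax over the pair
    log_probs = torch.log_softmax(torch.stack([r1, r2], dim=1), dim=1)  # (batch, 2)
    return -(mu * log_probs).sum(dim=1).mean()

# Toy usage: 4 comparisons, segments of 25 timesteps each
s1, s2 = torch.randn(4, 25, 16), torch.randn(4, 25, 16)
mu = torch.tensor([[1., 0.], [0., 1.], [0.5, 0.5], [1., 0.]])
print(preference_loss(s1, s2, mu))
```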
The most important breakthrough: We can align AI systems with human values using comparative feedback from non-experts. This insight would prove crucial when training language models - instead of trying to define “helpful” or “harmless” mathematically, we can simply ask humans to compare outputs.
This comparative approach scales much better than rating individual responses, making it practical for training large language models on human preferences.
Fun story: one time researchers tried to train a helicopter with RL and it started flying backwards.
PPO: Proximal Policy Optimization
Link to paper: Proximal Policy Optimization Algorithms
Link to implementation: [WORK IN PROGRESS]
Quick Summary
This paper by John Schulman et al. from OpenAI introduces Proximal Policy Optimization (PPO), a family of policy gradient methods for reinforcement learning that achieves the reliability and data efficiency of Trust Region Policy Optimization (TRPO) while being much simpler to implement and more compatible with various neural network architectures.
Key contributions:
- A novel “clipped” surrogate objective function that provides a pessimistic estimate of policy performance
- An algorithm that alternates between data collection and multiple epochs of optimization on the same data
- Empirical validation showing PPO outperforms other online policy gradient methods across continuous control tasks and Atari games
- A balance between sample complexity, implementation simplicity, and computation time
The core innovation is their clipped probability ratio approach, which constrains policy updates without requiring the complex second-order optimization techniques used in TRPO. This makes PPO more practical while maintaining performance guarantees.
Another algorithm central to modern LLM training that came out in 2017, and again from OpenAI. It really goes to show how much they tried to advance AI and be public about it (at least in the early days).
This is going to be math-heavy, so be prepared (don’t worry, I will guide you through each step).
Problem
However, there is room for improvement in developing a method that is scalable (to large models and parallel implementations), data efficient, and robust (i.e., successful on a variety of problems without hyperparameter tuning). Q-learning (with function approximation) fails on many simple problems and is poorly understood, vanilla policy gradient methods have poor data efficiency and robustness; and trust region policy optimization (TRPO) is relatively complicated, and is not compatible with architectures that include noise (such as dropout) or parameter sharing (between the policy and value function, or with auxiliary tasks).
Essentially there were a lot of RL algorithms, but none of them worked efficiently at scale.
Solution
This paper seeks to improve the current state of affairs by introducing an algorithm that attains the data efficiency and reliable performance of TRPO, while using only first-order optimization. We propose a novel objective with clipped probability ratios, which forms a pessimistic estimate (i.e., lower bound) of the performance of the policy. To optimize policies, we alternate between sampling data from the policy and performing several epochs of optimization on the sampled data
The authors found a way to take the best RL algorithm of the time (TRPO) and make it work at scale.
The following blogs & articles helped me write this section
- Spinning up docs by OpenAI, consider going through this to help understand the nomenclature used throughout this section
- RL blogs by Jonathan Hui; they really simplified the ideas for me
- Understanding Policy Gradients, this blog really helped me understand the math behind the idea
- These blogs were extremely helpful too (each word is a different link)
- The bible of modern RL
What is Reinforcement Learning
Image taken from HuggingFace Course
In RL we create an Agent (an ML model, like an artificial neural network) and give it a defined set of Actions $A_t$ (in this case: move left, move right, press A to shoot).
The agent then chooses an action and interacts with the Environment, which returns a new state as well as a reward (positive if we survive or achieve a favourable outcome, negative if we die or do something unfavourable).
Step by Step it looks something like this:
- The agent receives state $S_0$ from the environment (in this case, that would be the first frame of the game)
- Based on state $S_0$, the agent takes action $A_0$ (chooses to move right)
- The environment goes to new frame, new state $S_1$.
- The environment gives the agent a reward $R_t$ (still alive!!!).
The idea behind RL is based on reward hypothesis, which states that
All goals can be described as the maximization of the expected return (expected cumulative reward)
This can be mathematically represented as $R(\tau) = r_{t+1} + r_{t+2} + r_{t+3} + r_{t+4} + \ldots$ ($\tau$ is read as “tau”).
Remember this, It will prove useful later.
Policy π: The Agent’s Brain
The Policy π is the brain of our Agent, it’s the function that tells an Agent what action it should take at a given state and time.
The policy is what we want to train: we want to find the optimal policy π*, the one that maximizes expected return when the agent acts according to it (remember, that is the idea behind RL).
Image taken from OpenAI Spinning Up
There are many RL algorithms that we can use to train the policy, as you can see from the image above, but most of them are developed from two central methods:
- Policy-based methods: Directly, by teaching the agent which action to take, given the current state
- Value-based methods: Indirectly, by teaching the agent which states are more valuable and then taking the actions that lead to the more valuable states
Image taken from HuggingFace Course
(Don’t get scared by the equations, I will explain them as we move forward. Also, this was a quick recap of RL, for a better deep dive. Consider going through the HF course)
As this section is dedicated to PPO, I will primarily be talking about the topics concerned with it. It can broadly be put in the following order:
- Policy Gradient Methods
- TRPO
- PPO
I am skipping over many other interesting and amazing algorithms like Q-Learning, DQN, Actor-Critic, etc., as they are not relevant to this section. I still implore you to explore them through the links I have provided to get a better, broader and deeper grasp of RL.
Before we move to the next section, I want to talk about a question that baffled me when I started learning about RL.
“Why do we need a value-based approach?”
Policy-based approaches seem to work great and are intuitive as well: given a state, choose an action. Then why do we use value-based approaches? Isn’t that needless complexity? Think for a minute, then see the answer.
Answer
Value-based methods shine in scenarios where policy-based methods struggle:
1. Discrete Action Spaces with Clear Optimal Actions: In environments like Atari games or grid worlds, there’s often a single best action for each state. Value-based methods (like DQN) can directly learn which action has the highest expected return, making them sample-efficient for these deterministic scenarios.
2. Exploration Efficiency: Value functions provide natural exploration strategies. Methods like ε-greedy or UCB can systematically explore based on value estimates. Policy methods often struggle with exploration, especially in sparse reward environments where random policy perturbations rarely discover good behavior.
3. Off-Policy Learning: Value-based methods can learn from any data - even old experiences stored in replay buffers. This makes them incredibly sample-efficient. Policy methods traditionally required on-policy data, though modern techniques like importance sampling have bridged this gap.
4. Computational Efficiency: In discrete action spaces, value-based methods often require just one forward pass to select an action (argmax over Q-values). Policy methods might need to sample from complex probability distributions or solve optimization problems.
Where Policy Methods Fail:
- High-dimensional discrete actions: Computing argmax becomes intractable
- Continuous control: You can’t enumerate all possible actions to find the maximum
- Stochastic optimal policies: Sometimes the best strategy is inherently random (like rock-paper-scissors), which value methods can’t represent directly
The truth is, both approaches are complementary tools for different types of problems.
Policy Gradient Methods
Policy gradient methods directly optimize a policy function by adjusting its parameters in the direction of greater expected rewards. They work by:
- Collecting experience (state-action pairs and rewards) using the current policy
- Estimating the policy gradient (the direction that would improve the policy)
- Updating the policy parameters using this gradient
The Gradient Estimator
In our discussion so far, we talked about deterministic policy-based methods, i.e., given a state, choose an action: $\pi(s) = a$. But when we are talking about policy gradients, we use a stochastic policy-based method, i.e., given a state, return a probability distribution over actions: $\pi(a|s) = P[A|s]$.
We also need to be aware of a few terms and mathematical tricks before moving forward:
- Trajectory: A series of state-action pairs is called a trajectory.

  \[\tau = (s_1,a_1,s_2,a_2,\ldots,s_H,a_H)\]

- Log derivative trick:

  \[\nabla_\theta \log z = \frac{1}{z} \nabla_\theta z\]

  This trick allows us to convert the gradient of a probability into the gradient of its logarithm, which is computationally more stable and easier to work with. (To derive it, just apply the chain rule and note that the derivative of $\log(x)$ is $1/x$.)

- Definition of Expectation:

  For discrete distributions: \(\mathbb{E}_{x \sim p(x)}[f(x)] = \sum_x p(x)f(x) \tag{1}\)

  For continuous distributions: \(\mathbb{E}_{x \sim p(x)}[f(x)] = \int_x p(x)f(x) \, dx \tag{2}\)

  If you are new to the idea of expectation, consider checking out this amazing blog on the topic.
Deriving the Policy Gradient
Let $\tau$ be a trajectory (sequence of state-action pairs), $\theta$ be the weights of our neural network policy. Our policy $\pi_\theta$ outputs action probabilities that depend upon the current state and network weights.
We begin with the reward hypothesis: we want to maximize $R(\tau)$ where $\tau$ is a trajectory.
We can write the objective as the probability of a trajectory being chosen by the policy multiplied by the reward for that trajectory:
\[J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)] = \sum_\tau \pi_\theta(\tau)R(\tau)\]

This formulation is crucial because it connects:
- $\pi_\theta(\tau)$: How likely our current policy is to generate trajectory $\tau$
- $R(\tau)$: How much reward we get from that trajectory
For continuous trajectory spaces, we can write this as:
\[J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)] = \int \pi_\theta(\tau)R(\tau)d\tau\]

Now we can derive the policy gradient by taking the gradient of our objective:

\[\nabla_\theta J(\theta) = \nabla_\theta \int \pi_\theta(\tau)R(\tau)d\tau \tag{3}\]
\[= \int \nabla_\theta \pi_\theta(\tau)R(\tau)d\tau \tag{4}\]
\[= \int \pi_\theta(\tau) \frac{\nabla_\theta \pi_\theta(\tau)}{\pi_\theta(\tau)} R(\tau)d\tau \tag{5}\]
\[= \int \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau) R(\tau)d\tau \tag{6}\]
\[= \mathbb{E}_{\tau \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(\tau) R(\tau)] \tag{7}\]

Step-by-step explanation:
- (3) Start with gradient of our objective function
- (4) Push gradient inside the integral
- (5) Multiply and divide by $\pi_\theta(\tau)$
- (6) Apply the log derivative trick: $\nabla_\theta \log(z) = \frac{1}{z} \nabla_\theta z$
- (7) Convert back to expectation form
The trajectory probability factors as: \(\pi_\theta(\tau) = \prod_{t=0}^{T} \pi_\theta(a_t|s_t)\)
So the log probability becomes: \(\log \pi_\theta(\tau) = \sum_{t=0}^{T} \log \pi_\theta(a_t|s_t)\)
What does this mean for us? If you want to maximize your expected reward, you can use gradient ascent. The gradient of the expected reward has an elegant form - it’s simply the expectation of the trajectory return times the sum of the gradients of the log probabilities of the actions taken in that trajectory.
In reinforcement learning, a trajectory $\tau = (s_1, a_1, s_2, a_2, \ldots, s_T, a_T)$ is generated through a sequential process. The probability of observing this specific trajectory under policy $\pi_\theta$ comes from the chain rule of probability.
This is quite complex to intuitively understand in my opinion. Consider going through this stack exchange.

Intuition: Let’s calculate the joint probability of a sequence like $P(\text{sunny weather, white shirt, ice cream})$ - what’s the chance it’s sunny outside, I’m wearing a white shirt, and I chose to eat ice cream, all happening together? We can break this down step by step: First, what’s the probability it’s sunny outside? That’s $P(\text{sunny})$. Given that it’s sunny, what are the chances I wear a white shirt? That’s $P(\text{white shirt | sunny})$. Finally, given it’s sunny and I’m wearing white, what’s the probability I eat ice cream? That’s $P(\text{ice cream | sunny, white shirt})$.

\(P(\text{sunny, white shirt, ice cream}) = P(\text{sunny}) \cdot P(\text{white shirt | sunny}) \cdot P(\text{ice cream | sunny, white shirt})\)

By multiplying these conditional probabilities, we get the full joint probability. In reinforcement learning, trajectories work the same way: $P(s_1, a_1, s_2, a_2, \ldots)$ breaks down into “what state do we start in?” then “what action do we take?” then “where do we transition?” and so on. Each step depends only on what happened before, making complex trajectory probabilities manageable to compute and optimize.
The joint probability of a sequence of events can be factored as: \(P(s_1, a_1, s_2, a_2, \ldots, s_T, a_T) = P(s_1) \cdot P(a_1|s_1) \cdot P(s_2|s_1, a_1) \cdot P(a_2|s_1, a_1, s_2) \cdots\)
However, in the Markov Decision Process (MDP) setting, we have two key assumptions:
- Markov Property: Next state depends only on current state and action: $P(s_{t+1}|s_1, a_1, \ldots, s_t, a_t) = P(s_{t+1}|s_t, a_t)$
- Policy Markov Property: Action depends only on current state: $P(a_t|s_1, a_1, \ldots, s_t) = \pi_\theta(a_t|s_t)$
Chapter 3 of the RL book by Sutton and Barto covers this topic well.
Applying these assumptions:
\[\pi_\theta(\tau) = \pi_\theta(s_1, a_1, \ldots, s_T, a_T) = \underbrace{p(s_1) \prod_{t=1}^{T} \pi_\theta(a_t|s_t)\,p(s_{t+1}|s_t, a_t)}_{\pi_\theta(\tau)}\]

- $p(s_1)$: Initial state distribution (environment dependent)
- $\pi_\theta(a_t|s_t)$: Policy probability of choosing action $a_t$ in state $s_t$
- $p(s_{t+1}|s_t, a_t)$: Environment transition probability (environment dependent)
When we take the log of a product, it becomes a sum:
\[\log \pi_\theta(\tau) = \log p(s_1) + \sum_{t=1}^{T} \log \pi_\theta(a_t|s_t) + \sum_{t=1}^{T} \log p(s_{t+1}|s_t, a_t)\]

The first and last terms do not depend on $\theta$ and can be removed when taking gradients (and this is often done in practice):
- $\log p(s_1)$: Initial state is determined by environment, not our policy
- $\log p(s_{t+1}|s_t, a_t)$: Environment dynamics don’t depend on our policy parameters
Therefore: \(\nabla_\theta \log \pi_\theta(\tau) = \nabla_\theta \sum_{t=1}^{T} \log \pi_\theta(a_t|s_t) = \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\)
So the policy gradient: \(\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(\tau) R(\tau)]\)
becomes: \(\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\left(\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\right) R(\tau)\right]\)
The trajectory return $R(\tau)$ is the total reward collected along the trajectory: \(R(\tau) = \sum_{t=1}^{T} r(s_t, a_t)\)
So our gradient becomes: \(\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\left(\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\right) \left(\sum_{t=1}^{T} r(s_t, a_t)\right)\right]\)
How do we compute expectations in practice?
We can’t compute the expectation $\mathbb{E}_{\tau \sim \pi_\theta}[\cdot]$ analytically because:
- There are infinitely many possible trajectories
- We don’t know the environment dynamics $p(s_{t+1}|s_t, a_t)$
Instead, we use Monte Carlo sampling:
- Collect $N$ sample trajectories by running our current policy: $\{\tau_1, \tau_2, \ldots, \tau_N\}$
- Approximate the expectation using the sample average: $\mathbb{E}_{\tau \sim \pi_\theta}[f(\tau)] \approx \frac{1}{N} \sum_{i=1}^{N} f(\tau_i)$
Applying Monte Carlo approximation
This is a fabulous video to understand Monte Carlo approximation.
Substituting our specific function: \(f(\tau) = \left(\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\right) \left(\sum_{t=1}^{T} r(s_t, a_t)\right)\)
We get: \(\boxed{\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \left(\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_{i,t}|s_{i,t})\right) \left(\sum_{t=1}^{T} r(s_{i,t}, a_{i,t})\right)}\)
\[\boxed{\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)}\]

Where:
- $i$ indexes the sampled trajectories ($1$ to $N$)
- $t$ indexes time steps within each trajectory ($1$ to $T$)
- $(s_{i,t}, a_{i,t})$ is the state-action pair at time $t$ in trajectory $i$
The elegant result is that we only need gradients of our policy’s action probabilities - the environment dynamics completely disappear from our gradient computation! This makes policy gradients model-free and widely applicable.
And we use this policy gradient to update the policy $\theta$.
To get an intuition behind the idea consider reading the intuition part of this blog.
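Here is a minimal PyTorch sketch of this estimator (my own toy illustration; the policy network, its dimensions, and the pre-collected batch are assumptions). In practice we build a "pseudo-loss" whose gradient equals the policy gradient estimate and let autograd handle the rest:

```python
import torch
import torch.nn as nn

# Hypothetical policy network: 4-dimensional state -> logits over 2 discrete actions
policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

def reinforce_update(states, actions, returns):
    """states: (N*T, 4), actions: (N*T,), returns: (N*T,) with R(tau) repeated
    for every timestep of the trajectory it came from."""
    dist = torch.distributions.Categorical(logits=policy(states))
    log_probs = dist.log_prob(actions)            # log pi_theta(a_t | s_t)
    # Pseudo-loss: its gradient is the (negated) Monte Carlo policy gradient estimate
    loss = -(log_probs * returns).mean()
    optimizer.zero_grad()
    loss.backward()                               # autograd computes grad log pi(a|s) * R(tau)
    optimizer.step()                              # i.e. gradient ascent on J(theta)
    return loss.item()

# Toy batch of previously collected experience (placeholders, not a real environment)
states = torch.randn(64, 4)
actions = torch.randint(0, 2, (64,))
returns = torch.randn(64)
print(reinforce_update(states, actions, returns))
```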
Policy Gradient for Continuous Space
So far, we’ve been working with discrete action spaces, like our super mad bot game where you can move left, move right, or press A to shoot. But what happens when your agent needs to control a robot arm, steer a car, or even select the “best” next token in language model fine-tuning? Welcome to the world of continuous control!
In discrete spaces, our policy outputs probabilities for each possible action:
- Move left: 30%
- Move right: 45%
- Shoot: 25%
But in continuous spaces, actions are real numbers. Imagine trying to control a robot arm where the joint angle can be any value between -180° and +180°. You can’t enumerate probabilities for every possible angle, there are infinitely many! (like in real numbers, you cannot even count the numbers present between 179 and 180… Where do you even begin?)
The solution is to make our neural network output the parameters of a probability distribution (e.g., the mean and standard deviation of a normal distribution) instead of individual action probabilities. Specifically, we use a Gaussian (normal) distribution.
Here’s how it works:
Instead of: $\pi_\theta(a_t|s_t) = \text{[probability for each discrete action]}$
We use: $\pi_\theta(a_t|s_t) = \mathcal{N}(f_{\text{neural network}}(s_t); \Sigma)$
Let’s break it down:
- Feed the state $s_t$ into your neural network
- Network outputs the mean $\mu = f_{\text{neural network}}(s_t)$ - this is the “preferred” action
- Choose a covariance matrix $\Sigma$ - this controls how much exploration/uncertainty around that mean
- Sample the actual action from the Gaussian: $a_t \sim \mathcal{N}(\mu, \Sigma)$
Now comes the amazing part. Remember our policy gradient formula?
\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\left(\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t)\right) R(\tau)\right]\]

The exact same formula still applies! We just need to compute $\nabla_\theta \log \pi_\theta(a_t|s_t)$ differently.
Let’s start with what a multivariate Gaussian distribution actually looks like. For continuous actions, we assume they follow this probability density function:
\[f(x) = \frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}} \exp\left\{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)\right\}\]

This looks scary, but it’s just the mathematical way of saying: “actions are most likely to be near the mean $\mu$, with spread determined by covariance $\Sigma$.”
(To understand where this idea comes from, read section 13.7 of RL by Sutton and Barto.)
Now, since our policy $\pi_\theta(a_t|s_t) = \mathcal{N}(f_{\text{neural network}}(s_t); \Sigma)$, we have:
\[\log \pi_\theta(a_t|s_t) = \log f(a_t)\]

Taking the logarithm of our Gaussian:

\[\log \pi_\theta(a_t|s_t) = \log\left[\frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}} \exp\left\{-\frac{1}{2}(a_t - \mu)^T \Sigma^{-1} (a_t - \mu)\right\}\right]\]

Using properties of logarithms ($\log(AB) = \log A + \log B$ and $\log(e^x) = x$):

\[\log \pi_\theta(a_t|s_t) = \log\left[\frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}}\right] - \frac{1}{2}(a_t - \mu)^T \Sigma^{-1} (a_t - \mu)\]

The first term is just a constant (it doesn’t depend on our neural network parameters $\theta$), so we can ignore it when taking gradients:

\[\log \pi_\theta(a_t|s_t) = -\frac{1}{2}(a_t - \mu)^T \Sigma^{-1} (a_t - \mu) + \text{const}\]

Since $\mu = f_{\text{neural network}}(s_t)$, we can rewrite this as:

\[\log \pi_\theta(a_t|s_t) = -\frac{1}{2}||f(s_t) - a_t||^2_\Sigma + \text{const}\]

The two expressions above are the same; the second is just shorthand notation. This quantity is also known as the squared Mahalanobis distance.
Now we can compute the gradient with respect to our network parameters $\theta$:
\[\nabla_\theta \log \pi_\theta(a_t|s_t) = \nabla_\theta \left[-\frac{1}{2}(a_t - f(s_t))^T \Sigma^{-1} (a_t - f(s_t))\right]\]

Let’s define $u = a_t - f(s_t)$ to simplify notation. Our expression becomes:

\[\nabla_\theta \log \pi_\theta(a_t|s_t) = \nabla_\theta \left[-\frac{1}{2} u^T \Sigma^{-1} u\right]\]

Since $a_t$ and $\Sigma^{-1}$ don’t depend on $\theta$, we have:

\[\frac{\partial u}{\partial \theta} = \frac{\partial}{\partial \theta}(a_t - f(s_t)) = -\frac{\partial f(s_t)}{\partial \theta}\]

For the quadratic form $u^T \Sigma^{-1} u$, using the chain rule:

\[\frac{\partial}{\partial \theta}(u^T \Sigma^{-1} u) = \frac{\partial u^T}{\partial \theta} \Sigma^{-1} u + u^T \Sigma^{-1} \frac{\partial u}{\partial \theta}\]

Since $\Sigma^{-1}$ is symmetric, we can write:

\[\frac{\partial}{\partial \theta}(u^T \Sigma^{-1} u) = 2 u^T \Sigma^{-1} \frac{\partial u}{\partial \theta}\]

Substituting back our expressions:

\[\nabla_\theta \log \pi_\theta(a_t|s_t) = -\frac{1}{2} \cdot 2 \cdot u^T \Sigma^{-1} \frac{\partial u}{\partial \theta}\]
\[= -u^T \Sigma^{-1} \left(-\frac{\partial f(s_t)}{\partial \theta}\right)\]
\[= u^T \Sigma^{-1} \frac{\partial f(s_t)}{\partial \theta}\]

Substituting $u = a_t - f(s_t)$ back:

\[\nabla_\theta \log \pi_\theta(a_t|s_t) = (a_t - f(s_t))^T \Sigma^{-1} \frac{\partial f(s_t)}{\partial \theta}\]

Since $\Sigma^{-1}$ is symmetric, $(a_t - f(s_t))^T \Sigma^{-1} = \left(\Sigma^{-1}(a_t - f(s_t))\right)^T$, so the same quantity can be written with the weighting applied to the error term:

\[\nabla_\theta \log \pi_\theta(a_t|s_t) = \Sigma^{-1}(a_t - f(s_t)) \frac{\partial f(s_t)}{\partial \theta}\]

Rearranging to match the original form:

\[\nabla_\theta \log \pi_\theta(a_t|s_t) = -\Sigma^{-1}(f(s_t) - a_t) \frac{\partial f(s_t)}{\partial \theta}\]

This gradient has a beautiful intuitive interpretation:
- $(f(s_t) - a_t)$: The difference between what your network predicted and the action you actually took
- $\frac{df}{d\theta}$: How to change the network parameters to affect the output
- $\Sigma^{-1}$: Weighting factor (less weight for high-variance directions)
When you collect experience and compute rewards, here’s what happens:
- Good action taken ($R(\tau) > 0$): The gradient pushes $f(s_t)$ closer to the good action $a_t$
- Bad action taken ($R(\tau) < 0$): The gradient pushes $f(s_t)$ away from the bad action $a_t$
- Standard backpropagation: This gradient flows back through the network to update $\theta$
Our policy gradient update remains: \(\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)\)
The only difference is how we compute $\nabla_\theta \log \pi_\theta(a_t|s_t)$:
- Discrete case: Gradient of softmax probabilities
- Continuous case: Gradient of Gaussian log-likelihood (what we just derived!)
Everything else stays identical - collect trajectories, compute returns, update parameters. The same core algorithm seamlessly handles both discrete and continuous control problems!
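Here is a small sketch of a Gaussian policy in PyTorch (my own illustration, assuming a learned diagonal covariance and placeholder data); `torch.distributions.Normal` provides exactly the log-likelihood we just derived, so the same update rule carries over to continuous actions:

```python
import torch
import torch.nn as nn

# Hypothetical network: 8-dimensional state -> mean of a 2-dimensional action
mean_net = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 2))
log_std = torch.zeros(2, requires_grad=True)   # learned log standard deviation (diagonal Sigma)
optimizer = torch.optim.Adam(list(mean_net.parameters()) + [log_std], lr=3e-4)

states = torch.randn(32, 8)
mu = mean_net(states)                                  # f_theta(s_t): the "preferred" action
dist = torch.distributions.Normal(mu, log_std.exp())   # pi_theta(a|s) = N(mu, sigma^2)

actions = dist.sample()                                # a_t ~ N(mu, Sigma)
log_probs = dist.log_prob(actions).sum(dim=-1)         # sum log-likelihood over action dimensions
returns = torch.randn(32)                              # placeholder R(tau), for illustration only

loss = -(log_probs * returns).mean()                   # same pseudo-loss as the discrete case
optimizer.zero_grad()
loss.backward()                                        # gradient flows through mu and log_std
optimizer.step()
```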
Policy Gradient Improvements
There are two main ways in which RL agents are trained:
- Monte Carlo Learning: The cumulative reward of the entire episode (an entire run of the environment) is used to update the policy
- Temporal Difference Learning: The reward is used to update the policy at every step
Image taken from Reinforcement Learning and Bandits for Speech and Language Processing: Tutorial, Review and Outlook
Policy Gradient (PG) uses Monte Carlo learning. This causes it to have low bias (the expected reward is close to the actual reward, as the same policy is used throughout the run) but high variance (some runs produce great results, some really bad ones).
A Stack Exchange thread on bias & variance in RL
Remember, our policy gradient formula is:
\[\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \left(\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_{i,t}|s_{i,t})\right) \left(\sum_{t=1}^{T} r(s_{i,t}, a_{i,t})\right)\]

We can rewrite this more compactly as:

\[\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_{i,t}|s_{i,t}) \cdot Q(s_{i,t}, a_{i,t})\]

Where $Q(s,a)$ represents the total reward we get from taking action $a$ in state $s$ (this is called the Q-function or action-value function).
The Baseline Trick
Here’s a mathematical insight: we can subtract any term from our gradient as long as that term doesn’t depend on our policy parameters $\theta$.
Why? Because: \(\nabla_\theta [f(\theta) - c] = \nabla_\theta f(\theta) - \nabla_\theta c = \nabla_\theta f(\theta) - 0 = \nabla_\theta f(\theta)\)
So instead of using $Q(s,a)$ directly, we can use $Q(s,a) - V(s)$, where $V(s)$ is some baseline function.
The most natural choice for baseline is $V(s) =$ the expected reward from state $s$ (The value function). This represents “how good is this state on average?”
Our new gradient becomes: \(\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_{i,t}|s_{i,t}) \cdot (Q(s_{i,t}, a_{i,t}) - V(s_{i,t}))\)
This is defined as the Advantage Function: \(A^{\pi}(s,a) = Q^{\pi}(s,a) - V^{\pi}(s)\)
The advantage function answers the question: “How much better is taking action $a$ in state $s$ compared to the average action in that state?”
- $A(s,a) > 0$: Action $a$ is better than average → increase its probability
- $A(s,a) < 0$: Action $a$ is worse than average → decrease its probability
- $A(s,a) = 0$: Action $a$ is exactly average → no change needed
Our final policy gradient becomes: \(\boxed{\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_{i,t}|s_{i,t}) \cdot A^{\pi}(s_{i,t}, a_{i,t})}\)
Let’s understand why this reduces variance with an example:
Situation 1: Trajectory A gets +10 rewards, Trajectory B gets -10 rewards
- If average performance is 0: $A_A = +10$, $A_B = -10$
- Result: Increase A’s probability, decrease B’s probability ✓
Situation 2: Trajectory A gets +10 rewards, Trajectory B gets +1 rewards
- If average performance is +5.5: $A_A = +4.5$, $A_B = -4.5$
- Result: Increase A’s probability, decrease B’s probability ✓
Even when both trajectories have positive rewards, the advantage function correctly identifies which one is relatively better!
In deep learning, we want input features to be zero-centered. The advantage function does exactly this for our rewards:
- Without baseline: All positive rewards → always increase probabilities
- With advantage: Rewards centered around zero → increase good actions, decrease bad ones
This gives our policy gradient much clearer, less conflicting signals, significantly reducing variance and improving convergence.
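A tiny sketch of the baseline trick (my own numbers, using the batch-average return as the simplest possible stand-in for $V(s)$):

```python
import torch

# Toy returns for a batch of trajectories: all positive, but clearly not equally good
returns = torch.tensor([10.0, 1.0, 7.0, 2.0])

# Without a baseline, every trajectory gets its probability pushed up.
# With a simple average baseline, good trajectories get positive advantages
# and bad ones get negative advantages.
baseline = returns.mean()            # in practice, a learned value network V(s)
advantages = returns - baseline
print(advantages)                    # tensor([ 5., -4.,  2., -3.])

# Common extra step: normalize advantages to zero mean and unit variance
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
```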
Vanilla Policy Gradient Algorithm
Now that we understand the advantage function, let’s see how it all comes together in the complete algorithm:
\[\nabla U(\theta) \approx \hat{g} = \frac{1}{m} \sum_{i=1}^{m} \nabla_\theta \log P(\tau^{(i)}; \theta)(R(\tau^{(i)}) - b)\]

(The notation may change from paper to paper, but the core idea remains the same)
Image taken from RL — Policy Gradient Explained
Reward Discount
There’s one more important technique that further reduces variance: reward discounting.
Reward discount reduces variance by reducing the impact of distant actions. The intuition is that actions taken now should have more influence on immediate rewards than on rewards received far in the future.
You can think of it in terms of money: would you rather have money right now, or later?
Instead of using the raw cumulative reward, we use a discounted return:
\[Q^{\pi,\gamma}(s, a) \leftarrow r_0 + \gamma r_1 + \gamma^2 r_2 + \cdots \mid s_0 = s, a_0 = a\]

Where:
- $\gamma \in [0,1]$ is the discount factor
- $\gamma = 0$: Only immediate rewards matter
- $\gamma = 1$: All future rewards are equally important
- $\gamma \approx 0.99$: Common choice that slightly prioritizes near-term rewards
The corresponding objective function becomes:
\[\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_{i,t}|s_{i,t}) \left(\sum_{t'=t}^{T} \gamma^{t'-t} r(s_{i,t'}, a_{i,t'})\right)\]Why Discounting Helps:
- Reduces variance: Distant rewards have less influence, so random events far in the future don’t dominate the gradient
- Focuses learning: The agent learns to optimize for more predictable, near-term outcomes
- Mathematical stability: Prevents infinite returns in continuing tasks
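Concretely, the discounted reward-to-go for every timestep can be computed with a single backward pass over the reward list. A minimal sketch:

```python
def discounted_returns(rewards, gamma=0.99):
    """Reward-to-go: sum over t' >= t of gamma^(t'-t) * r_{t'} for every timestep t."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```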
All of this comprises the complete Vanilla Policy Gradient Algorithm, which serves as the foundation for more advanced methods like PPO, TRPO, and GRPO, which we’ll explore in subsequent sections.
TRPO
The Sample Efficiency Problem
Our vanilla policy gradient algorithm works, but it has a critical flaw that makes it impractical for real-world applications. Let’s examine what happens during training:
- Collect trajectories using current policy π_θ
- Compute gradients from these trajectories
- Update policy θ → θ_new
- Throw away all previous data and start over
This last step is the problem. Imagine training a robot to walk - every time you make a small adjustment to the policy, you must collect entirely new walking data and discard everything you learned before. For complex tasks requiring thousands of timesteps per trajectory, this becomes computationally prohibitive.
Recall our policy gradient formula:
\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A(s_t, a_t)\right]\]The expectation $\mathbb{E}_{\tau \sim \pi_\theta}$ means we must sample trajectories using the current policy π_θ. When we update θ, this distribution changes, invalidating all our previous samples.
Importance Sampling
What if we could reuse old data to estimate the performance of our new policy? This is exactly what importance sampling enables. The core idea is beautifully simple:
If you want to compute an expectation under distribution p, but you have samples from distribution q, you can reweight the samples by the ratio p/q. |
For any function f(x), the expectation under distribution p can be computed as:
\[\mathbb{E}_{x \sim p}[f(x)] = \sum_x p(x)f(x)\]But using importance sampling, we can compute this same expectation using samples from a different distribution q:
\[\mathbb{E}_{x \sim p}[f(x)] = \sum_x p(x)f(x) = \sum_x \frac{p(x)}{q(x)} \cdot q(x)f(x) = \mathbb{E}_{x \sim q}\left[\frac{p(x)}{q(x)} f(x)\right]\]The magic happens in that middle step - we multiply and divide by q(x), creating a ratio p(x)/q(x) that reweights our samples.
Let’s see this in action with an example. Suppose we want to compute the expected value of f(x) = x under two different distributions:
Distribution p: P(x=1) = 0.5, P(x=3) = 0.5
Distribution q: P(x=1) = 0.8, P(x=3) = 0.2
Direct calculation under p: \(\mathbb{E}_{x \sim p}[f(x)] = 0.5 \times 1 + 0.5 \times 3 = 2.0\)
Using importance sampling with samples from q:
If we sample from q and get samples [1, 1, 1, 3], we can estimate the expectation under p by reweighting:
For x=1: weight = p(1)/q(1) = 0.5/0.8 = 0.625
For x=3: weight = p(3)/q(3) = 0.5/0.2 = 2.5
Averaging the reweighted samples gives $(0.625 + 0.625 + 0.625 + 2.5 \times 3)/4 \approx 2.34$, which is already close to the true value of 2.0 and converges to it as we draw more samples from q.
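Here is a small Python check of the same idea, using the toy p and q above; the importance-sampled estimate approaches 2.0 as the number of samples from q grows.

```python
import random

random.seed(0)
p = {1: 0.5, 3: 0.5}   # target distribution
q = {1: 0.8, 3: 0.2}   # sampling distribution

def is_estimate(n):
    """Monte Carlo estimate of E_p[x] using n samples drawn from q."""
    samples = random.choices(list(q.keys()), weights=list(q.values()), k=n)
    return sum(p[x] / q[x] * x for x in samples) / n

for n in (10, 1_000, 100_000):
    print(n, round(is_estimate(n), 3))   # estimates approach 2.0
```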
Now we can revolutionize our policy gradient approach. Instead of:
\[\mathbb{E}_{\tau \sim \pi_\theta}[f(\tau)]\]We can use:
\[\mathbb{E}_{\tau \sim \pi_{\theta_{old}}}\left[\frac{\pi_\theta(\tau)}{\pi_{\theta_{old}}(\tau)} f(\tau)\right]\]Remember that trajectory probabilities factor as: \(\pi_\theta(\tau) = {p(s_1) \prod_{t=1}^{T} \pi_\theta(a_t|s_t)p(s_{t+1}|s_t, a_t)}\)
The environment dynamics $p(s_{t+1}|s_t, a_t)$ and $p(s_1)$ are the same for both policies, so they cancel out in the ratio:
\[\frac{\pi_\theta(\tau)}{\pi_{\theta_{old}}(\tau)} = \frac{\prod_{t=1}^{T} \pi_\theta(a_t|s_t)}{\prod_{t=1}^{T} \pi_{\theta_{old}}(a_t|s_t)} = \prod_{t=1}^{T} \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\]Our objective becomes:
\[J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta_{old}}}\left[\prod_{t=1}^{T} \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \cdot R(\tau)\right]\]This is huge! We can now:
- Collect data with policy ${\pi_{\theta_{old}}}$
- Reuse this data multiple times to evaluate different policies ${\pi_{\theta}}$
- Dramatically improve sample efficiency
But there’s a catch. Importance sampling works well only when the two distributions are similar. If πθ becomes very different from πθ_old, the probability ratios can explode or vanish:
- Ratio ≫ 1: New policy assigns much higher probability to some actions
- Ratio ≪ 1: New policy assigns much lower probability to some actions
- Ratio ≈ 0: Catastrophic - new policy never takes actions the old policy preferred
Consider what happens if one action has ratio = 100 while others have ratio = 0.01. A single high-ratio sample can dominate the entire gradient estimate, leading to:
- Unstable training: Gradients vary wildly between batches
- Poor convergence: The algorithm makes erratic updates
- Sample inefficiency: We need many more samples to get reliable estimates
Constrained Policy Updates
The breakthrough insight: constrain how much the policy can change to keep importance sampling ratios well-behaved. This leads us naturally to the concept of trust regions - regions where we trust our importance sampling approximation to be accurate.
But we must also ask: how do we guarantee that our policy updates always improve performance?
These observations bring us to two key concepts:
- The Minorize-Maximization (MM) algorithm
- Trust regions
Minorize-Maximization (MM) Algorithm
Can we guarantee that any policy update always improves the expected rewards? This seems impossible, but it’s theoretically achievable through the MM algorithm.
The idea: Instead of directly optimizing the complex true objective η(θ), we iteratively optimize simpler lower bound functions M(θ) that approximate η(θ) locally.
The MM algorithm follows this iterative process:
- Find a lower bound M that approximates the expected reward η locally at the current guess θ_i
- Optimize the lower bound M to find the next policy guess θ_{i+1}
- Repeat until convergence
For this to work, M must be:
- A lower bound: M(θ) ≤ η(θ) for all θ
- Tight at current point: M(θ_i) = η(θ_i)
- Easier to optimize: M should be simpler than η (typically quadratic)
The lower bound function has the form: $M(\theta) = g \cdot (\theta - \theta_{old}) - \frac{1}{2}(\theta - \theta_{old})^T F (\theta - \theta_{old})$
This is a quadratic approximation where:
- g is the gradient at θ_old
- F is a positive definite matrix (often related to the Hessian)
Image taken from RL — Trust Region Policy Optimization (TRPO) Explained
If M is a lower bound that never crosses η, then maximizing M must improve η.
Proof sketch:
- Since $M(\theta_{\text{old}}) = \eta(\theta_{\text{old}})$ and $M(\theta) \leq \eta(\theta)$ everywhere
- If we find $\theta_{\text{new}}$ such that $M(\theta_{\text{new}}) > M(\theta_{\text{old}})$
- Then $\eta(\theta_{\text{new}}) \geq M(\theta_{\text{new}}) > M(\theta_{\text{old}}) = \eta(\theta_{\text{old}})$
- Therefore $\eta(\theta_{\text{new}}) > \eta(\theta_{\text{old}})$ ✓
In simpler terms, we have a function $\eta(\theta)$ parameterized by $\theta$ (the weights of our neural network). It is not computationally tractable to optimize this function directly. Hence we create a close approximation function $M(\theta)$ using the lower bound function form described above. This approximation comes from the general theory of Minorize-Maximization algorithms (see Hunter & Lange, 2004).
This approximation $M(\theta)$ is computationally feasible and easier to optimize. What we have proved here is that as we improve $M(\theta)$, that improvement guarantees we also improve $\eta(\theta)$.
By optimizing a lower bound function approximating η locally, MM guarantees policy improvement every iteration and leads us to the optimal policy eventually. |
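To see the MM recipe in isolation, here is a toy 1-D sketch (the objective and the smoothness constant are made up for illustration, not anything from the TRPO paper). Each iteration maximizes a quadratic lower bound that touches the objective at the current point, so the objective value never decreases.

```python
import math

# Toy objective: eta(theta) = -theta^2 + sin(3*theta). Its second derivative,
# -2 - 9*sin(3*theta), is bounded below by -11, so with L = 11 the quadratic
#   M(theta) = eta(t0) + eta'(t0)*(theta - t0) - (L/2)*(theta - t0)^2
# is a true lower bound that touches eta at t0: exactly the MM conditions.
L = 11.0
def eta(t):  return -t**2 + math.sin(3 * t)
def deta(t): return -2 * t + 3 * math.cos(3 * t)

theta = 2.0
for i in range(10):
    theta = theta + deta(theta) / L   # closed-form maximizer of the quadratic bound
    print(i, round(theta, 4), round(eta(theta), 4))   # eta never decreases across iterations
```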
Trust Regions
There are two major optimization paradigms:
- Line Search (like gradient descent): Choose direction first, then step size
- Trust Region: Choose maximum step size first (the size of the trust region), then find optimal point within that region
In trust region methods, we:
- Define a trust region of radius δ around current policy θ_old
- Find the optimal policy within this constrained region
- Adapt the radius based on how well our approximation worked
The optimization problem becomes: $\max_{\theta} \; M(\theta)$ $\text{subject to} \; \|\theta - \theta_{old}\| \leq \delta$
Adaptive Trust Region Sizing
The trust region radius δ can be dynamically adjusted:
- If approximation is good: Expand δ for next iteration
- If approximation is poor: Shrink δ for next iteration
- If policy diverges too much: Shrink δ to prevent importance sampling breakdown
Why Trust Regions Work for RL
In reinforcement learning, trust regions serve a dual purpose:
- Mathematical: Keep our quadratic approximation M valid
- Statistical: Prevent importance sampling ratios from exploding
When policies change too much, both our lower bound approximation AND our importance sampling become unreliable. Trust regions keep us in the safe zone for both.
Mathematical Notation Reference
Symbol | Meaning |
---|---|
$\pi_\theta(a|s)$ | Policy probability of action a given state s |
$\pi_{\theta_{old}}(a|s)$ | Old policy probability |
$\tau$ | Trajectory $(s_1, a_1, s_2, a_2, \ldots)$ |
$\pi_\theta(\tau)$ | Probability of trajectory under policy $\pi_\theta$ |
$\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ | Importance sampling ratio for single timestep |
$\prod_{t=1}^{T} \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ | Importance sampling ratio for full trajectory |
$R(\tau)$ | Total reward of trajectory |
$A(s_t, a_t)$ | Advantage function |
$\eta(\theta)$ | Expected reward under policy $\pi_\theta$ |
$M(\theta)$ | Lower bound function in MM algorithm |
$\theta_{old}$ | Current policy parameters |
$\delta$ | Trust region radius |
$F$ | Positive definite matrix (approximating curvature) |
$g$ | Policy gradient vector |
Trust Region Policy Optimization (TRPO)
Now we can finally understand how TRPO elegantly combines all the concepts we’ve explored:
- Importance Sampling - to reuse old data efficiently
- MM Algorithm - to guarantee policy improvement
- Trust Regions - to constrain policy changes and keep approximations valid
TRPO is a culmination of these ideas into a practical, theoretically-grounded algorithm.
Recall that our original objective was:
\[J(\pi) = \mathbb{E}_{\tau \sim \pi}[R(\tau)]\]This is the expected return (total reward) when following policy π. Instead of maximizing absolute performance $J(\pi')$, TRPO maximizes the policy improvement:
\[\max_{\pi'} J(\pi') - J(\pi)\]This is mathematically equivalent to maximizing $J(\pi')$ (since $J(\pi)$ is constant), but conceptually important - we’re explicitly measuring progress from our current policy.
Why focus on improvement? Because we can construct better approximations for the improvement $J(\pi') - J(\pi)$ than for the absolute performance $J(\pi')$. The MM algorithm works by finding lower bounds for this improvement.
To apply the MM algorithm, TRPO constructs a lower bound function ℒ that uses importance sampling:
\[\mathcal{L}_\pi(\pi') = \frac{1}{1-\gamma} \mathbb{E}_{s\sim d^\pi} \left[ \frac{\pi'(a|s)}{\pi(a|s)} A^\pi(s,a) \right] = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t \frac{\pi'(a_t|s_t)}{\pi(a_t|s_t)} A^\pi(s_t, a_t) \right]\]ℒ looks complex, but let’s break this down piece by piece to understand what’s really happening here.
The discounted state visitation distribution $d^\pi(s)$ tells us how often we expect to visit each state when following policy π:
\[d^\pi(s) = (1-\gamma) \sum_{t=0}^{\infty} \gamma^t P(s_t = s|\pi)\]Think of this as a “popularity contest” for states. If γ = 1, this becomes just the regular state visit frequency under policy π. But when γ < 1, we care more about states we visit early in episodes than those we reach later. It’s like asking: “If I run my policy many times, which states will I spend most of my time in, giving more weight to earlier visits?”
The advantage function $A^\pi(s,a)$ we’ve already met - it tells us how much better taking action $a$ in state $s$ is compared to what the policy would do on average in that state.
But here’s where the magic happens. The function ℒ is essentially asking a clever question using importance sampling: “If I reweight all the actions my current policy π took according to how likely my new policy π’ would be to take them, what would my expected advantage be?”
This is brilliant because it lets us estimate how well policy π’ would perform without actually running it in the environment. We just take all our old experience from policy π and reweight it according to the probability ratio $\frac{\pi'(a|s)}{\pi(a|s)}$. When the new policy is more likely to take an action than the old one, we give that experience more weight. When it’s less likely, we give it less weight.
This importance sampling approach is what allows TRPO to reuse old data efficiently - a huge computational win over vanilla policy gradients that throw away all previous experience after each update.
The theoretical foundation comes from this crucial bound (proven in Appendix 2 of the TRPO paper):
\[J(\pi') - J(\pi) \geq \mathcal{L}_\pi(\pi') - C\sqrt{\mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi' \| \pi)[s]]}\]This tells us:
- Left side: True policy improvement
- Right side: Our lower bound estimate minus a penalty term
The penalty term grows with KL divergence, so the bound becomes loose when policies differ too much.
Consider reading this blog to get a better idea about KLD |
Image taken from Wikipedia
The KL divergence measures how different two probability distributions are:
\[D_{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}\]For continuous distributions, this becomes:
\[D_{KL}(P \| Q) = \int P(x) \log \frac{P(x)}{Q(x)} dx\]Think of KL divergence as asking: “If I have samples from distribution P, how surprised would I be if I thought they came from distribution Q instead?” When the distributions are identical, KL divergence is zero. As they become more different, the divergence grows.
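For discrete distributions this is a one-liner. A quick sketch, reusing the toy distributions from the importance sampling example:

```python
import math

def kl_divergence(p, q):
    """D_KL(P || Q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p, q = [0.5, 0.5], [0.8, 0.2]
print(kl_divergence(p, q))  # ~0.223
print(kl_divergence(q, p))  # ~0.193: note the asymmetry
print(kl_divergence(p, p))  # 0.0: identical distributions
```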
TRPO can be formulated in two mathematically equivalent ways:
KL-Penalized (Unconstrained): \(\max_{\pi'} \mathcal{L}_\pi(\pi') - C\sqrt{\mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi' \| \pi)[s]]}\)
KL-Constrained: \(\max_{\pi'} \mathcal{L}_\pi(\pi')\) \(\text{subject to } \mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi'||\pi)[s]] \leq \delta\)
These formulations arise directly from the theoretical bound we mentioned earlier:
\[J(\pi') - J(\pi) \geq \mathcal{L}_\pi(\pi') - C\sqrt{\mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi'||\pi)[s]]}\]The unconstrained version simply maximizes this lower bound directly. The constrained version takes a different approach: instead of penalizing large KL divergences, it prevents them entirely by adding a hard constraint.
These are mathematically equivalent due to Lagrangian duality - a beautiful result from optimization theory. For every penalty coefficient C in the unconstrained problem, there exists a constraint threshold δ in the constrained problem that gives the same optimal solution. You can think of it like this: instead of saying “I’ll pay a penalty for going over the speed limit,” you’re saying “I absolutely won’t go over the speed limit.” Both approaches can lead to the same driving behavior, just with different enforcement mechanisms.
The lower bound is what we try to maximize to find the optimum $\theta$
Image taken from here
However, in practice, the constrained formulation wins by a landslide. Here’s why: the penalty coefficient C becomes a nightmare to tune when the discount factor γ gets close to 1. As γ approaches 1, the coefficient explodes, making the algorithm incredibly sensitive to small changes in γ. Imagine trying to tune a parameter that changes by orders of magnitude when you adjust γ from 0.99 to 0.995 - it’s practically impossible.
\[C \propto 1/(1-\gamma)^2\]The constrained version, on the other hand, gives you direct, interpretable control. The parameter δ simply says “don’t let the policy change too much,” which is much easier to understand and tune across different environments. It’s the difference between having a thermostat that directly controls temperature versus one that requires you to calculate complex equations involving heat transfer coefficients.
This practical insight would later inspire PPO’s breakthrough innovation. PPO took the unconstrained formulation and made it work brilliantly by replacing the complex second-order penalty with a simple first-order clipping mechanism. Instead of computing expensive Fisher Information Matrices, PPO just clips the importance sampling ratios directly - achieving similar performance with a fraction of the computational cost. |
The beauty of TRPO lies in its theoretical guarantee. Since we have the fundamental bound:
\[J(\pi') - J(\pi) \geq \mathcal{L}_\pi(\pi') - C\sqrt{\mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi'(·|s) \| \pi(·|s))]}\]TRPO’s algorithm ensures three key things happen:
- Optimize $\mathcal{L}_\pi(\pi')$ using importance sampling
- Constrain the KL divergence to stay small
- Rely on the fact that $\mathcal{L}_\pi(\pi) = 0$ when $\pi' = \pi$
This last point is crucial and deserves explanation.
Why is ℒ_π(π) = 0? At the current policy, the importance sampling ratio becomes $\frac{\pi(a|s)}{\pi(a|s)} = 1$ for all actions. So we get:
\[\mathcal{L}_\pi(\pi) = \mathbb{E}_{s\sim d^\pi} \left[ \mathbb{E}_{a \sim \pi} \left[ 1 \cdot A^\pi(s,a) \right] \right] = \mathbb{E}_{s\sim d^\pi} \left[ \mathbb{E}_{a \sim \pi} \left[ A^\pi(s,a) \right] \right]\]But by definition, the advantage function has zero expectation under the policy - $\mathbb{E}_{a \sim \pi}[A^\pi(s,a)] = 0$ because it measures how much better each action is compared to the average. This means if we can make ℒ_π(π’) > 0 while keeping KL divergence small, we’re guaranteed that J(π’) > J(π). TRPO never moves backwards.
You can read more about the proof here |
$\mathcal{L}_\pi(\pi') \geq 0$ implies $J(\pi') \geq J(\pi)$ (Our new policy will always be at least as good as our current policy)
TRPO’s guarantee: Every policy update improves performance or leaves it unchanged. We never move backwards.
Think of TRPO this way:
- Sample trajectories with current policy $\pi$
- Estimate how well any nearby policy $\pi'$ would do on these same trajectories (importance sampling)
- Find the best nearby policy within our trust region (constrained optimization)
- Verify the policy is actually better before committing (safety check)
The trust region ensures our importance sampling estimates remain accurate, while the MM algorithm structure guarantees we always improve. The constrained optimization problem: \(\max_{\pi'} \mathcal{L}_\pi(\pi')\) \(\text{subject to } \mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi'||\pi)[s]] \leq \delta\)
looks intimidating, but we can solve it elegantly using a Taylor expansion around our current policy parameters θ_k. This is where the mathematical beauty of TRPO really shines through.
Definition from Wikipedia
Let’s expand both the objective function and the constraint to second order around $\theta_k$. For the objective function $\mathcal{L}$:
\[\mathcal{L}_{\theta_k}(\theta) \approx \mathcal{L}_{\theta_k}(\theta_k) + g^T (\theta - \theta_k) + \frac{1}{2}(\theta - \theta_k)^T H_{\mathcal{L}} (\theta - \theta_k)\]Where:
- $g = \nabla_\theta \mathcal{L}_{\theta_k}(\theta) |_{\theta_k}$ (the gradient of the objective at $\theta_k$)
- $H_{\mathcal{L}} = \nabla^2_\theta \mathcal{L}_{\theta_k}(\theta) |_{\theta_k}$ (the Hessian of the objective at $\theta_k$)
We can skip the terms beyond second order because they become negligibly small when $\theta$ is close to $\theta_k$. This is the fundamental assumption of trust region methods - we’re making small enough steps that higher-order terms don’t significantly affect our approximation quality. |
For the KL constraint: \(\overline{D}_{KL}(\theta|\theta_k) \approx \overline{D}_{KL}(\theta_k|\theta_k) + \nabla_\theta \overline{D}_{KL}(\theta|\theta_k)|_{\theta_k}^T (\theta - \theta_k) + \frac{1}{2}(\theta - \theta_k)^T H_{KL} (\theta - \theta_k)\)
Now comes the key insight that simplifies everything. At the current policy θ_k, several terms vanish:
- $\mathcal{L}_{\theta_k}(\theta_k) = 0$ (we showed this earlier - the advantage has zero expectation)
- $\overline{D}_{KL}(\theta_k|\theta_k) = 0$ (KL divergence of a distribution with itself is always zero)
- $\nabla_\theta \overline{D}_{KL}(\theta|\theta_k) |_{\theta_k} = 0$ (the gradient of the KL divergence at the reference point is zero)
This leaves us with a beautifully clean quadratic optimization problem:
\(\max_\theta g^T (\theta - \theta_k)\) \(\text{subject to } \frac{1}{2}(\theta - \theta_k)^T H_{KL} (\theta - \theta_k) \leq \delta\)
where g is the policy gradient: \(g = \nabla_\theta \mathcal{L}_{\theta_k}(\theta) |_{\theta_k}\)
and $H_{KL}$ is the Hessian of the KL divergence, which has a special name: the Fisher Information Matrix (FIM):
\[H_{KL} = \nabla^2_\theta \overline{D}_{KL}(\theta|\theta_k) |_{\theta_k} = F = \mathbb{E}_{s,a \sim \pi_k} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]\]This constrained quadratic optimization problem has a closed-form solution that can be derived using Lagrange multipliers. Setting up the Lagrangian:
\[\mathcal{L}(\theta, \lambda) = g^T (\theta - \theta_k) - \lambda \left( \frac{1}{2}(\theta - \theta_k)^T F (\theta - \theta_k) - \delta \right)\]Taking the gradient with respect to θ and setting it to zero: \(\nabla_\theta \mathcal{L} = g - \lambda F (\theta - \theta_k) = 0\)
Solving for the optimal step: \(\theta - \theta_k = \frac{1}{\lambda} F^{-1} g\)
To find λ, we substitute back into the constraint: \(\frac{1}{2} \left( \frac{1}{\lambda} F^{-1} g \right)^T F \left( \frac{1}{\lambda} F^{-1} g \right) = \delta\)
\[\frac{1}{2\lambda^2} g^T F^{-1} g = \delta\] \[\lambda = \sqrt{\frac{g^T F^{-1} g}{2\delta}}\]Putting it all together, we get the Natural Policy Gradient update:
\[\theta_{k+1} = \theta_k + \sqrt{\frac{2\delta}{g^T F^{-1} g}} F^{-1} g\]Why “Natural”? This is where things get philosophically beautiful. Regular gradient descent uses the Euclidean distance in parameter space - it treats all parameter changes as equal. But this is fundamentally wrong for probability distributions!
Consider two neural networks that represent the same policy but with different parameterizations. Vanilla gradient descent would give them different updates, even though they’re the same policy. The Natural Policy Gradient fixes this by using the Fisher Information Matrix to measure distance in the space of probability distributions rather than parameter space.
The Fisher Information Matrix captures the “curvature” of the log-likelihood surface. Areas where small parameter changes cause big changes in the policy get more weight in the distance metric. This makes the algorithm reparameterization invariant - the actual policy updates remain the same regardless of how you parameterize your neural network.
Think of it like this: if you’re navigating on a curved surface, you shouldn’t use straight-line distances to plan your path. The Fisher Information Matrix gives you the “natural” metric for that curved space, ensuring you take the most efficient steps toward better policies.
The Backtracking Line Search
Our elegant mathematical derivation gives us the Natural Policy Gradient update:
\[\theta_{k+1} = \theta_k + \sqrt{\frac{2\delta}{g^T F^{-1} g}} F^{-1} g\]But here’s the rub: this assumes our quadratic approximation is perfect. In reality, neural networks are highly nonlinear, and our Taylor expansion is only valid in a small neighborhood around $\theta_k$. The computed step might violate our KL constraint or even decrease performance when the approximation breaks down.
TRPO’s solution is beautifully practical: backtracking line search. Instead of blindly taking the full computed step, TRPO modifies the update to:
\[\theta_{k+1} = \theta_k + \alpha^j \sqrt{\frac{2\delta}{g^T F^{-1} g}} F^{-1} g\]where $\alpha \in (0, 1)$ is the backtracking coefficient (typically 0.5), and $j$ is the smallest nonnegative integer such that the new policy satisfies both:
- The KL constraint: $\mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi_{\theta_{k+1}}(\cdot|s) \| \pi_{\theta_k}(\cdot|s))] \leq \delta$
- Positive improvement: $\mathcal{L}_{\theta_k}(\theta_{k+1}) \geq 0$
This conservative verification ensures TRPO never violates its theoretical guarantees, even when the quadratic approximation becomes inaccurate. It’s the algorithm’s safety net - systematically reducing the step size until both conditions are met.
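A minimal sketch of that backtracking loop, assuming hypothetical helpers `kl_fn` and `surrogate_fn` that evaluate the mean KL and the surrogate objective of a candidate parameter vector on the sampled states:

```python
def backtracking_line_search(theta_k, full_step, kl_fn, surrogate_fn,
                             delta, alpha=0.5, max_backtracks=10):
    """Shrink the natural gradient step until both TRPO checks pass."""
    for j in range(max_backtracks):
        theta_new = theta_k + (alpha ** j) * full_step
        if kl_fn(theta_new) <= delta and surrogate_fn(theta_new) >= 0:
            return theta_new          # both conditions hold: accept the update
    return theta_k                    # give up and keep the current policy
```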
However, there’s a computational nightmare lurking here. The Natural Policy Gradient requires computing $F^{-1}g$, which means inverting the Fisher Information Matrix:
\[F = \mathbb{E}_{s,a \sim \pi_k} \left[ \nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T \right]\]For modern deep networks with millions of parameters, F is a massive n×n matrix. Computing its inverse is O(n³) - completely impractical. Even storing the full matrix requires O(n²) memory, which quickly becomes impossible for large networks.
This computational bottleneck is what led to the development of Truncated Natural Policy Gradient, and ultimately to TRPO’s clever use of conjugate gradient methods to approximate the matrix inversion without ever computing F⁻¹ explicitly.
Truncated Natural Policy Gradient
The solution to our computational nightmare is elegantly simple: instead of computing F⁻¹g directly, we use the Conjugate Gradient method to solve the linear system:
\[F x = g\]This iterative approach finds x ≈ F⁻¹g without ever computing the matrix inverse, requiring only matrix-vector products Fv. It’s like finding the solution to a puzzle without having to understand every piece - we just need to know how the pieces interact.
Conjugate Gradient works by generating search directions that are “conjugate” - meaning they’re orthogonal after being transformed by matrix F. This ensures that each new search direction doesn’t undo progress from previous iterations. For quadratic problems like ours, CG guarantees finding the exact solution in at most n steps (where n is the number of parameters), but typically converges much faster in practice.
The key insight is that we never need to compute or store the full Fisher Information Matrix. Instead, we only need the matrix-vector product $Fx$ for any vector $x$. This can be computed efficiently using automatic differentiation:
\[Fx = \nabla_\theta \left( \left( \nabla_\theta \overline{D}_{KL}(\theta||\theta_k) \right)^T x \right)\]This Hessian-vector product gives us exactly what conjugate gradient needs without ever materializing the massive $F$ matrix.
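Here is a sketch of how this typically looks in PyTorch: a Fisher-vector product via double backprop plus a textbook conjugate gradient solver. The `kl_fn` closure (which rebuilds the scalar mean KL on the sampled states) is an assumed helper, and the small damping term is a common practical addition rather than part of the derivation above.

```python
import torch

def fisher_vector_product(kl_fn, params, v, damping=1e-2):
    """Compute F @ v via double backprop, without ever materializing F."""
    kl = kl_fn()                                   # scalar mean KL, rebuilt on each call
    grads = torch.autograd.grad(kl, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    hvp = torch.autograd.grad((flat_grad * v).sum(), params)
    return torch.cat([h.reshape(-1) for h in hvp]) + damping * v

def conjugate_gradient(fvp, g, iters=10, tol=1e-10):
    """Solve F x = g for x ~ F^{-1} g using only matrix-vector products."""
    x = torch.zeros_like(g)
    r, p = g.clone(), g.clone()
    rs_old = r.dot(r)
    for _ in range(iters):
        Fp = fvp(p)
        alpha = rs_old / p.dot(Fp)
        x = x + alpha * p
        r = r - alpha * Fp
        rs_new = r.dot(r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Usage sketch: step_dir = conjugate_gradient(lambda v: fisher_vector_product(kl_fn, params, v), g)
```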
The Complete TRPO Algorithm
TRPO weaves together all these concepts into a surprisingly elegant algorithm that carefully balances theoretical guarantees with practical implementation:
Step 1: Data Collection
- Collect trajectories using the current policy $\pi_k$
- Estimate the advantage function $A^{\pi_k}$ using any method (GAE, Monte Carlo returns, or temporal difference learning)
Step 2: Gradient Computation
- Compute the policy gradient \(g = \nabla_\theta \mathcal{L}_{\theta_k}(\theta) |_{\theta_k}\)
- Set up the Fisher Information Matrix function for conjugate gradient operations
Step 3: Search Direction
- Solve $Fx = g$ using Conjugate Gradient to get the search direction $x$
- This gives us the natural gradient direction without explicitly inverting $F$
Step 4: Step Size Calculation
- Compute the initial step size to satisfy the trust region constraint
- Calculate $\alpha = \sqrt{\frac{2\delta}{g^T F^{-1} g}}$
Step 5: Conservative Verification (Line Search with Exponential Backoff)
- Propose an update: $\theta' = \theta_k + \alpha \cdot x$
- Verify two critical conditions:
- KL divergence constraint: \(\mathbb{E}_{s\sim d^\pi}[D_{KL}(\pi'(\cdot|s) \| \pi(\cdot|s))] \leq \delta\)
- Surrogate improvement: $\mathcal{L}_{\theta_k}(\theta') \geq 0$
- If either verification fails: reduce $\alpha$ (typically by half) and try again
- Only commit to the policy update after both conditions are satisfied
This conservative approach guarantees the theoretical properties we derived, but it also reveals TRPO’s fundamental tension between theory and practice. The algorithm is theoretically beautiful but computationally demanding, requiring multiple verification steps and potential backtracking that can make each update quite expensive.
TRPO’s Limitations
Despite its theoretical elegance, TRPO faces several practical challenges that motivated simpler alternatives:
- Computational Overhead: Computing Fisher Information Matrices and running conjugate gradient makes each update significantly more expensive than first-order methods like Adam
- Sample Inefficiency: Requires large batch sizes to accurately estimate the FIM - small batches lead to noisy estimates and unstable training
- Scalability Issues: Second-order computations become impractical for very large neural networks where first-order methods excel
TRPO’s story represents a classic tension in machine learning: the trade-off between theoretical rigor and practical utility. While TRPO provided crucial theoretical insights about policy optimization - principled policy updates, trust region concepts, and guaranteed improvement - its computational complexity limited its real-world impact.
This limitation sparked a natural question: could we achieve similar performance guarantees with a much simpler algorithm? The answer would come in the form of Proximal Policy Optimization (PPO), which took TRPO’s core insights and packaged them into an algorithm so simple and effective that it would become the workhorse of modern policy optimization.
As we noted earlier from the PPO paper: “Q-learning (with function approximation) fails on many simple problems and is poorly understood, vanilla policy gradient methods have poor data efficiency and robustness; and trust region policy optimization (TRPO) is relatively complicated, and is not compatible with architectures that include noise (such as dropout) or parameter sharing.”
PPO’s breakthrough was recognizing that you don’t need complex second-order methods to implement trust regions effectively. Instead of computing Fisher Information Matrices and running conjugate gradient, PPO simply clips the importance sampling ratios directly. This first-order approach achieves similar practical performance while being orders of magnitude simpler to implement and debug.
Note: TRPO at its crux is simple: it is a constrained optimization problem. Solving it requires second-order derivatives, which are computationally expensive, and no current ML framework handles them without significant overhead. But do know, it is a tough topic to truly understand. One needs to be well-versed in a lot of prerequisite mathematics. Do not be disheartened if it takes you time to understand it thoroughly. Read slowly, daily, iteratively.
Proximal Policy Optimization (PPO)
PPO emerged from a simple observation: what if we could achieve TRPO’s stability guarantees without all the computational complexity? The genius of PPO lies in replacing TRPO’s hard KL constraint with a clever objective function that naturally prevents large policy updates.
Let’s first understand the core innovation. Remember TRPO’s constrained optimization problem:
\[\max_{\theta} \mathcal{L}_{\theta_{old}}(\theta) \quad \text{subject to } \mathbb{E}_{s \sim d^{\pi_{old}}}[D_{KL}(\pi_{old}(\cdot|s) || \pi_\theta(\cdot|s))] \leq \delta\]PPO asks: instead of explicitly constraining the KL divergence, what if we modify the objective function itself to penalize large policy changes? This transforms a constrained optimization problem into an unconstrained one that standard optimizers like Adam can handle.
The Clipped Surrogate Objective
PPO introduces a brilliantly simple idea. Define the probability ratio:
\[r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\]This ratio tells us how much more (or less) likely the new policy is to take the same action compared to the old policy. When $r_t(\theta) = 1$, the policies agree perfectly. When $r_t(\theta) = 2$, the new policy is twice as likely to take that action.
The vanilla policy gradient objective using importance sampling would be:
\[L^{IS}(\theta) = \mathbb{E}_t[r_t(\theta) \cdot A_t]\]But this can lead to destructively large policy updates when $r_t(\theta)$ becomes very large or very small. PPO’s innovation is to clip this ratio:
\[L^{CLIP}(\theta) = \mathbb{E}_t[\min(r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t)]\]
Image taken from paper
Let’s unpack this equation with concrete examples to build intuition.
Case 1: Positive Advantage (A_t > 0)
When an action led to better-than-expected rewards, we want to increase its probability. Let’s say $A_t = 2$ and $\epsilon = 0.2$.
- If $r_t(\theta) = 0.5$ (new policy half as likely):
- Unclipped objective: $0.5 \times 2 = 1$
- Clipped objective: $\min(1, 0.8 \times 2) = 1$
- No clipping occurs since we’re making the policy worse for a good action
- If $r_t(\theta) = 1.5$ (new policy 50% more likely):
- Unclipped objective: $1.5 \times 2 = 3$
- Clipped objective: $\min(3, 1.2 \times 2) = 2.4$
- Clipping kicks in! We cap the improvement to prevent overconfidence
The key insight: for positive advantages, clipping prevents us from changing the policy too aggressively in the “good” direction. Once $r_t(\theta) > 1 + \epsilon$, there’s no benefit to increasing it further.
Case 2: Negative Advantage (A_t < 0)
When an action led to worse-than-expected rewards, we want to decrease its probability. Let’s say $A_t = -2$ and $\epsilon = 0.2$.
- If $r_t(\theta) = 0.5$ (new policy half as likely):
- Unclipped objective: $0.5 \times (-2) = -1$
- Clipped objective: $\min(-1, 0.8 \times (-2)) = -1.6$
- Clipping makes the objective more negative, encouraging further reduction
- If $r_t(\theta) = 1.5$ (new policy 50% more likely):
- Unclipped objective: $1.5 \times (-2) = -3$
- Clipped objective: $\min(-3, 1.2 \times (-2)) = -3$
- No clipping since we’re increasing probability of a bad action
For negative advantages, clipping prevents us from reducing the probability too aggressively. Once $r_t(\theta) < 1 - \epsilon$, there’s no benefit to decreasing it further.
The Mathematical Beauty of PPO’s Objective
The clipped objective creates a “trust region” implicitly. When the policy changes too much (beyond $1 \pm \epsilon$), the gradient of the clipped objective becomes zero with respect to $\theta$. This elegant mechanism automatically prevents destructive updates without requiring second-order optimization.
To see this mathematically, consider the gradient when $A_t > 0$ and $r_t(\theta) > 1 + \epsilon$:
\[\frac{\partial L^{CLIP}}{\partial \theta} = \frac{\partial}{\partial \theta}[\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t] = 0\]The gradient vanishes because the clipped value $(1+\epsilon)$ doesn’t depend on $\theta$. This creates a “flat” region in the loss landscape that prevents further movement in that direction.
PPO with Adaptive KL Penalty
Before arriving at the clipped objective, the PPO paper explored a KL penalty approach that directly connects to TRPO:
\[L^{KLPEN}(\theta) = \mathbb{E}_t[r_t(\theta) \cdot A_t - \beta \cdot D_{KL}(\pi_{\theta_{old}}(\cdot|s_t) || \pi_\theta(\cdot|s_t))]\]This is exactly the unconstrained version of TRPO’s problem! The Lagrangian of TRPO’s constrained optimization:
\[\max_{\theta} \mathcal{L}_{\theta_{old}}(\theta) - \lambda(\mathbb{E}[D_{KL}] - \delta)\]becomes PPO’s KL penalty objective when we fix $\beta = \lambda$. The key difference: PPO adapts $\beta$ dynamically during training:
```python
if kl_divergence > 1.5 * target_kl:
    beta *= 2   # Increase penalty
elif kl_divergence < target_kl / 1.5:
    beta /= 2   # Decrease penalty
```
However, this adaptive mechanism proved finicky in practice. The clipped objective achieved similar goals with fixed hyperparameters, making it the preferred approach.
Why Multiple Epochs Work: The Importance Sampling Perspective
A subtle but crucial aspect of PPO is performing multiple epochs of updates on the same data. This seems to violate our earlier concern about importance sampling breaking down when policies diverge. The clipping mechanism is precisely what makes this safe.
Consider what happens over multiple epochs:
- Epoch 1: Policy changes slightly, ratios stay near 1
- Epoch 2: For trajectories where policy already changed, clipping prevents further movement
- Epochs 3-10: Most gradients are zero due to clipping, only “unexploited” trajectories contribute
The clipping essentially creates a curriculum where different parts of the data become “active” in different epochs, naturally preventing overfitting to any particular trajectory.
PPO’s clipping prevents the fine-tuned model from diverging too far from the base model’s distribution, maintaining fluency while optimizing for human preferences. This is why responses from RLHF models feel coherent - they’re constrained to stay within a trust region of the original model’s behavior.
The journey from policy gradients through TRPO to PPO represents a beautiful example of how complex theoretical insights can be distilled into simple, practical algorithms. PPO takes TRPO’s guarantee of monotonic improvement and approximates it with a first-order method that captures the essential insights: prevent destructive updates, enable data reuse, and maintain computational simplicity.
MoE
Link to paper: Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer
Link to implementation: [WORK IN PROGRESS]
Quick Summary
This 2017 paper by Shazeer et al. introduces a novel approach to dramatically increase neural network capacity without proportionally increasing computational costs. The core innovation is the Sparsely-Gated Mixture-of-Experts (MoE) layer, which contains thousands of feed-forward neural networks (experts), with a trainable gating network that selectively activates only a small subset of experts for each input example.
Key highlights:
- The authors achieve over 1000x improvements in model capacity while maintaining computational efficiency
- Their approach addresses several challenges of conditional computation, including GPU utilization and load balancing
- When applied to language modeling and machine translation tasks, their MoE models significantly outperform state-of-the-art models with lower computational cost
- Their largest model contains up to 137 billion parameters and demonstrates continued performance improvements with increased capacity
This paper represents a significant advancement in scaling neural networks efficiently, presaging some of the techniques that would later become important in very large language models.
Another explosive paper, in 2017. Talk about a crazy year, right? Well, to be perfectly honest, MoE was actually introduced back in 1991 in the paper Adaptive Mixture of Local Experts. But Shazeer et al. brought the idea to LSTMs at scale, which is when it really blew up.
Problem
The capacity of a neural network to absorb information is limited by its number of parameters.
Solution
Conditional computation, where parts of the network are active on a per-example basis, has been proposed in theory as a way of dramatically increasing model capacity without a proportional increase in computation.
The following blogs helped me immensely while writing this section
A simple intuition behind MoE can be seen above: a single dense neural network is like one big student who has general knowledge about a lot of things without being particularly great at any one topic. When you ask him a question, he takes his time to think and then answers you. He also eats a lot, because he is big.
But with a MoE layer, a smart router reads the question and directs it to the right expert. That expert gives a focused answer since they only need to worry about their specialty. As we’re only activating one small expert instead of the entire large model, we use much less computation while still having access to lots of specialized knowledge.
The above visualization is good from an intuition point of view, but that is not how MoEs actually work in practice. For starters, each expert is not an expert in a topic but an expert in certain kinds of tokens: some can be punctuation experts, some can be noun experts, etc. (More on this later.)
This work introduced MoEs to LSTMs, so let us proceed forward in understanding that. Consider reading the following blog Understanding LSTM Networks by Christopher Olah & Recurrent Neural Networks (RNNs), Clearly Explained!!! by Josh Starmer if you need a refresher on the topic.
In an MoE model, the fully connected neural network (FCNN) (or the hidden state in the case of RNNs & LSTMs) is replaced with an MoE layer. The MoE layer consists of a gating function which outputs a probability distribution over the likely experts. The experts themselves are smaller FCNNs. The output of each selected expert is multiplied by its probability, and the results are summed.
The idea seems simple enough, but there are multiple complexities like:
- How do you create a fair gating function?
- How many experts do you choose?
- How many tokens do you send to each expert?
Let’s go through each question one by one.
Note: We will see many changes that were made on this idea as we progress, but this was the foundational paper on MoEs for large models. So it is crucial that you understand it well.
Sparse vs Dense Networks
Dense Networks: Every parameter is used for every input
- High computational cost that scales linearly with parameters
- Limited by memory and compute constraints
- Parameters must be “generalists” handling all types of inputs
Sparse Networks (MoE): Only a subset of parameters are used per input
- Computational cost scales with active experts, not total experts
- Can have 1000x more parameters with similar compute budget
- Parameters can be highly specialized for specific patterns
Conditional computation allows us to scale model capacity without proportional compute scaling. It’s like having a library with thousands of specialized books, but only reading the few relevant ones for each question.
The Gating Network
First let us understand how the output is calculated in a sparse MoE.
We begin with an input matrix X and multiply it by the router weights W. We take the softmax of this output to get the probability distribution $G(x)$: the likelihood of which experts are best for the given input.
Depending on how many experts we choose, we take the output of those experts and multiply it by the probability of that expert being chosen (this distributes the importance of each output according to how likely its expert is to be chosen). That gives us the final output.
When we put it all together, this is how it looks.
The original paper introduced two key innovations:
Softmax Gating (Dense Baseline)
```
# Simple dense gating - activates ALL experts with different weights
G(x) = Softmax(x · W_g)
y = Σ G(x)_i * Expert_i(x)   # All experts contribute
```
Noisy Top-K Gating (The Sparse Innovation)
```
# Step 1: Add trainable noise for load balancing
H(x)_i = (x · W_g)_i + StandardNormal() * Softplus((x · W_noise)_i)

# Step 2: Keep only top K experts, set others to -∞
KeepTopK(v, k)_i = {
    v_i   if v_i is in the top k elements of v
    -∞    otherwise
}

# Step 3: Apply softmax (experts with -∞ get probability 0)
G(x) = Softmax(KeepTopK(H(x), k))
```
Why the noise? The Gaussian noise helps with load balancing. Without it, the same few experts would always dominate, creating a “rich get richer” problem where popular experts get more training and become even more popular.
Why Top-K? By keeping only the top K experts (typically K=2 or K=4), we achieve:
- Sparsity: Most experts are inactive, saving computation
- Specialization: Each expert focuses on specific patterns
- Scalability: We can add thousands of experts without proportional compute increase
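Putting the pseudocode together, here is a small runnable sketch of noisy top-k gating; the dimensions, weight matrices, and expert count are made up for illustration.

```python
import torch
import torch.nn.functional as F

def noisy_top_k_gating(x, w_gate, w_noise, k=2):
    """Return a sparse (batch, num_experts) matrix of gate weights, zero outside the top k."""
    clean = x @ w_gate                                                  # (batch, num_experts)
    noisy = clean + torch.randn_like(clean) * F.softplus(x @ w_noise)   # H(x)
    topk_val, topk_idx = noisy.topk(k, dim=-1)
    masked = torch.full_like(noisy, float('-inf'))
    masked.scatter_(1, topk_idx, topk_val)                              # keep top-k logits, rest stay -inf
    return F.softmax(masked, dim=-1)                                    # -inf entries become exactly 0

# Tiny usage: 4 tokens of dimension 8 routed over 6 hypothetical experts
torch.manual_seed(0)
x = torch.randn(4, 8)
w_gate, w_noise = torch.randn(8, 6), torch.randn(8, 6)
gates = noisy_top_k_gating(x, w_gate, w_noise, k=2)
print(gates)   # each row has exactly 2 non-zero weights that sum to 1
```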
Addressing Performance Challenges
The paper identified several critical challenges that needed solving for MoE to work in practice:
The Shrinking Batch Problem
Original batch: 1024 examples
With 256 experts, k=4: Each expert sees only ~16 examples
Small batches = inefficient GPU utilization
Solution:
Mix Data and Model Parallelism
- Combine batches from multiple GPUs before sending to experts
- Each expert gets a larger effective batch size: `(batch_size * num_devices * k) / num_experts`
- Achieves a factor of `d` improvement in expert batch size with `d` devices
Network Bandwidth Bottleneck
Modern GPUs have computational power thousands of times greater than their network bandwidth, meaning most of the time is spent waiting on I/O rather than computing.
Solution:
- Keep experts stationary on devices (don’t move the experts)
- Only send inputs/outputs across network (much smaller)
- Use larger hidden layers to improve computation-to-communication ratio
To understand this better, consider reading Making Deep Learning Go Brrrr From First Principles |
Load Balancing Problem
Without intervention, a few experts dominate while others are rarely used. This creates a vicious cycle: popular experts get more training data, become better, and thus get selected even more often. Meanwhile, neglected experts remain undertrained and essentially become dead weight.
Think of it like a classroom where only the brightest students get called on - they get more practice and become even brighter, while others stagnate.
The Dual Challenge
The paper identifies that we need to balance two distinct but related problems:
- Importance Imbalance: Some experts get high gating weights but few examples
- Load Imbalance: Some experts get many examples but low individual weights
Both scenarios are problematic. An expert with few high-weight examples overfits to specific patterns, while an expert with many low-weight examples receives weak learning signals.
Mathematical Solution: Auxiliary Loss
The authors introduce a load balancing loss that uses the coefficient of variation (CV) to measure and penalize imbalance:
\[CV = \frac{\sigma}{\mu} = \frac{\text{standard deviation}}{\text{mean}}\]The CV is beautiful because it’s scale-invariant - it measures relative variability regardless of the absolute magnitudes. A CV of 0 means perfect balance, while higher values indicate increasing imbalance.
Step 1: Measuring Importance
For each expert $i$, we sum its gating probabilities across the entire batch:
\[\text{Importance}(X)_i = \sum_{x \in X} G(x)_i\]This gives us the “importance scores” - how much each expert contributes regardless of which specific examples it processes.
Step 2: Computing the Importance Loss
\[\mathcal{L}_{\text{importance}} = w_{\text{importance}} \cdot CV(\text{Importance}(X))^2\]Where: \(CV(\text{Importance}(X)) = \frac{\sigma(\text{Importance}(X))}{\mu(\text{Importance}(X))}\)
Why square the CV? This creates a stronger penalty for large imbalances and makes the gradient more well-behaved during optimization.
Step 3: Measuring Load
Load measures how many examples each expert actually processes:
\[\text{Load}(X)_i = \sum_{x \in X} \mathbf{1}[\text{expert } i \text{ is in top-k for } x]\]In practice, this uses a smooth differentiable approximation rather than the hard indicator function.
Step 4: Computing the Load Loss
\[\mathcal{L}_{\text{load}} = w_{\text{load}} \cdot CV(\text{Load}(X))^2\]The Complete Auxiliary Loss
\[\mathcal{L}_{\text{auxiliary}} = w_{\text{importance}} \cdot CV(\text{Importance}(X))^2 + w_{\text{load}} \cdot CV(\text{Load}(X))^2\]Final Training Objective
\[\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{main}} + \mathcal{L}_{\text{auxiliary}}\]Why Both Losses Matter
Consider these scenarios:
- Expert A: Gets selected for 100 examples with average weight 0.01 each
- Expert B: Gets selected for 2 examples with average weight 0.5 each
- Expert C: Gets selected for 50 examples with average weight 0.02 each
All have similar total importance (≈ 1.0), but vastly different training dynamics:
- Expert A gets many weak signals → slow learning
- Expert B gets few strong signals → overfitting risk
- Expert C gets balanced signal → healthy learning
The dual loss ensures both the total contribution (importance) and the number of training examples (load) are balanced across experts.
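A minimal sketch of this auxiliary loss, assuming a `gates` matrix like the one produced by the gating sketch earlier. Note that the paper uses a smooth, differentiable estimator for load; the hard count below is only for clarity.

```python
import torch

def cv_squared(x, eps=1e-10):
    """Squared coefficient of variation: (std / mean)^2, zero when perfectly balanced."""
    return x.var() / (x.mean() ** 2 + eps)

def load_balancing_loss(gates, w_importance=0.1, w_load=0.1):
    """Auxiliary loss from gate weights of shape (batch, num_experts)."""
    importance = gates.sum(dim=0)              # total gate weight received by each expert
    load = (gates > 0).float().sum(dim=0)      # hard count of examples routed to each expert
    return w_importance * cv_squared(importance) + w_load * cv_squared(load)
```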
Practical Impact
With proper load balancing:
- All experts receive sufficient training signal
- No expert dominates the computation
- Model capacity is fully utilized
- Training stability improves dramatically
This auxiliary loss was crucial for making MoE work at scale - without it, the models would collapse to using only a handful of experts, defeating the entire purpose of conditional computation.
Expert Capacity
This wasn’t introduced in this paper, but let’s talk about it too since it’s crucial for modern MoE implementations. Even with perfect load balancing, there’s another challenge: token overflow. In the example above, FFNN 1 receives the majority of tokens. To prevent any single expert from being overwhelmed, we set an Expert Capacity - a maximum number of tokens each expert can process per batch. When an expert reaches capacity, additional tokens that would have been routed to it are either sent to the next-best expert or bypass the MoE layer entirely (called token overflow). This capacity mechanism ensures balanced computational load across experts and prevents memory bottlenecks, though it can sometimes mean tokens don’t get processed by their optimal expert. The trade-off between perfect routing and practical constraints is a key engineering challenge in scaling MoE systems.
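The capacity itself is usually derived from the batch size and a capacity factor; the exact formula varies by implementation, but a common form (the numbers below are purely illustrative) looks like this:

```python
def expert_capacity(tokens_per_batch, num_experts, capacity_factor=1.25, k=2):
    """Max tokens one expert may process per batch; tokens beyond this overflow."""
    return int(capacity_factor * k * tokens_per_batch / num_experts)

print(expert_capacity(tokens_per_batch=4096, num_experts=64))   # 160 token slots per expert
```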
Training the MoE Model
Key Challenge: How do experts specialize without explicit supervision?
The specialization emerges through training dynamics:
- Initial randomness: All experts start random and perform similarly
- Noise-induced preferences: The noise in gating creates slight preferences
- Reinforcement loop: Experts that perform well for certain inputs get selected more
- Emergent specialization: Through this process, experts develop distinct capabilities
What do experts actually learn? (From the paper’s analysis)
Image taken from Mistral paper
Unlike the intuitive “biology expert” or “math expert”, real MoE experts learn much more fine-grained patterns:
- Syntactic specialization: Expert 381 specializes in contexts with “researchers”, “innovation”, and “technology”
- Positional patterns: Expert 752 handles phrases where indefinite article “a” introduces important concepts
- Semantic clustering: Expert 2004 focuses on contexts involving speed and rapid change
This emergent specialization is what makes MoE powerful - experts automatically discover useful divisions of labor without being explicitly told what to specialize in.
Revolutionary Results
Language Modeling (1B Word Benchmark):
- 4096-expert MoE: 24% better perplexity than dense baseline
- Same computational cost as much smaller dense models
- Up to 137B parameters (1000x parameter increase) with minimal compute overhead
- Training time: 12 hours vs weeks for equivalent dense models
Machine Translation (WMT’14):
- En→Fr: 40.56 BLEU (vs 39.22 for GNMT)
- En→De: 26.03 BLEU (vs 24.91 for GNMT)
- Achieved new state-of-the-art with lower computational cost
- Faster training than dense models with better quality
Computational Efficiency:
- MoE models achieved 0.74-1.56 TFLOPS/GPU
- Significant fraction of theoretical maximum (4.29 TFLOPS/GPU)
- Only 37-46% of operations were in expert computations
The Breakthrough: This was the first time conditional computation delivered on its theoretical promise at scale. Previous attempts had struggled with the practical challenges that this paper solved.
From LSTMs to Modern Transformers
While this paper applied MoE to LSTMs (the dominant architecture in 2017), the core insights proved even more powerful when later applied to Transformers, which we will learn more about in later sections.
The path from this 2017 paper to modern LLMs shows how foundational ideas can have delayed but massive impact. Key lessons that influenced later work:
- Sparsity enables scale: The core insight that you can have orders of magnitude more parameters without proportional compute increase
- Load balancing is crucial: Without proper load balancing, MoE models fail to train effectively
- Engineering matters: Success required solving practical challenges like communication costs and batch sizing
- Specialization emerges: Given proper training dynamics, experts will naturally develop useful specializations
Today’s largest language models increasingly rely on MoE architectures, making this paper’s contributions more relevant than ever. The ability to scale to trillion-parameter models while maintaining reasonable training costs has become essential for pushing the boundaries of AI capabilities.
WORK IN PROGRESS NOTICE
The rest of the sections from 2018-2025 are still being worked on; I have a rough draft prepared for each year. But to do justice to the material, as well as create visualizations that clearly and explicitly explain each idea, it takes me considerable time. I am also spending time reimplementing each paper and publishing it on GitHub. Consider following me on my socials to stay up to date with what I am doing. Thank you for all the support and for reading what I write!!! You are awesome and your love keeps me motivated :)