Unboxing LLMs

July 3, 2023

Activation Checkpointing: Trading Computation for Memory in Deep Learning

Introduction

Modern deep learning models have ballooned to monstrous scales – shifting from millions to billions, soon trillions, of parameters. Training these behemoths slams you headfirst into a brutal reality: GPU memory consumption. The training process isn’t just about holding the model weights. It demands hoarding intermediate activation values for the sacred ritual of backpropagation. This memory hunger becomes a hard wall, preventing researchers from scaling up further or even using batch sizes that aren’t embarrassingly small.
[Figure: Standard Training (High Memory)]

Enter Activation Checkpointing (or Gradient Checkpointing, if you prefer). It’s not magic, but a brutally effective trick to sidestep the memory gods. Instead of dutifully storing every damned activation from the forward pass, you strategically save only a handful at designated “checkpoint” locations. The rest? Discarded. Vanished. Only to be regenerated, through sheer computational effort, when the backward pass demands them.

This offers a stark, uncomfortable, but often necessary tradeoff: you pay a tax in computation (recalculating those activations) to buy yourself precious memory headroom. For colossal models like the transformers that define the current landscape, this technique is often the only thing making training feasible on the hardware we actually possess.

[Figure: With Checkpointing (Lower Memory)]

How Activation Checkpointing Works

The Memory Problem in Backpropagation

To grasp checkpointing, you first have to internalize why training these deep networks is such a memory hog:

  1. Forward Pass: Data flows in, layers churn, activations pop out at each step. Simple enough.
  2. Storage Mandate: Here’s the rub. Backpropagation needs those intermediate activations to compute gradients. No activations, no gradients, no learning. So, you must store them.
  3. The Memory Bottleneck: For deep networks, stacking up activations layer after layer consumes gigabytes. Throw in large batches, and your GPU memory evaporates faster than venture capital in a bear market.

For a network of depth L trained with batch size N, activation memory scales as O(N × L). This scaling is the enemy, the direct cause of countless “CUDA out of memory” errors that haunt practitioners.
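To make the storage mandate concrete, here is a toy sketch (not how any framework is implemented internally): a chain of L scalar layers, each computing sin(w·x). The backward pass needs every layer's input, so the forward pass has to keep all of them. The helper names `forward_store_all` and `backward` are hypothetical.

```python
import math

# Toy network: L layers, each computing y = sin(w * x) on a single scalar.
# The local derivative w * cos(w * x) needs the layer's INPUT x, so
# standard backprop must keep every layer's input until the backward pass.

def forward_store_all(x, weights):
    saved = []                          # one stored activation per layer
    for w in weights:
        saved.append(x)
        x = math.sin(w * x)
    return x, saved

def backward(saved, weights):
    grad = 1.0                          # d(out)/d(out)
    for x, w in zip(reversed(saved), reversed(weights)):
        grad *= w * math.cos(w * x)     # chain rule, last layer first
    return grad

weights = [0.9] * 12                    # depth L = 12
out, saved = forward_store_all(0.5, weights)
grad = backward(saved, weights)
print(len(saved))                       # 12: storage grows linearly with depth
```

With a batch of N inputs, each saved entry would hold N values instead of one, which is where the N × L product comes from.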

Core Mechanics of Checkpointing

Activation checkpointing tackles this with a piece of computational judo:

  1. Strategic Segmentation
    You carve the network into segments. Think of checkpoints as gates between these segments. Only the activations at these gates are deemed worthy of storage.
  2. Selective Storage During Forward Pass
    • Activations at checkpoints? Keep ’em.
    • Activations between checkpoints? Calculate them, use them for the next step, then immediately discard them. Free up that memory.
  3. Recomputation During Backward Pass
    • When backpropagation hits the boundary of a segment, it needs the intermediate activations that you just threw away.
    • The solution? Recompute the forward pass for just that segment, starting from the stored activation at the segment’s beginning.
    • This regenerates the needed activations on the fly. Now you can calculate gradients and continue the backward pass.


This maneuver drastically cuts peak memory usage. The price? You’re doing extra forward computation during the backward pass – paying the computational toll for your memory savings.
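The three steps above can be sketched on the same kind of toy chain of sin(w·x) layers. This is a minimal illustration under simplifying assumptions (scalar activations, fixed-length segments, hypothetical helper names such as `checkpointed_grad`), not a description of how PyTorch or TensorFlow work internally: the forward pass keeps only each segment's input, and the backward pass rebuilds one segment at a time from its checkpoint.

```python
import math

def layer(x, w):
    return math.sin(w * x)

def layer_grad(x, w):
    return w * math.cos(w * x)          # needs the layer's input x

def checkpointed_grad(x0, weights, seg):
    """d(out)/d(x0) for the whole chain, storing only every seg-th layer input.
    Returns (gradient, layer evaluations, peak number of live activations)."""
    n, evals = len(weights), 0
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):                 # forward pass
        if i % seg == 0:
            checkpoints[i] = x                      # keep segment inputs only
        x = layer(x, w)                             # everything else is dropped
        evals += 1
    grad, peak = 1.0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):  # backward, last segment first
        end = min(start + seg, n)
        inputs = [checkpoints[start]]
        for i in range(start, end - 1):             # recompute this segment's inputs
            inputs.append(layer(inputs[-1], weights[i]))
            evals += 1
        # live values: all checkpoints plus this segment's recomputed inputs
        peak = max(peak, len(checkpoints) + len(inputs) - 1)
        for i in range(end - 1, start - 1, -1):     # chain rule through the segment
            grad *= layer_grad(inputs[i - start], weights[i])
    return grad, evals, peak

weights = [0.9] * 16
grad, evals, peak = checkpointed_grad(0.5, weights, seg=4)
print(evals, peak)   # 28 7
```

With 16 layers and segments of 4, the sketch does 28 layer evaluations instead of 16 and never holds more than 7 activations at once instead of 16. Setting seg=1 degenerates to ordinary backprop (every input stored, nothing recomputed) and yields the same gradient.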

The Memory-Computation Tradeoff

Let’s put some rough numbers on this Faustian bargain:

  • Without Checkpointing:
    • Memory (Activations): O(N × L) – The killer.
    • Computation: 1 forward pass + 1 backward pass – The baseline.
  • With Checkpointing (using √L checkpoints, a common strategy):
    • Memory: Slashed to roughly O(N × √L) – The salvation.
    • Computation: 1 forward pass + (approx. 1 extra forward pass during backprop) + 1 backward pass – The cost.
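    • Computational Overhead: About one extra forward pass per training step. Since a backward pass typically costs roughly twice a forward pass, that is 1 extra unit on a baseline of 3, which is where the commonly cited ~33% figure comes from. Actual overhead varies with the backward-to-forward cost ratio and with how much of the network gets recomputed.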
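The √L choice and the ~33% figure fall out of a little arithmetic. The sketch below assumes a deliberately simple cost model: k checkpoints are stored, one segment of about L/k activations is rebuilt at a time, and a backward pass costs about twice a forward pass. Real models deviate from both assumptions; `peak_activation_units` is a hypothetical helper.

```python
import math

def peak_activation_units(L, k):
    """Cost model: k stored checkpoints plus one segment of ceil(L/k)
    activations rebuilt at a time during the backward pass."""
    return k + math.ceil(L / k)

L = 100
best_k = min(range(1, L + 1), key=lambda k: peak_activation_units(L, k))
print(best_k, peak_activation_units(L, best_k))  # 10 20, versus 100 stored without checkpointing

# Compute overhead, assuming a backward pass costs about 2x a forward pass
forward, backward = 1.0, 2.0
baseline = forward + backward                 # 3 units per training step
with_ckpt = forward + forward + backward      # plus one recomputed forward: 4 units
print(f"{with_ckpt / baseline - 1:.0%}")      # 33%
```

Minimizing k + L/k gives k = √L, and the total becomes about 2√L, which is the O(√L) memory term (times N for the batch).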

[Figure: Memory Usage Comparison]

This exchange rate – trading potentially 33%+ compute time for potentially massive memory reduction – is often incredibly attractive when the alternative is simply not being able to train the model at all.

Optimal Checkpoint Placement

Where you stick these checkpoints matters. Slamming them down randomly isn’t smart. Common sense dictates a few strategies:

  1. Uniform Checkpointing
    The simplest approach: divide the network into roughly equal segments. Placing √L checkpoints evenly spaced gets you that O(N × √L) memory complexity. Simple, predictable, often good enough.
  2. Nested Checkpointing
    Get fancy. Apply checkpointing recursively within segments. Buys you even more memory, but the recomputation cost starts climbing faster. Use when desperate.
  3. Selective Checkpointing
    Be surgical. Identify the layers hogging the most activation memory (often the larger or earlier layers) and checkpoint around them. Requires more insight into your model’s guts but can be more efficient.
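For a plain stack of layers, PyTorch ships uniform checkpointing as `torch.utils.checkpoint.checkpoint_sequential`, which splits a sequence of modules into a given number of segments and checkpoints all but the last one. A minimal sketch, assuming PyTorch 2.x and a hypothetical 16-block MLP, using √16 = 4 segments and checking that the input gradient matches ordinary backprop:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
depth, width = 16, 64
# Hypothetical model: 16 identical Linear+Tanh blocks
model = nn.Sequential(*[nn.Sequential(nn.Linear(width, width), nn.Tanh()) for _ in range(depth)])
x = torch.randn(8, width, requires_grad=True)

# Baseline: ordinary forward/backward, every activation stored
model(x).sum().backward()
ref_grad = x.grad.clone()
x.grad = None

# Uniform checkpointing with about sqrt(depth) segments
segments = int(depth ** 0.5)
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()
print(torch.allclose(x.grad, ref_grad, atol=1e-6))  # True: same gradients, less stored
```

The recomputation replays exactly the same operations, so the gradients match; only the memory profile and the compute time change.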

[Figure: Selective Checkpointing (Checkpointing memory hogs)]

Research papers delve into finding the absolute optimal placement, but it’s often model-specific voodoo. For transformers, a pragmatic sweet spot is often checkpointing each transformer block (or every few blocks). It strikes a decent balance between memory savings, computational overhead, and implementation simplicity.

Thankfully, the major frameworks have internalized this pain and offer relatively painless ways to implement checkpointing, hiding much of the messy recomputation logic.

PyTorch

PyTorch’s torch.utils.checkpoint is the weapon of choice. You wrap the part of your model you want to checkpoint.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

# A simplified transformer block built from standard PyTorch modules
class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, heads, batch_first=True) # Example component
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.ReLU(), nn.Linear(dim * 4, dim)) # Example component
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    # This function defines the computation to be checkpointed
    def _forward(self, x):
        # Simplified block logic for illustration
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)
        ffn_output = self.ffn(x)
        x = self.norm2(x + ffn_output)
        return x

    # The public forward method applies checkpointing
    def forward(self, x):
        # checkpoint runs _forward(x) during the forward pass but does not keep
        # the intermediate activations inside it; they are rebuilt by re-running
        # _forward(x) when the backward pass needs them.
        # use_reentrant=False is the variant the PyTorch docs recommend: it works
        # with hooks and keyword arguments, and unlike the reentrant variant it
        # still produces parameter gradients even when no input requires grad.
        return checkpoint.checkpoint(self._forward, x, use_reentrant=False)

# Usage in a larger model
class MemoryEfficientTransformer(nn.Module):
    def __init__(self, dim, depth, heads):
        super().__init__()
        self.layers = nn.ModuleList([
            CheckpointedTransformerBlock(dim, heads)
            for _ in range(depth)
        ])
        # other components like embeddings, final layer...

    def forward(self, x):
        # Pass input through the stack of checkpointed blocks
        for layer in self.layers:
            x = layer(x)
        # ... pass through final layers ...
        return x
```

The key is that the framework handles the “when” and “how” of recomputation during backprop. You just declare what segment gets this treatment.
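You can watch the framework do this. In the minimal sketch below (assuming PyTorch 2.x; the `segment` function and the call counter are hypothetical), the checkpointed function runs once in the forward pass and a second time when `backward()` needs its activations.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

counter = {"calls": 0}
linear = nn.Linear(32, 32)

def segment(x):
    counter["calls"] += 1               # count how often the segment's forward runs
    return torch.tanh(linear(x))

x = torch.randn(4, 32, requires_grad=True)
y = checkpoint(segment, x, use_reentrant=False)
print(counter["calls"])  # 1: the ordinary forward pass
y.sum().backward()
print(counter["calls"])  # 2: segment() was re-run to rebuild its activations
```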

TensorFlow/Keras

TensorFlow uses the tf.recompute_grad decorator, achieving a similar effect: wrap the function defining the computation you want to checkpoint.

```python
import tensorflow as tf

# A layer whose inner computation is wrapped with tf.recompute_grad.
# The sublayer is created once, so the recomputation during backprop reuses
# the same weights instead of building a fresh Dense layer on every call.
class CheckpointedDense(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.inner = tf.keras.layers.Dense(units, activation='relu')  # Example

    def build(self, input_shape):
        # Create the weights up front, outside the recomputed function
        self.inner.build(input_shape)
        super().build(input_shape)

    def _layer_logic(self, inputs):
        # ... dense layers, convolutions, attention, etc. ...
        return self.inner(inputs)

    def call(self, inputs):
        # Activations inside _layer_logic are not kept; TensorFlow recomputes
        # them when gradients are requested
        return tf.recompute_grad(self._layer_logic)(inputs)

# Usage within a Keras model
def build_model(input_shape):
    inputs = tf.keras.Input(shape=input_shape)
    # ... some initial layers ...
    x = tf.keras.layers.Dense(256, activation='relu')(inputs)  # Example
    # Apply the checkpointed operation
    x = CheckpointedDense(128)(x)
    # ... subsequent layers ...
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)  # Example

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

# model = build_model((784,))
# model.summary()  # Shows the structure; checkpointing is a runtime behavior
```

HuggingFace Transformers

The HuggingFace ecosystem, built for massive models, makes this trivial. You often just flip a switch: one method call on the model, or one flag in the training arguments.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('gpt2')

# Enable gradient checkpointing on the model's transformer blocks
model.gradient_checkpointing_enable()
# The generation KV cache is not needed during training and is incompatible
# with gradient checkpointing, so turn it off
model.config.use_cache = False

# When training with the Trainer API, the same switch is available as:
# training_args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
```

This level of abstraction is fantastic, but it’s crucial to understand the underlying trade-off you’re enabling with that single boolean flag.

Real-World Impact: Case Studies

This isn’t theoretical. Checkpointing is the bedrock beneath many large-scale successes.

Training GPT-3 Scale Models

Let’s be blunt: the 175B-parameter GPT-3 would likely have been stillborn without activation checkpointing and other scaling techniques. At realistic batch sizes, standard backpropagation would demand well over a terabyte of activation memory per replica – ludicrous. Checkpointing, combined with model/data parallelism, wrestled this down into the realm of the merely extremely difficult, fitting within the distributed hardware clusters available.
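To sanity-check the order of magnitude, here is a back-of-the-envelope sketch. It uses the per-layer activation estimate s·b·h·(34 + 5·a·s/h) bytes from Korthikanti et al. (2022), “Reducing Activation Recomputation in Large Transformer Models”, which assumes 16-bit activations and no tensor or sequence parallelism, together with GPT-3 175B's published shape (96 layers, hidden size 12288, 96 heads, 2048-token context). It is a rough model, not a measurement of how GPT-3 was actually trained.

```python
def activation_bytes_per_layer(s, b, h, a):
    """Per-layer activation memory estimate in bytes (Korthikanti et al. 2022):
    s*b*h*(34 + 5*a*s/h), assuming 16-bit activations and no model parallelism."""
    return s * b * h * (34 + 5 * a * s / h)

# GPT-3 175B shape: 96 layers, hidden size 12288, 96 heads, 2048-token sequences
L, h, a, s = 96, 12288, 96, 2048
b = 1  # a single sequence per replica

per_layer = activation_bytes_per_layer(s, b, h, a)
no_ckpt = L * per_layer
# Checkpointing every layer: keep only each layer's 16-bit input (2*s*b*h bytes),
# plus one layer's full activations while it is being recomputed
with_ckpt = L * 2 * s * b * h + per_layer

print(f"{no_ckpt / 1e9:.0f} GB without checkpointing")           # 275 GB
print(f"{with_ckpt / 1e9:.1f} GB with per-layer checkpointing")  # 7.7 GB
```

At one sequence per replica the estimate is already about 275 GB; at four sequences it passes a terabyte, while per-layer checkpointing keeps the same workload in single-digit gigabytes per sequence.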

Vision Transformers (ViT)

Processing high-resolution images with deep Vision Transformers creates enormous activation maps. Checkpointing allows training deeper, more powerful ViTs, or fine-tuning them on larger images, even when individual GPU memory is a constraint. It unlocks architectures that would otherwise choke standard hardware.

Benchmarks and Performance Metrics

Numbers paint the picture. These are illustrative – reality depends heavily on architecture, hardware, and implementation details.

| Model Size | Batch Size | Memory (No Checkpointing) | Memory (With Checkpointing) | Compute Overhead (Approx.) |
|---|---|---|---|---|
| ~125M params | 32 | 16 GB | 8 GB | ~30% |
| ~1.5B params | 16 | OOM (on 24 GB GPU) | ~20 GB | ~35% |
| ~6B params | 8 | OOM (on 40 GB GPU) | ~32 GB | ~40% |

OOM = Out of Memory. Memory figures are peak activation memory estimates.

The trend is clear: checkpointing buys you significant memory headroom, especially crucial for larger models, at the cost of noticeable but often acceptable computational slowdown.

Advanced Techniques and Optimizations

The basic idea can be refined further:

Combining with Mixed Precision Training

Checkpointing and mixed-precision (FP16/BF16) are natural allies. Together they deliver a potent one-two punch:

  • Checkpointing slashes the number of activations stored.
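  • Mixed precision roughly halves the size of those stored activations compared with FP32. (Weights are usually still kept as an FP32 master copy, so their footprint does not shrink the same way.)
  • Mixed precision often speeds up computation via tensor cores.

The synergy allows dramatically larger models and/or batch sizes.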
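In PyTorch the two compose directly: `torch.utils.checkpoint` records the autocast state of the forward pass and recomputes under the same settings. A minimal CPU sketch, assuming PyTorch 2.x with bfloat16 autocast support on CPU and a hypothetical MLP block:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
x = torch.randn(16, 256, requires_grad=True)

# Inside the block, the matmuls run in bfloat16; checkpoint stores only the
# block input and recomputes the rest under the same autocast settings.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = checkpoint(block, x, use_reentrant=False)

print(y.dtype)                     # torch.bfloat16
y.float().sum().backward()         # backward may run outside the autocast region
print(block[0].weight.grad.dtype)  # torch.float32: parameters stay in FP32
```

On a GPU the same pattern applies with `device_type="cuda"`, and that is where tensor cores provide the speedup.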

Selective Activation Recomputation

Moving beyond simple uniform checkpointing, smarter implementations analyze the computation graph:

  • Recompute cheap layers, checkpoint expensive ones.
  • Factor in the memory footprint of each layer’s activation.
  • Identify critical paths to minimize recomputation overhead.

This requires deeper analysis but can squeeze out more efficiency.
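A greedy version of this idea fits in a few lines. Everything here is hypothetical: the layer names, the per-layer activation sizes (in MB) and the relative recompute costs are made-up inputs. The greedy rule (drop the activations that free the most memory per unit of recompute cost until the rest fits the budget) is a heuristic, not the optimal solver approach of work like Checkmate.

```python
def choose_recompute(layers, memory_budget):
    """Greedy sketch: pick which activations to drop and recompute.

    layers: list of (name, activation_mb, recompute_cost) tuples, cost > 0.
    Returns (names to recompute, MB of activations still stored).
    """
    stored = sum(mem for _, mem, _ in layers)
    recompute = []
    # Most memory freed per unit of extra compute goes first
    for name, mem, cost in sorted(layers, key=lambda t: t[1] / t[2], reverse=True):
        if stored <= memory_budget:
            break
        stored -= mem
        recompute.append(name)
    return recompute, stored

# Hypothetical profile: (name, activation MB, relative recompute cost)
layers = [
    ("attn_softmax", 800, 1.0),   # large, cheap elementwise-style op
    ("gelu_out", 300, 1.0),       # cheap elementwise op
    ("ffn_matmul", 200, 10.0),    # expensive matmul: better kept
    ("layernorm_out", 100, 1.0),
]
print(choose_recompute(layers, memory_budget=500))
# (['attn_softmax', 'gelu_out'], 300)
```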

Integration with Model Parallelism

On massive multi-GPU setups:

  • Tensor Parallelism splits operations – checkpointing applies within each device’s part of the operation.
  • Pipeline Parallelism splits layers across GPUs – checkpointing applies within each stage/GPU.

Checkpointing remains a vital tool even when distributing the model itself.

Alternatives and Complementary Approaches

Checkpointing isn’t the only weapon in the memory-saving arsenal. It’s often deployed alongside:

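  1. Activation Offloading: Copying activations out to (slow) CPU RAM during the forward pass and back to the GPU for the backward pass, instead of recomputing them. Useful when CPU-GPU transfer bandwidth is high enough to hide the copies and compute is the scarcer resource.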
  2. Gradient Accumulation: Simulating larger batches by processing smaller micro-batches sequentially and accumulating gradients before updating weights. Reduces memory per step but doesn’t shrink activation size within a micro-batch.
  3. Reversible Layers: Specially designed layers (like RevNets) whose inputs can be reconstructed from their outputs during the backward pass, so activations need not be stored. Reconstruction still costs compute and is exact only up to floating-point rounding. Elegant, but restricts architectural choices.
  4. CPU Offloading (Parameters/Optimizer): Shifting model weights or optimizer states to CPU memory (e.g., DeepSpeed ZeRO-Offload, or ZeRO Stage 3 with parameter and optimizer offloading enabled). Complementary to activation checkpointing.
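The reversible-layer idea from item 3 can be sketched with an additive coupling in the style of RevNets: split the input into two halves, and the outputs determine the inputs (up to rounding). `F` and `G` below are hypothetical stand-ins for arbitrary sub-networks; they never need to be inverted themselves.

```python
import math

def F(x):  # hypothetical sub-network
    return math.tanh(0.7 * x)

def G(x):  # another hypothetical sub-network
    return math.sin(1.3 * x)

def rev_forward(x1, x2):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def rev_inverse(y1, y2):
    # Rebuild the inputs from the outputs: nothing had to be stored
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2

x1, x2 = 0.25, -1.5
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)
print(abs(r1 - x1) < 1e-12, abs(r2 - x2) < 1e-12)  # True True
```

In a full reversible network the backward pass applies the inverse block by block, starting from the final output, so per-block activations are regenerated rather than kept.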

When to Use Activation Checkpointing

You need to consider activation checkpointing when:

  • You’re building truly deep networks (hundreds/thousands of layers, or dozens of large transformer blocks).
  • Your GPU screams “CUDA out of memory” despite reasonable batch sizes.
  • You desperately need a larger batch size for training stability or convergence, but memory is the bottleneck.
  • You have compute cycles to burn but are drowning in memory limitations.
  • Transformers are your game – it’s almost standard practice there.

You can likely skip it when:

  • Raw training speed is paramount, and you have memory to spare.
  • Your training is already compute-bound (e.g., dominated by expensive custom operations), so extra recomputation lands directly on the bottleneck.
  • The model fits comfortably within GPU memory with your desired batch size. Why pay the compute tax if you don’t have to?

Conclusion

I believe that activation checkpointing is a fundamental algorithmic pillar supporting the current era of large-scale deep learning. By offering a controllable way to trade computation for memory, it ripped down walls that previously limited model size and complexity. It transformed impossibly large models into merely resource-intensive ones.

As models continue their relentless march towards larger scales, checkpointing techniques will undoubtedly become even more sophisticated, optimizing the balance between memory footprint and computational cost. But the core principle – selectively forgetting and recomputing – will remain essential knowledge for anyone pushing the boundaries of deep learning on real-world hardware.

For researchers and engineers grappling with GPU memory limits, mastering activation checkpointing isn’t optional. It’s a critical survival skill, often the deciding factor between hitting a hard hardware ceiling and successfully training the next generation of powerful models.

References and Further Reading

  • Chen, Tianqi, et al. “Training deep nets with sublinear memory cost.” arXiv preprint arXiv:1604.06174 (2016). (The seminal paper, often cited as “Gradient Checkpointing”)
  • Gruslys, Audrunas, et al. “Memory-efficient backpropagation through time.” Advances in Neural Information Processing Systems 29 (2016). (Related ideas in recurrent settings)
  • Jain, Paras, et al. “Checkmate: Breaking the memory wall with optimal tensor rematerialization.” Proceedings of Machine Learning and Systems 2 (2020). (More advanced optimal checkpointing strategies)
  • PyTorch Documentation: torch.utils.checkpoint
  • TensorFlow Documentation: tf.recompute_grad
Posted in AI / ML, LLM Intermediate