
Ring Attention

How million-token contexts are split across GPUs in a ring — each GPU computes attention on its shard while passing Keys and Values to the next

The Long-Context Memory Wall

Attention has a quadratic memory problem: the attention matrix grows as O(n²) with sequence length n. For a million tokens that is 10¹² entries, about 2 TB at 2 bytes per entry for a single head in a single layer, far beyond any single GPU.
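As a quick check of that arithmetic, here is a minimal sketch that counts the bytes of one materialized n × n score matrix. It assumes 2 bytes per element (FP16/BF16) and counts a single head of a single layer; `attention_matrix_bytes` is a hypothetical helper name.

```python
def attention_matrix_bytes(n_tokens: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized n x n attention score matrix (one head, one layer)."""
    return n_tokens * n_tokens * bytes_per_elem


if __name__ == "__main__":
    for n in (4_096, 128_000, 1_000_000):
        gb = attention_matrix_bytes(n) / 1e9
        print(f"n={n:>9,}: {gb:,.3f} GB per head per layer")
```

At 128K tokens a single head already needs about 33 GB, which is why FlashAttention-style tiling avoids materializing the matrix at all.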

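FlashAttention fixed the quadratic part: by computing attention in tiles with an online softmax, it never materializes the n × n matrix, so attention memory grows only linearly with n. But linear is still too much at this scale. The keys, values, and activations for a million-token context are enormous, and a single GPU cannot hold a long enough sequence.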
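To put a number on "enormous", here is a rough sketch of the K/V footprint alone, which grows linearly with sequence length. The model shape (32 layers, 8 key/value heads of dimension 128, 2-byte elements) is a hypothetical configuration chosen for illustration; activations, weights, and optimizer state would come on top.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelShape:
    """Hypothetical transformer shape; only the fields that size K and V."""
    n_layers: int = 32
    n_kv_heads: int = 8
    head_dim: int = 128
    bytes_per_elem: int = 2  # FP16 / BF16


def kv_bytes_per_token(m: ModelShape) -> int:
    # One K vector and one V vector per KV head per layer.
    return 2 * m.n_layers * m.n_kv_heads * m.head_dim * m.bytes_per_elem


def kv_bytes(n_tokens: int, m: ModelShape) -> int:
    # Linear in sequence length, unlike the n x n score matrix.
    return n_tokens * kv_bytes_per_token(m)


if __name__ == "__main__":
    shape = ModelShape()
    print(kv_bytes_per_token(shape))         # 131072 bytes (128 KiB) per token
    print(kv_bytes(1_000_000, shape) / 1e9)  # ~131 GB for a million tokens
```

Under these assumptions K and V alone exceed the memory of a single 80 GB GPU before anything else is counted.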

Ring Attention removes this single-device limit by splitting the sequence across multiple GPUs arranged in a ring. Each GPU holds one KV shard, and the GPUs cooperate so that every GPU eventually sees all of K and V, but never all at once.

The Ring: Split, Compute, Pass

Imagine 4 GPUs in a ring. Each is assigned a contiguous chunk of the sequence and holds the Q, K, and V for that chunk.
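Assigning contiguous chunks is simple index arithmetic. A minimal sketch, assuming the sequence length divides evenly by the number of GPUs (real systems pad or handle the remainder); `shard_ranges` is a hypothetical helper name.

```python
def shard_ranges(seq_len: int, num_gpus: int) -> list[range]:
    """Contiguous token ranges, one per GPU (rank 0 gets the earliest tokens)."""
    if seq_len % num_gpus != 0:
        raise ValueError("this sketch assumes seq_len is divisible by num_gpus")
    chunk = seq_len // num_gpus
    return [range(r * chunk, (r + 1) * chunk) for r in range(num_gpus)]


if __name__ == "__main__":
    for rank, r in enumerate(shard_ranges(1_000_000, 4)):
        # Each GPU keeps Q, K, V only for the tokens in its own range.
        print(f"GPU {rank}: tokens {r.start:,} to {r.stop - 1:,}")
```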

The challenge: full attention needs every GPU's Q to see all tokens' K and V. Ring Attention computes attention incrementally as KV blocks circulate.
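Which block a GPU holds at each step follows directly from the rotation. A small sketch, assuming GPU i forwards its current block to GPU (i + 1) mod P after every step, so at step s it holds the KV block that originated on GPU (i - s) mod P. The helper names are hypothetical.

```python
def kv_block_at(rank: int, step: int, num_gpus: int) -> int:
    """Origin of the KV block that `rank` holds at `step` (blocks move to rank + 1)."""
    return (rank - step) % num_gpus


def ring_schedule(num_gpus: int) -> list[list[int]]:
    # schedule[step][rank] = index of the KV shard that rank computes with at that step
    return [[kv_block_at(r, s, num_gpus) for r in range(num_gpus)]
            for s in range(num_gpus)]


if __name__ == "__main__":
    for step, row in enumerate(ring_schedule(4)):
        print(f"step {step}: " + "  ".join(f"GPU{r}<-KV{b}" for r, b in enumerate(row)))
```

Reading any column top to bottom shows that each GPU meets every KV block exactly once in P steps, which is all full attention needs.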

# Ring Attention: the full loop around the ring (pseudocode)
# Each GPU holds: Q_local (its queries), K_local/V_local (its KV shard)
K_recv, V_recv = K_local, V_local   # step 0 attends to the GPU's own shard
output = None                       # running output plus softmax statistics
for step in range(num_gpus):
    # 1. Compute partial attention: Q_local against the current K_recv, V_recv
    #    (local Q dot incoming KV, accumulated into a running softmax)
    partial = flash_attention(Q_local, K_recv, V_recv)
    output = online_softmax_merge(output, partial)
    # 2. Pass the current K_recv, V_recv to the next GPU in the ring and
    #    receive the previous GPU's block. In practice this transfer is
    #    non-blocking and overlaps with compute (double buffering).
    if step < num_gpus - 1:         # the final block needs no forwarding
        send(K_recv, V_recv, next_gpu)
        K_recv, V_recv = recv(prev_gpu)
# After num_gpus steps, every GPU has attended to ALL tokens,
# but none ever held the full K and V at once.
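The loop can be checked end to end without any GPUs. The sketch below simulates P devices in one process with plain Python lists. It swaps in a naive per-block attention (`block_stats`) where the pseudocode calls `flash_attention`, and implements `online_softmax_merge` as the standard online-softmax rescaling: keep a running max, a running denominator, and an unnormalized output per query. These are our own illustrative functions, not a library API. The demo checks that the ring result matches ordinary full attention.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V computed over all keys at once."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [dot(q, k) * scale for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out


def block_stats(Q, K, V):
    """Stand-in for flash_attention on one KV block.

    Returns, per query: (block max score, softmax denominator, unnormalized output).
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    stats = []
    for q in Q:
        s = [dot(q, k) * scale for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        acc = [sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))]
        stats.append((m, sum(w), acc))
    return stats


def online_softmax_merge(running, partial):
    """Rescale both partial results to a shared max, then add them."""
    if running is None:
        return partial
    merged = []
    for (m1, l1, a1), (m2, l2, a2) in zip(running, partial):
        m = max(m1, m2)
        c1, c2 = math.exp(m1 - m), math.exp(m2 - m)
        merged.append((m, l1 * c1 + l2 * c2, [x * c1 + y * c2 for x, y in zip(a1, a2)]))
    return merged


def ring_attention(Q, K, V, num_gpus):
    """Simulate num_gpus devices; assumes len(Q) is divisible by num_gpus."""
    chunk = len(Q) // num_gpus

    def shard(X, r):
        return X[r * chunk:(r + 1) * chunk]

    q_local = [shard(Q, r) for r in range(num_gpus)]
    kv = [(shard(K, r), shard(V, r)) for r in range(num_gpus)]  # kv[r]: block on GPU r
    running = [None] * num_gpus
    for step in range(num_gpus):
        for r in range(num_gpus):  # on real hardware these run in parallel
            running[r] = online_softmax_merge(running[r], block_stats(q_local[r], *kv[r]))
        if step < num_gpus - 1:    # pass every block to rank + 1
            kv = [kv[(r - 1) % num_gpus] for r in range(num_gpus)]
    # Normalize once at the end: output = unnormalized sum / denominator.
    return [[x / denom for x in acc] for stats in running for (_, denom, acc) in stats]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    Q, K, V = (rand_matrix(8, 4, rng) for _ in range(3))
    ring = ring_attention(Q, K, V, num_gpus=4)
    ref = full_attention(Q, K, V)
    err = max(abs(a - b) for ra, rb in zip(ring, ref) for a, b in zip(ra, rb))
    print(f"max abs difference vs full attention: {err:.2e}")
```

The difference is at floating-point rounding level: the online-softmax merge is exact arithmetic, not an approximation, so the order in which KV blocks arrive does not change the result.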

The Magic: Overlapping Communication with Computation

The naive version would be slow: each step waits to receive the next KV block. The key optimization is overlapping: while computing attention with the current KV block, simultaneously send it onward and receive the next.
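A back-of-the-envelope timing model makes the benefit concrete. Assume, hypothetically, that each step's attention takes `compute` seconds and each KV transfer takes `transfer` seconds, and that P GPUs need P compute steps and P - 1 transfers. Without overlap the two costs add up; with overlap, each of the first P - 1 steps costs whichever is longer.

```python
def naive_time(p: int, compute: float, transfer: float) -> float:
    # Compute, then wait for the next block, strictly one after the other.
    return p * compute + (p - 1) * transfer


def overlapped_time(p: int, compute: float, transfer: float) -> float:
    # While computing on block s, block s + 1 is already in flight;
    # the last step has nothing left to receive.
    return (p - 1) * max(compute, transfer) + compute


if __name__ == "__main__":
    p = 4
    for transfer in (3.0, 10.0, 15.0):
        print(transfer, naive_time(p, 10.0, transfer), overlapped_time(p, 10.0, transfer))
```

Whenever `transfer <= compute`, the overlapped time equals `p * compute`, the pure-compute cost; once the transfer is slower than the compute, the GPUs start to idle waiting for data.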

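This double-buffered overlap hides communication completely when per-block compute time exceeds transfer time. That holds once each GPU's chunk is large enough, because the attention compute for a block grows quadratically with chunk length while the KV transfer grows only linearly. The ring then spins at the speed of computation.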
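Why must the chunks be large? Per head, a block of c tokens with head dimension d costs about 4c²d FLOPs (2c²d for QKᵀ and 2c²d for the product with V), while forwarding its K and V moves 2cd elements, or 4cd bytes at 2 bytes each. With compute throughput F (FLOP/s) and link bandwidth B (bytes/s), compute hides the transfer when 4c²d / F ≥ 4cd / B, that is, when c ≥ F / B. The sketch below evaluates that bound; the hardware numbers are hypothetical placeholders rather than measurements of any specific GPU or interconnect, and it assumes as many K/V heads as query heads.

```python
import math


def step_times(c: int, d: int, flops_per_s: float, bytes_per_s: float,
               bytes_per_elem: int = 2) -> tuple[float, float]:
    """(compute, transfer) seconds for one head and one KV block of c tokens."""
    compute = 4 * c * c * d / flops_per_s                # QK^T plus softmax(.)V
    transfer = 2 * c * d * bytes_per_elem / bytes_per_s  # send the K and V blocks
    return compute, transfer


def min_chunk_tokens(flops_per_s: float, bytes_per_s: float) -> int:
    # From 4 c^2 d / F >= 4 c d / B (2-byte elements)  =>  c >= F / B.
    return math.ceil(flops_per_s / bytes_per_s)


if __name__ == "__main__":
    F, B = 400e12, 50e9  # hypothetical: 400 TFLOP/s compute, 50 GB/s link
    print(min_chunk_tokens(F, B))  # 8000 tokens per GPU
    for c in (4_000, 16_000):
        comp, xfer = step_times(c, 128, F, B)
        print(c, "hidden" if comp >= xfer else "exposed")
```

Note that d cancels out: the threshold depends only on the ratio of compute throughput to link bandwidth, so a slower link simply demands longer chunks per GPU.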

The Ring in 3D

The visualization below shows 4 GPUs arranged in a ring, each holding a shard. Watch the KV blocks circulate: each GPU computes attention with its local Q and the incoming KV, then passes the KV onward.

4 GPUs in a ring. Colored KV blocks circulate clockwise. Each GPU computes attention with its local Q and the passing KV.

Causal Masking: An Asymmetric Twist

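For causal models, a token can attend only to itself and earlier tokens. In Ring Attention, 'earlier' and 'future' depend on which GPU holds which chunk. GPU 1's tokens are 'earlier' than GPU 3's.

So when GPU 3 receives GPU 1's KV, it computes full, unmasked attention. When GPU 1 receives GPU 3's KV, it skips the block entirely, since all of those tokens are in its future. Only when a GPU processes its own KV block does it apply the usual triangular causal mask. The overall mask is therefore block lower-triangular, and which case applies depends on the two GPUs' positions in the ring.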
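The rule can be written as a tiny function of ring positions, including the diagonal case where a GPU meets its own block. A minimal sketch, assuming contiguous chunks with lower ranks holding earlier tokens; `block_kind` and `causal_mask` are hypothetical helper names.

```python
def block_kind(q_rank: int, kv_rank: int) -> str:
    """How GPU q_rank treats the KV block that originated on GPU kv_rank."""
    if kv_rank < q_rank:
        return "full"    # every key is earlier than every local query
    if kv_rank == q_rank:
        return "causal"  # own block: standard lower-triangular mask
    return "skip"        # every key is in the future: fully masked


def causal_mask(chunk: int) -> list[list[bool]]:
    # mask[i][j] is True when query i may attend to key j within the same block.
    return [[j <= i for j in range(chunk)] for i in range(chunk)]


if __name__ == "__main__":
    P = 4
    for q in range(P):
        print(f"GPU {q}: " + " ".join(f"{block_kind(q, kv):>6}" for kv in range(P)))
```

Printed as a grid, the non-skipped blocks form the block lower-triangular pattern: GPU 0 computes one block, while the last GPU computes all four.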

Context Parallelism vs Tensor Parallelism

Ring Attention is a form of context parallelism: the sequence dimension is split across GPUs. This is orthogonal to tensor parallelism (splitting weights) and data parallelism (splitting the batch), and the three combine.

# Three axes of parallelism (orthogonal, combinable)
Tensor Parallelism:  split weight matrices → same sequence, different neurons
Context Parallelism: split the sequence    → different tokens per GPU (Ring Attention)
Data Parallelism:    split the batch       → different examples per GPU

# A 1M-token training run might use:
#   TP=8 (within node), CP=4 (across nodes), DP=32 (batch)
#   → 8 × 4 × 32 = 1024 GPUs; each CP rank holds 1M / 4 = 250K tokens
#     of the sequence, and each GPU stores KV only for its TP slice of the heads.
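One way to see that the three axes are independent is to factor a global rank into (DP, CP, TP) coordinates. This sketch assumes TP varies fastest (so a TP group stays inside one 8-GPU node) and DP slowest; that ordering is an illustrative assumption, and real frameworks let you configure it.

```python
from typing import NamedTuple


class Coord(NamedTuple):
    dp: int
    cp: int
    tp: int


def rank_to_coord(rank: int, tp: int, cp: int, dp: int) -> Coord:
    """Split a global rank into (dp, cp, tp) indices, TP fastest-varying."""
    if not 0 <= rank < tp * cp * dp:
        raise ValueError("rank out of range")
    return Coord(dp=rank // (tp * cp), cp=(rank // tp) % cp, tp=rank % tp)


if __name__ == "__main__":
    TP, CP, DP, SEQ = 8, 4, 32, 1_000_000
    print(TP * CP * DP)  # 1024 GPUs
    print(SEQ // CP)     # 250000 tokens per CP rank
    for rank in (0, 7, 8, 31, 32, 1023):
        print(rank, rank_to_coord(rank, TP, CP, DP))
```

Ranks that share `dp` and `tp` but differ in `cp` form one ring of 4. With TP fastest, ring neighbours are 8 ranks apart, which is why the ring spans nodes in this layout.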

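Context parallelism matters now because context lengths keep growing (1M, 10M tokens). Ring Attention removes the single-GPU memory limit on sequence length: since each GPU holds only its own chunk, the maximum context grows roughly linearly with the number of GPUs in the ring.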

Who Uses It: Llama 3, Gemini, and Beyond

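Ring Attention and its variants underpin the long-context capabilities of frontier models. Llama 3 uses context parallelism for 128K-token training, though its implementation all-gathers K and V instead of passing them around a ring. Gemini's 1M-2M token contexts likely rely on similar sequence-parallel attention, although Google has not published the details.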

Every time you paste a 500-page document into a modern LLM and it reasons over the whole thing, you are likely benefiting from the ring or one of its relatives.

Key Takeaways

1. Attention memory grows quadratically with sequence length, so a million-token context's attention matrix cannot fit on any single GPU; even with FlashAttention, the linear-size keys, values, and activations outgrow one device.
2. Ring Attention splits the sequence across GPUs in a ring; each GPU holds one KV shard and computes partial attention as KV blocks circulate, merging the partial results with an online softmax.
3. Overlapping communication with computation hides transfer latency whenever per-block compute exceeds transfer time, so the ring spins at the speed of computation.
4. Causal masking adds asymmetry: GPUs skip KV blocks from later GPUs, attend fully to blocks from earlier GPUs, and apply the usual triangular mask only to their own block.
5. Context parallelism is orthogonal to tensor and data parallelism; together they enable long contexts such as Llama 3's 128K tokens and Gemini's 1M+.
