Lighthouse Attention: The Training-Time Hierarchy That Makes Quadratic Attention Practical Again

TL;DR:

  • 1.4–1.7× pretraining wall-clock speedup against dense SDPA at 32K–128K context — no inference overhead, no architectural changes.
  • Symmetric pyramid pooling compresses queries, keys, and values together, whereas the prior sparse methods discussed below pool only K/V. The attention call becomes O(S²d) instead of O(NSd).
  • Two-stage training with a recoverability guarantee: Stage 1 trains with Lighthouse selection, Stage 2 recovers under dense SDPA — final loss beats the dense-from-scratch baseline.
  • Selection lives entirely outside the attention kernel, reusing stock FlashAttention on a contiguous gathered sub-sequence — no custom sparse kernels, no entangled selection logic.
  • Context-parallelism ready: pyramid stages run shard-locally with zero cross-rank communication; gathered sub-sequences participate in standard ring attention.

The Θ(N²) Wall Is Still There

FlashAttention solved attention's memory problem. It did not solve the compute problem. Scaled dot-product attention costs Θ(N²) in sequence length N: double the context and the FLOPs quadruple. At 512K context on a single B200 GPU, a dense SDPA forward+backward pass is extremely expensive. The only real way around it is to attend to fewer tokens.

Frontier models target million-token windows. Training one with dense attention needs 32 B200 GPUs for attention alone. Teams either train at short context or pay an attention cost that doubles with every ~1.4× (that is, √2×) increase in context length.
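
To make the scaling concrete, here is a minimal Python sketch of the relative cost. It counts only the N² term and ignores constants, head count, and head dimension; the function name is illustrative and not from the Lighthouse work.

```python
import math

def relative_attention_cost(n_tokens: int, base_tokens: int) -> float:
    """Dense attention cost at n_tokens relative to base_tokens (N^2 term only)."""
    return (n_tokens / base_tokens) ** 2

# Doubling the context quadruples the quadratic term.
print(relative_attention_cost(64_000, 32_000))  # 4.0

# The context multiplier that doubles the cost is sqrt(2).
print(round(math.sqrt(2), 3))  # 1.414
```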

Lighthouse Attention is a training-only hierarchy that selects tokens for attention, runs stock FlashAttention on the selected subset, and scatters the result back to the original positions. At inference the hierarchy is gone and a standard dense transformer remains. The result is a 1.40–1.69× end-to-end pretraining speedup at long context, with matching or lower loss.

Why Prior Sparse Attention Falls Short

Existing sparse methods — NSA, HISA, DSA, MoBA — share two design conventions that are wrong for pretraining.

Asymmetric pooling. Prior work pools only keys and values; queries stay at full resolution, so the hierarchy acts as a compressed KV memory. With N queries attending to S selected entries, the attention call costs O(NSd), which still grows linearly with N. Lighthouse pools Q symmetrically, so S pooled queries attend to S pooled keys and the call costs O(S²d) with S ≪ N. At 512K tokens, the forward pass is 21× faster.
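
A rough operation count shows where the gap comes from. The sketch below compares the two cost expressions for the score matrix only; the values of S and d are illustrative assumptions, and the resulting ratio is a multiply-add count, not the measured 21× figure.

```python
def asymmetric_cost(n: int, s: int, d: int) -> int:
    """Multiply-adds when N full-resolution queries attend to S pooled keys."""
    return n * s * d

def symmetric_cost(s: int, d: int) -> int:
    """Multiply-adds when S pooled queries attend to S pooled keys."""
    return s * s * d

# n matches the 512K example; s and d are assumed values for illustration.
n, s, d = 512 * 1024, 16 * 1024, 128
ratio = asymmetric_cost(n, s, d) / symmetric_cost(s, d)
print(ratio)  # 32.0, which is simply N / S
```

The ratio reduces to N/S, so the advantage of pooling queries grows with context length as long as S stays small.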

Kernel entanglement. Prior sparse methods embed selection inside custom attention kernels, so they cannot reuse optimized FlashAttention. Each method ships its own kernels, and these are slower than FlashAttention. Lighthouse instead gathers the selected tokens into a contiguous sequence and calls standard FlashAttention on it. No custom sparse attention kernel is needed, and the gather and scatter steps are simple enough for `torch.compile`.
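
The idea of keeping selection outside the kernel can be sketched in a few lines. This is a dependency-free toy, not the Lighthouse implementation: `dense_attention` is a plain-Python stand-in for the stock kernel, it is non-causal and single-level, and `gathered_attention` is a hypothetical name.

```python
import math

def dense_attention(q, k, v):
    """Plain softmax attention over lists of vectors; stands in for the stock kernel."""
    d = len(q[0])
    out = []
    for qi in q:
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out

def gathered_attention(q, k, v, selected):
    """Gather selected rows, run unmodified dense attention, scatter results back."""
    idx = sorted(selected)
    sub = dense_attention([q[i] for i in idx], [k[i] for i in idx], [v[i] for i in idx])
    out = [None] * len(q)  # unselected positions get no output in this toy
    for i, row in zip(idx, sub):
        out[i] = row
    return out

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out = gathered_attention(x, x, x, selected={0, 2})
print(out[1])  # None: position 1 was not selected
```

The attention function never sees the selection. In the toy, unselected positions simply get no output; the method described here avoids that by always retaining the coarsest pyramid level.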

The deeper difference: a training-time sparsifier must produce weights that work as dense attention at inference. An inference sparsifier only needs to match its backbone. Lighthouse treats recoverability as its central correctness condition.

The Four-Stage Pipeline

A Lighthouse layer wraps SDPA without modifying it. The wrapper has four stages: two custom kernels (pool, scatter) and two compiled PyTorch ops (score, gather).

Stage 1: Pyramid Construction. Average pooling with factor p builds an L-level pyramid from Q, K, and V. Level ℓ has N/p^ℓ tokens, each summarizing p^ℓ base positions. Applying the same pooling to all three projections keeps the (Q^(ℓ), K^(ℓ), V^(ℓ)) triples coherent. The cost is Θ(N). The coarsest level is always retained: it is cheap, and it guarantees that every position has at least one contributor.
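
A minimal sketch of the pyramid, assuming non-overlapping windows and a sequence length divisible by p at every level. `build_pyramid` is an illustrative name; the same function would be applied to Q, K, and V, and any causal handling is left out.

```python
def build_pyramid(x, p, num_levels):
    """Return [level 0, level 1, ...]; level l averages p**l consecutive base rows."""
    levels = [x]
    for _ in range(1, num_levels):
        prev = levels[-1]
        if len(prev) % p != 0:
            raise ValueError("sequence length must be divisible by p at every level")
        pooled = []
        for start in range(0, len(prev), p):
            window = prev[start:start + p]
            pooled.append([sum(col) / p for col in zip(*window)])
        levels.append(pooled)
    return levels

base = [[float(i)] for i in range(16)]  # N = 16 one-dimensional "tokens"
pyramid = build_pyramid(base, p=4, num_levels=3)
print([len(level) for level in pyramid])  # [16, 4, 1]
print(pyramid[1])                         # [[1.5], [5.5], [9.5], [13.5]]
print(pyramid[2])                         # [[7.5]]
```

Each level is built from the one below it, so the total work is N + N/p + N/p² + ..., which stays linear in N.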

Stage 2: Scoring and Stratified Top-K. Scores are per-head ℓ₂ norms of the pooled queries and keys, ‖Q^(ℓ)_i‖₂ and ‖K^(ℓ)_i‖₂. An alternative dilated softmax scorer costs ~9% more. Coarser levels inherit scores via max-pooling. A fused chunked-bitonic top-K then selects k entries across levels.

The chunked-bitonic design produces stratified top-K, not global top-K. The score stream is split into chunks, and each chunk maintains an in-register top-m buffer. If the k globally highest entries cluster in one chunk, some of them are displaced by lower-scoring entries from other chunks. This is deliberate: it keeps coverage balanced across the sequence and prevents attention from collapsing onto one region.
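
The difference between the two selection rules is easy to see on a small example. The sketch below uses `heapq.nlargest` in place of the fused bitonic kernel, and assumes each chunk keeps exactly m entries; the function names and the chunking scheme are illustrative.

```python
import heapq

def global_top_k(scores, k):
    """Indices of the k highest scores overall, in ascending index order."""
    return sorted(heapq.nlargest(k, range(len(scores)), key=scores.__getitem__))

def stratified_top_k(scores, chunk_size, m):
    """Indices of the top-m scores within each consecutive chunk."""
    picked = []
    for start in range(0, len(scores), chunk_size):
        chunk = range(start, min(start + chunk_size, len(scores)))
        picked.extend(heapq.nlargest(m, chunk, key=scores.__getitem__))
    return sorted(picked)

scores = [9.0, 8.0, 7.0, 6.0, 0.1, 0.4, 0.2, 0.3]  # high scores cluster in chunk 0
print(global_top_k(scores, k=4))                    # [0, 1, 2, 3]
print(stratified_top_k(scores, chunk_size=4, m=2))  # [0, 1, 5, 7]
```

Global top-K spends its whole budget on the first chunk. The stratified version gives up indices 2 and 3 in exchange for coverage of the second chunk.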

Top-K is discrete and non-differentiable — no straight-through estimator, no Gumbel softmax. Gradients flow through gathered Q, K, V into W_Q, W_K, W_V. The projections learn to produce values useful when selected, not scores good at selecting.

Stage 3: Gather and FlashAttention. Selected entries are gathered into a contiguous sub-sequence of length S = N/p^(L-1) + (L-1)·p·k. At N=1M, L=4, p=4, k=4096, this gives S ≈ 65,000 (about 16K coarsest-level entries plus 49,152 selected entries). The sub-sequence is guaranteed to be dense, because gaps would strand tokens without gradient paths during backprop. Standard FlashAttention runs on it.
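
The sub-sequence length formula can be checked directly. The helper below is a straightforward transcription of the expression in the text; whether "1M" means 2^20 or 10^6 is not stated, so both are shown.

```python
def subsequence_length(n: int, num_levels: int, p: int, k: int) -> int:
    """S = N / p**(L-1) + (L-1) * p * k, as stated in the text."""
    coarsest = n // p ** (num_levels - 1)  # entries in the always-retained coarsest level
    selected = (num_levels - 1) * p * k    # entries contributed by top-K selection
    return coarsest + selected

print(subsequence_length(2**20, num_levels=4, p=4, k=4096))      # 65536
print(subsequence_length(1_000_000, num_levels=4, p=4, k=4096))  # 64777
```

Either reading lands near 65,000, roughly 16× shorter than the full sequence.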

Stage 4: Scatter-Back. Each output is scattered to its p^ℓ base positions via an integer-atomic scatter with a causality-preserving shift. Per-position fan-in is bounded by L regardless of k.
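
One way to see the fan-in bound: if the blocks at each level are non-overlapping, a base position sits inside exactly one block per level, so at most L selected entries can write to it. The toy below counts writers per position under that assumption. It omits the causality-preserving shift, whose details are not given here, and `scatter_fan_in` is an illustrative name.

```python
def scatter_fan_in(n, p, selected):
    """Count how many selected (level, index) entries cover each base position."""
    fan_in = [0] * n
    for level, index in selected:
        span = p ** level  # a level-l entry covers p**l base positions
        for pos in range(index * span, (index + 1) * span):
            fan_in[pos] += 1
    return fan_in

# N = 16, p = 4, L = 3: the single coarsest entry plus a few finer ones.
selected = {(2, 0), (1, 0), (1, 2), (0, 1), (0, 9), (0, 10)}
counts = scatter_fan_in(16, 4, selected)
print(max(counts), min(counts))  # 3 1
```

The maximum equals L = 3 no matter how many entries are selected, and the minimum is 1 because the coarsest level always covers every position.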

The Recoverability Experiment

The acid test for any training-time sparsification method: do the resulting weights still produce a competent dense-attention model?

A 530M Llama-3-style decoder (d_model=1024, 30 layers, 8 heads) was trained on C4 at 98K context. Layers 0, 1, 28, and 29 retained dense SDPA — only the interior 26 layers used Lighthouse. The inner attention call within those 26 layers used the same cuDNN-backed SDPA kernel as the dense baseline.
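
The layer assignment amounts to a simple rule. This sketch encodes it with a hypothetical helper; the parameter names are illustrative and not taken from any released code.

```python
def attention_kind(layer: int, num_layers: int = 30, dense_edge: int = 2) -> str:
    """Return "dense" for the first and last `dense_edge` layers, else "lighthouse"."""
    if layer < dense_edge or layer >= num_layers - dense_edge:
        return "dense"
    return "lighthouse"

kinds = [attention_kind(i) for i in range(30)]
print(kinds.count("dense"), kinds.count("lighthouse"))  # 4 26
```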

Stage 1 trains with Lighthouse enabled for the majority of the step budget. Stage 2 resumes the checkpoint under dense SDPA — same optimizer state, same dataloader — for a short tail. If Stage 1 had hollowed out the model’s dense-attention capability, Stage 2 recovery would fail.

It does not fail. At 16,000 total steps (~50.3B tokens), three split points (10K+6K, 11K+5K, 12K+4K) were evaluated against a dense-from-scratch baseline. At each resume, loss spikes transiently by 1.12–1.57 nats, then recovers within ~1,000–1,500 SDPA steps and crosses below the dense baseline. By step 16,000, Lighthouse runs reach final losses of 0.6980–0.7102 vs. the dense baseline’s 0.7237 — while using 22.5–27.0 wall-clock hours instead of 37.9.
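
The schedule and the implied speedup are simple to reproduce. The sketch below assumes a hard switch at a fixed step with zero-indexed steps, and computes speedup directly from the reported hours; `attention_mode` is an illustrative name.

```python
def attention_mode(step: int, switch_step: int) -> str:
    """Stage 1 (Lighthouse) before switch_step, Stage 2 (dense SDPA) from it onward."""
    return "lighthouse" if step < switch_step else "dense"

TOTAL_STEPS = 16_000
for switch_step in (10_000, 11_000, 12_000):
    dense_tail = sum(attention_mode(s, switch_step) == "dense" for s in range(TOTAL_STEPS))
    print(switch_step, dense_tail)  # dense tails of 6000, 5000, and 4000 steps

# Wall-clock speedup implied by the reported hours.
dense_hours = 37.9
speedups = [round(dense_hours / hours, 2) for hours in (22.5, 27.0)]
print(speedups)  # [1.68, 1.4]
```

Given that the reported hours are rounded, this is in line with the 1.40–1.69× range quoted earlier.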

A simplified Needle-in-a-Haystack evaluation (4K–96K context, retrieval scored as one-token argmax, random chance 10%) confirms the pattern: Lighthouse with k=2048 dilated scorer reaches 0.76 retrieval rate vs. the dense baseline’s 0.72. Larger k is the dominant axis for retrieval; the norm scorer hurts retrieval more than training loss at matched k. Context parallelism scales cleanly to 1M tokens across 32 B200 GPUs with ~10% ring-rotation overhead and no kernel changes.

Engineering Takeaways

Lighthouse Attention is not a universal accelerator. At short contexts, pyramid overhead dominates and it provides no benefit. At 32K+ tokens, it is a drop-in pretraining optimization: no architectural changes, no inference penalty, no custom sparse kernels to maintain.

The method’s value rests on three architectural decisions. First, symmetric pooling of queries alongside keys and values changes the attention call from O(NSd) to O(S²d). Second, keeping selection outside the attention kernel allows reuse of battle-tested FlashAttention kernels and avoids the maintenance burden of custom sparse kernels. Third, a clean separation between non-differentiable selection and differentiable projection weights is what enables recoverability.

For teams training long-context models, the practical path is clear. If sequences exceed 32K tokens, Lighthouse provides a meaningful speedup with no inference penalty. The two-stage recipe is mandatory: switching to dense attention without the Stage 2 recovery leaves the model at the elevated loss seen at the resume point, a spike of 1.12–1.57 nats in the experiments above. The configuration reported as optimal by the ablation grid is L=3, p=4, k=1536 with the projection-norm scorer. Context-parallelism integration requires no sparse-aware collectives.

The main limitation: Lighthouse is training-only. Autoregressive decoding presents one query at a time, violating the all-queries-co-occur assumption that symmetric pooling depends on. For teams whose bottleneck is pretraining throughput at long context, which describes most frontier-model efforts, Lighthouse is a proven, recoverable speedup.

