TL;DR:
- 1.4–1.7× pretraining wall-clock speedup against dense SDPA at 32K–128K context — no inference overhead, no architectural changes.
- Symmetric pyramid pooling compresses queries, keys, and values together, unlike prior sparse methods that pool only K/V. The attention call drops from O(NSd) to O(S²d).
- Two-stage training with a recoverability guarantee: Stage 1 trains with Lighthouse selection, Stage 2 recovers under dense SDPA — final loss beats the dense-from-scratch baseline.
- Selection lives entirely outside the attention kernel, reusing stock FlashAttention on a contiguous gathered sub-sequence — no custom sparse kernels, no entangled selection logic.
- Context-parallelism ready: pyramid stages run shard-locally with zero cross-rank communication; gathered sub-sequences participate in standard ring attention.
—
The Θ(N²) Wall Is Still There
FlashAttention solved memory. It did not solve compute. Scaled dot-product attention scales Θ(N²) — double the context, quadruple the FLOPs. At 512K context on one B200 GPU, dense SDPA forward+backward burns enormous compute. The only real way around it: attend to fewer tokens.
Frontier models target million-token windows. Training one with dense attention needs 32 B200 GPUs for attention alone. Teams either train at shorter context or pay a tax that doubles with every ~1.4× context increase: cost grows with N², and (√2)² = 2.
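A back-of-the-envelope sketch of that quadratic tax. It assumes the common estimate of about 2·N²·d FLOPs for Q·Kᵀ plus 2·N²·d for the weighted sum over V per layer, halved under a causal mask. The constants are an assumption; only the N² scaling matters here. The function names are hypothetical, and d_model=1024 is borrowed from the 530M model described later.

```python
import math


def attention_forward_flops(n_tokens: int, d_model: int, causal: bool = True) -> float:
    """Rough FLOPs for one dense attention forward pass in a single layer.

    Assumes ~2*N^2*d for Q @ K^T plus ~2*N^2*d for P @ V, halved when the
    causal mask lets the kernel skip the upper triangle.
    """
    flops = 4.0 * n_tokens * n_tokens * d_model
    return flops / 2 if causal else flops


def context_growth_per_cost_multiple(multiple: float) -> float:
    """Context growth factor that multiplies a quadratic cost by `multiple`."""
    return math.sqrt(multiple)


base = attention_forward_flops(32_768, 1024)
for n in (65_536, 131_072, 524_288):
    ratio = attention_forward_flops(n, 1024) / base
    print(f"{n:>7} tokens: {ratio:.0f}x the attention FLOPs of 32K")
print(f"cost doubles every {context_growth_per_cost_multiple(2.0):.3f}x of context")
```

The loop reports 4×, 16×, and 256× relative to 32K. The last line prints 1.414, which is where the "~1.4×" figure comes from.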
Lighthouse Attention delivers a training-only hierarchy that selects tokens for attention, runs stock FlashAttention on the subset, and scatters the result. At inference, the hierarchy is gone — a standard dense transformer remains. Result: 1.40–1.69× end-to-end pretraining speedup at long context, with matching or lower loss.
—
Why Prior Sparse Attention Falls Short
Existing sparse methods (NSA, HISA, DSA, MoBA) share two design conventions that are ill-suited to pretraining.
Asymmetric pooling. Prior work pools only keys and values; queries stay at full resolution, so the hierarchy becomes a compressed KV memory and the attention call stays O(NSd), still linear in N. Lighthouse pools Q symmetrically, turning attention into O(S²d) with S ≪ N. At 512K tokens, the forward pass is 21× faster.
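To make the O(NSd) versus O(S²d) difference concrete, here is an idealized FLOP comparison. It is a sketch that ignores pyramid, scoring, gather, and scatter overhead, so it will not reproduce measured speedups such as the 21× figure. The inputs are assumptions chosen for illustration. S = 64,777 is what the Stage 3 formula gives for N = 1,000,000, L = 4, p = 4, k = 4096. head_dim = 128 is d_model 1024 divided by 8 heads. The function names are hypothetical.

```python
def attention_cost(n_queries: int, n_keys: int, head_dim: int) -> int:
    """Idealized FLOPs of Q @ K^T plus P @ V: proportional to queries x keys x d."""
    return 4 * n_queries * n_keys * head_dim


def compare_pooling(n: int, s: int, head_dim: int) -> dict:
    dense = attention_cost(n, n, head_dim)       # O(N^2 d): full SDPA
    asymmetric = attention_cost(n, s, head_dim)  # O(N S d): K/V pooled, all N queries kept
    symmetric = attention_cost(s, s, head_dim)   # O(S^2 d): Q, K, V pooled alike
    return {
        "dense_vs_asymmetric": dense / asymmetric,          # N / S
        "asymmetric_vs_symmetric": asymmetric / symmetric,  # N / S again
        "dense_vs_symmetric": dense / symmetric,            # (N / S) ** 2
    }


ratios = compare_pooling(n=1_000_000, s=64_777, head_dim=128)
for name, value in ratios.items():
    print(f"{name}: {value:.1f}x")
```

Pooling only K/V buys one factor of N/S (about 15× here). Pooling Q as well buys a second factor, for roughly 238× fewer attention FLOPs than dense SDPA in this idealized model.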
Kernel entanglement. Sparse methods embed selection inside custom attention kernels, so they cannot reuse optimized FlashAttention; every method ships its own kernels, and all of them are slower. Lighthouse gathers the selected tokens into a contiguous sequence, then calls standard FlashAttention. No custom sparse kernels are needed, and gather/scatter are simple enough for `torch.compile`.
The deeper difference: a training-time sparsifier must produce weights that work as dense attention at inference. An inference sparsifier only needs to match its backbone. Lighthouse treats recoverability as its central correctness condition.
—
The Four-Stage Pipeline
A Lighthouse layer wraps SDPA without modifying it. It runs in four stages, implemented as two custom kernels (pool, scatter) and two compiled PyTorch ops (score, gather).
Stage 1: Pyramid Construction. Average pooling builds an L-level pyramid from Q, K, V with factor p. Level ℓ has N/p^ℓ tokens summarizing p^ℓ base positions. Same pooling across all projections ensures coherent (Q^(ℓ), K^(ℓ), V^(ℓ)) triples. Cost: Θ(N). The coarsest level is always retained — cheap and guarantees every position a contributor.
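A readable sketch of the pyramid construction, assuming sequences are plain Python lists of vectors. The real implementation is a custom GPU kernel over batched, multi-head tensors. The function names are hypothetical. Level 0 is the base sequence, and each coarser level averages p consecutive entries of the level below. Because the blocks are equal-sized, this equals the mean over p^ℓ base positions. Applying the identical function to Q, K, and V keeps the triples aligned.

```python
from typing import List

Vector = List[float]


def avg_pool(seq: List[Vector], p: int) -> List[Vector]:
    """Average each run of p consecutive vectors."""
    if len(seq) % p:
        raise ValueError("sequence length must be divisible by the pooling factor p")
    pooled = []
    for start in range(0, len(seq), p):
        block = seq[start:start + p]
        pooled.append([sum(col) / p for col in zip(*block)])
    return pooled


def build_pyramid(x: List[Vector], levels: int, p: int) -> List[List[Vector]]:
    """Level l has len(x) / p**l entries, each the mean of p**l base positions."""
    pyramid = [x]
    for _ in range(1, levels):
        pyramid.append(avg_pool(pyramid[-1], p))
    return pyramid


def build_qkv_pyramids(q, k, v, levels: int, p: int):
    # Same pooling for all three projections: entry i at level l of Q, K and V
    # summarizes the same span of base positions, giving coherent triples.
    return tuple(build_pyramid(t, levels, p) for t in (q, k, v))
```

The total size is N·(1 + 1/p + 1/p² + …) < N·p/(p − 1), which is where the Θ(N) cost comes from.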
Stage 2: Scoring and Stratified Top-K. Scoring uses per-head ℓ₂ norms (‖Q^(ℓ)_i‖₂, ‖K^(ℓ)_i‖₂); an alternative dilated softmax scorer costs ~9% more. Coarser levels inherit scores via max-pooling. A fused chunked-bitonic top-K then selects k entries across levels.
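A sketch of the norm scorer and max-pooled score inheritance. Two points are assumptions, not details from the article. First, the article lists per-head query and key norms but not how they combine, so this sketch simply sums them. Second, "coarser levels inherit scores via max-pooling" is read as each coarse entry taking the maximum score of the p finer entries it summarizes. All names are hypothetical.

```python
import math
from typing import List


def l2_norm(vec: List[float]) -> float:
    return math.sqrt(sum(x * x for x in vec))


def base_scores(q_rows: List[List[float]], k_rows: List[List[float]]) -> List[float]:
    # ASSUMPTION: ||Q_i||_2 + ||K_i||_2 as the combined per-position score.
    return [l2_norm(q) + l2_norm(k) for q, k in zip(q_rows, k_rows)]


def max_pool(scores: List[float], p: int) -> List[float]:
    return [max(scores[i:i + p]) for i in range(0, len(scores), p)]


def score_pyramid(finest: List[float], levels: int, p: int) -> List[List[float]]:
    # A coarse entry inherits the highest score among the p entries it covers,
    # so a single salient token keeps its whole ancestor chain competitive.
    out = [finest]
    for _ in range(1, levels):
        out.append(max_pool(out[-1], p))
    return out
```

Max-pooling rather than averaging means a single high-norm token is not diluted when its block is pooled.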
The chunked-bitonic design produces stratified top-K, not global top-K. The score stream is chunked; each chunk maintains an in-register top-m buffer. If the k globally highest entries cluster in one chunk, some are displaced by lower entries from other chunks — balanced coverage, no attention collapse. This is deliberate.
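The difference between stratified and global top-K is easiest to see side by side. This is a sequential CPU sketch, not the in-register bitonic kernel. It assumes each of the C chunks keeps m = k / C entries. The article names a top-m buffer but does not state how m relates to k or how the chunk buffers are merged, so here all m·C = k survivors are simply kept. The function names are hypothetical.

```python
import heapq
from typing import List


def global_top_k(scores: List[float], k: int) -> List[int]:
    """Indices of the k highest scores anywhere in the stream."""
    return sorted(heapq.nlargest(k, range(len(scores)), key=scores.__getitem__))


def stratified_top_k(scores: List[float], k: int, num_chunks: int) -> List[int]:
    """Keep the top m = k // num_chunks indices of each contiguous chunk."""
    if k % num_chunks or len(scores) % num_chunks:
        raise ValueError("k and len(scores) must both be divisible by num_chunks")
    m = k // num_chunks
    chunk_len = len(scores) // num_chunks
    picked: List[int] = []
    for c in range(num_chunks):
        chunk = range(c * chunk_len, (c + 1) * chunk_len)
        # Each chunk's buffer only competes with entries from the same chunk.
        picked.extend(heapq.nlargest(m, chunk, key=scores.__getitem__))
    return sorted(picked)


scores = [9.0, 8.0, 7.0, 6.0, 1.0, 2.0, 0.0, 3.0]
print(global_top_k(scores, 4))         # all four picks come from the first chunk
print(stratified_top_k(scores, 4, 2))  # two picks from each chunk
```

Here the four global winners all sit in the first half. The stratified version trades two of them (indices 2 and 3) for the best entries of the second half (indices 5 and 7), which is the balanced coverage the text describes.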
Top-K is discrete and non-differentiable, and Lighthouse does not approximate it: no straight-through estimator, no Gumbel softmax. Gradients flow through the gathered Q, K, V into W_Q, W_K, W_V, so the projections learn to produce values that are useful when selected, not scores that are good at selecting.
Stage 3: Gather and FlashAttention. Selected entries gather into a contiguous sub-sequence of length S = N/p^(L-1) + (L-1)·p·k. At N=1M, L=4, p=4, k=4096: S ≈ 65,000. The sub-sequence is guaranteed dense — gaps would strand tokens without gradient paths during backprop. Standard FlashAttention runs on it.
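The sub-sequence length formula and the gather step can be checked with a few lines. This is a sketch with hypothetical names. The gather takes (level, index) pairs and packs the chosen pyramid entries into one contiguous list in the order given. The article does not describe how Lighthouse orders gathered entries for the causal mask, so the ordering here is left to the caller.

```python
from typing import List, Sequence, Tuple


def subsequence_length(n: int, levels: int, p: int, k: int) -> int:
    """S = N / p^(L-1) + (L-1) * p * k, as stated for Stage 3."""
    coarsest_len = p ** (levels - 1)
    if n % coarsest_len:
        raise ValueError("N must be divisible by p^(L-1)")
    # The coarsest level is always kept in full; finer levels add (L-1)*p*k entries.
    return n // coarsest_len + (levels - 1) * p * k


def gather(pyramid: Sequence[Sequence], selected: List[Tuple[int, int]]) -> list:
    # Pack the selected entries into one dense list: no gaps, so a stock
    # attention kernel sees an ordinary contiguous sequence.
    return [pyramid[level][index] for level, index in selected]


print(subsequence_length(1_000_000, levels=4, p=4, k=4096))
print(subsequence_length(131_072, levels=3, p=4, k=1536))
```

The first call gives 15,625 + 49,152 = 64,777, the "S ≈ 65,000" above. The second applies the L=3, p=4, k=1536 configuration from the takeaways to a 128K sequence: 8,192 + 12,288 = 20,480.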
Stage 4: Scatter-Back. Each output scatters to its p^ℓ base positions via integer-atomic scatter with a causality-preserving shift. Per-position fan-in bounded by L regardless of k.
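A scalar sketch of scatter-back, with two loudly labeled assumptions. First, the article does not define the causality-preserving shift. This sketch assumes a level-ℓ entry (ℓ > 0) writes to the span one block later than the span it summarizes, so its output lands only on positions strictly after every token it pooled. Second, the integer atomics are mimicked with fixed-point integer accumulators, because integer sums do not depend on accumulation order. The names, the `shift_blocks` parameter, and `SCALE` are hypothetical.

```python
from typing import List, Tuple

SCALE = 1 << 16  # HYPOTHETICAL fixed-point scale standing in for integer atomics


def scatter_back(n: int, p: int, outputs: List[Tuple[int, int, float]],
                 shift_blocks: int = 1) -> Tuple[List[float], List[int]]:
    """Accumulate (level, index, value) outputs onto base positions.

    Entry `index` at `level` summarizes base positions
    [index * p**level, (index + 1) * p**level).
    """
    acc = [0] * n     # integer accumulators: order-independent sums
    fan_in = [0] * n  # number of entries that wrote to each base position
    for level, index, value in outputs:
        span = p ** level
        # ASSUMED shift: coarse outputs land one block later; level 0 is unshifted.
        shift = span * shift_blocks if level > 0 else 0
        start = index * span + shift
        for pos in range(start, min(start + span, n)):  # last block may fall off the end
            acc[pos] += round(value * SCALE)
            fan_in[pos] += 1
    return [a / SCALE for a in acc], fan_in
```

Within one level the spans are disjoint, and a uniform shift keeps them disjoint, so each base position receives at most one write per level. That gives a fan-in of at most L no matter how many entries k selects.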
—
The Recoverability Experiment
The acid test for any training-time sparsification method: do the resulting weights still produce a competent dense-attention model?
A 530M Llama-3-style decoder (d_model=1024, 30 layers, 8 heads) was trained on C4 at 98K context. Layers 0, 1, 28, and 29 retained dense SDPA — only the interior 26 layers used Lighthouse. The inner attention call within those 26 layers used the same cuDNN-backed SDPA kernel as the dense baseline.
Stage 1 trains with Lighthouse enabled for the majority of the step budget. Stage 2 resumes the checkpoint under dense SDPA — same optimizer state, same dataloader — for a short tail. If Stage 1 had hollowed out the model’s dense-attention capability, Stage 2 recovery would fail.
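The training setup above reduces to a per-layer, per-step rule. The sketch below encodes only what the text states: layers 0, 1, 28, and 29 always use dense SDPA, interior layers use Lighthouse until the switch step, and everything runs dense SDPA afterwards. The function and constant names are hypothetical, and zero-based step counting is an assumption.

```python
DENSE_LAYERS = frozenset({0, 1, 28, 29})  # kept on dense SDPA throughout
NUM_LAYERS = 30


def attention_backend(layer: int, step: int, switch_step: int) -> str:
    """Which attention path a layer uses at a given optimizer step.

    Stage 1 (step < switch_step): interior layers run Lighthouse.
    Stage 2 (step >= switch_step): every layer runs dense SDPA, resuming the
    same checkpoint with the same optimizer state and dataloader.
    """
    if layer in DENSE_LAYERS or step >= switch_step:
        return "sdpa"
    return "lighthouse"


# The three split points evaluated below, each totalling 16,000 steps.
SPLITS = {"10K+6K": 10_000, "11K+5K": 11_000, "12K+4K": 12_000}

lighthouse_layers = [l for l in range(NUM_LAYERS) if attention_backend(l, 0, SPLITS["10K+6K"]) == "lighthouse"]
print(len(lighthouse_layers))  # the 26 interior layers
```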
It does not fail. At 16,000 total steps (~50.3B tokens), three split points (10K+6K, 11K+5K, 12K+4K) were evaluated against a dense-from-scratch baseline. At each resume, loss spikes transiently by 1.12–1.57 nats, then recovers within ~1,000–1,500 SDPA steps and crosses below the dense baseline. By step 16,000, Lighthouse runs reach final losses of 0.6980–0.7102 vs. the dense baseline’s 0.7237 — while using 22.5–27.0 wall-clock hours instead of 37.9.
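The reported numbers can be cross-checked with simple arithmetic. The sketch below uses only figures from the paragraph above. Because the article does not say which wall-clock time belongs to which split point or final loss, only range endpoints are compared and no pairing is assumed.

```python
DENSE_HOURS = 37.9
LIGHTHOUSE_HOURS = (22.5, 27.0)    # endpoints of the reported range
DENSE_LOSS = 0.7237
LIGHTHOUSE_LOSS = (0.6980, 0.7102)  # endpoints of the reported range
TOKENS, STEPS = 50.3e9, 16_000

speedups = [DENSE_HOURS / h for h in LIGHTHOUSE_HOURS]
loss_gaps = [DENSE_LOSS - loss for loss in LIGHTHOUSE_LOSS]
tokens_per_step = TOKENS / STEPS

print("wall-clock speedup: " + ", ".join(f"{s:.2f}x" for s in speedups))
print("final-loss margin vs dense: " + ", ".join(f"{g:.4f} nats" for g in loss_gaps))
print(f"tokens per step: {tokens_per_step:,.0f}")
```

The hours give roughly 1.40× to 1.68×, close to the headline 1.40–1.69× range (the small gap at the top end presumably comes from rounding in the reported hours). The Lighthouse runs finish about 0.014–0.026 nats below the dense baseline, and each step processes about 3.14M tokens.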
A simplified Needle-in-a-Haystack evaluation (4K–96K context, retrieval scored as a one-token argmax, random chance 10%) confirms the pattern: Lighthouse with k=2048 and the dilated scorer reaches a 0.76 retrieval rate vs. the dense baseline's 0.72. Larger k is the dominant axis for retrieval, and at matched k the norm scorer hurts retrieval more than it hurts training loss.

Context parallelism scales cleanly to 1M tokens across 32 B200 GPUs, with ~10% ring-rotation overhead and no kernel changes.
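The one-token argmax scoring used in the Needle-in-a-Haystack evaluation can be sketched as follows. The sketch assumes each trial offers 10 candidate answer tokens, which is consistent with the stated 10% random chance. A trial counts as a hit when the candidate with the highest logit is the needle. The names and trial format are hypothetical.

```python
from typing import List, Sequence, Tuple

NUM_CANDIDATES = 10  # ASSUMPTION: consistent with the stated 10% random chance


def argmax(xs: Sequence[float]) -> int:
    return max(range(len(xs)), key=xs.__getitem__)


def retrieval_rate(trials: List[Tuple[Sequence[float], int]]) -> float:
    """trials: (candidate_logits, needle_index) pairs, one per haystack."""
    hits = sum(argmax(logits) == needle for logits, needle in trials)
    return hits / len(trials)


chance = 1 / NUM_CANDIDATES  # a uniform guesser scores 0.10 on average
```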
—
Engineering Takeaways
Lighthouse Attention is not a universal accelerator. At short contexts, pyramid overhead dominates and it provides no benefit. At 32K+ tokens, it is a drop-in pretraining optimization: no architectural changes, no inference penalty, no custom sparse kernels to maintain.
The method’s value rests on three architectural decisions. Symmetric pooling of queries alongside keys and values — this changes the attention call from O(NSd) to O(S²d). Selection outside the attention kernel — this enables reuse of battle-tested FlashAttention kernels and avoids the maintenance burden of custom sparse kernels. A clean separation between non-differentiable selection and differentiable projection weights — this enables the recoverability guarantee.
For teams training long-context models, the practical path is clear. If sequences exceed 32K tokens, Lighthouse provides meaningful speedup with no inference penalty. The two-stage recipe is mandatory: skipping Stage 2 leaves weights that have not re-adapted to dense attention, as the 1.12–1.57-nat loss spike at each resume shows. The ablation grid points to L=3, p=4, k=1536 with the projection-norm scorer as the best configuration. Context-parallelism integration requires no sparse-aware collectives.
The main limitation: Lighthouse is training-only. Autoregressive decoding presents one query at a time, which violates the all-queries-co-occur assumption that symmetric pooling depends on. For teams whose bottleneck is pretraining throughput at long context, Lighthouse offers a recoverable speedup, demonstrated here on a 530M model at 98K context.
Discover more from Susiloharjo