Alleviating Forgetfulness of Linear Attention by Hybrid Sparse Attention and Contextualized Learnable Token Eviction
Welcome to the reference implementation for the paper Alleviating Forgetfulness of Linear Attention by Hybrid Sparse Attention and Contextualized Learnable Token Eviction! This repository provides the code for the hybrid sparse & linear attention models, the Learnable Token Eviction (LTE) layer, the SWA+KV sparse attention prefilling & decoding kernels, and the NSA decoding kernels, built upon Flash Linear Attention.
Linear-attention models that compress the entire input sequence into a fixed-size recurrent state offer an efficient alternative to Transformers, but their finite memory induces forgetfulness that harms retrieval-intensive tasks. To mitigate this issue, we explore a series of hybrid models that restore direct access to past tokens. We interleave token mixers with intermediate time and space complexity between linear and full attention, including sparse attention with token eviction and the query-aware native sparse attention (NSA). In particular, we propose a novel learnable token eviction approach: combined with sliding-window attention, an end-to-end trainable lightweight CNN aggregates information from both past and future adjacent tokens to adaptively retain a limited set of critical KV-pairs per head, maintaining linear attention's constant time and space complexity. Efficient Triton kernels for the sparse attention mechanisms are provided. Empirical evaluations on retrieval-intensive benchmarks support the effectiveness of our approaches.
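As a rough illustration of the token-eviction idea (not the exact layer from the paper or this codebase), the sketch below uses a small non-causal Conv1d to score each token per head from its neighbouring tokens and keeps only the top-k highest-scoring KV pairs per head. All module names, shapes, and hyperparameters here are illustrative.

import torch
import torch.nn as nn

class TokenEvictionScorer(nn.Module):
    """Illustrative sketch: score tokens with a small CNN and keep top-k KV pairs per head."""

    def __init__(self, hidden_size: int, num_heads: int, kernel_size: int = 5, keep: int = 64):
        super().__init__()
        self.keep = keep
        # Non-causal convolution: each token's retention score depends on both
        # past and future neighbouring tokens within the kernel window.
        self.conv = nn.Conv1d(hidden_size, num_heads, kernel_size, padding=kernel_size // 2)

    def forward(self, hidden_states, keys, values):
        # hidden_states: [B, T, D]; keys/values: [B, H, T, Dh]
        scores = self.conv(hidden_states.transpose(1, 2))                  # [B, H, T]
        k = min(self.keep, scores.shape[-1])
        idx = scores.topk(k, dim=-1).indices                               # [B, H, k]
        gather_idx = idx.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])  # [B, H, k, Dh]
        # Retain only the top-k KV pairs per head; the rest are "evicted".
        # In the actual method the scores would also have to enter the attention
        # computation so the scorer can be trained end to end (top-k alone is
        # non-differentiable), and the retained pairs would complement a
        # sliding-window attention branch.
        return keys.gather(2, gather_idx), values.gather(2, gather_idx), scores

# Example usage with dummy tensors:
# scorer = TokenEvictionScorer(hidden_size=1024, num_heads=16)
# h, k, v = torch.randn(2, 128, 1024), torch.randn(2, 16, 128, 64), torch.randn(2, 16, 128, 64)
# k_kept, v_kept, scores = scorer(h, k, v)   # k_kept: [2, 16, 64, 64]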
The code is developed on top of Flash Linear Attention and Flame, so please refer to their repositories for general instructions. For installation, please run:
pip install -r requirements.txt
pip install git+https://github.com/fla-org/flame@f5c618c6 --no-deps
Note the required versions of torch and transformers to avoid any incompatibility.
We are undergoing our internal administrative process, and our pretrained models will be released soon. To train a 1.4B laLTE model yourself on 2 GPUs, run the following command:
NNODE=1 NGPU=2 LOG_RANK=0 bash train.sh \
--job.config_file configs/fla.toml \
--job.dump_folder exp/gated-deltanet-lte-1.4B-30B --model.config configs/gated_delta_net_lte_1.4B.json \
--model.tokenizer_path fla-hub/transformer-1.3B-100B \
--optimizer.name AdamW --optimizer.lr 3e-4 \
--lr_scheduler.warmup_steps 512 --lr_scheduler.lr_min 0.1 --lr_scheduler.decay_type cosine \
--training.data_parallel_shard_degree 2 --training.data_parallel_replicate_degree 1 \
--training.batch_size 8 --training.seq_len 4096 --training.gradient_accumulation_steps 16 \
--training.steps 30720 --training.max_norm 1.0 --training.skip_nan_inf \
--training.dataset HuggingFaceFW/fineweb-edu --training.dataset_name sample-100BT --training.dataset_split train \
--training.num_workers 32 --training.prefetch_factor 2 --training.seed 42 \
--checkpoint.interval 320 --checkpoint.load_step -1 --metrics.log_freq 4 --profiling.profile_freq 5120
Other model configs are available in the configs/ directory. Please refer to the Flame documentation or code for more details about the command arguments.
For evaluation, we run lm-eval as follows:
python -m evals.harness --output_path results/gated-deltanet-lte-0.4B-10B/ruler --tasks niah_single_1,niah_single_2,niah_single_3 \
--model_args pretrained=MODEL_PATH,dtype=float32,max_length=16384,trust_remote_code=True,add_bos_token=True \
--metadata '{"max_seq_lengths": [1024, 2048, 4096]}' --batch_size 32 --show_config --trust_remote_code
You may need to run pip install lm-eval[ruler] beforehand.
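For a quick sanity check of a trained checkpoint outside lm-eval, the following sketch loads it for generation via transformers. MODEL_PATH is a placeholder for your checkpoint directory, and the model classes are assumed to be registered through trust_remote_code, mirroring the model_args in the evaluation command above; the tokenizer path matches the one used for training.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# MODEL_PATH is a placeholder for a local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("fla-hub/transformer-1.3B-100B")
model = AutoModelForCausalLM.from_pretrained(
    "MODEL_PATH", trust_remote_code=True, torch_dtype=torch.float32
)
inputs = tokenizer("The capital of France is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))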
If you find the paper or code useful, please consider citing:
@misc{he_alleviating_2025,
title = {Alleviating {Forgetfulness} of {Linear} {Attention} by {Hybrid} {Sparse} {Attention} and {Contextualized} {Learnable} {Token} {Eviction}},
url = {http://arxiv.org/abs/2510.20787},
doi = {10.48550/arXiv.2510.20787},
publisher = {arXiv},
author = {He, Mutian and Garner, Philip N.},
month = oct,
year = {2025},
note = {arXiv:2510.20787 [cs]},
}

