Linear Attention Sequence Parallelism
Format: Article
Language: English
Online access: Order full text
Abstract: Sequence Parallel (SP) serves as a prevalent strategy to handle long
sequences that exceed the memory limit of a single GPU. However, existing SP
methods do not take advantage of linear attention features, resulting in
sub-optimal parallelism efficiency and usability for linear attention-based
language models. In this paper, we introduce Linear Attention Sequence Parallel
(LASP), an efficient SP method tailored to linear attention-based language
models. Specifically, we design an efficient point-to-point communication
mechanism to leverage the right-product kernel trick of linear attention, which
sharply decreases the communication overhead of SP. We also enhance the
practical efficiency of LASP by performing kernel fusion and intermediate state
caching, making the implementation of LASP hardware-friendly on GPU clusters.
Furthermore, we meticulously ensure the compatibility of sequence-level LASP
with all types of batch-level data parallel methods, which is vital for
distributed training on large clusters with long sequences and large batches.
We conduct extensive experiments on two linear attention-based models with
varying sequence lengths and GPU cluster sizes. LASP scales sequence length up
to 4096K using 128 A100 80G GPUs on 1B models, which is 8 times longer than
existing SP methods while being significantly faster. The code is available at
https://github.com/OpenNLPLab/LASP.
DOI: 10.48550/arxiv.2404.02882
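
The "right-product kernel trick" mentioned in the abstract can be illustrated with a minimal, self-contained sketch. This is not the authors' LASP implementation; it is a single-process NumPy toy in which the chunk loop stands in for the point-to-point exchange between GPUs, and the function names, shapes, and absence of normalization are illustrative assumptions. The key point it demonstrates is that only a small (d x d) state, not activations of sequence length, needs to cross chunk (device) boundaries.

```python
import numpy as np

def linear_attention_chunked(Q, K, V, chunk_size):
    """Chunk-wise causal linear attention via the right-product trick (toy sketch).

    Rather than materializing the (n x n) matrix (Q K^T) before multiplying by V,
    each chunk combines its local queries with a running state sum(K^T V) from
    earlier chunks. In a sequence-parallel setting, that small state is the only
    quantity that would be communicated point-to-point between devices.
    """
    n, d = Q.shape
    out = np.zeros_like(V)
    state = np.zeros((d, V.shape[1]))  # accumulated K^T V from all earlier chunks
    for start in range(0, n, chunk_size):
        q = Q[start:start + chunk_size]
        k = K[start:start + chunk_size]
        v = V[start:start + chunk_size]
        # contribution of every earlier chunk arrives only through `state`
        inter = q @ state
        # causal intra-chunk part (masked left product inside the chunk)
        intra = np.tril(q @ k.T) @ v
        out[start:start + chunk_size] = inter + intra
        # update the state that the next chunk (next device) would receive
        state = state + k.T @ v
    return out

def linear_attention_full(Q, K, V):
    """Reference: full causal (non-normalized) linear attention without chunking."""
    return np.tril(Q @ K.T) @ V

# quick consistency check of the chunked form against the reference
rng = np.random.default_rng(0)
n, d = 64, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
assert np.allclose(linear_attention_chunked(Q, K, V, chunk_size=16),
                   linear_attention_full(Q, K, V))
```

The chunked and unchunked forms agree because the causal output at position i is sum over j <= i of (q_i . k_j) v_j, which splits exactly into the intra-chunk masked term plus q_i applied to the accumulated state of earlier chunks.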