2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ Below shows the generation speed gain by using FastSeq.
| Model | W/O FastSeq (in samples/s) | W/ FastSeq (in samples/s) | Speedup |
|------------------|:--------------------------:|:-------------------------:|:-----:|
| [ProphetNet](examples/prophetnet/README.md) | 2.8 | 10.7 | 3.8x |
| [Bart (`fs`)](examples/bart/README.md) | 2.4 | 19.7 | 8.2x |
| [Bart (`fs`)](examples/bart/README.md) | 2.4 | 25.3 | 10.5x |
| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 2.5 | 12.4 | 5.0x |
| [DistilBart (`hf`)](examples/distilbart/README.md) | 3.4 | 18.5 | 5.4x |
| [T5 (`hf`)](examples/t5/README.md) | 8.7 | 31.3 | 3.6x |
2 changes: 2 additions & 0 deletions benchmarks/benchmark.sh
@@ -7,6 +7,8 @@ bss="$1"; shift
shell=benchmark_fs.sh
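# ':' is bash's no-op; the visible branches simply accept the known framework names.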
if [ "$framework" = "fairseq+fastseq" ]; then
:
elif [ "$framework" = "fairseq+fastseq+el" ]; then
:
elif [ "$framework" = "fairseq" ]; then
:
elif [ "$framework" = "transformers+fastseq" ]; then
28 changes: 25 additions & 3 deletions benchmarks/benchmark_fs.sh
@@ -41,7 +41,7 @@ if [[ $framework == fairseq ]]; then
else
util=fairseq-generate
fi
elif [[ "$framework" == "fairseq+fastseq" ]]; then
elif [[ "$framework" == fairseq+fastseq* ]]; then
ver1=`pip show fairseq | awk '{if($1=="Version:")print $2}'`
ver2=`pip show fastseq | awk '{if($1=="Version:")print $2}'`
framework_versioned="fairseq_v$ver1+fastseq_v$ver2"
@@ -51,7 +51,7 @@ elif [[ "$framework" == "fairseq+fastseq" ]]; then
util=fastseq-generate-for-fairseq
fi
fi
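
The changed condition above swaps an exact string comparison for a glob: with an unquoted right-hand side, bash's `[[ == ]]` performs pattern matching, so both the plain and the `+el` framework names reach this branch. A quick self-contained sketch of that behavior:

```bash
for name in fairseq+fastseq fairseq+fastseq+el fairseq; do
  # An unquoted RHS of [[ == ]] is treated as a glob pattern.
  if [[ "$name" == fairseq+fastseq* ]]; then
    echo "$name -> matches"
  else
    echo "$name -> no match"
  fi
done
# fairseq+fastseq    -> matches
# fairseq+fastseq+el -> matches
# fairseq            -> no match
```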

echo $framework
mark1=" with beam="
mark2="| Evaluated "
for i in `seq $LOOP`; do
@@ -94,6 +94,28 @@ for bs in "${bs_list[@]}"; do
--remove-bpe \
--gen-subset $split $* \
> $STDOUT_FILE 2> $STDERR_FILE
elif [[ $framework == "fairseq+fastseq+el" ]]; then
echo "USING EL"
$util \
$data_dir \
--path $model_path \
--fp16 \
--task translation \
--batch-size $bs \
--gen-subset $split \
--truncate-source \
--bpe gpt2 \
--beam 4 \
--num-workers 4 \
--min-len 55 \
--max-len-b 140 \
--no-repeat-ngram-size 3 \
--lenpen 2.0 \
--use-el-attn \
`#--print-alignment` \
`#--print-step # KeyError: steps` \
--skip-invalid-size-inputs-valid-test $* \
> $STDOUT_FILE 2> $STDERR_FILE
else
$util \
$data_dir \
@@ -110,7 +132,7 @@ for bs in "${bs_list[@]}"; do
--max-len-b 140 \
--no-repeat-ngram-size 3 \
--lenpen 2.0 \
`#--print-alignment` \
`#--print-alignment` \
`#--print-step # KeyError: steps` \
--skip-invalid-size-inputs-valid-test $* \
> $STDOUT_FILE 2> $STDERR_FILE
9 changes: 9 additions & 0 deletions benchmarks/models/fs_bart.sh
@@ -22,6 +22,12 @@ source utils.sh
valid \
32/64/128/256 \
--max-tokens 131072
./benchmark.sh \
fairseq+fastseq+el \
bart.large.cnn \
cnn_dm/len-1024.bin \
valid \
320

# Accuracy
grep "bart.large.cnn cnn_dm/len-1024.bin valid " perf \
@@ -43,3 +49,6 @@ grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 12
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 256 " perf \
| awk '{s+=$13}END{print s/NR}' \
| ./range.sh 19 100
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 320 " perf \
| awk '{s+=$13}END{print s/NR}' \
| ./range.sh 25 100
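
A hedged reading of these checks: column 13 of each matching `perf` line is assumed to hold throughput in samples/s, `awk '{s+=$13}END{print s/NR}'` prints their mean over the matched runs, and `range.sh` presumably fails unless that mean falls within the given bounds. A self-contained demo of the awk step:

```bash
# Two fake perf lines whose 13th field is the throughput in samples/s.
printf '%s\n' \
  'f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 25.3' \
  'f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 25.1' \
  | awk '{s+=$13} END {print s/NR}'   # prints 25.2
```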
9 changes: 5 additions & 4 deletions examples/bart/README.md
@@ -10,10 +10,11 @@ BART is sequence-to-sequence model trained with denoising as pretraining objecti

- CNN daily mail validation data, NVIDIA-V100-16GB

| BatchSize | 32 | 64 | 128 | 256* |
|:----------------:|:-------------:|:---------------:|:--------------:|:--------------:|
| fairseq-0.9.0 | 2.4 samples/s | OOM | OOM | OOM |
| above + fastseq | 8.1 samples/s | 13.3 samples/s | 18.4 samples/s | 19.7 samples/s |
| BatchSize | 32 | 64 | 128 | 256* | 320 |
|:----------------:|:-------------:|:---------------:|:--------------:|:--------------:|:--------------:|
| fairseq-0.9.0 | 2.4 samples/s | OOM | OOM | OOM | OOM |
| above + fastseq | 8.1 samples/s | 13.3 samples/s | 18.4 samples/s | 19.7 samples/s | OOM |
| above + el_attn | --- samples/s | ---- samples/s | ---- samples/s | --- samples/s | 25.3 samples/s |
\* with `--max-tokens 131072` to keep attn_weights' total number of elements below INT_MAX, a limit imposed by the softmax op.
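
A rough back-of-envelope for why the cap is needed (assuming encoder self-attention weights of shape `(batch × heads, src_len, src_len)` and BART-large's 16 heads): at batch size 256 with `src_len` 1024, that is 256 × 16 × 1024 × 1024 = 2^32 ≈ 4.3 × 10^9 elements, above INT_MAX ≈ 2.1 × 10^9. `--max-tokens 131072` bounds a forward pass to roughly 131072 / 1024 = 128 source sentences, halving the element count to about the 2^31 limit.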

### Model
4 changes: 4 additions & 0 deletions fastseq/config.py
@@ -24,3 +24,7 @@
# supported versions of fairseq
MIN_FAIRSEQ_VERSION = '0.9.0'
MAX_FAIRSEQ_VERSION = '0.9.0'

# Set the following environment variable to '1' to enable Efficient-Lossless Attention.
USE_EL_ATTN = os.getenv('USE_EL_ATTN', '0')

7 changes: 7 additions & 0 deletions fastseq/optimizer/fairseq/__init__.py
@@ -13,6 +13,10 @@
from fastseq.config import FASTSEQ_VERSION, MAX_FAIRSEQ_VERSION, MIN_FAIRSEQ_VERSION
from fastseq.logging import get_logger
from fastseq.utils.api_decorator import OPTIMIZED_CLASSES
from fastseq import config

# Efficient-Lossless Attention
use_el_attn = config.USE_EL_ATTN == '1'

logger = get_logger(__name__, logging.INFO)

@@ -46,6 +50,9 @@ def apply_fairseq_optimization():
return

import fastseq.optimizer.fairseq.beam_search_optimizer # pylint: disable=import-outside-toplevel
if use_el_attn:
import fastseq.optimizer.fairseq.el_attention_optimizer # pylint: disable=import-outside-toplevel

import fastseq.optimizer.fairseq.generate # pylint: disable=import-outside-toplevel
_update_fairseq_model_registration()
logger.info(f"fairseq(v{fairseq.__version__}) has been optimized by "
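Note that the gate is read at Python import time, while the benchmark passes `--use-el-attn` to the generate CLI; both appear in this PR. A usage sketch (assuming, as the conditional import above suggests, that fastseq applies its fairseq patches when imported):

```python
import os

# USE_EL_ATTN is read once, at import time (fastseq/config.py), so it must
# be set before fastseq is first imported in the process.
os.environ['USE_EL_ATTN'] = '1'

import fastseq  # noqa: E402 -- importing fastseq triggers the patching
```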
22 changes: 17 additions & 5 deletions fastseq/optimizer/fairseq/beam_search_optimizer.py
@@ -7,6 +7,7 @@
from typing import Optional

import torch
import logging
import torch.nn.functional as F
from torch import Tensor

@@ -17,8 +18,19 @@
from fairseq.sequence_generator import SequenceGenerator
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
from fastseq.utils.api_decorator import replace
from fastseq import config
from fastseq.logging import get_logger

@replace(BeamSearch)
logger = get_logger(__name__, logging.INFO)


# Efficient-Lossless Attention
use_el_attn = config.USE_EL_ATTN == '1'
if use_el_attn:
logger.info("Using Efficient-Lossless Attention optimization")


@replace(BeamSearch, True)
class BeamSearchV2(BeamSearch):

def step(self, step, lprobs, scores):
@@ -47,7 +59,7 @@ def step(self, step, lprobs, scores):
self.indices_buf.fmod_(vocab_size)
return self.scores_buf, self.indices_buf, self.beams_buf

@replace(TransformerEncoder)
@replace(TransformerEncoder, not use_el_attn)
class TransformerEncoderV2(TransformerEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
@@ -62,7 +74,7 @@ def _reorder_encoder_out(self, encoder_out, new_order):
return encoder_out


@replace(TransformerModel)
@replace(TransformerModel, not use_el_attn)
class TransformerModelV2(TransformerModel):
""" Represent the BART model."""

@@ -74,7 +86,7 @@ def make_generation_fast_(self, **kwargs):
self.encoder.reorder_encoder_out = self.encoder._reorder_encoder_out


@replace(MultiheadAttention)
@replace(MultiheadAttention, not use_el_attn)
class MultiheadAttentionV2(MultiheadAttention):
"""Multi-headed attention.

@@ -426,7 +438,7 @@ def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
self.set_beam_size(beamable_mm_beam_size)


@replace(SequenceGenerator)
@replace(SequenceGenerator, not use_el_attn)
class SequenceGeneratorV2(SequenceGenerator):
"""
Sequence Generator is optimized by reducing the cached memory usage
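The second argument to `@replace` is new in this PR: `BeamSearchV2` is always swapped in (`True`), while the other replacements are suppressed when EL attention is active (`not use_el_attn`), presumably because the EL-attention module installs its own variants of those classes. A minimal sketch of what a two-argument `replace` decorator could look like — the real `fastseq.utils.api_decorator.replace` may differ:

```python
import sys

# Hypothetical registry mirroring fastseq's OPTIMIZED_CLASSES.
OPTIMIZED_CLASSES = {}

def replace(original, enabled=True):
    """Return a class decorator that swaps `original` for the decorated
    class in the module that defines it, but only when `enabled` is True."""
    def decorator(new_cls):
        if enabled:
            OPTIMIZED_CLASSES[original] = new_cls
            defining_module = sys.modules[original.__module__]
            setattr(defining_module, original.__name__, new_cls)
        return new_cls
    return decorator
```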