[whisper] static kv cache #31166

Merged (76 commits) on Jul 2, 2024

Conversation

@sanchit-gandhi (Contributor) commented May 31, 2024

What does this PR do?

Supersedes #28931 and extends it by adding static k/v cache support for Whisper. Also improves the performance of the eager attention implementation by removing unnecessary reshapes (inspired by LlamaAttention).

Similar to #28931, we use a separate cache for the self-attention and cross-attention layers. We define a lightweight EncoderDecoderCache wrapper that holds these two cache classes and implements common base methods (e.g. to_legacy_cache()) by calling the corresponding methods for each cache class.
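
For illustration, a minimal sketch of how the wrapper composes two cache objects (shown here with DynamicCache for brevity; a pair of static caches is what enables compilation). The names mirror the PR, but treat this as a sketch rather than the exact API:

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# one cache for the decoder self-attention, one for the cross-attention
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

# common base methods are dispatched to both underlying caches
legacy_format = past_key_values.to_legacy_cache()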

However, there is one hurdle in enabling compatibility with torch.compile. Namely, we have to determine whether we're in the first decoding step, or second step onwards:

  1. In the first decoding step, we compute the cross-attention k/v states and update the cache accordingly
  2. In the second step onwards, we re-use the k/v states directly from the cache. There’s no further update to the cross-attention cache, since the k/v states are derived entirely from the encoder hidden-states (which stay fixed)

=> the difficulty is in detecting whether we’re in the first decoding step (1), or second step onwards (2). With eager mode, we can condition on past_key_values.get_seq_length() to determine the decoding step. However, for torch.compile this introduces a graph break. Consequently, we add a boolean flag is_updated to the StaticCache class, which informs us whether the cache has been updated or not. The alternative would be to employ the same logic we do in the Flax code, where we re-compute the cross-attention k/v states each time. Benchmarks show this approach is 1.4x slower than adding the CPU flag.
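
To make the flag concrete, here is a toy, self-contained sketch of the idea (illustrative names only; the real logic lives in WhisperAttention and the cache classes):

import torch

class ToyCrossAttentionCache:
    """Toy stand-in for a cross-attention cache with an is_updated flag."""

    def __init__(self):
        self.key_cache = {}
        self.value_cache = {}
        self.is_updated = {}

    def get_kv(self, layer_idx, encoder_hidden_states, k_proj, v_proj):
        if self.is_updated.get(layer_idx):
            # second step onwards: plain re-use, no projection and no cache update
            return self.key_cache[layer_idx], self.value_cache[layer_idx]
        # first decoding step: project the (fixed) encoder states and cache them
        self.key_cache[layer_idx] = k_proj(encoder_hidden_states)
        self.value_cache[layer_idx] = v_proj(encoder_hidden_states)
        self.is_updated[layer_idx] = True
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

k_proj, v_proj = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
cache = ToyCrossAttentionCache()
encoder_hidden_states = torch.randn(1, 4, 8)

k1, v1 = cache.get_kv(0, encoder_hidden_states, k_proj, v_proj)  # computes + caches
k2, v2 = cache.get_kv(0, encoder_hidden_states, k_proj, v_proj)  # pure cache re-use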

Using the .generate API with Whisper medium, we get approximately 5x speed-up when generating 64 tokens using sdpa attention. Note here that we compile the forward pass only:

| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|-----|---------------|----------------|----------|
| 1   | 55.6          | 270.7          | 4.9      |
| 2   | 111.4         | 541.3          | 4.9      |
| 4   | 222.3         | 1078.8         | 4.9      |
| 8   | 446.3         | 2167.4         | 4.9      |
Extended results:

Whisper large-v3

| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|-----|---------------|----------------|----------|
| 1   | 41.1          | 190.4          | 4.6      |
| 2   | 82.1          | 381.2          | 4.6      |
| 4   | 162.9         | 761.2          | 4.7      |
| 8   | 331.3         | 1522.5         | 4.6      |

Distil-Whisper distil-large-v3

| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|-----|---------------|----------------|----------|
| 1   | 278.7         | 449.1          | 1.6      |
| 2   | 560.5         | 900.3          | 1.6      |
| 4   | 1113.2        | 1798.7         | 1.6      |
| 8   | 2225.0        | 3592.8         | 1.6      |

As expected, the speed-ups for Distil-Whisper are less pronounced:

  1. With only 2 decoder layers, the decoder forward pass is already >6x faster than Whisper, and we have a very small decoder graph that can be compiled
  2. The overhead from the logits post-processing now occupies a greater proportion of the generation time. Compiling the logits processors is a good next step for speeding-up generation further.
Code example:
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import logging
import time

torch._logging.set_logs(graph_breaks=True, recompiles=True)

torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
model.to(torch_device, dtype=torch_dtype)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").to(torch_device)
input_features = inputs.input_features.to(torch_dtype)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"

# warm-up: two generate calls to trigger compilation
for i in range(2):
    model.generate(input_features)

# inference
pred_ids = model.generate(input_features)
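
# (addition for completeness, not in the original snippet)
# decode the generated ids back to text with the processor
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
print(pred_text)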

In refactoring the eager attention implementation for the cache abstraction, I managed to remove a lot of wasteful .view operations, generally aligning it with LLaMA and giving a performance boost even without compile (TODO: quantify speed-up).

The only regression comes when using FA2 and compile, where we have to introduce a bunch of new .transpose operations for compatibility with the shape of our k/v cache (TODO: quantify regression). This is also a known problem in LLaMA.
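
As a rough illustration of the layout mismatch (a sketch with made-up dimensions, not the modeling code): the k/v cache stores tensors head-first as [batch_size, num_heads, seq_len, head_dim], while Flash Attention 2 expects [batch_size, seq_len, num_heads, head_dim], hence the extra transposes:

import torch

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)  # layout used by the k/v cache
key_states_fa2 = key_states.transpose(1, 2).contiguous()     # layout expected by Flash Attention 2
print(key_states_fa2.shape)  # torch.Size([2, 16, 8, 64])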

There are a few tidy-up points left to do. Once we're happy with the design, I'll complete the PR with the final checklist items:

  • Fix failing fast tests
  • Tidy docstrings for new arguments (past_key_values, cache_position)
  • Update model doc with FA2 usage
  • Run all Whisper slow tests
  • Run all ASR pipeline slow tests
  • Check gradients propagate correctly when training with output_attentions=True


@ArthurZucker (Collaborator) left a comment:

Very nice overall. cc @zhenglongjiepheonix I reviewed this one instead of #30949 because it had fewer changes, sorry that work got duplicated here!

@zhenglongjiepheonix (Contributor) commented Jun 3, 2024:

You can reference my PR #30949 for the failing tests; it passes all the tests that the current main branch passes and will save you a lot of time debugging. @sanchit-gandhi

@sanchit-gandhi (Contributor, Author) left a comment:
cc @gante @ArthurZucker ready for review! (just one TODO left with the custom 4d test)

key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
@sanchit-gandhi (Contributor, Author) commented on the lines above:

Note: copied from LLaMA 2

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache

@ArthurZucker (Collaborator) left a comment:
Few nits but super nice!

Comment on lines +305 to +308
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
Collaborator review comment on the lines above:

Again here, we can just check the cache_position. Though this is only possible if we enforce passing cache positions when generating; let's add a deprecation warning saying that generating without cache positions is deprecated.

@sanchit-gandhi (Contributor, Author) replied:

Note that cache_position is not a reliable way of determining which decoding step we're in. Checking cache_position would only tell us whether we're in the self-attention or the cross-attention. We need the is_updated logic to detect whether we're doing the first decoding step (in which case we compute the cross-attention k/v projections) or the second decoding step onwards (in which case we completely re-use the cross-attention cache).

Checking cache_position amounts to the "re-computing" approach benchmarked in this comment: #31166 (comment)

We see that we need the is_updated flag to determine which decoding step we're in for the fastest performance.

@sanchit-gandhi (Contributor, Author) added:

As discussed offline, conditioning on cache_position introduces a graph break, since we branch on the value of cache_position:

        if is_cross_attention and past_key_value and cache_position[0] > 0:

=> let's stick with the proposed design if good with you?
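
For reference, a minimal standalone sketch of the kind of data-dependent branch that triggers the graph break (illustrative only, not the modeling code):

import torch

torch._logging.set_logs(graph_breaks=True)

def step(x, cache_position):
    # branching on the value of a tensor forces dynamo to break the graph here
    if cache_position[0] > 0:
        return x + 1
    return x - 1

compiled_step = torch.compile(step)
compiled_step(torch.randn(4), torch.tensor([0]))  # expected to log a graph break at the `if`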

@gante (Member) left a comment:
Happy with the PR 🔥🔥 Let's goooo

@ArthurZucker (Collaborator) left a comment:
Would maybe just run the slow tests?

Comment on lines +1347 to +1351
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
Collaborator review comment on the lines above:
💌

@sanchit-gandhi (Contributor, Author) commented Jul 2, 2024:

Thanks both for the reviews! Confirming that the slow tests pass on the DGX A100.

Going to merge this one to enable static kv cache for:

  1. Short-form generation
  2. Long-form generation without fallback (i.e. sequential generation without temperature fallback)

We'll need a follow-up PR to enable:

  1. Long-form generation with fallback: remember that we dynamically reduce the batch size when we do temperature fallback. We'll need to change this to fixed batch sizes for compile
  2. Long-form chunked generation with pipeline: again, the batch size is set dynamically in the pipeline class, depending on the length of the inputs

@sanchit-gandhi merged commit a970195 into huggingface:main on Jul 2, 2024. 23 checks passed.
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
Contributor review comment on the lines above:

_shape and _reshape are not the same op; is it fine to replace?

@sanchit-gandhi (Contributor, Author) replied:

We do the transpose later to get it into the original format: https://github.com/huggingface/transformers/pull/31166/files#r1652420859

But this is a good point: we don't need to _shape then .transpose the q-states; we can get them directly into the correct format.

@SaeedNajafi commented Jul 2, 2024:

Hi, I am getting some cache errors while doing generation with Llama 3 and FSDP. I am using flash_attention_2 and use_cache=True in the generate function, with the latest transformers from the repo, including your recent PR.

[rank1]: Traceback (most recent call last):
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/src/squadv2_finetuning.py", line 129, in <module>
[rank1]:     app.run(main)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/absl/app.py", line 308, in run
[rank1]:     _run_main(main, args)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
[rank1]:     sys.exit(main(argv))
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/src/squadv2_finetuning.py", line 91, in main
[rank1]:     results = train(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/src/utils/train_utils.py", line 108, in train
[rank1]:     eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity, eval_scores = evaluation(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/src/utils/train_utils.py", line 405, in evaluation
[rank1]:     for ret_row, ret_loss in model.predict(batch):
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/src/llm.py", line 245, in predict
[rank1]:     answers, log_ps = self.generation_pass(batch)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/src/llm.py", line 216, in generation_pass
[rank1]:     predictions_output = self.model.generate(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/peft/peft_model.py", line 1491, in generate
[rank1]:     outputs = self.base_model.generate(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/generation/utils.py", line 1945, in generate
[rank1]:     result = self._sample(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/generation/utils.py", line 2693, in _sample
[rank1]:     outputs = self(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 1174, in forward
[rank1]:     outputs = self.model(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 978, in forward
[rank1]:     layer_outputs = decoder_layer(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 857, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
[rank1]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
[rank1]:     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 494, in checkpoint
[rank1]:     ret = function(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 718, in forward
[rank1]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 431, in forward
[rank1]:     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
[rank1]:   File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/cache_utils.py", line 366, in update
[rank1]:     return self.key_cache[layer_idx], self.value_cache[layer_idx]
[rank1]: IndexError: list index out of range

@sanchit-gandhi (Contributor, Author) replied:

Hey @SaeedNajafi - do you have a minimal reproducer you could use to open a new issue on the repo? Thanks!

@ArthurZucker (Collaborator) commented:

The pipeline needs more work, specifically for longer audios + the merging solution. Your contribution is welcome, especially for 1); if you have a working snippet, feel free to add it to the doc.

@Jiltseb commented Jul 22, 2024:

Thanks. I deleted the comment once I saw the PR already in progress (#31772) for this exact thing. I think it's better to wait for the merge.
