Speculative Decoding: A 47% Leap in Whisper ASR Performance

Speculative Decoding leads to a 47% improvement in Whisper's performance

Abstract

This blog post examines how Speculative Decoding speeds up automatic speech recognition (ASR) with OpenAI's Whisper model. Speculative Decoding pairs the large Whisper model with a smaller assistant model that drafts tokens for the large model to verify, which reduces inference time. Our experiment uses openai/whisper-large-v2 as the base model and distil-whisper/distil-large-v2 as the assistant model. We ran it on two audio files of different lengths. Speculative Decoding combined with Flash Attention 2 cut average processing time by 47.91% for the shorter file and by 34.89% for the longer one. The post covers the methodology, the experimental setup, and what the findings imply for ASR technology.

Introduction

OpenAI's Whisper model is a significant advance in automatic speech recognition (ASR). Even so, a model of its size can be slow at inference time. Speculative Decoding addresses this by speeding up Whisper's decoding.

Speculative Decoding speeds up a large model by pairing it with a smaller one. The smaller model acts as a draft, or "assistant", model and predicts the next several tokens in the sequence. The larger model then checks those predictions and keeps the ones it agrees with, which is faster than generating every token itself.

Process of Events

  1. The draft model predicts a short sequence of tokens that may come next.
  2. The target model scores the draft tokens in a single forward pass, estimating the probability of the next token at each position.
  3. The target model "approves" each draft token that agrees with its own prediction.
  4. At the first draft token it rejects, the target model discards the rest of the draft and predicts the next token itself.
  5. Drafting resumes from that point and repeats until the target model selects the token signifying the End of Sequence.
  6. The approved draft tokens and the tokens supplied by the target model are concatenated to form the final output.
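
The steps above can be sketched with toy "models" that simply look up the next token in a fixed list. This is a minimal illustration under stated assumptions, not the Whisper or transformers implementation: `make_toy_model` and `speculative_decode` are hypothetical names, acceptance is a simple greedy match rather than a probability test, and the target is called once per token here, whereas a real implementation scores all draft positions in one forward pass.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]


def make_toy_model(text: List[int], eos: int) -> NextToken:
    """Return a toy 'model' that emits `text` one token per position, then EOS."""
    def next_token(seq: List[int]) -> int:
        return text[len(seq)] if len(seq) < len(text) else eos
    return next_token


def speculative_decode(target: NextToken, draft: NextToken, eos: int,
                       num_draft: int = 4, max_len: int = 32) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, accepted draft tokens)."""
    seq: List[int] = []
    accepted = 0
    while len(seq) < max_len and (not seq or seq[-1] != eos):
        # Step 1: the draft model proposes a short run of tokens.
        proposal: List[int] = []
        for _ in range(num_draft):
            proposal.append(draft(seq + proposal))
        # Steps 2-4: the target checks each proposed token in order.
        for tok in proposal:
            expected = target(seq)
            seq.append(expected)  # equals `tok` whenever the draft is accepted
            if tok != expected:
                break  # rejected: keep the target's token, drop the rest
            accepted += 1
            if expected == eos or len(seq) >= max_len:
                break
    return seq, accepted


EOS = 0
target_model = make_toy_model([5, 3, 8, 1, 9, 2, 7, EOS], EOS)
draft_model = make_toy_model([5, 3, 8, 4, 9, 2, 7, EOS], EOS)  # one wrong guess

tokens, accepted = speculative_decode(target_model, draft_model, EOS)
print(tokens, accepted)
```

Because every appended token is the target's own choice, the output matches what the target would produce alone. The draft only determines how many tokens can be confirmed per round.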

Experiment

The purpose of this experiment is to see if Speculative Decoding can improve the performance of Whisper in various conditions.

The base model used is openai/whisper-large-v2. The draft model used is distil-whisper/distil-large-v2.

The Distil-Whisper model is 49% smaller than the base model and 6 times faster (Gandhi, von Platen, Rush, 2023).

The experiment was run on a single NVIDIA RTX 4090 GPU with 24GB of VRAM.

Process

The experiment was run on two different audio files, one short and one long.

Each test was run 5 times, and the average completion time was recorded.
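
A repeat-and-average measurement like this can be sketched as follows. The helper name `time_runs` is hypothetical, and the workload is a stand-in: in the experiment the timed call would be the transcription itself, and the original scripts use `time.time()` rather than `time.perf_counter()`.

```python
import statistics
import time
from typing import Callable, List, Tuple


def time_runs(fn: Callable[[], object], runs: int = 5) -> Tuple[List[float], float]:
    """Call `fn` `runs` times; return each duration in seconds and their mean."""
    durations: List[float] = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations, statistics.mean(durations)


# Stand-in workload (assumption); replace with the call being benchmarked.
durations, average = time_runs(lambda: sum(range(100_000)))
print([f"{d:.4f}" for d in durations], f"average: {average:.4f} s")
```

Keeping the individual durations alongside the mean makes it possible to report the range of each configuration as well as its average.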

[Figure: GPU information]

The files consisted of a 2m30s audio file and a 1h16m audio file and were chosen because they represent different sides of the usage spectrum.

This helps to determine if Speculative Decoding can improve performance in both short and long audio files.

Results

| Category | Average Time (s) | Time Reduction vs. Control |
|---|---|---|
| No Speculative (Control, Short) | 15.17 | - |
| No Speculative + Flash 2 (Short) | 14.27 | - |
| Speculative Decoding (Short) | 9.11 | 39.95% |
| Speculative + Flash 2 (Short) | 7.90 | 47.91% |
| No Speculative (Control, Long) | 258.08 | - |
| No Speculative + Flash 2 (Long) | 231.98 | - |
| Speculative Decoding (Long) | 181.58 | 29.64% |
| Speculative + Flash 2 (Long) | 168.05 | 34.89% |
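
The percentages are the reduction in average time relative to the control run for the same file. A small sketch of that arithmetic, using the rounded averages from the table (`percent_reduction` is a hypothetical helper name; because the inputs are rounded, recomputed values can differ from the reported ones in the last digit):

```python
def percent_reduction(control_s: float, test_s: float) -> float:
    """Reduction in completion time relative to the control, in percent."""
    return (control_s - test_s) / control_s * 100


# Average completion times in seconds, copied from the results table.
averages = {
    "short": {"control": 15.17, "speculative": 9.11, "speculative_flash2": 7.90},
    "long": {"control": 258.08, "speculative": 181.58, "speculative_flash2": 168.05},
}

for file_label, times in averages.items():
    for method in ("speculative", "speculative_flash2"):
        reduction = percent_reduction(times["control"], times[method])
        print(f"{file_label:5} {method:18} {reduction:.2f}%")
```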

[Figures: combined visualizations of the results]

Control

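The following script produced the non-speculative results.

import time

import librosa
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Function to load and resample an audio file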
def load_and_resample_audio(file_path, target_sample_rate=16000):
    audio, _ = librosa.load(file_path, sr=target_sample_rate)
    return {"array": audio, "sampling_rate": target_sample_rate}

# Set device and data type for PyTorch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

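# Load the speech recognition model.
# use_flash_attention_2=True enables Flash Attention 2 (the "+ Flash 2" rows);
# remove it to run without Flash Attention 2. Newer transformers releases
# use attn_implementation="flash_attention_2" instead.
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    use_flash_attention_2=True,
)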
model.to(device)

# Load the processor
processor = AutoProcessor.from_pretrained(model_id)

# Initialize the pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    tokenizer=processor.tokenizer,
    torch_dtype=torch_dtype,
    device=device,
)

# Specify the path to your audio file here
audio_file_path = 'audio/audio-file.mp3'

# Load and resample the audio file
audio_data = load_and_resample_audio(audio_file_path)

# Measure the time taken for speech recognition
start_time = time.time()
result = pipe(audio_data)
end_time = time.time()

# Calculate and print the time taken
time_taken = end_time - start_time
print(f"Time to completion: {time_taken:.2f} seconds")
print(result["text"])

The following script was used to test the speculative decoding methods.

from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
import librosa
import time

# Function to load and resample an audio file
def load_and_resample_audio(file_path, target_sample_rate=16000):
    audio, _ = librosa.load(file_path, sr=target_sample_rate)
    return {"array": audio, "sampling_rate": target_sample_rate}

# Set device and data type for PyTorch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

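# Load the assistant model. Distil-Whisper shares its encoder with
# whisper-large-v2, so only its decoder is loaded, via AutoModelForCausalLM.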
assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True,
)
assistant_model.to(device)

# Load the speech recognition model
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True,
)
model.to(device)

# Load the processor
processor = AutoProcessor.from_pretrained(model_id)

# Initialize the pipeline
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

# Specify the path to your audio file here
audio_file_path = 'audio/audio-file.mp3'

# Load and resample the audio file
audio_data = load_and_resample_audio(audio_file_path)

# Measure the time taken for speech recognition
start_time = time.time()
result = pipe(audio_data)
end_time = time.time()

# Calculate and print the time taken
time_taken = end_time - start_time
print(f"Time to completion: {time_taken:.2f} seconds")
print(result["text"])

  1. Time Comparison for Short File (2m30s)

[Figure: time comparison for the short file]

Speculative Decoding

  • Completion times decreased over the five runs, from 10.79 seconds to 8.26 seconds, so efficiency improved with each run.

Speculative Decoding + Flash Attention 2

  • The completion times range from 8.23 seconds to 8.60 seconds, showing a relatively stable and improved performance compared to Speculative Decoding.

No Speculative Decoding

  • Completion times are higher, ranging from 14.48 seconds to 16.55 seconds: consistent, but slower than both speculative methods.

As seen in the chart, the Speculative Decoding methods, particularly when combined with Flash Attention 2, significantly reduce average completion time compared to standard decoding.

  2. Time Comparison for Long File (1h16m)

[Figure: time comparison for the long file]

Speculative Decoding

  • The completion times for Speculative Decoding across the five runs fluctuate, ranging from 175.32 seconds to 190.95 seconds. This indicates variability in the efficiency of the method across different runs.

Speculative Decoding + Flash Attention 2

  • This method demonstrates more consistent and efficient performance compared to standard Speculative Decoding. The completion times for the five runs range narrowly from 166.95 seconds to 169.78 seconds, showing an improvement in processing speed.

No Speculative Decoding

  • The completion times for this control group are significantly higher, ranging from 243.98 seconds to 287.55 seconds. Without speculative decoding, processing the long file takes considerably longer.
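
One way to compare run-to-run consistency is the spread between the slowest and fastest run of each configuration. The sketch below uses only the minimum and maximum times quoted in the bullets above, since the individual run times are not listed; `spread` is a hypothetical helper name.

```python
# (fastest, slowest) completion times in seconds, as quoted in the text.
reported_ranges = {
    "short": {
        "Speculative Decoding": (8.26, 10.79),
        "Speculative + Flash 2": (8.23, 8.60),
        "No Speculative": (14.48, 16.55),
    },
    "long": {
        "Speculative Decoding": (175.32, 190.95),
        "Speculative + Flash 2": (166.95, 169.78),
        "No Speculative": (243.98, 287.55),
    },
}


def spread(fastest_s: float, slowest_s: float) -> float:
    """Difference between the slowest and fastest run, in seconds."""
    return slowest_s - fastest_s


for file_label, methods in reported_ranges.items():
    for method, (fastest, slowest) in methods.items():
        print(f"{file_label:5} {method:22} spread {spread(fastest, slowest):6.2f} s")
```

A smaller spread means the five runs landed closer together. With only two data points per configuration this is a coarse measure, and a standard deviation over all five runs would be more informative.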

Conclusion

Speculative Decoding substantially improves the processing time of OpenAI's Whisper model in our experiment, most notably for the shorter audio file. Combining Speculative Decoding with Flash Attention 2 reduces average completion time by nearly 48% compared to standard decoding for the short file. For the long file the reduction is smaller but still substantial, at 34.89%.
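
The same averages can also be read as throughput: seconds of audio transcribed per second of processing. This is a rough sketch that assumes the files are exactly 2m30s and 1h16m long (the post gives durations only to that precision); `audio_seconds_per_second` is a hypothetical helper name.

```python
SHORT_AUDIO_S = 2 * 60 + 30        # 2m30s
LONG_AUDIO_S = 1 * 3600 + 16 * 60  # 1h16m; exact seconds are an assumption


def audio_seconds_per_second(audio_s: float, processing_s: float) -> float:
    """Seconds of audio transcribed per second of wall-clock time."""
    return audio_s / processing_s


# Average completion times in seconds, from the results table.
runs = [
    ("short, control", SHORT_AUDIO_S, 15.17),
    ("short, speculative + Flash 2", SHORT_AUDIO_S, 7.90),
    ("long, control", LONG_AUDIO_S, 258.08),
    ("long, speculative + Flash 2", LONG_AUDIO_S, 168.05),
]

for label, audio_s, processing_s in runs:
    rate = audio_seconds_per_second(audio_s, processing_s)
    print(f"{label:30} {rate:6.2f}x real time")
```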

These findings suggest that Speculative Decoding, particularly when combined with efficient attention mechanisms like Flash Attention 2, is not only viable but highly effective in optimizing ASR models. The technique's ability to leverage the predictive capabilities of a smaller model to guide and accelerate the larger model's performance marks a significant step forward in the field of speech recognition. This approach could potentially be applied to other large models, opening new avenues for efficiency improvements in various AI applications.

The future of ASR and AI model efficiency looks brighter with these advancements, promising more responsive and resource-efficient applications. As technology continues to evolve, techniques like Speculative Decoding will play a crucial role in pushing the boundaries of what's possible in the world of artificial intelligence.

Citations

@misc{gandhi2023distilwhisper,
      title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling}, 
      author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
      year={2023},
      eprint={2311.00430},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
@misc{leviathan2023fast,
      title={Fast Inference from Transformers via Speculative Decoding}, 
      author={Yaniv Leviathan and Matan Kalman and Yossi Matias},
      year={2023},
      eprint={2211.17192},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}