From eb2a3e30c359003cd595b8aa5a883c68d40d905b Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi Date: Mon, 5 Dec 2022 12:28:32 +0000 Subject: [PATCH] Force <|startoftranscript|> Updates the `forced_decoder_ids` to force the `<|startoftranscript|>` token at position 1. This is to match the official Whisper implementation, which always predicts `<|startoftranscript|>` at position 1: ```python #!pip install git+https://github.com/openai/whisper.git import whisper from datasets import load_dataset import torch device = "cuda" if torch.cuda.is_available() else "cpu" model = whisper.load_model("tiny.en").to(device) tokenizer = whisper.tokenizer.get_tokenizer(False, task="transcribe", language="en") tokenizer = tokenizer.tokenizer librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") def to_pad_to_mel(array): """Static function which: 1. Pads/trims a list of audio arrays to a max length of 30s 2. Computes log-mel filter coefficients from padded/trimmed audio sequences Inputs: array: list of audio arrays Returns: input_ids: torch.tensor of log-mel filter bank coefficients """ padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32)) input_ids = whisper.log_mel_spectrogram(padded_input) return input_ids audio_array = librispeech[0]["audio"]["array"] log_mel = to_pad_to_mel(audio_array).unsqueeze(0) tokens = model.generate(log_mel.to(device))[0] transcript = tokenizer.decode(tokens, skip_special_tokens=False) print(transcript) ``` **Print Output:** ``` <|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to ``` --- config.json | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/config.json b/config.json index 18c737b..4a0f799 100644 --- a/config.json +++ b/config.json @@ -26,6 +26,10 @@ "forced_decoder_ids": [ [ 1, + 50257 + ] + [ + 2, 50362 ] ],