Force <|startoftranscript|>
Updates the `forced_decoder_ids` to force the `<|startoftranscript|>` token at position 1. This matches the official Whisper implementation, which always predicts `<|startoftranscript|>` at position 1:

```python
#!pip install git+https://github.com/openai/whisper.git
import whisper
from datasets import load_dataset
import numpy as np  # needed by to_pad_to_mel below
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = whisper.load_model("tiny.en").to(device)
tokenizer = whisper.tokenizer.get_tokenizer(False, task="transcribe", language="en")
tokenizer = tokenizer.tokenizer

librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

def to_pad_to_mel(array):
    """Static function which:
        1. Pads/trims a list of audio arrays to a max length of 30s
        2. Computes log-mel filter coefficients from padded/trimmed audio sequences
    Inputs:
        array: list of audio arrays
    Returns:
        input_ids: torch.tensor of log-mel filter bank coefficients
    """
    padded_input = whisper.pad_or_trim(np.asarray(array, dtype=np.float32))
    input_ids = whisper.log_mel_spectrogram(padded_input)
    return input_ids

audio_array = librispeech[0]["audio"]["array"]
log_mel = to_pad_to_mel(audio_array).unsqueeze(0)

tokens = model.generate(log_mel.to(device))[0]
transcript = tokenizer.decode(tokens, skip_special_tokens=False)
print(transcript)
```

**Print Output:**

```
<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to
```
parent f3b5e97e6e
commit eb2a3e30c3
@@ -26,6 +26,10 @@
   "forced_decoder_ids": [
     [
       1,
+      50257
+    ],
+    [
+      2,
       50362
     ]
   ],
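To illustrate how a `forced_decoder_ids` list of `[position, token_id]` pairs could be applied during decoding, here is a minimal sketch in plain Python. The helper names `apply_forced_ids` and `decode_prefix`, the placeholder `start_id`, and the dummy predicted ids are hypothetical and are not part of Whisper or any library; only the two pairs come from this commit.

```python
# The two [position, token_id] pairs are the values after this commit.
# Everything else in this block is an illustrative assumption.
forced_decoder_ids = [[1, 50257], [2, 50362]]


def apply_forced_ids(forced_ids, position, predicted_id):
    """Return the forced token id for `position`, else the predicted id."""
    forced = {pos: tok for pos, tok in forced_ids}
    return forced.get(position, predicted_id)


def decode_prefix(forced_ids, predictions, start_id):
    """Build a token sequence from position 0, overriding forced positions.

    `start_id` stands in for whatever token occupies position 0;
    `predictions` are hypothetical model outputs for positions 1, 2, ...
    """
    sequence = [start_id]
    for position, predicted_id in enumerate(predictions, start=1):
        sequence.append(apply_forced_ids(forced_ids, position, predicted_id))
    return sequence


# Dummy predictions: positions 1 and 2 are overridden, position 3 is kept.
sequence = decode_prefix(forced_decoder_ids, [111, 222, 333], start_id=0)
print(sequence)
```

With the pairs from this commit, positions 1 and 2 always receive 50257 and 50362 regardless of what the model would have predicted there, which is the behaviour the commit message describes.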