Update README.md

This commit is contained in:
Patrick von Platen 2021-08-27 15:37:00 +00:00 committed by huggingface-web
parent 87f7f02dc3
commit 6f0b7949d1
1 changed files with 9 additions and 9 deletions

View File

@ -31,13 +31,13 @@ The original model can be found under https://github.com/pytorch/fairseq/tree/ma
To transcribe audio files the model can be used as a standalone acoustic model as follows: To transcribe audio files the model can be used as a standalone acoustic model as follows:
```python ```python
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset from datasets import load_dataset
import soundfile as sf import soundfile as sf
import torch import torch
# load model and tokenizer # load model and processor
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
# define function to read in sound file # define function to read in sound file
@ -51,14 +51,14 @@ To transcribe audio files the model can be used as a standalone acoustic model a
ds = ds.map(map_to_array) ds = ds.map(map_to_array)
# tokenize # tokenize
input_values = tokenizer(ds["speech"][:2], return_tensors="pt", padding="longest").input_values # Batch size 1 input_values = processor(ds["speech"][:2], return_tensors="pt", padding="longest").input_values # Batch size 1
# retrieve logits # retrieve logits
logits = model(input_values).logits logits = model(input_values).logits
# take argmax and decode # take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1) predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids) transcription = processor.batch_decode(predicted_ids)
``` ```
## Evaluation ## Evaluation
@ -67,7 +67,7 @@ To transcribe audio files the model can be used as a standalone acoustic model a
```python ```python
from datasets import load_dataset from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf import soundfile as sf
import torch import torch
from jiwer import wer from jiwer import wer
@ -76,7 +76,7 @@ from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to("cuda") model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to("cuda")
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
def map_to_array(batch): def map_to_array(batch):
speech, _ = sf.read(batch["file"]) speech, _ = sf.read(batch["file"])
@ -86,7 +86,7 @@ def map_to_array(batch):
librispeech_eval = librispeech_eval.map(map_to_array) librispeech_eval = librispeech_eval.map(map_to_array)
def map_to_pred(batch): def map_to_pred(batch):
inputs = tokenizer(batch["speech"], return_tensors="pt", padding="longest") inputs = processor(batch["speech"], return_tensors="pt", padding="longest")
input_values = inputs.input_values.to("cuda") input_values = inputs.input_values.to("cuda")
attention_mask = inputs.attention_mask.to("cuda") attention_mask = inputs.attention_mask.to("cuda")
@ -94,7 +94,7 @@ def map_to_pred(batch):
logits = model(input_values, attention_mask=attention_mask).logits logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1) predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids) transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription batch["transcription"] = transcription
return batch return batch