Update README.md

Patrick von Platen 2021-01-30 22:26:14 +00:00 committed by huggingface-web
parent 0b362f54b7
commit 8d8823668f
1 changed file with 51 additions and 87 deletions

README.md

@@ -1,96 +1,60 @@

Previous README (removed):

# Wav2Vec2 Acoustic Model fine-tuned on LibriSpeech

The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.

Paper: https://arxiv.org/abs/2006.11477

## Usage

Make sure you are working on [this branch](https://github.com/huggingface/transformers/tree/add_wav2vec) of transformers (it will hopefully be merged into master soon):

```bash
$ git checkout add_wav2vec
```

In the following, we'll show a simple example of how the model can be used for automatic speech recognition.

First, let's load the model:
```python
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("patrickvonplaten/wav2vec2-base-960h")
```
Next, let's load a dummy librispeech dataset
```python
from datasets import load_dataset
import soundfile as sf
libri_speech_dummy = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

def map_to_array(batch):
    speech_array, _ = sf.read(batch["file"])
    batch["speech"] = speech_array
    return batch

libri_speech_dummy = libri_speech_dummy.map(map_to_array, remove_columns=["file"])
# check out dataset
print(libri_speech_dummy)
input_speech_16kHz = libri_speech_dummy[2]["speech"]
expected_trans = libri_speech_dummy[2]["text"]
```
Cool, now we can run an inference pass to retrieve the logits:
```python
import torch

# run a forward pass; the returned logits have shape (batch_size, sequence_length, vocab_size)
# note: soundfile returns float64 audio, so cast to float32 before feeding it to the model
logits = model(torch.tensor(input_speech_16kHz, dtype=torch.float32)[None, :]).logits

# take the argmax over the vocabulary dimension to get the predicted token ids
pred_ids = torch.argmax(logits[0], axis=-1)
```
Finally, let's decode the prediction.
Let's create a simple CTC-Decoder:
```python
import numpy as np
from itertools import groupby

class Decoder:
    def __init__(self, json_dict):
        self.dict = json_dict
        self.look_up = np.asarray(list(self.dict.keys()))

    def decode(self, ids):
        converted_tokens = self.look_up[ids]
        # collapse repeated tokens (CTC-style), then drop the blank token "<s>"
        # and turn the word delimiter "|" into a space
        fused_tokens = [tok[0] for tok in groupby(converted_tokens)]
        output = ' '.join(''.join(''.join(fused_tokens).split("<s>")).split("|"))
        return output
```
and instantiate it with the corresponding character dictionary:
```python
# hard-coded json dict taken from: https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt
json_dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}
decoder = Decoder(json_dict=json_dict)
```
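To make the decoding rule concrete, here is a quick illustrative call on the decoder we just instantiated. The id sequence below is hand-made for this example (it is not model output): consecutive duplicates are collapsed first, then the blank token `<s>` is dropped and the word delimiter `|` becomes a space.

```python
# toy example: these ids spell out "HELLO WORLD" with CTC-style repeats and a blank mixed in
example_ids = [11, 11, 5, 15, 15, 0, 15, 8, 8, 4, 18, 8, 13, 15, 14]
print(decoder.decode(example_ids))  # should print: HELLO WORLD
```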
Now, let's decode the actual model prediction:
```python
pred_trans = decoder.decode(pred_ids)
print("Prediction:\n", pred_trans)
print("\n" + 50 * "=" + "\n")
print("Correct result:\n", expected_trans)
```
🎉

Updated README (added):

---
language: en
datasets:
- librispeech_asr
tags:
- speech
license: apache-2.0
---

# Wav2Vec2-Base-960h

[Facebook's Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)

The base model, pretrained and fine-tuned on 960 hours of Librispeech.

[Paper](https://arxiv.org/abs/2006.11477)

Authors: Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli

**Abstract**

We show for the first time that learning powerful representations from speech audio alone followed by fine-tuning on transcribed speech can outperform the best semi-supervised methods while being conceptually simpler. wav2vec 2.0 masks the speech input in the latent space and solves a contrastive task defined over a quantization of the latent representations which are jointly learned. Experiments using all labeled data of Librispeech achieve 1.8/3.3 WER on the clean/other test sets. When lowering the amount of labeled data to one hour, wav2vec 2.0 outperforms the previous state of the art on the 100 hour subset while using 100 times less labeled data. Using just ten minutes of labeled data and pre-training on 53k hours of unlabeled data still achieves 4.8/8.2 WER. This demonstrates the feasibility of speech recognition with limited amounts of labeled data.

The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.

# Usage

The model can be used as follows to transcribe speech input:

```python
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForMaskedLM
from datasets import load_dataset
import soundfile as sf
import torch
# load model and tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
# define function to read in sound file
def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

# load dummy dataset and read soundfiles
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)
# tokenize
input_values = tokenizer(ds["speech"][:2], return_tensors="pt", padding="longest").input_values  # Batch size 2
# retrieve logits
logits = model(input_values).logits
# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)
```
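To sanity-check the output, the predicted transcriptions can be compared to the reference texts that ship with the dummy dataset. This is a minimal sketch that continues the snippet above and assumes the transcriptions live in a `text` column, as in the example further up:

```python
# compare each predicted transcription to the dataset's reference text
for predicted, reference in zip(transcription, ds["text"][:2]):
    print("Prediction:", predicted)
    print("Reference: ", reference)
```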