# Hubert-Base for Keyword Spotting
This model was ported from the [S3PRL speech toolkit](https://github.com/s3prl/s3prl) and builds on [Facebook's Hubert](https://ai.facebook.com/blog/hubert-self-supervised-representation-learning-for-speech-recognition-generation-and-compression). The base model is [hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960), which is pretrained on 16kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16kHz.
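If your audio was recorded at a different rate, resample it before running the model. A minimal sketch using `torchaudio` (not used elsewhere in this card; the 48kHz random waveform is a stand-in for your own data):
```python
import torch
import torchaudio.functional as F

# hypothetical input: one second of mono audio sampled at 48kHz
waveform = torch.randn(1, 48_000)

# resample to the 16kHz rate the model expects
waveform_16k = F.resample(waveform, orig_freq=48_000, new_freq=16_000)
```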
The original model can be found under https://github.com/s3prl/s3prl/tree/master/s3prl/downstream/speech_commands.
# Usage examples
You can use the model via the Audio Classification pipeline:
```python
import numpy as np
from datasets import load_dataset
from transformers import pipeline

superb_ks = load_dataset("anton-l/superb_dummy", "ks", split="test")

# the pipeline loads both the model and its feature extractor;
# audio classification does not need a tokenizer
classifier = pipeline("audio-classification", model="superb/hubert-base-superb-ks")

audio = np.array(superb_ks[0]["speech"])
labels = classifier(audio, top_k=5)
```
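The pipeline returns `labels` as a list of `{"score", "label"}` dicts sorted by descending score. Continuing from the snippet above, you can inspect them like this:
```python
# print each predicted keyword with its confidence
for result in labels:
    print(f"{result['label']}: {result['score']:.3f}")
```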
Or use the model directly:
```python
import torch
import numpy as np
from datasets import load_dataset
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

superb_ks = load_dataset("anton-l/superb_dummy", "ks", split="test")

model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")

audio = np.array(superb_ks[0]["speech"])

# compute attention masks and normalize the waveform if needed
inputs = feature_extractor(audio, sampling_rate=16_000, return_tensors="pt")

# forward pass without gradient tracking, since we are only predicting
with torch.no_grad():
    logits = model(**inputs).logits

# map the highest-scoring class id to its keyword label
predicted_ids = torch.argmax(logits, dim=-1)
labels = [model.config.id2label[_id] for _id in predicted_ids.tolist()]
```
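To mirror the pipeline's `top_k=5` output, you can turn the logits into per-keyword probabilities yourself. A minimal sketch continuing from the snippet above:
```python
import torch

# softmax over the class dimension gives per-keyword probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)

# take the five most likely keywords for the first (and only) example
top5 = torch.topk(probs[0], k=5)
for score, class_id in zip(top5.values.tolist(), top5.indices.tolist()):
    print(f"{model.config.id2label[class_id]}: {score:.3f}")
```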