diff --git a/README.md b/README.md
index e1358af..28c0731 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ license: apache-2.0
 # Hubert-Base for Keyword Spotting
 
 [S3PRL speech toolkit](https://github.com/s3prl/s3prl)
+
 [Facebook's Hubert](https://ai.facebook.com/blog/hubert-self-supervised-representation-learning-for-speech-recognition-generation-and-compression)
 
 The base model is pretrained on 16kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16kHz.
@@ -26,4 +27,41 @@ Self-supervised learning (SSL) has proven vital for advancing research in natura
 
 The original model can be found under https://github.com/s3prl/s3prl/tree/master/s3prl/downstream/speech_commands.
 
-# Usage
+The base model is [hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960).
+
+# Usage examples
+
+You can use the model via the Audio Classification pipeline:
+```python
+import numpy as np
+from datasets import load_dataset
+from transformers import pipeline, PreTrainedTokenizer
+
+superb_ks = load_dataset("anton-l/superb_dummy", "ks", split="test")
+model = "superb/hubert-base-superb-ks"
+tokenizer = PreTrainedTokenizer()  # a dummy tokenizer, since the classifier doesn't need a real one
+classifier = pipeline("audio-classification", model=model, feature_extractor=model, tokenizer=tokenizer)
+
+audio = np.array(superb_ks[0]["speech"])
+labels = classifier(audio, top_k=5)
+```
+
+Or use the model directly:
+```python
+import torch
+import numpy as np
+from datasets import load_dataset
+from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor
+
+superb_ks = load_dataset("anton-l/superb_dummy", "ks", split="test")
+model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
+feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
+
+audio = np.array(superb_ks[0]["speech"])
+# compute attention masks and normalize the waveform if needed
+inputs = feature_extractor(audio, sampling_rate=16_000, return_tensors="pt")
+
+logits = model(**inputs).logits
+predicted_ids = torch.argmax(logits, dim=-1)
+labels = [model.config.id2label[_id] for _id in predicted_ids.tolist()]
+```
\ No newline at end of file
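
The model card above stresses that input audio must be sampled at 16 kHz. As a minimal sketch of how to meet that requirement with the direct-inference example, assuming `librosa` is installed and using a hypothetical local file `speech.wav` (neither is part of the patch above):

```python
import torch
import librosa
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")

# librosa resamples to the requested rate while loading;
# "speech.wav" is a placeholder path for any arbitrary-rate recording
audio, sr = librosa.load("speech.wav", sr=16_000)

inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_id = logits.argmax(dim=-1).item()
print(model.config.id2label[predicted_id])
```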