add evaluation

parent 7f6f8d586f
commit 765fb8c2cf

README.md (121 changed lines)
@@ -2,32 +2,62 @@
 language: en
 datasets:
-- common_voice
+- mozilla-foundation/common_voice_6_0
 metrics:
 - wer
 - cer
 tags:
 - en
 - audio
 - automatic-speech-recognition
 - speech
 - xlsr-fine-tuning-week
+- robust-speech-event
+- mozilla-foundation/common_voice_6_0
 license: apache-2.0
 model-index:
 - name: XLSR Wav2Vec2 English by Jonatas Grosman
   results:
   - task:
-      name: Speech Recognition
+      name: Automatic Speech Recognition
       type: automatic-speech-recognition
     dataset:
       name: Common Voice en
       type: common_voice
       args: en
     metrics:
     - name: Test WER
       type: wer
-      value: 18.98
+      value: 19.06
     - name: Test CER
       type: cer
-      value: 8.29
+      value: 7.69
+    - name: Test WER (+LM)
+      type: wer
+      value: 14.81
+    - name: Test CER (+LM)
+      type: cer
+      value: 6.84
+  - task:
+      name: Automatic Speech Recognition
+      type: automatic-speech-recognition
+    dataset:
+      name: Robust Speech Event - Dev Data
+      type: speech-recognition-community-v2/dev_data
+      args: en
+    metrics:
+    - name: Test WER
+      type: wer
+      value: 27.72
+    - name: Test CER
+      type: cer
+      value: 11.65
+    - name: Test WER (+LM)
+      type: wer
+      value: 20.85
+    - name: Test CER (+LM)
+      type: cer
+      value: 11.01
 ---
 
 # Wav2Vec2-Large-XLSR-53-English

@@ -109,83 +139,14 @@ for i, predicted_sentence in enumerate(predicted_sentences):
 
 ## Evaluation
 
-The model can be evaluated as follows on the English test data of Common Voice.
+1. To evaluate on `mozilla-foundation/common_voice_6_0` with split `test`
 
-```python
-import torch
-import re
-import warnings
-import librosa
-from datasets import load_dataset, load_metric
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-
-LANG_ID = "en"
-MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
-DEVICE = "cuda"
-
-CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", "；", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
-                   "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
-                   "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
-                   "、", "﹂", "﹁", "‧", "～", "﹏", "，", "｛", "｝", "（", "）", "［", "］", "【", "】", "‥", "〽",
-                   "『", "』", "〝", "〟", "⟨", "⟩", "〜", "：", "！", "？", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
-
-test_dataset = load_dataset("common_voice", LANG_ID, split="test")
-
-wer = load_metric("wer.py")  # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py
-cer = load_metric("cer.py")  # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py
-
-chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
-
-processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
-model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
-model.to(DEVICE)
-
-# Preprocessing the datasets.
-# We need to read the audio files as arrays
-def speech_file_to_array_fn(batch):
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
-    batch["speech"] = speech_array
-    batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper()
-    return batch
-
-test_dataset = test_dataset.map(speech_file_to_array_fn)
-
-# Run inference on the preprocessed dataset
-def evaluate(batch):
-    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
-
-    with torch.no_grad():
-        logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits
-
-    pred_ids = torch.argmax(logits, dim=-1)
-    batch["pred_strings"] = processor.batch_decode(pred_ids)
-    return batch
-
-result = test_dataset.map(evaluate, batched=True, batch_size=8)
-
-predictions = [x.upper() for x in result["pred_strings"]]
-references = [x.upper() for x in result["sentence"]]
-
-print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
-print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
-```
+
+```bash
+python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset mozilla-foundation/common_voice_6_0 --config en --split test
+```
 
-**Test Result**:
+2. To evaluate on `speech-recognition-community-v2/dev_data`
 
-In the table below I report the Word Error Rate (WER) and the Character Error Rate (CER) of the model. I ran the evaluation script described above on other models as well (on 2021-06-17). Note that the table below may show results that differ from those reported elsewhere; this is likely due to specifics of the other evaluation scripts used.
-
-| Model | WER | CER |
-| ------------- | ------------- | ------------- |
-| jonatasgrosman/wav2vec2-large-xlsr-53-english | **18.98%** | **8.29%** |
-| jonatasgrosman/wav2vec2-large-english | 21.53% | 9.66% |
-| facebook/wav2vec2-large-960h-lv60-self | 22.03% | 10.39% |
-| facebook/wav2vec2-large-960h-lv60 | 23.97% | 11.14% |
-| boris/xlsr-en-punctuation | 29.10% | 10.75% |
-| facebook/wav2vec2-large-960h | 32.79% | 16.03% |
-| facebook/wav2vec2-base-960h | 39.86% | 19.89% |
-| facebook/wav2vec2-base-100h | 51.06% | 25.06% |
-| elgeish/wav2vec2-large-lv60-timit-asr | 59.96% | 34.28% |
-| facebook/wav2vec2-base-10k-voxpopuli-ft-en | 66.41% | 36.76% |
-| elgeish/wav2vec2-base-timit-asr | 68.78% | 36.81% |
+
+```bash
+python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset speech-recognition-community-v2/dev_data --config en --split validation --chunk_length_s 5.0 --stride_length_s 1.0
+```
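
The `(+LM)` rows in the updated metadata come from the default decoding path of `eval.py` below, which loads a `Wav2Vec2ProcessorWithLM` and beam-search decodes with an n-gram language model, while the plain rows correspond to `--greedy` CTC decoding. A minimal sketch of the two decoding paths ("sample.wav" is a placeholder path; any speech clip works):

```python
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english"

# "sample.wav" is a placeholder; librosa resamples it to the expected 16 kHz
speech, _ = librosa.load("sample.wav", sr=16_000)

model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)

inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

# greedy CTC decoding (the plain "Test WER" / "Test CER" rows)
greedy_text = processor.batch_decode(torch.argmax(logits, dim=-1))[0]

# beam-search decoding with the bundled n-gram LM (the "(+LM)" rows)
processor_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_ID)
lm_text = processor_lm.batch_decode(logits.numpy()).text[0]

print(greedy_text)
print(lm_text)
```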

eval.py (new file):

```python
#!/usr/bin/env python3
import argparse
import re
from typing import Dict

import torch
from datasets import Audio, Dataset, load_dataset, load_metric
from transformers import (
    AutoConfig,
    AutoModelForCTC,
    AutoTokenizer,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    pipeline,
)


def log_results(result: Dataset, args: Dict[str, str]):
    """DO NOT CHANGE. This function computes and logs the result metrics."""

    log_outputs = args.log_outputs
    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])

    # load metrics
    wer = load_metric("wer")
    cer = load_metric("cer")

    # compute metrics
    wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
    cer_result = cer.compute(references=result["target"], predictions=result["prediction"])

    # print & log results
    result_str = f"WER: {wer_result}\nCER: {cer_result}"
    print(result_str)

    with open(f"{dataset_id}_eval_results.txt", "w") as f:
        f.write(result_str)

    # log all results in text files. Possibly interesting for analysis
    if log_outputs:
        pred_file = f"log_{dataset_id}_predictions.txt"
        target_file = f"log_{dataset_id}_targets.txt"

        with open(pred_file, "w") as p, open(target_file, "w") as t:

            # mapping function to write output
            def write_to_file(batch, i):
                p.write(f"{i}\n")
                p.write(batch["prediction"] + "\n")
                t.write(f"{i}\n")
                t.write(batch["target"] + "\n")

            result.map(write_to_file, with_indices=True)


def normalize_text(text: str, invalid_chars_regex: str, to_lower: bool) -> str:
    """DO ADAPT FOR YOUR USE CASE. This function normalizes the target text."""

    text = text.lower() if to_lower else text.upper()
    # replace every character the model cannot emit with a space
    text = re.sub(invalid_chars_regex, " ", text)
    # collapse runs of whitespace
    text = re.sub(r"\s+", " ", text).strip()

    return text


def main(args):
    # load dataset
    dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)

    # for testing: only process the first few examples
    # dataset = dataset.select(range(10))

    # load processor: plain CTC processor for greedy decoding, processor with LM otherwise
    if args.greedy:
        processor = Wav2Vec2Processor.from_pretrained(args.model_id)
        decoder = None
    else:
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
        decoder = processor.decoder

    feature_extractor = processor.feature_extractor
    tokenizer = processor.tokenizer

    # resample audio to the rate the feature extractor expects
    dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))

    # load eval pipeline
    if args.device is None:
        args.device = 0 if torch.cuda.is_available() else -1

    config = AutoConfig.from_pretrained(args.model_id)
    model = AutoModelForCTC.from_pretrained(args.model_id)

    # asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
    asr = pipeline(
        "automatic-speech-recognition",
        config=config,
        model=model,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        decoder=decoder,
        device=args.device,
    )

    # build normalizer config from the model vocabulary
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    tokens = tokenizer.convert_ids_to_tokens(list(range(tokenizer.vocab_size)))
    special_tokens = [
        tokenizer.pad_token,
        tokenizer.word_delimiter_token,
        tokenizer.unk_token,
        tokenizer.bos_token,
        tokenizer.eos_token,
    ]
    non_special_tokens = [x for x in tokens if x not in special_tokens]
    invalid_chars_regex = f"[^\\s{re.escape(''.join(set(non_special_tokens)))}]"
    normalize_to_lower = False
    for token in non_special_tokens:
        if token.isalpha() and token.islower():
            normalize_to_lower = True
            break

    # map function to decode audio
    def map_to_pred(
        batch, args=args, asr=asr, invalid_chars_regex=invalid_chars_regex, normalize_to_lower=normalize_to_lower
    ):
        prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)

        batch["prediction"] = prediction["text"]
        batch["target"] = normalize_text(batch["sentence"], invalid_chars_regex, normalize_to_lower)
        return batch

    # run inference on all examples
    result = dataset.map(map_to_pred, remove_columns=dataset.column_names)

    # filter out empty targets
    result = result.filter(lambda example: example["target"] != "")

    # compute and log results
    # do not change function below
    log_results(result, args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
    )
    parser.add_argument(
        "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
    )
    parser.add_argument(
        "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
    )
    parser.add_argument(
        "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to None. For long audio files a good value would be 5.0 seconds."
    )
    parser.add_argument(
        "--stride_length_s", type=float, default=None, help="Stride of the audio chunks. Defaults to None. For long audio files a good value would be 1.0 seconds."
    )
    parser.add_argument(
        "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
    )
    parser.add_argument(
        "--greedy", action="store_true", help="If defined, the LM will be ignored during inference."
    )
    parser.add_argument(
        "--device",
        type=int,
        default=None,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    args = parser.parse_args()

    main(args)
```
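
To make the normalization logic above concrete, here is a small self-contained sketch; the vocabulary below is a hypothetical stand-in for the tokens the script pulls from the model's tokenizer:

```python
import re

# hypothetical CTC vocabulary: uppercase letters plus apostrophe (special tokens already removed)
non_special_tokens = [chr(c) for c in range(ord("A"), ord("Z") + 1)] + ["'"]

# any character that is neither whitespace nor in the vocabulary counts as invalid
invalid_chars_regex = f"[^\\s{re.escape(''.join(set(non_special_tokens)))}]"
# no lowercase letters in this vocab, so targets get uppercased
to_lower = any(t.isalpha() and t.islower() for t in non_special_tokens)

text = "Hello, world! It's 12 o'clock."
text = text.lower() if to_lower else text.upper()
text = re.sub(invalid_chars_regex, " ", text)  # punctuation and digits become spaces
text = re.sub(r"\s+", " ", text).strip()       # collapse runs of whitespace

print(text)  # HELLO WORLD IT'S O'CLOCK
```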

A new shell script with the full set of evaluation commands (its file name is not shown in this view):

```bash
# CV - TEST

python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset mozilla-foundation/common_voice_6_0 --config en --split test --log_outputs --greedy
mv log_mozilla-foundation_common_voice_6_0_en_test_predictions.txt log_mozilla-foundation_common_voice_6_0_en_test_predictions_greedy.txt
mv mozilla-foundation_common_voice_6_0_en_test_eval_results.txt mozilla-foundation_common_voice_6_0_en_test_eval_results_greedy.txt

python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset mozilla-foundation/common_voice_6_0 --config en --split test --log_outputs

# HF EVENT - DEV

python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset speech-recognition-community-v2/dev_data --config en --split validation --chunk_length_s 5.0 --stride_length_s 1.0 --log_outputs --greedy
mv log_speech-recognition-community-v2_dev_data_en_validation_predictions.txt log_speech-recognition-community-v2_dev_data_en_validation_predictions_greedy.txt
mv speech-recognition-community-v2_dev_data_en_validation_eval_results.txt speech-recognition-community-v2_dev_data_en_validation_eval_results_greedy.txt

python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset speech-recognition-community-v2/dev_data --config en --split validation --chunk_length_s 5.0 --stride_length_s 1.0 --log_outputs
```
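
The `mv` steps rename the greedy runs' outputs before the LM-decoded runs start; otherwise the second `eval.py` invocation for each dataset would overwrite the same `*_eval_results.txt` and `log_*_predictions.txt` files.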

Six further file diffs are suppressed: three because they are too large, and three because one or more lines are too long.

mozilla-foundation_common_voice_6_0_en_test_eval_results.txt (new file):

WER: 0.1481828839390387
CER: 0.06848087313203592

mozilla-foundation_common_voice_6_0_en_test_eval_results_greedy.txt (new file):

WER: 0.19067492882264278
CER: 0.07694957927516068

speech-recognition-community-v2_dev_data_en_validation_eval_results.txt (new file):

WER: 0.2085057090848916
CER: 0.11011805154105943

speech-recognition-community-v2_dev_data_en_validation_eval_results_greedy.txt (new file):

WER: 0.27722157868608305
CER: 0.11652265190008215
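
The result files store WER and CER as fractions, while the model-index metadata reports percentages that appear to be truncated, not rounded, to two decimals (0.19067... becomes 19.06 rather than 19.07). A small sketch of that conversion:

```python
import math

def as_percent(rate: float) -> float:
    # truncate, rather than round, to two decimal places
    return math.floor(rate * 10_000) / 100

print(as_percent(0.19067492882264278))  # 19.06
print(as_percent(0.1481828839390387))   # 14.81
```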