add evaluation

This commit is contained in:
Jonatas Grosman 2022-02-11 19:20:55 +00:00
parent 7f6f8d586f
commit 765fb8c2cf
13 changed files with 97248 additions and 80 deletions

121
README.md
View File

@ -2,32 +2,62 @@
language: en
datasets:
- common_voice
- mozilla-foundation/common_voice_6_0
metrics:
- wer
- cer
tags:
- en
- audio
- automatic-speech-recognition
- speech
- xlsr-fine-tuning-week
- robust-speech-event
- mozilla-foundation/common_voice_6_0
license: apache-2.0
model-index:
- name: XLSR Wav2Vec2 English by Jonatas Grosman
results:
- task:
name: Speech Recognition
name: Automatic Speech Recognition
type: automatic-speech-recognition
dataset:
name: Common Voice en
name: Common Voice pt
type: common_voice
args: en
metrics:
- name: Test WER
type: wer
value: 18.98
value: 19.06
- name: Test CER
type: cer
value: 8.29
value: 7.69
- name: Test WER (+LM)
type: wer
value: 14.81
- name: Test CER (+LM)
type: cer
value: 6.84
- task:
name: Automatic Speech Recognition
type: automatic-speech-recognition
dataset:
name: Robust Speech Event - Dev Data
type: speech-recognition-community-v2/dev_data
args: en
metrics:
- name: Test WER
type: wer
value: 27.72
- name: Test CER
type: cer
value: 11.65
- name: Test WER (+LM)
type: wer
value: 20.85
- name: Test CER (+LM)
type: cer
value: 11.01
---
# Wav2Vec2-Large-XLSR-53-English
@ -109,83 +139,14 @@ for i, predicted_sentence in enumerate(predicted_sentences):
## Evaluation
The model can be evaluated as follows on the English test data of Common Voice.
1. To evaluate on `mozilla-foundation/common_voice_6_0` with split `test`
```python
import torch
import re
import librosa
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
LANG_ID = "en"
MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
DEVICE = "cuda"
CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", "", ":", '""', "%", '"', "<22>", "ʿ", "·", "჻", "~", "՞",
"؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "", "", "《", "》", "(", ")", "[", "]",
"{", "}", "=", "`", "_", "+", "<", ">", "…", "", "°", "´", "ʾ", "", "", "©", "®", "—", "→", "。",
"、", "﹂", "﹁", "‧", "", "", "", "", "", "", "", "", "", "【", "】", "‥", "〽",
"『", "』", "〝", "〟", "⟨", "⟩", "〜", "", "", "", "♪", "؛", "/", "\\", "º", "", "^", "ʻ", "ˆ"]
test_dataset = load_dataset("common_voice", LANG_ID, split="test")
wer = load_metric("wer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py
cer = load_metric("cer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py
chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
model.to(DEVICE)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
batch["speech"] = speech_array
batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def evaluate(batch):
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch
result = test_dataset.map(evaluate, batched=True, batch_size=8)
predictions = [x.upper() for x in result["pred_strings"]]
references = [x.upper() for x in result["sentence"]]
print(f"WER: {wer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
print(f"CER: {cer.compute(predictions=predictions, references=references, chunk_size=1000) * 100}")
```bash
python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset mozilla-foundation/common_voice_6_0 --config en --split test
```
**Test Result**:
2. To evaluate on `speech-recognition-community-v2/dev_data`
In the table below I report the Word Error Rate (WER) and the Character Error Rate (CER) of the model. I ran the evaluation script described above on other models as well (on 2021-06-17). Note that the table below may show different results from those already reported, this may have been caused due to some specificity of the other evaluation scripts used.
| Model | WER | CER |
| ------------- | ------------- | ------------- |
| jonatasgrosman/wav2vec2-large-xlsr-53-english | **18.98%** | **8.29%** |
| jonatasgrosman/wav2vec2-large-english | 21.53% | 9.66% |
| facebook/wav2vec2-large-960h-lv60-self | 22.03% | 10.39% |
| facebook/wav2vec2-large-960h-lv60 | 23.97% | 11.14% |
| boris/xlsr-en-punctuation | 29.10% | 10.75% |
| facebook/wav2vec2-large-960h | 32.79% | 16.03% |
| facebook/wav2vec2-base-960h | 39.86% | 19.89% |
| facebook/wav2vec2-base-100h | 51.06% | 25.06% |
| elgeish/wav2vec2-large-lv60-timit-asr | 59.96% | 34.28% |
| facebook/wav2vec2-base-10k-voxpopuli-ft-en | 66.41% | 36.76% |
| elgeish/wav2vec2-base-timit-asr | 68.78% | 36.81% |
```bash
python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset speech-recognition-community-v2/dev_data --config en --split validation --chunk_length_s 5.0 --stride_length_s 1.0
```

164
eval.py Normal file
View File

@ -0,0 +1,164 @@
#!/usr/bin/env python3
from datasets import load_dataset, load_metric, Audio, Dataset
from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, AutoConfig, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
import re
import torch
import argparse
from typing import Dict
def log_results(result: Dataset, args: Dict[str, str]):
""" DO NOT CHANGE. This function computes and logs the result metrics. """
log_outputs = args.log_outputs
dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
# load metric
wer = load_metric("wer")
cer = load_metric("cer")
# compute metrics
wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
# print & log results
result_str = (
f"WER: {wer_result}\n"
f"CER: {cer_result}"
)
print(result_str)
with open(f"{dataset_id}_eval_results.txt", "w") as f:
f.write(result_str)
# log all results in text file. Possibly interesting for analysis
if log_outputs is not None:
pred_file = f"log_{dataset_id}_predictions.txt"
target_file = f"log_{dataset_id}_targets.txt"
with open(pred_file, "w") as p, open(target_file, "w") as t:
# mapping function to write output
def write_to_file(batch, i):
p.write(f"{i}" + "\n")
p.write(batch["prediction"] + "\n")
t.write(f"{i}" + "\n")
t.write(batch["target"] + "\n")
result.map(write_to_file, with_indices=True)
def normalize_text(text: str, invalid_chars_regex: str, to_lower: bool) -> str:
""" DO ADAPT FOR YOUR USE CASE. this function normalizes the target text. """
text = text.lower() if to_lower else text.upper()
text = re.sub(invalid_chars_regex, " ", text)
text = re.sub("\s+", " ", text).strip()
return text
def main(args):
# load dataset
dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
# for testing: only process the first two examples as a test
# dataset = dataset.select(range(10))
# load processor
if args.greedy:
processor = Wav2Vec2Processor.from_pretrained(args.model_id)
decoder = None
else:
processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
decoder = processor.decoder
feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer
# resample audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
# load eval pipeline
if args.device is None:
args.device = 0 if torch.cuda.is_available() else -1
config = AutoConfig.from_pretrained(args.model_id)
model = AutoModelForCTC.from_pretrained(args.model_id)
#asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
asr = pipeline("automatic-speech-recognition", config=config, model=model, tokenizer=tokenizer,
feature_extractor=feature_extractor, decoder=decoder, device=args.device)
# build normalizer config
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
tokens = [x for x in tokenizer.convert_ids_to_tokens(range(0, tokenizer.vocab_size))]
special_tokens = [
tokenizer.pad_token, tokenizer.word_delimiter_token,
tokenizer.unk_token, tokenizer.bos_token,
tokenizer.eos_token,
]
non_special_tokens = [x for x in tokens if x not in special_tokens]
invalid_chars_regex = f"[^\s{re.escape(''.join(set(non_special_tokens)))}]"
normalize_to_lower = False
for token in non_special_tokens:
if token.isalpha() and token.islower():
normalize_to_lower = True
break
# map function to decode audio
def map_to_pred(batch, args=args, asr=asr, invalid_chars_regex=invalid_chars_regex, normalize_to_lower=normalize_to_lower):
prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
batch["prediction"] = prediction["text"]
batch["target"] = normalize_text(batch["sentence"], invalid_chars_regex, normalize_to_lower)
return batch
# run inference on all examples
result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
# filtering out empty targets
result = result.filter(lambda example: example["target"] != "")
# compute and log_results
# do not change function below
log_results(result, args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
)
parser.add_argument(
"--dataset", type=str, required=True, help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
)
parser.add_argument(
"--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
)
parser.add_argument(
"--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
)
parser.add_argument(
"--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to None. For long audio files a good value would be 5.0 seconds."
)
parser.add_argument(
"--stride_length_s", type=float, default=None, help="Stride of the audio chunks. Defaults to None. For long audio files a good value would be 1.0 seconds."
)
parser.add_argument(
"--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
)
parser.add_argument(
"--greedy", action='store_true', help="If defined, the LM will be ignored during inference."
)
parser.add_argument(
"--device",
type=int,
default=None,
help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
)
args = parser.parse_args()
main(args)

15
full_eval.sh Normal file
View File

@ -0,0 +1,15 @@
# CV - TEST
python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset mozilla-foundation/common_voice_6_0 --config en --split test --log_outputs --greedy
mv log_mozilla-foundation_common_voice_6_0_en_test_predictions.txt log_mozilla-foundation_common_voice_6_0_en_test_predictions_greedy.txt
mv mozilla-foundation_common_voice_6_0_en_test_eval_results.txt mozilla-foundation_common_voice_6_0_en_test_eval_results_greedy.txt
python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset mozilla-foundation/common_voice_6_0 --config en --split test --log_outputs
# HF EVENT - DEV
python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset speech-recognition-community-v2/dev_data --config en --split validation --chunk_length_s 5.0 --stride_length_s 1.0 --log_outputs --greedy
mv log_speech-recognition-community-v2_dev_data_en_validation_predictions.txt log_speech-recognition-community-v2_dev_data_en_validation_predictions_greedy.txt
mv speech-recognition-community-v2_dev_data_en_validation_eval_results.txt speech-recognition-community-v2_dev_data_en_validation_eval_results_greedy.txt
python eval.py --model_id jonatasgrosman/wav2vec2-large-xlsr-53-english --dataset speech-recognition-community-v2/dev_data --config en --split validation --chunk_length_s 5.0 --stride_length_s 1.0 --log_outputs

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,2 @@
WER: 0.1481828839390387
CER: 0.06848087313203592

View File

@ -0,0 +1,2 @@
WER: 0.19067492882264278
CER: 0.07694957927516068

View File

@ -0,0 +1,2 @@
WER: 0.2085057090848916
CER: 0.11011805154105943

View File

@ -0,0 +1,2 @@
WER: 0.27722157868608305
CER: 0.11652265190008215