diff --git a/README.md b/README.md
index 7d4b80d..e2ec465 100644
--- a/README.md
+++ b/README.md
@@ -114,8 +114,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 """
 Evaluation on the full test set:
 - takes ~20mins (RTX 3090).
-- requires ~170GB RAM to compute the WER. A potential solution to this is computing it in chunks.
-  See https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5 on how to implement this.
+- requires ~170GB RAM to compute the WER. Below, we use a chunked implementation of WER to avoid large RAM consumption.
 """
 test_dataset = load_dataset("common_voice", "de", split="test") # use "test[:1%]" for 1% sample
 wer = load_metric("wer")
@@ -151,8 +150,30 @@ def evaluate(batch):
 
 result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
 
-print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
+# non-chunked version:
+# print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
 # WER: 12.615308
+
+# Chunked version, see https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5:
+import jiwer
+
+def chunked_wer(targets, predictions, chunk_size=None):
+    if chunk_size is None: return jiwer.wer(targets, predictions)
+    start = 0
+    end = chunk_size
+    H, S, D, I = 0, 0, 0, 0
+    while start < len(targets):
+        chunk_metrics = jiwer.compute_measures(targets[start:end], predictions[start:end])
+        H = H + chunk_metrics["hits"]
+        S = S + chunk_metrics["substitutions"]
+        D = D + chunk_metrics["deletions"]
+        I = I + chunk_metrics["insertions"]
+        start += chunk_size
+        end += chunk_size
+    return float(S + D + I) / float(H + S + D)
+
+print("Total (chunk_size=1000), WER: {:2f}".format(100 * chunked_wer(result["sentence"], result["pred_strings"], chunk_size=1000)))
+# Total (chunk_size=1000), WER: 12.615308
 ```
 **Test Result**: 12.62 %