Compare commits

...

10 Commits

Author SHA1 Message Date
Patrick von Platen 893c0ff874 upload flax model 2021-07-06 12:32:21 +00:00
Patrick von Platen b4b2ae8a8c allow flax 2021-07-06 12:31:31 +00:00
Patrick von Platen aaab3794b6 Update README.md 2021-03-30 05:29:48 +00:00
maxidl 698a3d155e update WER to 12.90 after evaluating with cleared cache 2021-03-29 15:51:48 +02:00
maxidl 45a4f3a3c6 add chunked wer to eval script 2021-03-29 14:49:05 +02:00
maxidl 4246f93dec update model card 2021-03-28 23:49:37 +02:00
maxidl 57658944a9 update model card 2021-03-28 23:48:59 +02:00
Maximilian Idahl 1f39a7116d Create README.md 2021-03-28 21:46:15 +00:00
maxidl 32cda75624 update model to latest checkpoint 2021-03-28 23:16:07 +02:00
maxidl 0956b07463 Add model files 2021-03-28 21:43:59 +02:00
9 changed files with 312 additions and 0 deletions

1
.gitattributes vendored
View File

@ -14,3 +14,4 @@
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text

218
README.md Normal file
View File

@ -0,0 +1,218 @@
---
language: de
datasets:
- common_voice
metrics:
- wer
tags:
- audio
- automatic-speech-recognition
- speech
- xlsr-fine-tuning-week
license: apache-2.0
model-index:
- name: {XLSR Wav2Vec2 Large 53 CV-de}
results:
- task:
name: Speech Recognition
type: automatic-speech-recognition
dataset:
name: Common Voice de
type: common_voice
args: de
metrics:
- name: Test WER
type: wer
value: 12.77
---
# Wav2Vec2-Large-XLSR-53-German
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on German using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
When using this model, make sure that your speech input is sampled at 16kHz.
## Usage
The model can be used directly (without a language model) as follows:
```python
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
test_dataset = load_dataset("common_voice", "de", split="test[:8]") # use a batch of 8 for demo purposes
processor = Wav2Vec2Processor.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
model = Wav2Vec2ForCTC.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
"""
Preprocessing the dataset by:
- loading audio files
- resampling to 16kHz
- converting to array
- prepare input tensor using the processor
"""
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
# run forward
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"])
"""
Example Result:
Prediction: [
'zieh durch bittet draußen die schuhe aus',
'es kommt zugvorgebauten fo',
'ihre vorterstrecken erschienen it modemagazinen wie der voge karpes basar mariclair',
'fürliepert eine auch für manachen ungewöhnlich lange drittelliste',
'er wurde zu ehren des reichskanzlers otto von bismarck errichtet',
'was solls ich bin bereit',
'das internet besteht aus vielen computern die miteinander verbunden sind',
'der uranus ist der siebinteplanet in unserem sonnensystem s'
]
Reference: [
'Zieht euch bitte draußen die Schuhe aus.',
'Es kommt zum Showdown in Gstaad.',
'Ihre Fotostrecken erschienen in Modemagazinen wie der Vogue, Harpers Bazaar und Marie Claire.',
'Felipe hat eine auch für Monarchen ungewöhnlich lange Titelliste.',
'Er wurde zu Ehren des Reichskanzlers Otto von Bismarck errichtet.',
'Was solls, ich bin bereit.',
'Das Internet besteht aus vielen Computern, die miteinander verbunden sind.',
'Der Uranus ist der siebente Planet in unserem Sonnensystem.'
]
"""
```
## Evaluation
The model can be evaluated as follows on the German test data of Common Voice:
```python
import re
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
"""
Evaluation on the full test set:
- takes ~20mins (RTX 3090).
- requires ~170GB RAM to compute the WER. Below, we use a chunked implementation of WER to avoid large RAM consumption.
"""
test_dataset = load_dataset("common_voice", "de", split="test") # use "test[:1%]" for 1% sample
wer = load_metric("wer")
processor = Wav2Vec2Processor.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
model = Wav2Vec2ForCTC.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
model.to("cuda")
chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\"\\“]'
resampler = torchaudio.transforms.Resample(48_000, 16_000)
# Preprocessing the datasets.
# We need to read the aduio files as arrays
def speech_file_to_array_fn(batch):
\tbatch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
\tspeech_array, sampling_rate = torchaudio.load(batch["path"])
\tbatch["speech"] = resampler(speech_array).squeeze().numpy()
\treturn batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def evaluate(batch):
\tinputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
\twith torch.no_grad():
\t\tlogits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
\tpred_ids = torch.argmax(logits, dim=-1)
\tbatch["pred_strings"] = processor.batch_decode(pred_ids)
\treturn batch
result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
# non-chunked version:
# print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
# WER: 12.900291
# Chunked version, see https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/5:
import jiwer
def chunked_wer(targets, predictions, chunk_size=None):
if chunk_size is None: return jiwer.wer(targets, predictions)
start = 0
end = chunk_size
H, S, D, I = 0, 0, 0, 0
while start < len(targets):
chunk_metrics = jiwer.compute_measures(targets[start:end], predictions[start:end])
H = H + chunk_metrics["hits"]
S = S + chunk_metrics["substitutions"]
D = D + chunk_metrics["deletions"]
I = I + chunk_metrics["insertions"]
start += chunk_size
end += chunk_size
return float(S + D + I) / float(H + S + D)
print("Total (chunk_size=1000), WER: {:2f}".format(100 * chunked_wer(result["pred_strings"], result["sentence"], chunk_size=1000)))
# Total (chunk=1000), WER: 12.768981
```
**Test Result**: WER: 12.77 %
## Training
The Common Voice German `train` and `validation` were used for training.
The script used for training can be found [here](https://github.com/maxidl/wav2vec2).
The model was trained for 50k steps, taking around 30 hours on a single A100.
The arguments used for training this model are:
```
python run_finetuning.py \\
--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \\
--dataset_config_name="de" \\
--output_dir=./wav2vec2-large-xlsr-german \\
--preprocessing_num_workers="16" \\
--overwrite_output_dir \\
--num_train_epochs="20" \\
--per_device_train_batch_size="64" \\
--per_device_eval_batch_size="32" \\
--learning_rate="1e-4" \\
--warmup_steps="500" \\
--evaluation_strategy="steps" \\
--save_steps="5000" \\
--eval_steps="5000" \\
--logging_steps="1000" \\
--save_total_limit="3" \\
--freeze_feature_extractor \\
--activation_dropout="0.055" \\
--attention_dropout="0.094" \\
--feat_proj_dropout="0.04" \\
--layerdrop="0.04" \\
--mask_time_prob="0.08" \\
--gradient_checkpointing="1" \\
--fp16 \\
--do_train \\
--do_eval \\
--dataloader_num_workers="16" \\
--group_by_length
```

76
config.json Normal file
View File

@ -0,0 +1,76 @@
{
"_name_or_path": "facebook/wav2vec2-large-xlsr-53",
"activation_dropout": 0.055,
"apply_spec_augment": true,
"architectures": [
"Wav2Vec2ForCTC"
],
"attention_dropout": 0.094,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "mean",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.04,
"final_dropout": 0.0,
"gradient_checkpointing": true,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.04,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.08,
"mask_time_selection": "static",
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 171,
"transformers_version": "4.4.2",
"vocab_size": 174
}

BIN
flax_model.msgpack (Stored with Git LFS) Normal file

Binary file not shown.

8
preprocessor_config.json Normal file
View File

@ -0,0 +1,8 @@
{
"do_normalize": true,
"feature_size": 1,
"padding_side": "right",
"padding_value": 0.0,
"return_attention_mask": true,
"sampling_rate": 16000
}

BIN
pytorch_model.bin (Stored with Git LFS) Normal file

Binary file not shown.

1
special_tokens_map.json Normal file
View File

@ -0,0 +1 @@
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]"}

1
tokenizer_config.json Normal file
View File

@ -0,0 +1 @@
{"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|"}

1
vocab.json Normal file
View File

@ -0,0 +1 @@
{"к": 0, "„": 1, "⟩": 2, "ą": 3, "‟": 4, "о": 5, "臣": 6, "å": 7, "_": 8, "ğ": 9, "ï": 10, "x": 11, "ź": 12, "ç": 13, "ö": 14, "´": 15, "ë": 16, "ä": 17, "t": 18, "f": 19, "ň": 20, "ť": 21, "ņ": 22, "ü": 23, "s": 24, "е": 25, "ả": 26, "b": 27, "в": 28, "ū": 29, "q": 30, "ț": 31, "̆": 32, "≡": 33, "̇": 34, "ś": 35, "孙": 36, "ṟ": 38, "i": 39, "û": 40, "ō": 41, "`": 42, "z": 43, "": 44, "ན": 45, "ù": 46, "r": 47, "и": 48, "″": 49, "ọ": 50, "e": 51, "w": 52, "ạ": 53, "a": 54, "比": 55, "ʻ": 56, "°": 57, "d": 58, "ī": 59, "g": 60, "k": 61, "カ": 62, "ч": 63, "=": 64, "ż": 65, "ħ": 66, "o": 67, "ñ": 68, "ô": 69, "ď": 70, "ø": 71, "ž": 72, "y": 73, "ý": 74, "ș": 75, "ḫ": 76, "ó": 77, "ē": 78, "ő": 79, "無": 80, "à": 81, "ş": 82, "→": 83, "ộ": 84, "с": 85, "ắ": 86, "ê": 87, "”": 88, "ă": 89, "幺": 90, "é": 91, "h": 92, "": 93, "ř": 94, "]": 95, "尣": 96, "乡": 97, "á": 98, "æ": 99, "ǐ": 100, "ı": 101, "[": 102, "î": 103, "u": 104, "": 105, "č": 106, "⟨": 107, "p": 108, "ę": 109, "l": 110, "ā": 111, "ằ": 112, "c": 113, "š": 114, "“": 115, "临": 116, "ѹ": 117, "ć": 118, "ṭ": 119, "ŏ": 120, "m": 121, "ě": 122, "¡": 123, "ů": 124, "": 125, "…": 126, "道": 127, "་": 128, "м": 129, "v": 130, "ń": 131, "р": 132, "ф": 133, "": 134, "": 135, "»": 136, "ð": 137, "«": 138, "支": 139, "þ": 140, "ế": 141, "ễ": 142, "ʿ": 143, "ġ": 144, "í": 145, "ß": 146, "": 147, "ú": 148, "ứ": 149, "ə": 150, "n": 151, "а": 152, "ш": 153, "đ": 154, "": 155, "ã": 156, "â": 157, "œ": 158, "": 159, "µ": 160, "—": 161, "ṣ": 162, "õ": 163, "ò": 164, "辶": 165, "j": 166, "ì": 167, "ė": 168, "ł": 169, "|": 37, "[UNK]": 170, "[PAD]": 171}