Update README.md

Patrick von Platen, 2021-03-30 05:29:48 +00:00 (committed by huggingface-web)
parent 698a3d155e
commit aaab3794b6
1 changed file with 40 additions and 40 deletions

@@ -23,7 +23,7 @@ model-index:
     metrics:
     - name: Test WER
      type: wer
-     value: 12.90
+     value: 12.77
 ---
 
 # Wav2Vec2-Large-XLSR-53-German
@@ -123,30 +123,30 @@ processor = Wav2Vec2Processor.from_pretrained("maxidl/wav2vec2-large-xlsr-german
 model = Wav2Vec2ForCTC.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
 model.to("cuda")
 
-chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“]'
+chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\"\\“]'
 
 resampler = torchaudio.transforms.Resample(48_000, 16_000)
 
 # Preprocessing the datasets.
 # We need to read the audio files as arrays
 def speech_file_to_array_fn(batch):
-    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
-    speech_array, sampling_rate = torchaudio.load(batch["path"])
-    batch["speech"] = resampler(speech_array).squeeze().numpy()
-    return batch
+\tbatch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
+\tspeech_array, sampling_rate = torchaudio.load(batch["path"])
+\tbatch["speech"] = resampler(speech_array).squeeze().numpy()
+\treturn batch
 
 test_dataset = test_dataset.map(speech_file_to_array_fn)
 
 # Preprocessing the datasets.
 # We need to read the audio files as arrays
 def evaluate(batch):
-    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
-    with torch.no_grad():
-        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
-    pred_ids = torch.argmax(logits, dim=-1)
-    batch["pred_strings"] = processor.batch_decode(pred_ids)
-    return batch
+\tinputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
+\twith torch.no_grad():
+\t\tlogits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
+\tpred_ids = torch.argmax(logits, dim=-1)
+\tbatch["pred_strings"] = processor.batch_decode(pred_ids)
+\treturn batch
 
 result = test_dataset.map(evaluate, batched=True, batch_size=8) # batch_size=8 -> requires ~14.5GB GPU memory
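As committed, the snippet stores indentation as literal `\t` escape sequences. With those escapes interpreted as real indentation, the post-commit code corresponds to the runnable sketch below. The `load_dataset` call for the Common Voice German test split is an assumption carried over from the rest of the card, which this diff excerpt does not show:

```python
import re

import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Assumption: the test split is Common Voice German, as loaded earlier in the card.
test_dataset = load_dataset("common_voice", "de", split="test")

processor = Wav2Vec2Processor.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
model = Wav2Vec2ForCTC.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
model.to("cuda")

chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\"\\“]'
resampler = torchaudio.transforms.Resample(48_000, 16_000)

def speech_file_to_array_fn(batch):
    # Strip punctuation, lowercase, and resample the 48 kHz clips to 16 kHz.
    batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower()
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
    batch["pred_strings"] = processor.batch_decode(torch.argmax(logits, dim=-1))
    return batch

result = test_dataset.map(evaluate, batched=True, batch_size=8)  # batch_size=8 -> ~14.5GB GPU memory
```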
@@ -176,7 +176,7 @@ print("Total (chunk_size=1000), WER: {:2f}".format(100 * chunked_wer(result["pre
 # Total (chunk=1000), WER: 12.768981
 ```
 
-**Test Result**: WER: 12.90 %
+**Test Result**: WER: 12.77 %
 
 ## Training
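The print statement in the hunk above calls a `chunked_wer` helper that is defined elsewhere in the card and not shown in this diff. A minimal sketch of such a helper, assuming its purpose is to keep jiwer's memory use bounded by scoring the corpus in fixed-size chunks and aggregating the error counts (`jiwer.compute_measures` is the jiwer 2.x API):

```python
import jiwer

def chunked_wer(targets, predictions, chunk_size=1000):
    # Hypothetical reconstruction: align the corpus chunk by chunk so jiwer
    # never builds one giant alignment, then combine the counts into one WER.
    if chunk_size is None:
        return jiwer.wer(targets, predictions)
    hits = substitutions = deletions = insertions = 0
    for start in range(0, len(targets), chunk_size):
        measures = jiwer.compute_measures(
            targets[start:start + chunk_size],
            predictions[start:start + chunk_size],
        )
        hits += measures["hits"]
        substitutions += measures["substitutions"]
        deletions += measures["deletions"]
        insertions += measures["insertions"]
    # WER = (S + D + I) / (H + S + D)
    return (substitutions + deletions + insertions) / (hits + substitutions + deletions)
```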
@@ -187,32 +187,32 @@ The model was trained for 50k steps, taking around 30 hours on a single A100.
 The arguments used for training this model are:
 
 ```
-python run_finetuning.py \
---model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
---dataset_config_name="de" \
---output_dir=./wav2vec2-large-xlsr-german \
---preprocessing_num_workers="16" \
---overwrite_output_dir \
---num_train_epochs="20" \
---per_device_train_batch_size="64" \
---per_device_eval_batch_size="32" \
---learning_rate="1e-4" \
---warmup_steps="500" \
---evaluation_strategy="steps" \
---save_steps="5000" \
---eval_steps="5000" \
---logging_steps="1000" \
---save_total_limit="3" \
---freeze_feature_extractor \
---activation_dropout="0.055" \
---attention_dropout="0.094" \
---feat_proj_dropout="0.04" \
---layerdrop="0.04" \
---mask_time_prob="0.08" \
---gradient_checkpointing="1" \
---fp16 \
---do_train \
---do_eval \
---dataloader_num_workers="16" \
+python run_finetuning.py \\
+--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \\
+--dataset_config_name="de" \\
+--output_dir=./wav2vec2-large-xlsr-german \\
+--preprocessing_num_workers="16" \\
+--overwrite_output_dir \\
+--num_train_epochs="20" \\
+--per_device_train_batch_size="64" \\
+--per_device_eval_batch_size="32" \\
+--learning_rate="1e-4" \\
+--warmup_steps="500" \\
+--evaluation_strategy="steps" \\
+--save_steps="5000" \\
+--eval_steps="5000" \\
+--logging_steps="1000" \\
+--save_total_limit="3" \\
+--freeze_feature_extractor \\
+--activation_dropout="0.055" \\
+--attention_dropout="0.094" \\
+--feat_proj_dropout="0.04" \\
+--layerdrop="0.04" \\
+--mask_time_prob="0.08" \\
+--gradient_checkpointing="1" \\
+--fp16 \\
+--do_train \\
+--do_eval \\
+--dataloader_num_workers="16" \\
 --group_by_length
 ```