fix attention flag

This commit is contained in:
ceyda 2021-04-06 23:50:00 +00:00
parent ab8048eec0
commit 3dee02cafd
5 changed files with 11 additions and 45 deletions

View File

@ -23,7 +23,7 @@ model-index:
metrics:
- name: Test WER
type: wer
value: 24.91
value: 22.60
---
# Wav2Vec2-Base-760-Turkish
@ -102,11 +102,13 @@ test_dataset = test_dataset.map(speech_file_to_array_fn)
# Preprocessing the datasets.
# We need to read the audio files as arrays
#Attention mask is not used because the base-model was not trained with it. reference: https://github.com/huggingface/transformers/blob/403d530eec105c0e229fc2b754afdf77a4439def/src/transformers/models/wav2vec2/tokenization_wav2vec2.py#L305
def evaluate(batch):
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
logits = model(inputs.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids,skip_special_tokens=True)
@ -117,7 +119,9 @@ result = test_dataset.map(evaluate, batched=True, batch_size=8)
print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
```
**Test Result**: 24.91 % (in progress)
**Test Results**:
- WER: 22.602390
- CER: 6.054137
## Training

View File

@ -2,7 +2,7 @@
"do_normalize": true,
"feature_size": 1,
"padding_side": "right",
"padding_value": 0.0,
"padding_value": 0,
"return_attention_mask": true,
"sampling_rate": 16000
}

BIN
pytorch_model.bin (Stored with Git LFS)

Binary file not shown.

View File

@ -1 +1 @@
{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "word_delimiter_token": "|","special_tokens_map_file": "/home/ceyda/workspace/libs/fairseq/hf_finetuned_output/special_tokens_map.json", "tokenizer_file": null}
{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "word_delimiter_token": "|"}

View File

@ -1,39 +1 @@
{"|": 4,
"p": 5,
"i": 6,
"r": 7,
"n": 8,
"s": 9,
"ö": 10,
"z": 11,
"l": 12,
"e": 13,
"h": 14,
"â": 15,
"y": 16,
"a": 17,
"k": 18,
"ı": 19,
"o": 20,
"m": 21,
"ü": 22,
"g": 23,
"c": 24,
"b": 25,
"ş": 26,
"d": 27,
"u": 28,
"t": 29,
"ç": 30,
"ğ": 31,
"v": 32,
"f": 33,
"j": 34,
"x": 35,
"w": 36,
"q": 37,
"î": 38,
"<s>": 0,
"<pad>": 1,
"</s>": 2,
"<unk>": 3}
{"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "p": 5, "i": 6, "r": 7, "n": 8, "s": 9, "ö": 10, "z": 11, "l": 12, "e": 13, "h": 14, "â": 15, "y": 16, "a": 17, "k": 18, "ı": 19, "o": 20, "m": 21, "ü": 22, "g": 23, "c": 24, "b": 25, "ş": 26, "d": 27, "u": 28, "t": 29, "ç": 30, "ğ": 31, "v": 32, "f": 33, "j": 34, "x": 35, "w": 36, "q": 37, "î": 38}