fix attention flag
This commit is contained in:
parent ab8048eec0
commit 3dee02cafd

README.md (10 lines changed)
@@ -23,7 +23,7 @@ model-index:
     metrics:
       - name: Test WER
         type: wer
-        value: 24.91
+        value: 22.60
 ---

 # Wav2Vec2-Base-760-Turkish
@@ -102,11 +102,13 @@ test_dataset = test_dataset.map(speech_file_to_array_fn)
 # Preprocessing the datasets.
 # We need to read the audio files as arrays

+# The attention mask is not used because the base model was not trained with it. Reference: https://github.com/huggingface/transformers/blob/403d530eec105c0e229fc2b754afdf77a4439def/src/transformers/models/wav2vec2/tokenization_wav2vec2.py#L305
 def evaluate(batch):
     inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

     with torch.no_grad():
-        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
+        logits = model(inputs.input_values.to("cuda")).logits

     pred_ids = torch.argmax(logits, dim=-1)
     batch["pred_strings"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
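For context on the hunk above, here is a minimal sketch of the setup the `evaluate` snippet assumes. The repository-id placeholder, the `Wav2Vec2Processor`/`Wav2Vec2ForCTC` classes, and the CUDA device are assumptions for illustration; they are not stated in this diff.

```python
# Sketch only: load a processor and model so that evaluate() above can run.
# "<this-repo-id>" is a placeholder for the actual hub id of this
# Wav2Vec2-Base-760-Turkish checkpoint; it is not given in the diff.
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_id = "<this-repo-id>"  # placeholder
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id).to("cuda")
model.eval()

def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        # Only input_values are passed; no attention_mask, as in the commit above.
        logits = model(inputs.input_values.to("cuda")).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
    return batch
```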
@@ -117,7 +119,9 @@ result = test_dataset.map(evaluate, batched=True, batch_size=8)
 print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
 ```

-**Test Result**: 24.91 % (in progress)
+**Test Results**:
+- WER: 22.602390
+- CER: 6.054137


 ## Training
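The `wer` object in the print line above is defined earlier in the README. As a hedged sketch, this is one way the reported WER and CER figures could be computed with `datasets` metrics; the `cer` metric and the reuse of `result` from the hunk above are assumptions, not part of this commit.

```python
# Sketch under assumptions: `result` is the output of
# test_dataset.map(evaluate, batched=True, batch_size=8) as in the hunk above.
from datasets import load_metric

wer = load_metric("wer")  # word error rate, as printed in the README
cer = load_metric("cer")  # character error rate; assumed source of the CER figure

print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
print("CER: {:2f}".format(100 * cer.compute(predictions=result["pred_strings"], references=result["sentence"])))
```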

preprocessor_config.json
@@ -2,7 +2,7 @@
     "do_normalize": true,
     "feature_size": 1,
     "padding_side": "right",
-    "padding_value": 0.0,
+    "padding_value": 0,
     "return_attention_mask": true,
     "sampling_rate": 16000
 }
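A short, hedged sketch of how the configuration above is consumed: `Wav2Vec2FeatureExtractor` pads batched waveforms with `padding_value` and, because `return_attention_mask` is true, also returns a mask over the padded positions. The toy waveforms and the explicit constructor call below are illustrative only.

```python
# Sketch only: values mirror the config above; inputs are toy waveforms.
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

batch = feature_extractor(
    [[0.1] * 8, [0.2] * 4],  # two raw waveforms of different lengths
    sampling_rate=16000,
    padding=True,
    return_tensors="np",
)
print(batch["input_values"].shape)  # (2, 8): shorter input padded with padding_value
print(batch["attention_mask"])      # zeros mark the padded positions
```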

pytorch_model.bin (Stored with Git LFS)
Binary file not shown.

special_tokens_map.json
@@ -1 +1 @@
-{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": "/home/ceyda/workspace/libs/fairseq/hf_finetuned_output/special_tokens_map.json", "tokenizer_file": null}
+{"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "word_delimiter_token": "|"}

vocab.json (40 lines changed)
@@ -1,39 +1 @@
-{"|": 4,
-"p": 5,
-"i": 6,
-"r": 7,
-"n": 8,
-"s": 9,
-"ö": 10,
-"z": 11,
-"l": 12,
-"e": 13,
-"h": 14,
-"â": 15,
-"y": 16,
-"a": 17,
-"k": 18,
-"ı": 19,
-"o": 20,
-"m": 21,
-"ü": 22,
-"g": 23,
-"c": 24,
-"b": 25,
-"ş": 26,
-"d": 27,
-"u": 28,
-"t": 29,
-"ç": 30,
-"ğ": 31,
-"v": 32,
-"f": 33,
-"j": 34,
-"x": 35,
-"w": 36,
-"q": 37,
-"î": 38,
-"<s>": 0,
-"<pad>": 1,
-"</s>": 2,
-"<unk>": 3}
+{"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "p": 5, "i": 6, "r": 7, "n": 8, "s": 9, "ö": 10, "z": 11, "l": 12, "e": 13, "h": 14, "â": 15, "y": 16, "a": 17, "k": 18, "ı": 19, "o": 20, "m": 21, "ü": 22, "g": 23, "c": 24, "b": 25, "ş": 26, "d": 27, "u": 28, "t": 29, "ç": 30, "ğ": 31, "v": 32, "f": 33, "j": 34, "x": 35, "w": 36, "q": 37, "î": 38}
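As a hedged illustration of how the single-line vocabulary above is used, the sketch below builds a `Wav2Vec2CTCTokenizer` from it. Writing the dict to a local `vocab.json` and the sample word are assumptions for the example, not part of this commit.

```python
# Sketch only: the vocab dict mirrors the new vocab.json above.
import json
from transformers import Wav2Vec2CTCTokenizer

vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "p": 5, "i": 6,
         "r": 7, "n": 8, "s": 9, "ö": 10, "z": 11, "l": 12, "e": 13, "h": 14,
         "â": 15, "y": 16, "a": 17, "k": 18, "ı": 19, "o": 20, "m": 21,
         "ü": 22, "g": 23, "c": 24, "b": 25, "ş": 26, "d": 27, "u": 28,
         "t": 29, "ç": 30, "ğ": 31, "v": 32, "f": 33, "j": 34, "x": 35,
         "w": 36, "q": 37, "î": 38}

with open("vocab.json", "w") as f:
    json.dump(vocab, f, ensure_ascii=False)

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="<unk>", pad_token="<pad>", bos_token="<s>", eos_token="</s>",
    word_delimiter_token="|",
)
print(tokenizer("merhaba").input_ids)  # each character mapped to its id in the vocab
```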