Fix encode method

duzx16 2023-04-02 01:04:40 +08:00
parent e1494f222d
commit e22cddf212
1 changed file with 8 additions and 6 deletions

@@ -398,18 +398,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         # Initialize attention mask if not present.
         if return_attention_mask:
-            context_length = required_input.index(bos_token_id)
+            if bos_token_id in required_input:
+                context_length = required_input.index(bos_token_id)
+            else:
+                context_length = seq_length
             attention_mask = np.ones((1, seq_length, seq_length))
             attention_mask = np.tril(attention_mask)
             attention_mask[:, :, :context_length] = 1
             attention_mask = np.bool_(attention_mask < 0.5)
             encoded_inputs["attention_mask"] = attention_mask

         if return_attention_mask:
-            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
-            mask_position = required_input.index(mask_token)
-            context_length = required_input.index(bos_token_id)
             position_ids = np.arange(seq_length, dtype=np.int64)
-            position_ids[context_length:] = mask_position
+            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+            if mask_token in required_input:
+                mask_position = required_input.index(mask_token)
+                position_ids[context_length:] = mask_position
             block_position_ids = np.concatenate(
                 [np.zeros(context_length, dtype=np.int64),
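
For context, a minimal runnable sketch of what the fixed logic computes, not the repository's file: the special-token id values, the helper name build_mask_and_positions, and the continuation of block_position_ids past the truncated hunk (the np.arange tail and the final np.stack) are assumptions for illustration.

import numpy as np

# Hypothetical special-token ids, for illustration only; the real values
# come from the ChatGLM vocabulary.
bos_token_id, mask_token_id, gmask_token_id = 130004, 130000, 130001

def build_mask_and_positions(required_input):
    seq_length = len(required_input)

    # After the fix: fall back to the full sequence length when no BOS token
    # is present, instead of letting list.index() raise ValueError.
    if bos_token_id in required_input:
        context_length = required_input.index(bos_token_id)
    else:
        context_length = seq_length

    # Lower-triangular (causal) mask with the context part fully visible;
    # True marks positions to mask out.
    attention_mask = np.ones((1, seq_length, seq_length))
    attention_mask = np.tril(attention_mask)
    attention_mask[:, :, :context_length] = 1
    attention_mask = np.bool_(attention_mask < 0.5)

    # After the fix: only redirect position ids to the mask slot when a
    # [MASK]/[gMASK] token actually occurs in the input.
    position_ids = np.arange(seq_length, dtype=np.int64)
    mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
    if mask_token in required_input:
        mask_position = required_input.index(mask_token)
        position_ids[context_length:] = mask_position

    # Assumed continuation beyond the truncated hunk: the second row of
    # ChatGLM's 2D position encoding, counting 1..n inside the generated block.
    block_position_ids = np.concatenate(
        [np.zeros(context_length, dtype=np.int64),
         np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
    return attention_mask, np.stack([position_ids, block_position_ids], axis=0)

# An input with neither BOS nor a mask token previously crashed in .index();
# with the guards it encodes cleanly.
mask, positions = build_mask_and_positions([5, 6, 7, 8])
print(positions)  # rows: [0 1 2 3] and [0 0 0 0]

The point of the commit is visible in both guards: the two .index() calls that assumed BOS and [MASK]/[gMASK] were always present now have membership checks with fallbacks, so encoding plain text without those special tokens no longer raises ValueError.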