Fix encode method: handle inputs where the BOS or mask token is missing
This commit is contained in:
parent
e1494f222d
commit
e22cddf212
|
@ -398,19 +398,21 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
|
|
||||||
# Initialize attention mask if not present.
|
# Initialize attention mask if not present.
|
||||||
if return_attention_mask:
|
if return_attention_mask:
|
||||||
context_length = required_input.index(bos_token_id)
|
if bos_token_id in required_input:
|
||||||
|
context_length = required_input.index(bos_token_id)
|
||||||
|
else:
|
||||||
|
context_length = seq_length
|
||||||
attention_mask = np.ones((1, seq_length, seq_length))
|
attention_mask = np.ones((1, seq_length, seq_length))
|
||||||
attention_mask = np.tril(attention_mask)
|
attention_mask = np.tril(attention_mask)
|
||||||
attention_mask[:, :, :context_length] = 1
|
attention_mask[:, :, :context_length] = 1
|
||||||
attention_mask = np.bool_(attention_mask < 0.5)
|
attention_mask = np.bool_(attention_mask < 0.5)
|
||||||
encoded_inputs["attention_mask"] = attention_mask
|
encoded_inputs["attention_mask"] = attention_mask
|
||||||
|
|
||||||
if return_attention_mask:
|
|
||||||
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
|
||||||
mask_position = required_input.index(mask_token)
|
|
||||||
context_length = required_input.index(bos_token_id)
|
|
||||||
position_ids = np.arange(seq_length, dtype=np.int64)
|
position_ids = np.arange(seq_length, dtype=np.int64)
|
||||||
position_ids[context_length:] = mask_position
|
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
|
||||||
|
if mask_token in required_input:
|
||||||
|
mask_position = required_input.index(mask_token)
|
||||||
|
position_ids[context_length:] = mask_position
|
||||||
block_position_ids = np.concatenate(
|
block_position_ids = np.concatenate(
|
||||||
[np.zeros(context_length, dtype=np.int64),
|
[np.zeros(context_length, dtype=np.int64),
|
||||||
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
|
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
|
||||||
|
|
Loading…
Reference in New Issue