Fix bug: accept pad_token and unk_token in ChatGLMTokenizer.__init__ and forward them to the base class; compute context_length inside the position_ids branch of _pad
This commit is contained in:
parent eb55ff050e
commit 53f019758b
@@ -176,6 +176,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         mask_token='[MASK]',
         gmask_token='[gMASK]',
         padding_side="left",
+        pad_token="<pad>",
+        unk_token="<unk>",
         num_image_tokens=20000,
         **kwargs
     ) -> None:
@@ -188,6 +190,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             end_token=end_token,
             mask_token=mask_token,
             gmask_token=gmask_token,
+            pad_token=pad_token,
+            unk_token=unk_token,
             num_image_tokens=num_image_tokens,
             **kwargs
         )
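The first two hunks make pad_token and unk_token explicit constructor arguments and forward them to PreTrainedTokenizer.__init__, so the base class can register them as special tokens. A minimal usage sketch; the repo id and the printed values are assumptions, not part of this commit:

# Sketch only: load a tokenizer that ships this fix and inspect the special tokens.
# "THUDM/chatglm-6b" is an assumed repo id; any checkout containing this tokenizer works.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# With pad_token/unk_token forwarded to super().__init__(), the base class exposes them:
print(tokenizer.pad_token, tokenizer.unk_token)        # expected: "<pad>" "<unk>"
print(tokenizer.pad_token_id, tokenizer.unk_token_id)  # ids resolved through the vocab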
@@ -402,6 +406,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
                 encoded_inputs["attention_mask"] = attention_mask

             if "position_ids" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
                 position_ids = np.arange(seq_length, dtype=np.int64)
                 mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
                 if mask_token in required_input:
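The third hunk recomputes context_length inside the position_ids branch of _pad, so that branch no longer depends on the attention-mask branch having run first (previously, inputs that already carried an attention_mask but no position_ids could reach this code with context_length undefined). Below is a standalone re-statement of the patched logic, as a sketch: the function name is invented, the hunk is truncated after "if mask_token in required_input:", and the continuation used here follows the upstream tokenizer only as an assumption.

import numpy as np

def build_position_ids(required_input, bos_token_id, mask_token_id, gmask_token_id):
    # Mirrors the "position_ids" branch after this commit (illustration only).
    seq_length = len(required_input)
    # The four added lines: derive context_length locally instead of relying on
    # the attention-mask branch having computed it earlier.
    if bos_token_id in required_input:
        context_length = required_input.index(bos_token_id)
    else:
        context_length = seq_length
    position_ids = np.arange(seq_length, dtype=np.int64)
    mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
    if mask_token in required_input:
        # Assumed continuation of the truncated hunk: positions after the context
        # all point at the mask position.
        position_ids[context_length:] = required_input.index(mask_token)
    return position_ids

# Toy token ids, not the real ChatGLM vocab: bos=4, [MASK]=2, [gMASK]=3.
print(build_position_ids([10, 11, 3, 4, 12, 13], bos_token_id=4, mask_token_id=2, gmask_token_id=3))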