Fix attention_mask and position_ids
parent e22cddf212
commit 373fd6b9d4
@@ -340,7 +340,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         token_ids_0 += [self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if token_ids_1[-1] != eop_id:
+            if not token_ids_1 or token_ids_1[-1] != eop_id:
                 token_ids_1 += [eop_id]
             token_ids_0 += token_ids_1
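The first hunk hardens `build_inputs_with_special_tokens` against an empty second sequence: the old condition indexed `token_ids_1[-1]` unconditionally, which raises `IndexError` when `token_ids_1` is `[]`. A minimal sketch of the difference, with a hypothetical placeholder value standing in for the tokenizer's real eop token id:

eop_id = 130005  # hypothetical placeholder; the real id comes from the tokenizer

token_ids_1 = []  # an empty second sequence

# Old guard: indexing an empty list raises IndexError.
# if token_ids_1[-1] != eop_id:
#     token_ids_1 += [eop_id]

# New guard: `not token_ids_1` short-circuits before the indexing,
# so an empty sequence simply gets the eop token appended.
if not token_ids_1 or token_ids_1[-1] != eop_id:
    token_ids_1 += [eop_id]

print(token_ids_1)  # -> [130005]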
@@ -397,26 +397,28 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
-        # Initialize attention mask if not present.
-        if return_attention_mask:
-            if bos_token_id in required_input:
-                context_length = required_input.index(bos_token_id)
-            else:
-                context_length = seq_length
-            attention_mask = np.ones((1, seq_length, seq_length))
-            attention_mask = np.tril(attention_mask)
-            attention_mask[:, :, :context_length] = 1
-            attention_mask = np.bool_(attention_mask < 0.5)
-            encoded_inputs["attention_mask"] = attention_mask
+        if max_length is not None:
+            if "attention_mask" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
+                attention_mask = np.ones((1, seq_length, seq_length))
+                attention_mask = np.tril(attention_mask)
+                attention_mask[:, :, :context_length] = 1
+                attention_mask = np.bool_(attention_mask < 0.5)
+                encoded_inputs["attention_mask"] = attention_mask
 
-            position_ids = np.arange(seq_length, dtype=np.int64)
-            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
-            if mask_token in required_input:
-                mask_position = required_input.index(mask_token)
-                position_ids[context_length:] = mask_position
-            block_position_ids = np.concatenate(
-                [np.zeros(context_length, dtype=np.int64),
-                 np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
-            encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+            if "position_ids" not in encoded_inputs:
+                position_ids = np.arange(seq_length, dtype=np.int64)
+                mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+                if mask_token in required_input:
+                    mask_position = required_input.index(mask_token)
+                    position_ids[context_length:] = mask_position
+                block_position_ids = np.concatenate(
+                    [np.zeros(context_length, dtype=np.int64),
+                     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+                encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
 
         if needs_to_be_padded:
             difference = max_length - len(required_input)
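The second hunk moves the mask and position-id construction out of the `return_attention_mask` branch: it now runs whenever `max_length` is set, and only fills in an `attention_mask` or `position_ids` entry that the caller did not already provide, so user-supplied values are no longer overwritten. For context, a standalone sketch of what this code builds, using toy token ids (all id values below are hypothetical placeholders, not the tokenizer's real ones):

import numpy as np

# Hypothetical placeholder ids; the real values come from the tokenizer.
bos_token_id, mask_token_id, gmask_token_id = 130004, 130000, 130001

required_input = [5, 6, 130001, 130004, 7, 8]  # toy sequence: context, gMASK, bos, generated
seq_length = len(required_input)

# Everything before bos is "context" and is attended bidirectionally.
if bos_token_id in required_input:
    context_length = required_input.index(bos_token_id)
else:
    context_length = seq_length

# Causal (lower-triangular) base mask, then full attention over the
# context prefix; True marks positions that are masked out.
attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
attention_mask[:, :, :context_length] = 1
attention_mask = attention_mask < 0.5

# 2D position ids: the first row freezes absolute positions at the mask
# slot after the context; the second row counts up inside the generated block.
position_ids = np.arange(seq_length, dtype=np.int64)
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
if mask_token in required_input:
    position_ids[context_length:] = required_input.index(mask_token)
block_position_ids = np.concatenate(
    [np.zeros(context_length, dtype=np.int64),
     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
position_ids = np.stack([position_ids, block_position_ids], axis=0)

print(position_ids)
# [[0 1 2 2 2 2]
#  [0 0 0 1 2 3]]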