Fix attention_mask and position_ids

duzx16 2023-04-02 01:58:45 +08:00
parent e22cddf212
commit 373fd6b9d4
1 changed file with 22 additions and 20 deletions


@@ -340,7 +340,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         token_ids_0 += [self.sp_tokenizer[self.bos_token]]
         if token_ids_1 is not None:
-            if token_ids_1[-1] != eop_id:
+            if not token_ids_1 or token_ids_1[-1] != eop_id:
                 token_ids_1 += [eop_id]
             token_ids_0 += token_ids_1
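
The old test indexed token_ids_1[-1] unconditionally, so an empty second segment raised an IndexError; the new short-circuit appends eop_id instead. A minimal sketch of the guarded append, using a placeholder eop_id = 7 rather than the real id from sp_tokenizer:

    # Minimal sketch of the guard above; eop_id = 7 is a placeholder value,
    # not the real ChatGLM vocabulary id.
    eop_id = 7

    def append_eop(token_ids_1):
        # Old form `if token_ids_1[-1] != eop_id:` raised IndexError for [].
        # The short-circuit handles the empty list and appends eop_id.
        if not token_ids_1 or token_ids_1[-1] != eop_id:
            token_ids_1 += [eop_id]
        return token_ids_1

    print(append_eop([]))      # [7]
    print(append_eop([4, 5]))  # [4, 5, 7]
    print(append_eop([4, 7]))  # [4, 7]  (already terminated, unchanged)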
@@ -397,26 +397,28 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

         # Initialize attention mask if not present.
-        if return_attention_mask:
-            if bos_token_id in required_input:
-                context_length = required_input.index(bos_token_id)
-            else:
-                context_length = seq_length
-            attention_mask = np.ones((1, seq_length, seq_length))
-            attention_mask = np.tril(attention_mask)
-            attention_mask[:, :, :context_length] = 1
-            attention_mask = np.bool_(attention_mask < 0.5)
-            encoded_inputs["attention_mask"] = attention_mask
+        if max_length is not None:
+            if "attention_mask" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
+                attention_mask = np.ones((1, seq_length, seq_length))
+                attention_mask = np.tril(attention_mask)
+                attention_mask[:, :, :context_length] = 1
+                attention_mask = np.bool_(attention_mask < 0.5)
+                encoded_inputs["attention_mask"] = attention_mask

-            position_ids = np.arange(seq_length, dtype=np.int64)
-            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
-            if mask_token in required_input:
-                mask_position = required_input.index(mask_token)
-                position_ids[context_length:] = mask_position
-            block_position_ids = np.concatenate(
-                [np.zeros(context_length, dtype=np.int64),
-                 np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
-            encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+            if "position_ids" not in encoded_inputs:
+                position_ids = np.arange(seq_length, dtype=np.int64)
+                mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+                if mask_token in required_input:
+                    mask_position = required_input.index(mask_token)
+                    position_ids[context_length:] = mask_position
+                block_position_ids = np.concatenate(
+                    [np.zeros(context_length, dtype=np.int64),
+                     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+                encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

         if needs_to_be_padded:
             difference = max_length - len(required_input)
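
With this change the tensors are only built when the caller has not already supplied attention_mask or position_ids, and the branch is gated on max_length instead of return_attention_mask. A standalone sketch of what the branch produces, assuming placeholder token ids (bos_token_id = 3, gmask_token_id = 5) and a made-up six-token input rather than real sp_tokenizer output:

    # Standalone sketch of the mask/position-id construction above, with
    # placeholder ids; the real ids come from the tokenizer's sp_tokenizer.
    import numpy as np

    required_input = [10, 11, 5, 3, 12, 13]  # prompt tokens, gMASK, BOS, generated tokens
    seq_length = len(required_input)
    bos_token_id, gmask_token_id = 3, 5

    if bos_token_id in required_input:
        context_length = required_input.index(bos_token_id)
    else:
        context_length = seq_length

    # Prefix-LM mask: lower-triangular (causal) overall, with the context
    # columns fully visible; True marks positions that are masked out.
    attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
    attention_mask[:, :, :context_length] = 1
    attention_mask = np.bool_(attention_mask < 0.5)

    # 2D position ids: absolute positions are frozen at the gMASK position
    # after the context; block positions count 1..n inside the generated span.
    position_ids = np.arange(seq_length, dtype=np.int64)
    mask_position = required_input.index(gmask_token_id)
    position_ids[context_length:] = mask_position
    block_position_ids = np.concatenate(
        [np.zeros(context_length, dtype=np.int64),
         np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
    position_ids = np.stack([position_ids, block_position_ids], axis=0)

    print(attention_mask.shape)  # (1, 6, 6)
    print(position_ids)          # [[0 1 2 2 2 2]
                                 #  [0 0 0 1 2 3]]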