Fix generate

This commit is contained in:
duzx16 2023-04-02 14:51:17 +08:00
parent 08bc85104d
commit fb23542cfe
2 changed files with 14 additions and 10 deletions

View File

@ -1054,6 +1054,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
new_attention_mask = attention_mask[:, :, -1:].clone()
@ -1092,8 +1093,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
# only last token for input_ids if past is not None
if past is not None or past_key_values is not None:
last_token = input_ids[:, -1].unsqueeze(-1)
if attention_mask is not None:
if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = attention_mask[:, :, -1:]
else:
attention_mask = None
if position_ids is not None:
position_ids = position_ids[..., -1:]
else:
@ -1115,6 +1118,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
"attention_mask": attention_mask
}
else:
if attention_mask is not None and attention_mask.dtype != torch.bool:
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
attention_mask = None
if attention_mask is None:
attention_mask = self.get_masks(
input_ids,

View File

@ -382,8 +382,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
mask_token_id = self.sp_tokenizer[self.mask_token]
gmask_token_id = self.sp_tokenizer[self.gmask_token]
assert self.padding_side == "left"
if return_attention_mask is None:
return_attention_mask = "attention_mask" in self.model_input_names
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)