Fix generate
This commit is contained in:
parent
08bc85104d
commit
fb23542cfe
|
@ -1054,13 +1054,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
# update attention mask
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
|
||||
new_attention_mask = attention_mask[:, :, -1:].clone()
|
||||
new_attention_mask[..., -1] = False
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, new_attention_mask], dim=2
|
||||
)
|
||||
if attention_mask is not None and attention_mask.dtype == torch.bool:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
|
||||
new_attention_mask = attention_mask[:, :, -1:].clone()
|
||||
new_attention_mask[..., -1] = False
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, new_attention_mask], dim=2
|
||||
)
|
||||
|
||||
# update position ids
|
||||
if "position_ids" in model_kwargs:
|
||||
|
@ -1092,8 +1093,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
# only last token for input_ids if past is not None
|
||||
if past is not None or past_key_values is not None:
|
||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||
if attention_mask is not None:
|
||||
if attention_mask is not None and attention_mask.dtype == torch.bool:
|
||||
attention_mask = attention_mask[:, :, -1:]
|
||||
else:
|
||||
attention_mask = None
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids[..., -1:]
|
||||
else:
|
||||
|
@ -1115,6 +1118,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|||
"attention_mask": attention_mask
|
||||
}
|
||||
else:
|
||||
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
||||
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
|
||||
attention_mask = None
|
||||
if attention_mask is None:
|
||||
attention_mask = self.get_masks(
|
||||
input_ids,
|
||||
|
|
|
@ -382,8 +382,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|||
mask_token_id = self.sp_tokenizer[self.mask_token]
|
||||
gmask_token_id = self.sp_tokenizer[self.gmask_token]
|
||||
assert self.padding_side == "left"
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
|
Loading…
Reference in New Issue