Fix generate: ignore attention masks whose dtype is not torch.bool
This commit is contained in:
parent
08bc85104d
commit
fb23542cfe
|
@ -1054,6 +1054,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
# update attention mask
|
# update attention mask
|
||||||
if "attention_mask" in model_kwargs:
|
if "attention_mask" in model_kwargs:
|
||||||
attention_mask = model_kwargs["attention_mask"]
|
attention_mask = model_kwargs["attention_mask"]
|
||||||
|
if attention_mask is not None and attention_mask.dtype == torch.bool:
|
||||||
attention_mask = torch.cat(
|
attention_mask = torch.cat(
|
||||||
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
|
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
|
||||||
new_attention_mask = attention_mask[:, :, -1:].clone()
|
new_attention_mask = attention_mask[:, :, -1:].clone()
|
||||||
|
@ -1092,8 +1093,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
# only last token for input_ids if past is not None
|
# only last token for input_ids if past is not None
|
||||||
if past is not None or past_key_values is not None:
|
if past is not None or past_key_values is not None:
|
||||||
last_token = input_ids[:, -1].unsqueeze(-1)
|
last_token = input_ids[:, -1].unsqueeze(-1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None and attention_mask.dtype == torch.bool:
|
||||||
attention_mask = attention_mask[:, :, -1:]
|
attention_mask = attention_mask[:, :, -1:]
|
||||||
|
else:
|
||||||
|
attention_mask = None
|
||||||
if position_ids is not None:
|
if position_ids is not None:
|
||||||
position_ids = position_ids[..., -1:]
|
position_ids = position_ids[..., -1:]
|
||||||
else:
|
else:
|
||||||
|
@ -1115,6 +1118,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||||
"attention_mask": attention_mask
|
"attention_mask": attention_mask
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
||||||
|
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
|
||||||
|
attention_mask = None
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = self.get_masks(
|
attention_mask = self.get_masks(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
|
|
@ -382,8 +382,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||||
mask_token_id = self.sp_tokenizer[self.mask_token]
|
mask_token_id = self.sp_tokenizer[self.mask_token]
|
||||||
gmask_token_id = self.sp_tokenizer[self.gmask_token]
|
gmask_token_id = self.sp_tokenizer[self.gmask_token]
|
||||||
assert self.padding_side == "left"
|
assert self.padding_side == "left"
|
||||||
if return_attention_mask is None:
|
|
||||||
return_attention_mask = "attention_mask" in self.model_input_names
|
|
||||||
|
|
||||||
required_input = encoded_inputs[self.model_input_names[0]]
|
required_input = encoded_inputs[self.model_input_names[0]]
|
||||||
seq_length = len(required_input)
|
seq_length = len(required_input)
|
||||||
|
|
Loading…
Reference in New Issue