diff --git a/modeling_chatglm.py b/modeling_chatglm.py
index cdcc39b..3d0888e 100644
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -1054,13 +1054,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # update attention mask
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
-            attention_mask = torch.cat(
-                [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
-            new_attention_mask = attention_mask[:, :, -1:].clone()
-            new_attention_mask[..., -1] = False
-            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, new_attention_mask], dim=2
-            )
+            if attention_mask is not None and attention_mask.dtype == torch.bool:
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
+                new_attention_mask = attention_mask[:, :, -1:].clone()
+                new_attention_mask[..., -1] = False
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, new_attention_mask], dim=2
+                )
 
         # update position ids
         if "position_ids" in model_kwargs:
@@ -1092,8 +1093,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
             last_token = input_ids[:, -1].unsqueeze(-1)
-            if attention_mask is not None:
+            if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = attention_mask[:, :, -1:]
+            else:
+                attention_mask = None
             if position_ids is not None:
                 position_ids = position_ids[..., -1:]
             else:
@@ -1115,6 +1118,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 "attention_mask": attention_mask
             }
         else:
+            if attention_mask is not None and attention_mask.dtype != torch.bool:
+                logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
+                attention_mask = None
             if attention_mask is None:
                 attention_mask = self.get_masks(
                     input_ids,
diff --git a/tokenization_chatglm.py b/tokenization_chatglm.py
index ba86626..3062c7c 100644
--- a/tokenization_chatglm.py
+++ b/tokenization_chatglm.py
@@ -382,8 +382,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         mask_token_id = self.sp_tokenizer[self.mask_token]
         gmask_token_id = self.sp_tokenizer[self.gmask_token]
         assert self.padding_side == "left"
-        if return_attention_mask is None:
-            return_attention_mask = "attention_mask" in self.model_input_names
 
         required_input = encoded_inputs[self.model_input_names[0]]
         seq_length = len(required_input)
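
For reference, below is a minimal standalone sketch of the boolean-mask update that the first hunk now guards with a dtype check. The batch/sequence sizes and the causal-mask construction are illustrative assumptions, not code from the patch; as the later hunks show, a non-bool mask skips this path and is instead dropped and rebuilt via `self.get_masks` in `prepare_inputs_for_generation`.

```python
import torch

# Illustrative sketch only (not part of the patch): how the guarded block in
# _update_model_kwargs_for_generation grows a 4-D boolean attention mask by one
# generated token. Shapes are hypothetical: [batch, 1, seq, seq], where True
# marks positions that are masked out.
batch, seq = 2, 5
causal = ~torch.ones(seq, seq, dtype=torch.bool).tril()  # upper triangle = masked
attention_mask = causal.expand(batch, 1, seq, seq).clone()

if attention_mask is not None and attention_mask.dtype == torch.bool:
    # Append one key column: existing query positions cannot attend to the new token.
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
    # Append one query row, copied from the last row, with its final entry set to
    # False so the new token can attend to itself.
    new_attention_mask = attention_mask[:, :, -1:].clone()
    new_attention_mask[..., -1] = False
    attention_mask = torch.cat([attention_mask, new_attention_mask], dim=2)

print(attention_mask.shape)  # torch.Size([2, 1, 6, 6])
```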