From 117d46a1159e97ee945be227e051a38ec23e2457 Mon Sep 17 00:00:00 2001 From: Yih-Dar SHIEH Date: Mon, 25 Oct 2021 08:13:14 +0000 Subject: [PATCH] Update pipeline.py --- pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline.py b/pipeline.py index e2f9ea3..0b6bbc8 100644 --- a/pipeline.py +++ b/pipeline.py @@ -20,7 +20,7 @@ class PreTrainedPipeline(): max_length = 16 num_beams = 4 # self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams} - self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, return_dict_in_generate=True} + self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "return_dict_in_generate": True} self.model.to("cpu") self.model.eval()