diff --git a/pipeline.py b/pipeline.py index e2f9ea3..0b6bbc8 100644 --- a/pipeline.py +++ b/pipeline.py @@ -20,7 +20,7 @@ class PreTrainedPipeline(): max_length = 16 num_beams = 4 # self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams} - self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, return_dict_in_generate=True} + self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "return_dict_in_generate": True} self.model.to("cpu") self.model.eval()