diff --git a/pipeline.py b/pipeline.py index 0b6bbc8..1d9894b 100644 --- a/pipeline.py +++ b/pipeline.py @@ -53,4 +53,6 @@ class PreTrainedPipeline(): preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True) preds = [pred.strip() for pred in preds] + preds = [{"label": preds[0], "score": 1.0}] + return preds