diff --git a/app.py b/app.py index cbdb15f..dea5e08 100644 --- a/app.py +++ b/app.py @@ -41,7 +41,7 @@ def predict(text, else: prompt,inputs=inputs begin_length = len(prompt) - + torch.cuda.empty_cache() input_ids = inputs["input_ids"].to(device) with torch.no_grad():