@@ -44,6 +44,7 @@ def predict(text,
    begin_length = len(prompt)
    torch.cuda.empty_cache()
    input_ids = inputs["input_ids"].to(device)
    global total_count
    total_count += 1
    print(total_count)
    with torch.no_grad():
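For context, the `global total_count` / `total_count += 1` / `print(total_count)` lines follow the usual module-level call-counter pattern: the counter lives at module scope, and the `global` declaration lets the function rebind it instead of creating a local variable. A minimal, self-contained sketch of that pattern (the name `predict_stub` and its return value are illustrative, not part of the patch):

total_count = 0  # module-level counter shared across calls

def predict_stub(text):
    # `global` is required so the assignment below rebinds the module-level
    # name rather than raising UnboundLocalError on a new local variable.
    global total_count
    total_count += 1
    print(total_count)
    return len(text)

predict_stub("hello")  # prints 1
predict_stub("world")  # prints 2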