From e0ec530bf40ce0555bd3320f04c75de93609d614 Mon Sep 17 00:00:00 2001
From: Baize
Date: Tue, 4 Apr 2023 15:55:15 +0000
Subject: [PATCH] Update app.py

---
 app.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/app.py b/app.py
index 7bcf3cd..4dadcf1 100644
--- a/app.py
+++ b/app.py
@@ -35,14 +35,14 @@ def predict(text,
         return
 
     inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
-    if inputs is False:
-        yield chatbot+[[text,"Sorry, the input is too long."]],history,"Generate Fail"
-        return
+    if inputs is None:
+        yield chatbot,history,"Too Long Input"
+        return
     else:
         prompt,inputs=inputs
         begin_length = len(prompt)
+    input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
     torch.cuda.empty_cache()
-    input_ids = inputs["input_ids"].to(device)
     global total_count
     total_count += 1
     print(total_count)
@@ -63,6 +63,7 @@ def predict(text,
                     return
         except:
             pass
+    torch.cuda.empty_cache()
     #print(text)
     #print(x)
     #print("="*80)
@@ -150,7 +151,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                 )
                 max_context_length_tokens = gr.Slider(
                     minimum=0,
-                    maximum=4096,
+                    maximum=3072,
                     value=2048,
                     step=128,
                     interactive=True,
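
The substantive change here bounds the prompt to the model's context window: input_ids is sliced to the last max_context_length_tokens tokens before it is moved to the GPU, and the failure path now checks for None, presumably because generate_prompt_with_history signals an oversized input with None rather than False. Below is a minimal sketch of the truncation pattern, assuming only PyTorch; the random input_ids tensor stands in for a real tokenizer's output, and the constant mirrors the slider default in app.py.

import torch

# Sketch of the truncation introduced by the patch. Assumptions: batch
# size 1 and a 32000-token vocabulary; the random tensor is a stand-in
# for real tokenizer output.
device = "cuda" if torch.cuda.is_available() else "cpu"
max_context_length_tokens = 2048  # mirrors the slider default in app.py

# Pretend the tokenized prompt is 3000 tokens long.
inputs = {"input_ids": torch.randint(0, 32000, (1, 3000))}

# Keep only the most recent tokens, then move the bounded tensor to the
# device. Slicing with a negative start index preserves the latest
# conversation turns, and slicing *before* .to(device) avoids copying
# tokens that would be discarded anyway.
input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
assert input_ids.shape[1] <= max_context_length_tokens

The extra torch.cuda.empty_cache() after generation and the lower 3072-token slider cap point the same way: both look like memory-pressure fixes for the demo GPU, though the commit message does not say so.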