Update app.py

Baize authored on 2023-04-04 15:55:15 +00:00; committed by huggingface-web
parent 9d96329391
commit e0ec530bf4
1 changed file with 6 additions and 5 deletions

app.py

@@ -35,14 +35,14 @@ def predict(text,
         return
     inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
-    if inputs is False:
-        yield chatbot+[[text,"Sorry, the input is too long."]],history,"Generate Fail"
+    if inputs is None:
+        yield chatbot,history,"Too Long Input"
         return
     else:
         prompt,inputs=inputs
         begin_length = len(prompt)
+    input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
     torch.cuda.empty_cache()
-    input_ids = inputs["input_ids"].to(device)
     global total_count
     total_count += 1
     print(total_count)
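
Note on this hunk: the added line slices the tokenized prompt down to its trailing max_context_length_tokens tokens before moving it to the GPU, so an over-long history is truncated instead of being sent to the model whole. A minimal sketch of the slicing behavior (shapes and values here are illustrative, not taken from the commit):

import torch

max_context_length_tokens = 2048
# Hypothetical over-long prompt: batch of 1, 3000 tokens.
input_ids = torch.randint(0, 32000, (1, 3000))

device = "cuda" if torch.cuda.is_available() else "cpu"
# Keep only the most recent max_context_length_tokens tokens, then move to GPU.
input_ids = input_ids[:, -max_context_length_tokens:].to(device)
print(input_ids.shape)  # torch.Size([1, 2048])
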
@@ -63,6 +63,7 @@ def predict(text,
                     return
                 except:
                     pass
+    torch.cuda.empty_cache()
     #print(text)
     #print(x)
     #print("="*80)
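
The torch.cuda.empty_cache() added after the generation loop hands PyTorch's cached-but-unused GPU blocks back to the driver, which helps a shared demo avoid accumulating reserved memory between requests. A small illustration of the effect (the measurement code is my own, not the commit's); live tensors are unaffected:

import torch

if torch.cuda.is_available():
    before = torch.cuda.memory_reserved()
    torch.cuda.empty_cache()  # frees unused cached blocks; allocated tensors stay put
    after = torch.cuda.memory_reserved()
    print(f"reserved bytes: {before} -> {after}")
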
@@ -150,7 +151,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                     )
                     max_context_length_tokens = gr.Slider(
                         minimum=0,
-                        maximum=4096,
+                        maximum=3072,
                         value=2048,
                         step=128,
                         interactive=True,
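
The final hunk lowers the slider's maximum from 4096 to 3072, narrowing the user-selectable context window; the default value=2048 still falls inside the new range. A sketch of the resulting component, assuming the standard Gradio Slider API (the label is an assumption, since it is not shown in this hunk):

import gradio as gr

max_context_length_tokens = gr.Slider(
    minimum=0,
    maximum=3072,  # lowered from 4096 in this commit
    value=2048,    # default stays within [minimum, maximum]
    step=128,
    interactive=True,
    label="Max context length (tokens)",  # hypothetical label
)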