Update app.py
parent 9d96329391
commit e0ec530bf4

app.py
@@ -35,14 +35,14 @@ def predict(text,
         return
     inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
     if inputs is False:
         yield chatbot+[[text,"Sorry, the input is too long."]],history,"Generate Fail"
         return
     if inputs is None:
         yield chatbot,history,"Too Long Input"
         return
     else:
         prompt,inputs=inputs
         begin_length = len(prompt)
-    input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
     torch.cuda.empty_cache()
+    input_ids = inputs["input_ids"].to(device)
     global total_count
     total_count += 1
     print(total_count)
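The functional change in this hunk is that input_ids is no longer left-truncated with [:,-max_context_length_tokens:] before being moved to the GPU; generate_prompt_with_history is already called with max_length=max_context_length_tokens, so the prompt is presumably bounded upstream. A toy sketch (values invented for illustration, not from app.py) of what the removed slice did, keeping only the newest tokens:

import torch

max_context_length_tokens = 4                                 # tiny value for illustration
input_ids = torch.tensor([[101, 7592, 2088, 999, 102, 0]])    # shape (1, 6)

# The removed slice dropped the oldest tokens from the left so the
# prompt fit inside the context window:
truncated = input_ids[:, -max_context_length_tokens:]         # shape (1, 4)
print(truncated)                                              # tensor([[2088,  999,  102,    0]])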
@@ -63,6 +63,7 @@ def predict(text,
                 return
             except:
                 pass
+    torch.cuda.empty_cache()
     #print(text)
     #print(x)
     #print("="*80)
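The second hunk only adds a torch.cuda.empty_cache() call after the generation loop. Note that empty_cache releases PyTorch's cached-but-unused CUDA blocks back to the driver; it does not free tensors that are still referenced. A simplified sketch of the pattern (the helper name is mine, not from app.py):

import torch

def after_generation():  # hypothetical helper, for illustration only
    # ... generation loop would run here ...
    if torch.cuda.is_available():
        # Return cached, currently unused CUDA memory to the driver so the
        # next request (or another process) can allocate it.
        torch.cuda.empty_cache()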
@@ -150,7 +151,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                     )
                     max_context_length_tokens = gr.Slider(
                         minimum=0,
-                        maximum=4096,
+                        maximum=3072,
                         value=2048,
                         step=128,
                         interactive=True,
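The last hunk lowers the slider's maximum from 4096 to 3072 tokens, matching the tighter context handling above. For reference, a minimal, self-contained version of that widget as configured after this commit (the label and the bare Blocks wrapper are assumptions; app.py's actual layout is more elaborate):

import gradio as gr

with gr.Blocks() as demo:
    max_context_length_tokens = gr.Slider(
        minimum=0,
        maximum=3072,   # lowered from 4096 in this commit
        value=2048,
        step=128,
        interactive=True,
        label="Max context length (tokens)",  # label assumed for illustration
    )

# demo.launch()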