Update app.py

parent 9d96329391
commit e0ec530bf4
app.py | 11
@@ -35,14 +35,14 @@ def predict(text,
         return
 
     inputs = generate_prompt_with_history(text,history,tokenizer,max_length=max_context_length_tokens)
-    if inputs is False:
-        yield chatbot+[[text,"Sorry, the input is too long."]],history,"Generate Fail"
+    if inputs is None:
+        yield chatbot,history,"Too Long Input"
         return
     else:
         prompt,inputs=inputs
         begin_length = len(prompt)
+    input_ids = inputs["input_ids"][:,-max_context_length_tokens:].to(device)
     torch.cuda.empty_cache()
-    input_ids = inputs["input_ids"].to(device)
     global total_count
     total_count += 1
     print(total_count)
@@ -63,6 +63,7 @@ def predict(text,
                     return
             except:
                 pass
+    torch.cuda.empty_cache()
     #print(text)
     #print(x)
     #print("="*80)
@@ -150,7 +151,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                     )
                     max_context_length_tokens = gr.Slider(
                         minimum=0,
-                        maximum=4096,
+                        maximum=3072,
                         value=2048,
                         step=128,
                         interactive=True,
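The core of the first hunk is the slice inputs["input_ids"][:,-max_context_length_tokens:], which keeps only the newest tokens when the encoded conversation outgrows the context window, so the oldest turns fall off the front instead of overflowing the model. A minimal standalone sketch of that slicing, where the "gpt2" tokenizer and the sample history string are stand-ins rather than code from this repo:

import torch
from transformers import AutoTokenizer

# Stand-in values; in app.py these come from the Gradio sliders.
max_context_length_tokens = 2048
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
history = "[|Human|] Hi!\n[|AI|] Hello!\n" * 2000   # an over-long conversation
inputs = tokenizer(history, return_tensors="pt")

# Keep only the last max_context_length_tokens ids: recent turns survive,
# the oldest ones are dropped from the front of the sequence.
input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
print(input_ids.shape)  # torch.Size([1, 2048])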