from threading import Thread
from typing import Iterator

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
model_id = 'meta-llama/Llama-2-13b-chat-hf'

# Load the chat model in 4-bit when a GPU is available; otherwise leave it
# unset so the prompt helpers below can still be used without a GPU.
if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id, use_auth_token='hf_*************************')
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_4bit=True,
        device_map='auto',
        use_auth_token='hf_*************************',
    )
else:
    model = None
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token='hf_*************************')
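
# Note: on recent transformers releases, passing `load_in_4bit` directly to
# `from_pretrained` is deprecated in favor of a `BitsAndBytesConfig` (and
# `token=` is preferred over `use_auth_token=`). A minimal sketch of the
# equivalent call, assuming the `bitsandbytes` package is installed:
#
#     from transformers import BitsAndBytesConfig
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         config=config,
#         quantization_config=BitsAndBytesConfig(
#             load_in_4bit=True,
#             bnb_4bit_compute_dtype=torch.float16,
#         ),
#         device_map='auto',
#     )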


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    for user_input, response in chat_history:
        texts.append(f'{user_input.strip()} [/INST] {response.strip()} </s><s> [INST] ')
    texts.append(f'{message.strip()} [/INST]')
    return ''.join(texts)
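
# For example, the Llama 2 chat format this helper produces:
#
#     get_prompt('How are you?',
#                [('Hi', 'Hello! How can I help you today?')],
#                'You are a helpful assistant.')
#
# returns:
#
#     '[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n'
#     'Hi [/INST] Hello! How can I help you today? </s><s> [INST] '
#     'How are you? [/INST]'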


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    # Tokenize to NumPy rather than torch tensors: only the sequence length is
    # needed, so there is no reason to allocate GPU tensors here.
    input_ids = tokenizer([prompt], return_tensors='np')['input_ids']
    return input_ids.shape[-1]
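
# A hypothetical guard built on this helper (MAX_INPUT_TOKEN_LENGTH is an
# assumed application-level limit, not defined in this module):
#
#     MAX_INPUT_TOKEN_LENGTH = 4000
#     if get_input_token_length(message, chat_history, system_prompt) > MAX_INPUT_TOKEN_LENGTH:
#         raise ValueError('The accumulated input is too long. Clear the chat history and retry.')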


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt').to('cuda')

    # The streamer yields decoded text as generate() produces tokens; running
    # generate() in a background thread lets this function yield incrementally.
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Yield the full response accumulated so far on every new chunk.
        yield ''.join(outputs)
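
# A minimal usage sketch (an assumption: it requires a CUDA GPU so that
# `model` was actually loaded above):
#
#     if __name__ == '__main__':
#         history: list[tuple[str, str]] = []
#         response = ''
#         for response in run('Hello!', history, 'You are a helpful assistant.'):
#             pass  # each value is the full response generated so far
#         print(response)
#
# Yielding the cumulative text (rather than deltas) suits chat UIs that
# re-render the whole message on every update.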