Create handler.py
This commit is contained in:
parent
dac03c3ac8
commit
0f1c288530
|
@ -0,0 +1,23 @@
|
|||
from typing import Dict, List, Any
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
|
||||
class EndpointHandler:
    """Inference-endpoint handler wrapping a chat-style transformers model.

    Loads a tokenizer/model pair with ``trust_remote_code`` and serves
    generation requests via ``__call__``. The checkpoint must ship a custom
    model class exposing ``chat(tokenizer, inputs, history)`` (ChatGLM-style
    — presumably; confirm against the target checkpoint).
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path* onto the GPU in fp16.

        Args:
            path: Local directory (or hub repo id) of the model checkpoint.
        """
        # trust_remote_code executes Python shipped with the checkpoint —
        # NOTE(review): only point this at trusted repositories.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        # fp16 weights on CUDA; assumes a GPU is present on the endpoint —
        # TODO confirm this handler is never deployed to CPU instances.
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).half().cuda()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle one inference request.

        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                The prompt is read from ``"inputs"``; an optional running
                conversation is read from ``"history"``.

        Returns:
            A one-element list ``[{"generated_text": <reply>}]``, matching
            the standard text-generation pipeline output shape.
        """
        # process input; fall back to the whole payload if "inputs" is absent
        inputs = data.pop("inputs", data)
        history = data.pop("history", None)

        # chat() returns (reply, updated_history); the updated history is
        # not part of the response contract, so it is discarded.
        response, _ = self.model.chat(self.tokenizer, inputs, history)

        return [{"generated_text": response}]
|
Loading…
Reference in New Issue