Create handler.py
This commit is contained in:
parent dac03c3ac8
commit 0f1c288530

@@ -0,0 +1,23 @@
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModel
import torch


class EndpointHandler:
    def __init__(self, path=""):
        # load tokenizer and model from the repository path; the model is
        # served in half precision on GPU
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).half().cuda()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data (:obj:`dict`):
                The payload with the text prompt and generation parameters.
        """
        # process input: use the "inputs" field, falling back to the raw payload
        inputs = data.pop("inputs", data)
        history = data.pop("history", None)

        # chat-style generation through the model's remote code
        response, new_history = self.model.chat(self.tokenizer, inputs, history)

        return [{"generated_text": response}]