23 lines
820 B
Python
23 lines
820 B
Python
from typing import Dict, List, Any
|
|
from transformers import AutoTokenizer, AutoModel
|
|
import torch
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a ChatGLM-style chat model.

    Loads the tokenizer and model from ``path`` at startup and serves chat
    requests via ``__call__``. The model is cast to fp16 and moved to the
    GPU, so a CUDA device is required when the endpoint boots.
    """

    def __init__(self, path: str = ""):
        # Load tokenizer and model from the checkpoint directory.
        # trust_remote_code=True is required because this model family ships
        # its modeling code inside the repo — only safe for trusted checkpoints.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).half().cuda()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Run one chat turn against the loaded model.

        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                Keys read:
                  - "inputs": the user prompt; per the standard endpoint
                    contract, falls back to the whole payload if absent.
                  - "history": optional prior conversation turns forwarded
                    to ``model.chat`` (``None`` starts a fresh conversation).

        Returns:
            A one-element list ``[{"generated_text": response}]``.
        """
        # Process input; pop() mutates the payload so remaining keys could be
        # forwarded as generation parameters by a future extension.
        inputs = data.pop("inputs", data)
        history = data.pop("history", None)

        # model.chat also returns the updated history; it is discarded because
        # the endpoint is stateless — callers resend "history" each request.
        response, _new_history = self.model.chat(self.tokenizer, inputs, history)

        return [{"generated_text": response}]