use float()

This commit is contained in:
Wen Sun 2023-03-30 02:28:59 +09:00
parent f7a66d7523
commit b0c66a1229
No known key found for this signature in database
GPG Key ID: C1231D5B4615398A
1 changed files with 1 additions and 1 deletions

View File

@ -6,7 +6,7 @@ class EndpointHandler:
def __init__(self, path=""): def __init__(self, path=""):
# load model and processor from path # load model and processor from path
self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(path, trust_remote_code=True).half().cuda() self.model = AutoModel.from_pretrained(path, trust_remote_code=True).float()
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]: def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
""" """