use float()
This commit is contained in:
parent
f7a66d7523
commit
b0c66a1229
|
@ -6,7 +6,7 @@ class EndpointHandler:
|
||||||
def __init__(self, path=""):
|
def __init__(self, path=""):
|
||||||
# load model and processor from path
|
# load model and processor from path
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
|
self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
|
||||||
self.model = AutoModel.from_pretrained(path, trust_remote_code=True).half().cuda()
|
self.model = AutoModel.from_pretrained(path, trust_remote_code=True).float()
|
||||||
|
|
||||||
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
|
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue