Update pipeline.py

This commit is contained in:
Yih-Dar SHIEH 2021-10-25 07:55:35 +00:00 committed by huggingface-web
parent 7bbd3592df
commit 25a7779507
1 changed files with 14 additions and 6 deletions

View File

@ -2,7 +2,8 @@ import os
from typing import Dict, List, Any from typing import Dict, List, Any
from PIL import Image from PIL import Image
import jax import jax
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel, VisionEncoderDecoderModel
import torch
class PreTrainedPipeline(): class PreTrainedPipeline():
@ -11,18 +12,24 @@ class PreTrainedPipeline():
model_dir = path model_dir = path
self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir) # self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
self.model = VisionEncoderDecoderModel.from_pretrained(model_dir)
self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir) self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
self.tokenizer = AutoTokenizer.from_pretrained(model_dir) self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
max_length = 16 max_length = 16
num_beams = 4 num_beams = 4
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams} # self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, return_dict_in_generate=True}
@jax.jit self.model.to("cpu")
self.model.eval()
# @jax.jit
def _generate(pixel_values): def _generate(pixel_values):
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences with torch.no_grad():
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
return output_ids return output_ids
self.generate = _generate self.generate = _generate
@ -39,7 +46,8 @@ class PreTrainedPipeline():
Return: Return:
""" """
pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values # pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
pixel_values = self.feature_extractor(images=inputs, return_tensors="pt").pixel_values
output_ids = self.generate(pixel_values) output_ids = self.generate(pixel_values)
preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True) preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)