Update pipeline.py

This commit is contained in:
Yih-Dar SHIEH 2021-10-25 07:55:35 +00:00 committed by huggingface-web
parent 7bbd3592df
commit 25a7779507
1 changed files with 14 additions and 6 deletions

View File

@ -2,7 +2,8 @@ import os
from typing import Dict, List, Any
from PIL import Image
import jax
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel, VisionEncoderDecoderModel
import torch
class PreTrainedPipeline():
@ -11,18 +12,24 @@ class PreTrainedPipeline():
model_dir = path
self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
# self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
self.model = VisionEncoderDecoderModel.from_pretrained(model_dir)
self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
max_length = 16
num_beams = 4
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
# self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, return_dict_in_generate=True}
@jax.jit
self.model.to("cpu")
self.model.eval()
# @jax.jit
def _generate(pixel_values):
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
with torch.no_grad():
output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
return output_ids
self.generate = _generate
@ -39,7 +46,8 @@ class PreTrainedPipeline():
Return:
"""
pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
# pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
pixel_values = self.feature_extractor(images=inputs, return_tensors="pt").pixel_values
output_ids = self.generate(pixel_values)
preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)