diff --git a/pipeline.py b/pipeline.py
index 91520b8..e2f9ea3 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -2,7 +2,8 @@
 import os
 from typing import Dict, List, Any
 from PIL import Image
 import jax
-from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
+from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel, VisionEncoderDecoderModel
+import torch
 
 class PreTrainedPipeline():
@@ -11,18 +12,24 @@ class PreTrainedPipeline():
 
         model_dir = path
 
-        self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
+        # self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
+        self.model = VisionEncoderDecoderModel.from_pretrained(model_dir)
         self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
         max_length = 16
         num_beams = 4
-        self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+        # self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+        self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "return_dict_in_generate": True}
 
-        @jax.jit
+        self.model.to("cpu")
+        self.model.eval()
+
+        # @jax.jit
         def _generate(pixel_values):
-            output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
+            with torch.no_grad():
+                output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences
             return output_ids
 
         self.generate = _generate
 
@@ -39,7 +46,8 @@ class PreTrainedPipeline():
 
         Return:
         """
-        pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
+        # pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
+        pixel_values = self.feature_extractor(images=inputs, return_tensors="pt").pixel_values
         output_ids = self.generate(pixel_values)
         preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 