diff --git a/pipeline.py b/pipeline.py deleted file mode 100644 index 2b7fef8..0000000 --- a/pipeline.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from typing import Dict, List, Any -from PIL import Image -import jax -from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel - - -class PreTrainedPipeline(): - - def __init__(self, path=""): - - model_dir = path - - self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir) - self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir) - self.tokenizer = AutoTokenizer.from_pretrained(model_dir) - - max_length = 16 - num_beams = 4 - self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams} - - @jax.jit - def _generate(pixel_values): - - output_ids = self.model.generate(pixel_values, **self.gen_kwargs).sequences - return output_ids - - self.generate = _generate - - # compile the model - image_path = os.path.join(path, 'val_000000039769.jpg') - image = Image.open(image_path) - self(image) - image.close() - - def __call__(self, inputs: "Image.Image") -> List[str]: - """ - Args: - Return: - """ - - pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values - - output_ids = self.generate(pixel_values) - preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True) - preds = [pred.strip() for pred in preds] - - return preds \ No newline at end of file