diff --git a/README.md b/README.md
index b5c2b0f..93382cb 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,9 @@
+---
+tags:
+- image-classification
+library_name: generic
+---
+
 ## Example
 
 The model is by no means a state-of-the-art model, but nevertheless
@@ -37,4 +43,4 @@ print(preds)
 
 # should produce
 # ['a cat laying on top of a couch next to another cat']
-```
\ No newline at end of file
+```
diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000..91520b8
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,57 @@
+import os
+from typing import Dict, List, Any
+from PIL import Image
+import jax
+from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
+
+
+class PreTrainedPipeline():
+    """Image-captioning pipeline wrapping a Flax vision encoder-decoder model.
+
+    Loads the model, feature extractor and tokenizer from ``path``; calling
+    the instance on an image returns a list of caption strings.
+    """
+
+    def __init__(self, path=""):
+        """Load the model artefacts from ``path`` and warm up the JIT cache.
+
+        Args:
+            path: Directory holding the pretrained model, feature extractor
+                and tokenizer files, plus the sample image
+                ``val_000000039769.jpg`` used for the warm-up call.
+        """
+        self.model = FlaxVisionEncoderDecoderModel.from_pretrained(path)
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+        # Generation settings forwarded to ``model.generate``.
+        self.gen_kwargs = {"max_length": 16, "num_beams": 4}
+
+        @jax.jit
+        def _generate(pixel_values):
+            return self.model.generate(pixel_values, **self.gen_kwargs).sequences
+
+        self.generate = _generate
+
+        # Run one prediction up front so the jitted generate function is
+        # compiled before the first real call. The context manager closes
+        # the image file even if that prediction raises.
+        image_path = os.path.join(path, 'val_000000039769.jpg')
+        with Image.open(image_path) as image:
+            self(image)
+
+    def __call__(self, inputs: "Image.Image") -> List[str]:
+        """Generate captions for an image.
+
+        Args:
+            inputs: The PIL image to caption.
+
+        Return:
+            The decoded captions, with special tokens removed and
+            surrounding whitespace stripped.
+        """
+        pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
+
+        output_ids = self.generate(pixel_values)
+        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        return [pred.strip() for pred in preds]