# Source listing: ailab/layoutlmv3-base-mpdocvqa/app.py (34 lines, 1.3 KiB, Python)

import gradio as gr
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering
from PIL import Image
# Checkpoint identifier shared by the processor and the question-answering model.
_CHECKPOINT = "rubentito/layoutlmv3-base-mpdocvqa"

# apply_ocr=False is passed to the processor; vqa() supplies the words and
# bounding boxes explicitly when it calls the processor.
processor = LayoutLMv3Processor.from_pretrained(_CHECKPOINT, apply_ocr=False)
model = LayoutLMv3ForQuestionAnswering.from_pretrained(_CHECKPOINT)
def vqa(image, question):
    """Answer ``question`` about a document ``image`` using the LayoutLMv3 QA model.

    Args:
        image: Image pixels as an array exposing ``astype`` (e.g. a NumPy
            array), or ``None`` when no image was provided.
        question: The question text.

    Returns:
        The decoded answer span as a stripped string. An empty string is
        returned when the image or question is missing, or when the predicted
        end position precedes the predicted start position.
    """
    # Nothing to answer without both an image and a non-blank question.
    if image is None or not question or not question.strip():
        return ""

    # convert("RGB") normalizes grayscale / RGBA input instead of forcing a mode.
    inp = Image.fromarray(image.astype('uint8')).convert('RGB')

    # The processor is built with apply_ocr=False, so words and boxes have to
    # be supplied here. These are placeholders: a single word whose box spans
    # the whole page. Boxes are given as one [x0, y0, x1, y1] entry per word.
    context = ["Example"]
    boxes = [[0, 0, 1000, 1000]]

    document_encoding = processor(
        inp, question, context, boxes=boxes, return_tensors="pt"
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**document_encoding)

    # Batch size is 1, so take the first (only) row and turn it into a plain int.
    start_idx = int(torch.argmax(outputs.start_logits, dim=1)[0])
    end_idx = int(torch.argmax(outputs.end_logits, dim=1)[0])
    if end_idx < start_idx:
        # Inverted span: the model did not produce a usable answer.
        return ""

    input_tokens = document_encoding["input_ids"][0].tolist()
    answer = processor.tokenizer.decode(
        input_tokens[start_idx:end_idx + 1], skip_special_tokens=True
    )
    return answer.strip()
# Example (image file, question) pairs offered in the UI.
_EXAMPLES = [
    ['income.png', 'What are the 2020 net sales?'],
    ['invoice.png', 'What is the invoice number?'],
]

# Gradio interface wiring an image + text input to vqa() with a text output.
demo = gr.Interface(
    fn=vqa,
    inputs=['image', 'text'],
    outputs='text',
    title="vqa",
    examples=_EXAMPLES,
)
if __name__ == "__main__":
    # Configure the request queue first, then launch whatever queue() returns,
    # listening on all interfaces at port 7026.
    # NOTE(review): `concurrency_count` depends on the installed Gradio
    # version; confirm it is still accepted by the pinned release.
    queued_demo = demo.queue(concurrency_count=3)
    queued_demo.launch(server_name="0.0.0.0", server_port=7026)