# Gradio demo: document visual question answering (VQA) with LayoutLMv3,
# using the rubentito/layoutlmv3-base-mpdocvqa checkpoint.
import gradio as gr
import torch
from PIL import Image
from transformers import LayoutLMv3ForQuestionAnswering, LayoutLMv3Processor
# Load the document-VQA checkpoint once at import time so every request
# reuses the same weights. OCR is disabled because vqa() supplies the
# words and bounding boxes itself.
processor = LayoutLMv3Processor.from_pretrained(
    "rubentito/layoutlmv3-base-mpdocvqa", apply_ocr=False
)
model = LayoutLMv3ForQuestionAnswering.from_pretrained(
    "rubentito/layoutlmv3-base-mpdocvqa"
)
def vqa(image, question):
    """Answer a question about a document image with LayoutLMv3.

    Args:
        image: RGB image as a numpy uint8-compatible array (Gradio 'image' input).
        question: Question text (Gradio 'text' input).

    Returns:
        The decoded answer span as a string.
    """
    pil_image = Image.fromarray(image.astype("uint8"), "RGB")

    # Placeholder OCR result: a single word covering the whole page. The
    # processor was built with apply_ocr=False, so words and boxes must be
    # provided; `boxes` is one [x0, y0, x1, y1] box PER WORD on the model's
    # 0-1000 normalized grid.
    # BUG FIX: the original passed a flat [0, 0, 1000, 1000] instead of a
    # list of per-word boxes.
    # TODO(review): wire in real OCR words/boxes for meaningful answers.
    words = ["Example"]
    boxes = [[0, 0, 1000, 1000]]

    encoding = processor(pil_image, question, words, boxes=boxes, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**encoding)

    # Most likely start/end token positions of the answer span.
    start_idx = int(torch.argmax(outputs.start_logits, dim=1))
    end_idx = int(torch.argmax(outputs.end_logits, dim=1))

    # BUG FIX: the original referenced `self.processor` (no `self` exists
    # here) and an undefined `input_tokens`; decode the answer span from the
    # encoded input ids instead.
    input_tokens = encoding["input_ids"][0]
    answer = processor.tokenizer.decode(input_tokens[start_idx:end_idx + 1]).strip()
    return answer
# Gradio UI: one image upload plus a free-text question, returning the
# predicted answer as text.
demo = gr.Interface(
    fn=vqa,
    inputs=["image", "text"],
    outputs="text",
    title="vqa",
    examples=[
        ["income.png", "What are the 2020 net sales?"],
        ["invoice.png", "What is the invoice number?"],
    ],
)
if __name__ == "__main__":
    # Serve on all interfaces; queue requests with up to 3 concurrent workers.
    demo.queue(concurrency_count=3).launch(server_name="0.0.0.0", server_port=7026)