from pathlib import Path

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import gradio as gr
from gradio.themes.utils import sizes

# Square-cornered default theme. Labels, titles and the primary button's
# text/border use a single blue accent; the primary button itself is white
# with a pale-blue hover fill.
_ACCENT = '#4D63FF'

_theme_overrides = dict(
    block_label_text_color=_ACCENT,
    block_title_text_color=_ACCENT,
    button_primary_text_color=_ACCENT,
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color=_ACCENT,
    button_primary_background_fill_hover='#EDEFFF',
)

theme = gr.themes.Default(radius_size=sizes.radius_none).set(**_theme_overrides)


# The preprocessor/tokenizer and the encoder-decoder weights are loaded once,
# at import time, from the same checkpoint so that the token ids produced by
# the model decode correctly with this processor.
_CHECKPOINT = 'microsoft/trocr-base-printed'
processor = TrOCRProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

def ocr(image):
    """Recognise printed text in a single image with the module-level TrOCR model.

    Args:
        image: The value delivered by the Gradio ``'image'`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The decoded text for the image, or an empty string when no image was
        provided.
    """
    # An empty Gradio image input arrives as None. The processor rejects a
    # missing image, so return early rather than surfacing a traceback in the UI.
    if image is None:
        return ""

    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per generated sequence. A single input
    # image yields a single sequence, so take the first entry.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Resolve the bundled example image next to this script instead of relative to
# the current working directory. Otherwise the example is only found when the
# app is launched from the script's own directory.
_EXAMPLE_IMAGE = str(Path(__file__).with_name('printed.jpg'))

demo = gr.Interface(
    fn=ocr,
    inputs='image',
    outputs='text',
    title="ocr",
    theme=theme,
    examples=[_EXAMPLE_IMAGE],
)


if __name__ == "__main__":
    # For a network-facing deployment, launch with e.g.
    # demo.launch(server_name="0.0.0.0", server_port=7011), and consider
    # enabling demo.queue() to limit concurrent inference requests.
    demo.launch()