import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from gradio.themes.utils import sizes
# Square-cornered default Gradio theme with a single blue accent used for block
# labels, block titles and the primary button's text/border; the primary button
# itself has a white fill and a light-blue hover fill.
_ACCENT = '#4D63FF'
_base_theme = gr.themes.Default(radius_size=sizes.radius_none)
theme = _base_theme.set(
    block_label_text_color=_ACCENT,
    block_title_text_color=_ACCENT,
    button_primary_text_color=_ACCENT,
    button_primary_border_color=_ACCENT,
    button_primary_background_fill='#FFFFFF',
    button_primary_background_fill_hover='#EDEFFF',
)


# The tokenizer and the seq2seq model must come from the same pretrained T5
# checkpoint. Both are loaded once at module import and reused by every call
# to ``translation``.
_CHECKPOINT = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)

# Task prefixes that select the target language for T5, keyed by the names
# offered in the UI.
_TASK_PREFIXES = {
    'German': "translate English to German: ",
    'French': "translate English to French: ",
    'Romanian': "translate English to Romanian: ",
}


def translation(english, language, max_new_tokens=512):
    """Translate English text with the module-level T5 model.

    Args:
        english: English source text.
        language: Target language name. 'German' and 'French' select those
            languages. Any other value falls back to Romanian.
        max_new_tokens: Upper bound on the number of generated tokens. Without
            an explicit limit, ``generate`` uses transformers' default
            ``max_length`` of 20, which truncates longer translations.

    Returns:
        The decoded translation, with special tokens stripped.
    """
    # Unrecognised languages keep the historical Romanian fallback.
    prefix = _TASK_PREFIXES.get(language, _TASK_PREFIXES['Romanian'])
    input_ids = tokenizer(prefix + english, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_new_tokens=max_new_tokens)
    # generate() returns a batch of sequences; only one input was encoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Web UI: a free-text English input plus a target-language picker, both passed
# positionally to ``translation``.
# ``gr.Radio`` replaces the deprecated ``gr.inputs.Radio`` legacy API, which
# was removed in Gradio 4. ``value=`` sets the initial selection (formerly
# ``default=``).
demo = gr.Interface(
    fn=translation,
    inputs=[
        'text',
        gr.Radio(['German', 'French', 'Romanian'], type='value', value='German', label='language'),
    ],
    outputs='text',
    theme=theme,
    title="翻译",  # "Translation"
)


if __name__ == "__main__":
    # Enable request queuing with 10 concurrent workers (Gradio 3.x
    # ``concurrency_count`` argument).
    demo.queue(concurrency_count=10)
    # ``launch`` must be called. A bare ``demo.launch`` reference does nothing,
    # and the script would exit without serving the app.
    # To listen on all interfaces instead, use:
    #     demo.launch(server_name="0.0.0.0", server_port=7028)
    demo.launch()