116 lines
5.0 KiB
Python
116 lines
5.0 KiB
Python
import json
import logging
import random
import sys

import gradio as gr

sys.path.append('./api')

from main import correct_api
|
|
|
|
# def correct(text='今天上学不要赤道啊'):
|
|
def correct(text):
    """Run the proofreading backend on ``text`` and normalize its response.

    ``correct_api`` (imported from ``main``) is called with ``text`` and its
    return value is decoded with ``json.loads``. From the decoded
    ``error_details`` list, only items whose ``have_target`` flag is truthy
    and whose ``target_text`` candidates are non-empty are kept. For kept
    items ``target_text`` is collapsed to its first candidate (assumed to be
    the preferred one -- TODO confirm against ``correct_api``). All other
    keys of an item are passed through unchanged.

    Spans are ``[index_start, index_end)``; e.g. for ``'今天上学不要赤道啊'``
    the expected result looks like::

        {"error_details": [{"index_start": 6, "index_end": 8,
                            "source_text": "赤道", "target_text": "迟到", ...}],
         "error_count": 1}

    Returns:
        dict with ``error_details`` (the kept items) and ``error_count``
        (their number).
    """
    response = json.loads(correct_api(text))
    # Debug-level logging instead of an unconditional print on every request.
    logging.getLogger(__name__).debug('correct_api response: %s', response)

    details = []
    for item in response['error_details']:
        if not item['have_target']:
            continue
        candidates = item['target_text']
        # A flagged item with no candidate cannot be offered as a suggestion;
        # indexing [0] on an empty list would otherwise raise IndexError.
        if not candidates:
            continue
        item['target_text'] = candidates[0]
        details.append(item)

    return {'error_details': details, 'error_count': len(details)}
|
|
|
|
def convert_to_ner(text):
    """Handle the "check" button: proofread ``text`` and build the UI outputs.

    The normalized result of ``correct(text)`` is tagged with the original
    ``text`` under the ``'text'`` key, so that later accept/reject clicks
    (which receive this dict back as state) can rewrite it. The three Gradio
    outputs are then produced by the shared ``get_tri_res`` helper instead of
    duplicating its construction here.

    Returns:
        ``(highlighted, state, card)``: the highlighted-text payload for the
        whole input, the correction state dict itself, and the card for the
        first pending suggestion.
    """
    model_output = correct(text)
    model_output['text'] = text
    return get_tri_res(model_output)
|
|
|
|
def get_error_item(error_json):
    """Build the highlighted-text card for the suggestion under review.

    The card text reads ``'(k/n)建议将<source>修改为<target>'``. ``n`` is
    ``error_json['error_count']`` and ``k`` is the 1-based position of the
    current item. Already accepted/rejected items have been popped from
    ``error_details``, hence ``k = 1 + n - len(error_details)``. The source
    span is tagged ``'wrong'`` and the target span ``'right'``.

    When no suggestion is pending, a fixed "no error found" message is
    returned, fully tagged ``'right'``. ``error_json`` is not modified.

    Returns:
        dict with ``'text'`` and ``'entities'`` (each entity has ``entity``,
        ``start``, ``end``; ``end`` is exclusive).
    """
    details = error_json['error_details']
    if not details:
        message = '未发现错误:)'
        return {'text': message, 'entities': [{'entity': 'right', 'start': 0, 'end': len(message)}]}

    item = details[0]
    total = error_json['error_count']
    position = 1 + total - len(details)
    head = f'({position}/{total})建议将'
    joiner = '修改为'
    source, target = item['source_text'], item['target_text']

    # Offsets are derived from the pieces themselves, so rewording `head` or
    # `joiner` cannot desynchronize the highlighted spans from the text.
    wrong_start = len(head)
    wrong_end = wrong_start + len(source)
    right_start = wrong_end + len(joiner)
    right_end = right_start + len(target)

    return {
        'text': f'{head}{source}{joiner}{target}',
        'entities': [
            {'entity': 'wrong', 'start': wrong_start, 'end': wrong_end},
            {'entity': 'right', 'start': right_start, 'end': right_end},
        ],
    }
|
|
|
|
def get_tri_res(error_json):
    """Turn the correction state into the three Gradio outputs.

    Every pending entry of ``error_json['error_details']`` becomes one
    highlighted span over ``error_json['text']``, labelled with its suggested
    replacement (``target_text``).

    Returns:
        ``(highlighted, error_json, card)``: the highlighted-text payload, the
        state dict passed in (unchanged), and ``get_error_item(error_json)``.
    """
    spans = [
        {
            'entity': detail['target_text'],
            'start': detail['index_start'],
            'end': detail['index_end'],
        }
        for detail in error_json['error_details']
    ]
    highlighted = {'text': error_json['text'], 'entities': spans}
    card = get_error_item(error_json)
    return highlighted, error_json, card
|
|
|
|
def accept(error_json):
    """Apply the suggestion under review and move on to the next one.

    The first entry of ``error_json['error_details']`` is removed. Its
    ``target_text`` is spliced into ``error_json['text']`` in place of the
    span ``[index_start, index_end)``. Remaining entries that start at or
    after the replaced span are shifted by the resulting change in length,
    so their offsets keep pointing at the same characters. Entries before the
    edit keep their offsets. ``error_json`` is mutated in place. With no
    pending suggestion the text is left unchanged.

    Returns:
        The three UI outputs built by ``get_tri_res(error_json)``.
    """
    pending = error_json['error_details']
    if pending:
        current = pending.pop(0)
        start, end = current['index_start'], current['index_end']
        text = error_json['text']
        error_json['text'] = text[:start] + current['target_text'] + text[end:]

        # Exactly the slice [start, end) was replaced, so this is the true
        # change in text length (independent of len(source_text)).
        delta = len(current['target_text']) - (end - start)
        if delta:
            for detail in pending:
                # Only spans after the edit move; an entry located before it
                # (possible if the list is not sorted by position) stays put.
                if detail['index_start'] >= end:
                    detail['index_start'] += delta
                    detail['index_end'] += delta

    return get_tri_res(error_json)
|
|
|
|
def reject(error_json):
    """Skip the suggestion under review without changing the text.

    The first entry of ``error_json['error_details']`` (if any) is discarded
    in place. The text and the offsets of the remaining entries are
    untouched.

    Returns:
        The three UI outputs built by ``get_tri_res(error_json)``.
    """
    remaining = error_json['error_details']
    if remaining:
        del remaining[0]
    return get_tri_res(error_json)
|
|
|
|
if __name__ == '__main__':
    # Layout: a title row, then an input column (textbox + "check" button)
    # next to a result column (highlighted text, the card for the current
    # suggestion, and accept/reject buttons).
    # NOTE(review): `.style(...)` is deprecated in later Gradio 3.x releases
    # and removed in Gradio 4 -- confirm against the pinned Gradio version.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('# <center> 文本智能校对Demo')

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(label='input', placeholder='请输入待校对的文本:').style(show_copy_button=True)
                button_submit = gr.Button(value="check")

            with gr.Column(scale=1):
                result_view = gr.HighlightedText(label='result', show_label=True)
                # Invisible component carrying the correction state (text and
                # pending error_details) between button clicks.
                state_json = gr.JSON(visible=False)
                # Entity labels 'wrong'/'right' are the ones emitted by get_error_item.
                item_card = gr.HighlightedText(label='error_item', show_label=False).style(color_map={'wrong': 'red', 'right': 'green'})
                with gr.Row():
                    ac_button = gr.Button(value="accept")
                    rj_button = gr.Button(value="reject")

        # All three handlers refresh the same outputs: highlighted text, state, card.
        # NOTE(review): accept/reject read the state, which is presumably empty
        # until "check" has run once -- clicking them first would fail there.
        shared_outputs = [result_view, state_json, item_card]
        button_submit.click(convert_to_ner, inputs=input_text, outputs=shared_outputs)
        ac_button.click(accept, inputs=state_json, outputs=shared_outputs)
        rj_button.click(reject, inputs=state_json, outputs=shared_outputs)

    # Binds every network interface, so the demo is reachable from other hosts.
    demo.launch(server_name='0.0.0.0')
|