207 lines
8.2 KiB
Python
207 lines
8.2 KiB
Python
# https://huggingface.co/deepkyu/ml-talking-face
|
|
import os
|
|
import subprocess
|
|
|
|
#REST_IP = os.environ['REST_IP']
|
|
#SERVICE_PORT = int(os.environ['SERVICE_PORT'])
|
|
#TRANSLATION_APIKEY_URL = os.environ['TRANSLATION_APIKEY_URL']
|
|
#GOOGLE_APPLICATION_CREDENTIALS = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
|
|
#subprocess.call(f"wget --no-check-certificate -O {GOOGLE_APPLICATION_CREDENTIALS} {TRANSLATION_APIKEY_URL}", shell=True)
|
|
|
|
#TOXICITY_THRESHOLD = float(os.getenv('TOXICITY_THRESHOLD', 0.7))
|
|
|
|
import gradio as gr
|
|
#from toxicity_estimator import PerspectiveAPI
|
|
from translator import Translator
|
|
from client_rest import RestAPIApplication
|
|
from pathlib import Path
|
|
import argparse
|
|
import threading
|
|
import yaml
|
|
|
|
# UI copy for the Gradio page, loaded once at import time from the docs/ folder.
# Decode explicitly as UTF-8 so the text does not depend on the host locale.
TITLE = Path("docs/title.txt").read_text(encoding="utf-8")
DESCRIPTION = Path("docs/description.md").read_text(encoding="utf-8")
|
|
|
|
|
|
class GradioApplication:
    """Gradio front end that sends text to a REST service and saves the returned video.

    The toxicity check and the translation step are currently disabled (their
    code is kept as comments), so every request is forwarded unchanged with the
    fixed language code ``"en-GB"``.

    NOTE(review): ``infer`` returns the three values built by ``return_format``
    while ``prepare_output`` in this file declares a single ``gr.Video``
    component -- confirm the installed Gradio version accepts that.
    """

    def __init__(self, rest_ip, rest_port, max_seed):
        """Create the REST client, the output directory and the Gradio interface.

        Args:
            rest_ip: Host passed to ``RestAPIApplication``.
            rest_port: Port passed to ``RestAPIApplication``.
            max_seed: Number of output file names to cycle through; values
                ``<= 0`` disable the wrap-around.
        """
        # Language code lookup used by get_lang_code().
        self.lang_list = {
            'ko': 'ko_KR',
            'en': 'en_US',
            'ja': 'ja_JP',
            'zh': 'zh_CN',
            'zh-CN': 'zh_CN'
        }
        # Position i matches option i of the background radio button;
        # None means "no background".
        self.background_list = [None,
                                "background_image/cvpr.png",
                                "background_image/black.png",
                                "background_image/river.mp4",
                                "background_image/sky.mp4"]

        #self.perspective_api = PerspectiveAPI()
        #self.translator = Translator()
        self.rest_application = RestAPIApplication(rest_ip, rest_port)
        self.output_dir = Path("output_file")
        # infer() writes into this directory, so make sure it exists up front.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Request counter state; set before the interface is built so that
        # infer() never sees a half-initialised object.
        self.max_seed = max_seed
        self._file_seed = 0
        self.lock = threading.Lock()

        inputs = prepare_input()
        outputs = prepare_output()

        self.iface = gr.Interface(fn=self.infer,
                                  title=TITLE,
                                  description=DESCRIPTION,
                                  inputs=inputs,
                                  outputs=outputs,
                                  allow_flagging='never',
                                  article=Path("docs/article.md").read_text(encoding="utf-8"))

    def _get_file_seed(self, seed=None):
        """Return a seed as a zero-padded string, wrapped by ``max_seed``.

        Args:
            seed: Counter value to format; defaults to the current counter.

        Returns:
            ``seed % max_seed`` formatted with at least two digits. The modulo
            is skipped when ``max_seed`` is not positive.
        """
        if seed is None:
            seed = self._file_seed
        if self.max_seed > 0:
            seed %= self.max_seed
        return f"{seed:02d}"

    def _reset_file_seed(self):
        """Reset the request counter to zero."""
        with self.lock:
            self._file_seed = 0

    def _counter_file_seed(self):
        """Increment the request counter and return its new value.

        The value is read while the lock is still held, so each caller gets
        its own number even when requests overlap.
        """
        with self.lock:
            self._file_seed += 1
            return self._file_seed

    def get_lang_code(self, lang):
        """Return the entry of ``lang_list`` for ``lang``.

        Raises:
            KeyError: If ``lang`` is not a key of ``lang_list``.
        """
        return self.lang_list[lang]

    def get_background_data(self, background_index):
        """Load the background selected in the UI.

        Args:
            background_index: Index into ``background_list``; ``None`` is
                treated as "no background".

        Returns:
            ``(background_data, is_video_background)``: the raw file bytes (or
            ``None``) and whether the file name ends with ``.mp4``.
        """
        data_path = None if background_index is None else self.background_list[background_index]

        if data_path is None:
            return None, False

        with open(data_path, 'rb') as rf:
            background_data = rf.read()
        is_video_background = str(data_path).endswith(".mp4")

        return background_data, is_video_background

    @staticmethod
    def return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=""):
        """Pack a result as ``(toxicity dict, summary text, video path string)``."""
        return {'Toxicity': toxicity_prob}, f"Language: {lang_dest}\nText: {target_text}\n-\nDetails: {detail}", str(video_filename)

    def infer(self, text, lang, duration_rate, action, background_index):
        """Request a video for ``text`` and store it in the output directory.

        Args:
            text: Text forwarded to the REST service.
            lang: Language chosen in the UI; unused while translation is disabled.
            duration_rate: Forwarded unchanged to the REST service.
            action: Action name; forwarded lower-cased.
            background_index: Index into ``background_list``.

        Returns:
            The tuple built by ``return_format``; its last element is the path
            of the written ``.mkv`` file.
        """
        # Take this request's number once, under the lock; re-reading
        # self._file_seed later could pick up another request's value.
        file_seed = self._counter_file_seed()
        print(f"File Seed: {file_seed}")
        toxicity_prob = 0.0
        # Placeholder values returned by the (disabled) early-exit paths below.
        target_text = ""
        lang_dest = ""
        video_filename = "vacant.mp4"

        # Toxicity estimation
        # try:
        #     toxicity_prob = self.perspective_api.get_score(text)
        # except Exception as e:  # when Perspective API doesn't work
        #     pass
        #
        # if toxicity_prob > TOXICITY_THRESHOLD:
        #     detail = "Sorry, it seems that the input text is too toxic."
        #     return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {detail}")

        target_text = text
        lang_rpc_code = "en-GB"

        # Google Translate API
        # try:
        #     target_text, lang_dest = self.translator.get_translation(text, lang)
        # except Exception as e:
        #     target_text = ""
        #     lang_dest = ""
        #     detail = f"Error from language translation: ({e})"
        #     return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {detail}")
        #
        # try:
        #     self.translator.length_check(lang_dest, target_text)  # assertion check
        # except AssertionError as e:
        #     return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {str(e)}")
        #
        # lang_rpc_code = self.get_lang_code(lang_dest)

        # Video Inference
        background_data, is_video_background = self.get_background_data(background_index)

        video_data = self.rest_application.get_video(target_text, lang_rpc_code, duration_rate, action.lower(),
                                                     background_data, is_video_background)
        print(f"Video data size: {len(video_data)}")

        # Wrap the counter by max_seed so old files are overwritten instead of
        # accumulating without bound.
        video_filename = self.output_dir / f"{self._get_file_seed(file_seed)}.mkv"
        with open(video_filename, "wb") as video_file:
            video_file.write(video_data)

        return self.return_format(toxicity_prob, target_text, lang_dest, video_filename)

    def run(self, server_port=7860, share=False):
        """Launch the Gradio interface on all network interfaces.

        Args:
            server_port: Port passed to ``launch``.
            share: Passed to ``launch`` as its ``share`` option.
        """
        try:
            self.iface.launch(height=900, server_name="0.0.0.0",
                              share=share, server_port=server_port,
                              enable_queue=True)

        except KeyboardInterrupt:
            # Ctrl-C: close every running Gradio app.
            gr.close_all()
|
|
|
|
|
|
def prepare_input():
    """Build the Gradio input components for the demo form.

    Returns:
        A list of five components in this order: text box, language radio,
        duration slider, action radio, background radio. The background radio
        uses ``type='index'``, so it reports the position of the selected
        option rather than its label.
    """
    default_text = ("Hello, this is demonstration for talking face generation "
                    "with multilingual text-to-speech.")

    return [
        gr.Textbox(
            lines=2,
            placeholder="Type your text with English, Chinese, Korean, and Japanese.",
            value=default_text,
            label="Text",
        ),
        gr.Radio(
            ['Korean', 'English', 'Japanese', 'Chinese'],
            type='value',
            value=None,
            label="Language",
        ),
        gr.Slider(
            minimum=0.8,
            maximum=1.2,
            step=0.01,
            value=1.0,
            label="Duration (The bigger the value, the slower the speech)",
        ),
        gr.Radio(
            ['Default', 'Hand', 'BothHand', 'HandDown', 'Sorry'],
            type='value',
            value='Default',
            label="Select an action ...",
        ),
        gr.Radio(
            ['None', 'CVPR', 'Black', 'River', 'Sky'],
            type='index',
            value='None',
            label="Select a background image/video ...",
        ),
    ]
|
|
|
|
|
|
def prepare_output():
    """Build the Gradio output components: a single mp4 video player.

    Returns:
        A one-element list holding the ``gr.Video`` component.
    """
    # The toxicity and translation outputs are currently disabled:
    #   gr.Label(num_top_classes=1, label="Toxicity (from Perspective API)")
    #   gr.Textbox(type="str", label="Translation Result")
    return [gr.Video(format='mp4')]
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``gradio_port``, ``rest_ip``,
        ``rest_port``, ``max_seed`` and ``share``.
    """
    parser = argparse.ArgumentParser(
        description='GRADIO DEMO for talking face generation submitted to CVPR2022')
    parser.add_argument('-p', '--port', dest='gradio_port', type=int, default=7860, help="Port for gradio")
    parser.add_argument('--rest_ip', type=str, default="0.0.0.0", help="IP for REST API")
    parser.add_argument('--rest_port', type=int, default=8881, help="Port for REST API")
    parser.add_argument('--max_seed', type=int, default=20, help="Max seed for saving video")
    parser.add_argument('--share', action='store_true', help='get publicly sharable link')
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI options, build the app, then serve it.
    args = parse_args()

    gradio_application = GradioApplication(args.rest_ip, args.rest_port, args.max_seed)
    # Use the parsed --port / --share options; they were previously ignored in
    # favour of hard-coded values (7860 / True).
    gradio_application.run(server_port=args.gradio_port, share=args.share)
|
|
|