ml-talking-face
Build-Deploy-Actions
Details
Build-Deploy-Actions
Details
This commit is contained in:
commit
98372b417b
|
@ -0,0 +1,2 @@
|
|||
output_file/*
|
||||
!output_file/.gitkeep
|
|
@ -0,0 +1,31 @@
|
|||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
output_file/* filter=lfs diff=lfs merge=lfs -text
|
||||
background_image/* filter=lfs diff=lfs merge=lfs -text
|
|
@ -0,0 +1,47 @@
|
|||
name: Build
|
||||
run-name: ${{ github.actor }} is upgrade release 🚀
|
||||
on: [push]
|
||||
env:
|
||||
REPOSITORY: ${{ github.repository }}
|
||||
COMMIT_ID: ${{ github.sha }}
|
||||
jobs:
|
||||
Build-Deploy-Actions:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
|
||||
- run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
|
||||
- run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v3
|
||||
-
|
||||
name: Setup Git LFS
|
||||
run: |
|
||||
git lfs install
|
||||
git lfs fetch
|
||||
git lfs checkout
|
||||
- name: List files in the repository
|
||||
run: |
|
||||
ls ${{ github.workspace }}
|
||||
-
|
||||
name: Docker Image Info
|
||||
id: image-info
|
||||
run: |
|
||||
echo "::set-output name=image_name::$(echo $REPOSITORY | tr '[:upper:]' '[:lower:]')"
|
||||
echo "::set-output name=image_tag::${COMMIT_ID:0:10}"
|
||||
-
|
||||
name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: artifacts.iflytek.com
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
-
|
||||
name: Build and push
|
||||
run: |
|
||||
docker version
|
||||
docker buildx build -t artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }} . --file ${{ github.workspace }}/Dockerfile --load
|
||||
docker push artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
|
||||
docker rmi artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
|
||||
- run: echo "🍏 This job's status is ${{ job.status }}."
|
|
@ -0,0 +1,14 @@
|
|||
.DS_Store
|
||||
flagged/
|
||||
__pycache__/
|
||||
.vscode/
|
||||
output_file/*
|
||||
|
||||
!output_file/.gitkeep
|
||||
|
||||
*.mp4
|
||||
*.png
|
||||
!background_image/*
|
||||
*.mkv
|
||||
gradio_queue.db*
|
||||
!vacant.mp4
|
|
@ -0,0 +1,11 @@
|
|||
#FROM python:3.8.13
|
||||
FROM artifacts.iflytek.com/docker-private/atp/base_image_for_ailab:0.0.1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . /app
|
||||
|
||||
RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
CMD ["python", "app.py"]
|
|
@ -0,0 +1,47 @@
|
|||
---
|
||||
title: Talking Face Generation with Multilingual TTS
|
||||
emoji: 👄
|
||||
colorFrom: blue
|
||||
colorTo: blue
|
||||
sdk: gradio
|
||||
sdk_version: 3.0.6
|
||||
app_file: app.py
|
||||
pinned: false
|
||||
license: cc-by-nc-sa-4.0
|
||||
---
|
||||
|
||||
# Configuration
|
||||
|
||||
`title`: _string_
|
||||
Display title for the Space
|
||||
|
||||
`emoji`: _string_
|
||||
Space emoji (emoji-only character allowed)
|
||||
|
||||
`colorFrom`: _string_
|
||||
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
||||
|
||||
`colorTo`: _string_
|
||||
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
||||
|
||||
`sdk`: _string_
|
||||
Can be either `gradio`, `streamlit`, or `static`
|
||||
|
||||
`sdk_version` : _string_
|
||||
Only applicable for `streamlit` SDK.
|
||||
See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
|
||||
|
||||
`app_file`: _string_
|
||||
Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
|
||||
Path is relative to the root of the repository.
|
||||
|
||||
`models`: _List[string]_
|
||||
HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
|
||||
Will be parsed automatically from your code if not specified here.
|
||||
|
||||
`datasets`: _List[string]_
|
||||
HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
|
||||
Will be parsed automatically from your code if not specified here.
|
||||
|
||||
`pinned`: _boolean_
|
||||
Whether the Space stays on top of your list.
|
|
@ -0,0 +1,206 @@
|
|||
# https://huggingface.co/deepkyu/ml-talking-face
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
#REST_IP = os.environ['REST_IP']
|
||||
#SERVICE_PORT = int(os.environ['SERVICE_PORT'])
|
||||
#ETRANSLATION_APIKEY_URL = os.environ['TRANSLATION_APIKEY_URL']
|
||||
#$GOOGLE_APPLICATION_CREDENTIALS = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
|
||||
#subprocess.call(f"wget --no-check-certificate -O {GOOGLE_APPLICATION_CREDENTIALS} {TRANSLATION_APIKEY_URL}", shell=True)
|
||||
|
||||
#TOXICITY_THRESHOLD = float(os.getenv('TOXICITY_THRESHOLD', 0.7))
|
||||
|
||||
import gradio as gr
|
||||
#from toxicity_estimator import PerspectiveAPI
|
||||
from translator import Translator
|
||||
from client_rest import RestAPIApplication
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import threading
|
||||
import yaml
|
||||
|
||||
TITLE = Path("docs/title.txt").read_text()
|
||||
DESCRIPTION = Path("docs/description.md").read_text()
|
||||
|
||||
|
||||
class GradioApplication:
|
||||
def __init__(self, rest_ip, rest_port, max_seed):
|
||||
self.lang_list = {
|
||||
'ko': 'ko_KR',
|
||||
'en': 'en_US',
|
||||
'ja': 'ja_JP',
|
||||
'zh': 'zh_CN',
|
||||
'zh-CN': 'zh_CN'
|
||||
}
|
||||
self.background_list = [None,
|
||||
"background_image/cvpr.png",
|
||||
"background_image/black.png",
|
||||
"background_image/river.mp4",
|
||||
"background_image/sky.mp4"]
|
||||
|
||||
#self.perspective_api = PerspectiveAPI()
|
||||
#self.translator = Translator()
|
||||
self.rest_application = RestAPIApplication(rest_ip, rest_port)
|
||||
self.output_dir = Path("output_file")
|
||||
|
||||
inputs = prepare_input()
|
||||
outputs = prepare_output()
|
||||
|
||||
self.iface = gr.Interface(fn=self.infer,
|
||||
title=TITLE,
|
||||
description=DESCRIPTION,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
allow_flagging='never',
|
||||
article=Path("docs/article.md").read_text())
|
||||
|
||||
self.max_seed = max_seed
|
||||
self._file_seed = 0
|
||||
self.lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_file_seed(self):
|
||||
return f"{self._file_seed % self.max_seed:02d}"
|
||||
|
||||
def _reset_file_seed(self):
|
||||
self._file_seed = 0
|
||||
|
||||
def _counter_file_seed(self):
|
||||
with self.lock:
|
||||
self._file_seed += 1
|
||||
|
||||
def get_lang_code(self, lang):
|
||||
return self.lang_list[lang]
|
||||
|
||||
def get_background_data(self, background_index):
|
||||
# get background filename and its extension
|
||||
data_path = self.background_list[background_index]
|
||||
|
||||
if data_path is not None:
|
||||
with open(data_path, 'rb') as rf:
|
||||
background_data = rf.read()
|
||||
is_video_background = str(data_path).endswith(".mp4")
|
||||
else:
|
||||
background_data = None
|
||||
is_video_background = False
|
||||
|
||||
return background_data, is_video_background
|
||||
|
||||
@staticmethod
|
||||
def return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=""):
|
||||
return {'Toxicity': toxicity_prob}, f"Language: {lang_dest}\nText: {target_text}\n-\nDetails: {detail}", str(video_filename)
|
||||
|
||||
def infer(self, text, lang, duration_rate, action, background_index):
|
||||
self._counter_file_seed()
|
||||
print(f"File Seed: {self._file_seed}")
|
||||
toxicity_prob = 0.0
|
||||
target_text = ""
|
||||
lang_dest = ""
|
||||
video_filename = "vacant.mp4"
|
||||
|
||||
# Toxicity estimation
|
||||
# try:
|
||||
# toxicity_prob = self.perspective_api.get_score(text)
|
||||
# except Exception as e: # when Perspective API doesn't work
|
||||
# pass
|
||||
#
|
||||
# if toxicity_prob > TOXICITY_THRESHOLD:
|
||||
# detail = "Sorry, it seems that the input text is too toxic."
|
||||
# return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {detail}")
|
||||
|
||||
target_text = text
|
||||
lang_rpc_code = "en-GB"
|
||||
|
||||
# Google Translate API
|
||||
# try:
|
||||
# target_text, lang_dest = self.translator.get_translation(text, lang)
|
||||
# except Exception as e:
|
||||
# target_text = ""
|
||||
# lang_dest = ""
|
||||
# detail = f"Error from language translation: ({e})"
|
||||
# return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {detail}")
|
||||
#
|
||||
# try:
|
||||
# self.translator.length_check(lang_dest, target_text) # assertion check
|
||||
# except AssertionError as e:
|
||||
# return self.return_format(toxicity_prob, target_text, lang_dest, video_filename, detail=f"Error: {str(e)}")
|
||||
#
|
||||
# lang_rpc_code = self.get_lang_code(lang_dest)
|
||||
|
||||
# Video Inference
|
||||
background_data, is_video_background = self.get_background_data(background_index)
|
||||
|
||||
video_data = self.rest_application.get_video(target_text, lang_rpc_code, duration_rate, action.lower(),
|
||||
background_data, is_video_background)
|
||||
print(f"Video data size: {len(video_data)}")
|
||||
|
||||
video_filename = self.output_dir / f"{self._file_seed:02d}.mkv"
|
||||
with open(video_filename, "wb") as video_file:
|
||||
video_file.write(video_data)
|
||||
|
||||
return self.return_format(toxicity_prob, target_text, lang_dest, video_filename)
|
||||
|
||||
def run(self, server_port=7860, share=False):
|
||||
try:
|
||||
self.iface.launch(height=900, server_name = "0.0.0.0",
|
||||
share=share, server_port=server_port,
|
||||
enable_queue=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
gr.close_all()
|
||||
|
||||
|
||||
def prepare_input():
|
||||
text_input = gr.Textbox(lines=2,
|
||||
placeholder="Type your text with English, Chinese, Korean, and Japanese.",
|
||||
value="Hello, this is demonstration for talking face generation "
|
||||
"with multilingual text-to-speech.",
|
||||
label="Text")
|
||||
lang_input = gr.Radio(['Korean', 'English', 'Japanese', 'Chinese'],
|
||||
type='value',
|
||||
value=None,
|
||||
label="Language")
|
||||
duration_rate_input = gr.Slider(minimum=0.8,
|
||||
maximum=1.2,
|
||||
step=0.01,
|
||||
value=1.0,
|
||||
label="Duration (The bigger the value, the slower the speech)")
|
||||
action_input = gr.Radio(['Default', 'Hand', 'BothHand', 'HandDown', 'Sorry'],
|
||||
type='value',
|
||||
value='Default',
|
||||
label="Select an action ...")
|
||||
background_input = gr.Radio(['None', 'CVPR', 'Black', 'River', 'Sky'],
|
||||
type='index',
|
||||
value='None',
|
||||
label="Select a background image/video ...")
|
||||
|
||||
return [text_input, lang_input, duration_rate_input,
|
||||
action_input, background_input]
|
||||
|
||||
|
||||
def prepare_output():
|
||||
# toxicity_output = gr.Label(num_top_classes=1, label="Toxicity (from Perspective API)")
|
||||
# translation_result_otuput = gr.Textbox(type="str", label="Translation Result")
|
||||
video_output = gr.Video(format='mp4')
|
||||
#return [toxicity_output, translation_result_otuput, video_output]
|
||||
return [video_output]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='GRADIO DEMO for talking face generation submitted to CVPR2022')
|
||||
parser.add_argument('-p', '--port', dest='gradio_port', type=int, default=7860, help="Port for gradio")
|
||||
parser.add_argument('--rest_ip', type=str, default="0.0.0.0", help="IP for REST API")
|
||||
parser.add_argument('--rest_port', type=int, default=8881, help="Port for REST API")
|
||||
parser.add_argument('--max_seed', type=int, default=20, help="Max seed for saving video")
|
||||
parser.add_argument('--share', action='store_true', help='get publicly sharable link')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
gradio_application = GradioApplication(args.rest_ip, args.rest_port, args.max_seed)
|
||||
gradio_application.run(server_port=7860, share=True)
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,74 @@
|
|||
import requests
|
||||
import json
|
||||
import base64
|
||||
import argparse
|
||||
|
||||
VIDEO_WIDTH = 1080
|
||||
VIDEO_HEIGHT = 1920
|
||||
SPEAKER_ID = 0
|
||||
|
||||
class RestAPIApplication:
|
||||
def __init__(self, ip, port):
|
||||
|
||||
if port < 0:
|
||||
self.post_request_addr = f"http://{ip}/register/"
|
||||
self.post_headers = {"Content-Type": "application/json"}
|
||||
self.generate_addr = (lambda id_: f'http://{ip}/generate/{id_}')
|
||||
else:
|
||||
self.post_request_addr = f"http://{ip}:{port}/register/"
|
||||
self.post_headers = {"Content-Type": "application/json"}
|
||||
self.generate_addr = (lambda id_: f'http://{ip}:{port}/generate/{id_}')
|
||||
|
||||
@staticmethod
|
||||
def _get_json_request(text, lang, duration_rate, action, background_data=None, is_video_background=False):
|
||||
request_form = dict()
|
||||
|
||||
request_form['text'] = text
|
||||
request_form['speaker'] = SPEAKER_ID
|
||||
request_form['width'] = VIDEO_WIDTH
|
||||
request_form['height'] = VIDEO_HEIGHT
|
||||
|
||||
request_form['action'] = action
|
||||
|
||||
if background_data is not None:
|
||||
background_base64 = base64.b64encode(background_data).decode("UTF-8")
|
||||
else:
|
||||
background_base64 = ""
|
||||
|
||||
request_form['background'] = background_base64
|
||||
request_form['durationRate'] = duration_rate
|
||||
request_form['isVideoBackground'] = is_video_background
|
||||
request_form['lang'] = lang
|
||||
|
||||
request_as_json = json.dumps(request_form)
|
||||
return request_as_json
|
||||
|
||||
@staticmethod
|
||||
def _get_video_id(results):
|
||||
return json.loads(bytes.decode(results.content))['id']
|
||||
|
||||
def get_video(self, text, lang, duration_rate, action, background_data=None, is_video_background=False):
|
||||
request_json = self._get_json_request(text, lang, duration_rate, action, background_data, is_video_background)
|
||||
|
||||
# POST request with jsonified request
|
||||
results = requests.post(self.post_request_addr, headers=self.post_headers, data=request_json)
|
||||
|
||||
# GET video with the given id
|
||||
video_id = self._get_video_id(results)
|
||||
video_results = requests.get(self.generate_addr(video_id))
|
||||
|
||||
return video_results.content
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='REST API interface for talking face generation submitted to CVPR2022')
|
||||
parser.add_argument('-i', '--ip', dest='rest_ip', type=str, default="127.0.0.1", help="IP for REST API")
|
||||
parser.add_argument('-p', '--port', dest='rest_port', type=int, default=8080, help="Port for REST API")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
rest_api_application = RestAPIApplication(args.rest_ip, args.rest_port)
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
## Why learn a new language, when your model can learn it for you?
|
||||
|
||||
<div style="max-width: 720px;max-height: 405px;margin: auto;">
|
||||
<div style="float: none;clear: both;position: relative;padding-bottom: 56.25%;height: 0;width: 100%">
|
||||
<iframe width="720" height="405" src="https://www.youtube.com/embed/toqdD1F_ZsU" title="YouTube video player" style="position: absolute;top: 0;left: 0;width: 100%;height: 100%;" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen>
|
||||
</iframe>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Abstract
|
||||
|
||||
Recent studies in talking face generation have focused on building a train-once-use-everywhere model i.e. a model that will generalize from any source speech to any target identity. A number of works have already claimed this functionality and have added that their models will also generalize to any language. However, we show, using languages from different language families, that these models do not translate well when the training language and the testing language are sufficiently different. We reduce the scope of the problem to building a language-robust talking face generation system on seen identities i.e. the target identity is the same as the training identity. In this work, we introduce a talking face generation system that will generalize to different languages. We evaluate the efficacy of our system using a multilingual text-to-speech system. We also discuss the usage of joint text-to-speech system and the talking face generation system as a neural dubber system.
|
||||
|
||||
[CVPR Open Access](https://openaccess.thecvf.com/content/CVPR2022/html/Song_Talking_Face_Generation_With_Multilingual_TTS_CVPR_2022_paper.html) [arXiv](https://arxiv.org/abs/2205.06421)
|
||||
|
||||
### News
|
||||
|
||||
(2022.08.18.) We got the CVPR Hugging Face prize! Thank you all and special thanks to AK([@akhaliq](https://huggingface.co/akhaliq)).
|
||||
|
||||
<center>
|
||||
<img alt="we-got-huggingface-prize" src="https://github.com/deepkyu/ml-talking-face/blob/main/docs/we-got-huggingface-prize.jpeg?raw=true" width="50%" />
|
||||
</center>
|
|
@ -0,0 +1,18 @@
|
|||
This system generates a talking face video based on the input text.
|
||||
You can provide the input text in one of the four languages: Chinese (Mandarin), English, Japanese, and Korean.
|
||||
You may also select the target language, the language of the output speech.
|
||||
If the input text language and the target language are different, the input text will be translated to the target language using Google Translate API.
|
||||
|
||||
### Updates
|
||||
|
||||
(2022.09.29.) **NOTE!** The core part of the demonstration has been working on the AWS instance of MINDsLab, and I found that it can't connect to the instance now. I want to fix this issue, but I'm sorry to say that I left the company last week. I've contacted the company, but it takes some time to restore the session. If you're in a hurry, please send the e-mail directly to MINDsLab (hello@mindslab.ai).
|
||||
Whatever the reason, I'm sorry again. Hope you understand.
|
||||
|
||||
(2022.06.17.) Thank you for visiting our demo!😊 This demo attracted a lot more attention than we anticipated. This, unfortunately, means that the computational burden is heavier than this demo was designed for. So, to maximize everyone's experience, we capped the length of the translated texts at:
|
||||
|
||||
- 200 characters for English
|
||||
- 100 characters for Chinese, Japaense, and Korean.
|
||||
|
||||
(2022.06.17.) We were originally planning to support any input text. However, when checking the logs recently, we found that there were a lot of inappropriate input texts. So, we decided to filter the inputs based on toxicity using [Perspective API @Google](https://developers.perspectiveapi.com/s/). Now, if you enter a possibily toxic text, the video generation will fail. We hope you understand.
|
||||
|
||||
(2022.06.05.) Due to the latency from HuggingFace Spaces and video rendering, it takes 15 ~ 30 seconds to get a video result.
|
|
@ -0,0 +1 @@
|
|||
Talking Face Generation with Multilingual TTS (CVPR 2022 Demo Track)
|
Binary file not shown.
After Width: | Height: | Size: 363 KiB |
|
@ -0,0 +1,20 @@
|
|||
ko:
|
||||
index: 1
|
||||
language: "Korean"
|
||||
locale: "ko_KR"
|
||||
google_dest: "ko"
|
||||
en:
|
||||
index: 2
|
||||
language: "English"
|
||||
locale: "en_US"
|
||||
google_dest: "en"
|
||||
ja:
|
||||
index: 3
|
||||
language: "Japanese"
|
||||
locale: "ja_JP"
|
||||
google_dest: "ja"
|
||||
zh:
|
||||
index: 4
|
||||
language: "Chinese"
|
||||
locale: "zh_CN"
|
||||
google_dest: "zh-CN"
|
|
@ -0,0 +1,7 @@
|
|||
gradio
|
||||
jinja2
|
||||
googletrans==4.0.0-rc1
|
||||
PyYAML
|
||||
opencv-python
|
||||
google-cloud-translate
|
||||
google-api-python-client
|
|
@ -0,0 +1,8 @@
|
|||
ko:
|
||||
- "안녕하세요? 한국어로 말하고 있습니다."
|
||||
en:
|
||||
- "Hello. Now I'm speaking in English."
|
||||
zh:
|
||||
- "你好? 我在说普通话。"
|
||||
ja:
|
||||
- "こんにちは。 今、日本語で話しています。"
|
|
@ -0,0 +1 @@
|
|||
from .module import PerspectiveAPI
|
|
@ -0,0 +1,51 @@
|
|||
from googleapiclient import discovery
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
API_KEY = os.environ['PERSPECTIVE_API_KEY']
|
||||
|
||||
class PerspectiveAPI:
|
||||
def __init__(self):
|
||||
self.client = discovery.build(
|
||||
"commentanalyzer",
|
||||
"v1alpha1",
|
||||
developerKey=API_KEY,
|
||||
discoveryServiceUrl="https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",
|
||||
static_discovery=False,
|
||||
)
|
||||
@staticmethod
|
||||
def _get_request(text):
|
||||
return {
|
||||
'comment': {'text': text},
|
||||
'requestedAttributes': {'TOXICITY': {}}
|
||||
}
|
||||
|
||||
def _infer(self, text):
|
||||
request = self._get_request(text)
|
||||
response = self.client.comments().analyze(body=request).execute()
|
||||
return response
|
||||
|
||||
def infer(self, text):
|
||||
return self._infer(text)
|
||||
|
||||
def get_score(self, text, label='TOXICITY'):
|
||||
response = self._infer(text)
|
||||
return response['attributeScores'][label]['spanScores'][0]['score']['value']
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Perspective API Test.')
|
||||
parser.add_argument('-i', '--input-text', type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
perspective_api = PerspectiveAPI()
|
||||
score = perspective_api.get_score(args.input_text)
|
||||
|
||||
print(score)
|
|
@ -0,0 +1 @@
|
|||
from .module import Translator
|
|
@ -0,0 +1,59 @@
|
|||
from .v3 import GoogleAuthTranslation
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
import os
|
||||
|
||||
MAX_ENG_TEXT_LENGTH = int(os.getenv('MAX_ENG_TEXT_LENGTH', 200))
|
||||
MAX_CJK_TEXT_LENGTH = int(os.getenv('MAX_CJK_TEXT_LENGTH', 100))
|
||||
|
||||
class Translator:
|
||||
def __init__(self, yaml_path='./lang.yaml'):
|
||||
self.google_translation = GoogleAuthTranslation(project_id="cvpr-2022-demonstration")
|
||||
with open(yaml_path) as f:
|
||||
self.supporting_languages = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
@staticmethod
|
||||
def length_check(lang, text):
|
||||
if lang in ['en']:
|
||||
if len(text) > MAX_ENG_TEXT_LENGTH:
|
||||
raise AssertionError(f"Input text is too long. For English, the text length should be less than {MAX_ENG_TEXT_LENGTH}. | Length: {len(text)}")
|
||||
elif lang in ['ko', 'ja', 'zh-CN', 'zh']:
|
||||
if len(text) > MAX_CJK_TEXT_LENGTH:
|
||||
raise AssertionError(f"Input text is too long. For CJK, the text length should be less than {MAX_CJK_TEXT_LENGTH}. | Length: {len(text)}")
|
||||
else:
|
||||
raise AssertionError(f"Not in ['ko', 'ja', 'zh-CN', 'zh', 'en'] ! | Language: {lang}")
|
||||
|
||||
return
|
||||
|
||||
def _get_text_with_lang(self, text, lang):
|
||||
lang_detected = self.google_translation.detect(text)
|
||||
print(f"Detected as: {lang_detected} | Destination: {lang}")
|
||||
|
||||
if lang is None:
|
||||
lang = lang_detected
|
||||
|
||||
if lang != lang_detected:
|
||||
target_text = self.google_translation.translate(text, lang=lang)
|
||||
else:
|
||||
target_text = text
|
||||
|
||||
return target_text, lang
|
||||
|
||||
def _convert_lang_from_index(self, lang):
|
||||
try:
|
||||
lang = [name for name in self.supporting_languages
|
||||
if self.supporting_languages[name]['language'] == lang][0]
|
||||
except Exception as e:
|
||||
raise RuntimeError(e)
|
||||
|
||||
return lang
|
||||
|
||||
def get_translation(self, text, lang, use_translation=True):
|
||||
lang_ = self._convert_lang_from_index(lang)
|
||||
|
||||
if use_translation:
|
||||
target_text, _ = self._get_text_with_lang(text, lang_)
|
||||
else:
|
||||
target_text = text
|
||||
|
||||
return target_text, lang_
|
|
@ -0,0 +1,58 @@
|
|||
from google.cloud import translate
|
||||
import yaml
|
||||
|
||||
|
||||
class GoogleAuthTranslation:
|
||||
def __init__(self, project_id, yaml_path='lang.yaml'):
|
||||
self.translator = translate.TranslationServiceClient()
|
||||
self.location = "global"
|
||||
self.parent = f"projects/{project_id}/locations/{self.location}"
|
||||
|
||||
with open(yaml_path) as f:
|
||||
self.supporting_languages = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
def _detect(self, query):
|
||||
response = self.translator.detect_language(
|
||||
request={
|
||||
"parent": self.parent,
|
||||
"content": query,
|
||||
"mime_type": "text/plain", # mime types: text/plain, text/html
|
||||
}
|
||||
)
|
||||
|
||||
for language in response.languages:
|
||||
# First language is the most confident one
|
||||
return language.language_code
|
||||
|
||||
def _get_dest_from_lang(self, lang):
|
||||
try:
|
||||
return self.supporting_languages[lang]['google_dest']
|
||||
|
||||
except KeyError as e:
|
||||
raise e
|
||||
|
||||
def _get_lang_from_dest(self, dest):
|
||||
for key in self.supporting_languages:
|
||||
if self.supporting_languages[key]['google_dest'] == dest:
|
||||
return key
|
||||
|
||||
raise RuntimeError(f"Detected langauge is not supported in our multilingual TTS. |\n Code: {dest} | See https://cloud.google.com/translate/docs/languages")
|
||||
|
||||
def translate(self, query, lang):
|
||||
|
||||
dest = self._get_dest_from_lang(lang)
|
||||
|
||||
response = self.translator.translate_text(
|
||||
request={
|
||||
"parent": self.parent,
|
||||
"contents": [query],
|
||||
"mime_type": "text/plain", # mime types: text/plain, text/html
|
||||
"target_language_code": dest,
|
||||
}
|
||||
)
|
||||
|
||||
return " ".join([translation.translated_text for translation in response.translations])
|
||||
|
||||
def detect(self, query):
|
||||
dest = self._detect(query)
|
||||
return self._get_lang_from_dest(dest)
|
Loading…
Reference in New Issue