Refactor Project to 3 parts: Models, Control, Data

Need readme
This commit is contained in:
babysor00
2022-12-03 16:54:06 +08:00
parent b402f9dbdf
commit 74a3fc97d0
179 changed files with 197 additions and 27924 deletions

control/__init__.py (new, empty file)

@@ -0,0 +1,64 @@
import argparse
from pathlib import Path
from models.encoder.preprocess import (preprocess_aidatatang_200zh,
preprocess_librispeech, preprocess_voxceleb1,
preprocess_voxceleb2)
from utils.argutils import print_args
if __name__ == "__main__":
class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
pass
parser = argparse.ArgumentParser(
description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
"writes them to the disk. This will allow you to train the encoder. The "
"datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
formatter_class=MyFormatter
)
parser.add_argument("datasets_root", type=Path, help=\
"Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
"Path to the output directory that will contain the mel spectrograms. If left out, "
"defaults to <datasets_root>/SV2TTS/encoder/")
parser.add_argument("-d", "--datasets", type=str,
default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
"Comma-separated list of the name of the datasets you want to preprocess. Only the train "
"set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
"voxceleb2.")
parser.add_argument("-s", "--skip_existing", action="store_true", help=\
"Whether to skip existing output files with the same name. Useful if this script was "
"interrupted.")
parser.add_argument("--no_trim", action="store_true", help=\
"Preprocess audio without trimming silences (not recommended).")
args = parser.parse_args()
# Verify webrtcvad is available
if not args.no_trim:
try:
import webrtcvad
except ImportError:
raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
"noise removal and is recommended. Please install and try again. If installation fails, "
"use --no_trim to disable this error message.")
del args.no_trim
# Process the arguments
args.datasets = args.datasets.split(",")
if not hasattr(args, "out_dir"):
args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
assert args.datasets_root.exists()
args.out_dir.mkdir(exist_ok=True, parents=True)
# Preprocess the datasets
print_args(args, parser)
preprocess_func = {
"librispeech_other": preprocess_librispeech,
"voxceleb1": preprocess_voxceleb1,
"voxceleb2": preprocess_voxceleb2,
"aidatatang_200zh": preprocess_aidatatang_200zh,
}
args = vars(args)
for dataset in args.pop("datasets"):
print("Preprocessing %s" % dataset)
preprocess_func[dataset](**args)
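
The dispatch at the end of this script relies on a common argparse pattern: vars() turns the Namespace into a plain dict, the dataset list is popped off, and everything that remains is forwarded as keyword arguments to the chosen preprocess function. A minimal, self-contained sketch of that pattern (the dataset names and the handler below are illustrative, not from this repo):

    import argparse

    def preprocess_demo(datasets_root, out_dir, skip_existing=False):
        # Stand-in for a real preprocess_* function; it just echoes its arguments.
        print(f"root={datasets_root} out={out_dir} skip={skip_existing}")

    parser = argparse.ArgumentParser()
    parser.add_argument("datasets_root")
    parser.add_argument("-o", "--out_dir", default="out")
    parser.add_argument("-d", "--datasets", default="demo_a,demo_b")
    parser.add_argument("-s", "--skip_existing", action="store_true")
    args = parser.parse_args(["/data", "-d", "demo_a"])

    handlers = {"demo_a": preprocess_demo, "demo_b": preprocess_demo}
    kwargs = vars(args)                     # Namespace -> plain dict
    for dataset in kwargs.pop("datasets").split(","):
        handlers[dataset](**kwargs)         # remaining keys become keyword arguments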


@@ -0,0 +1,47 @@
from utils.argutils import print_args
from models.encoder.train import train
from pathlib import Path
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("run_id", type=str, help= \
"Name for this model instance. If a model state from the same run ID was previously "
"saved, the training will restart from there. Pass -f to overwrite saved states and "
"restart from scratch.")
parser.add_argument("clean_data_root", type=Path, help= \
"Path to the output directory of encoder_preprocess.py. If you left the default "
"output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
"Path to the output directory that will contain the saved model weights, as well as "
"backups of those weights and plots generated during training.")
parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
"Number of steps between updates of the loss and the plots.")
parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
"Number of steps between updates of the umap projection. Set to 0 to never update the "
"projections.")
parser.add_argument("-s", "--save_every", type=int, default=500, help= \
"Number of steps between updates of the model on the disk. Set to 0 to never save the "
"model.")
parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
"Number of steps between backups of the model. Set to 0 to never make backups of the "
"model.")
parser.add_argument("-f", "--force_restart", action="store_true", help= \
"Do not load any saved model.")
parser.add_argument("--visdom_server", type=str, default="http://localhost")
parser.add_argument("--no_visdom", action="store_true", help= \
"Disable visdom.")
args = parser.parse_args()
# Process the arguments
args.models_dir.mkdir(exist_ok=True)
# Run the training
print_args(args, parser)
train(**vars(args))


@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
# Arguments
parser = argparse.ArgumentParser(description=
'Training PPG2Mel VC model.')
parser.add_argument('--config', type=str,
help='Path to experiment config, e.g., config/vc.yaml')
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
parser.add_argument('--logdir', default='log/', type=str,
help='Logging path.', required=False)
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
help='Checkpoint path.', required=False)
parser.add_argument('--outdir', default='result/', type=str,
help='Decode output path.', required=False)
parser.add_argument('--load', default=None, type=str,
help='Load pre-trained model (for training only)', required=False)
parser.add_argument('--warm_start', action='store_true',
help='Load model weights only, ignore specified layers.')
parser.add_argument('--seed', default=0, type=int,
help='Random seed for reproducible results.', required=False)
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true',
help='Disable pin-memory for dataloader')
parser.add_argument('--test', action='store_true', help='Test the model.')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
parser.add_argument('--finetune', action='store_true', help='Finetune model')
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
###
paras = parser.parse_args()
setattr(paras, 'gpu', not paras.cpu)
setattr(paras, 'pin_memory', not paras.no_pin)
setattr(paras, 'verbose', not paras.no_msg)
# Make the config dict dot visitable
config = HpsYaml(paras.config)
np.random.seed(paras.seed)
torch.manual_seed(paras.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(paras.seed)
print(">>> OneShot VC training ...")
mode = "train"
solver = Solver(config, paras, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> Oneshot VC train finished!")
sys.exit(0)
if __name__ == "__main__":
main()
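
The seeding calls plus the two cuDNN flags at the top of this script are the usual recipe for reproducible PyTorch training. A compact helper that groups the same calls (a sketch built only on standard torch/numpy APIs, not a function from this repo):

    import random
    import numpy as np
    import torch

    def seed_everything(seed: int = 0) -> None:
        # Seed every RNG a training loop typically touches.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels trade some speed for reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False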

control/cli/pre4ppg.py (new file)

@@ -0,0 +1,49 @@
from pathlib import Path
import argparse
from models.ppg2mel.preprocess import preprocess_dataset
recognized_datasets = [
"aidatatang_200zh",
"aidatatang_200zh_s", # sample
]
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Preprocesses audio files from datasets, to be used by the "
"ppg2mel model for training.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("datasets_root", type=Path, help=\
"Path to the directory containing your datasets.")
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
"Name of the dataset to process, allowing values: aidatatang_200zh.")
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
"Path to the output directory that will contain the mel spectrograms, the audios and the "
"embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
"Number of processes in parallel.")
# parser.add_argument("-s", "--skip_existing", action="store_true", help=\
# "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
# "interrupted. ")
# parser.add_argument("--hparams", type=str, default="", help=\
# "Hyperparameter overrides as a comma-separated list of name-value pairs")
# parser.add_argument("--no_trim", action="store_true", help=\
# "Preprocess audio without trimming silences (not recommended).")
parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
"Path your trained ppg encoder model.")
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
"Path your trained speaker encoder model.")
args = parser.parse_args()
assert args.dataset in recognized_datasets, f'"{args.dataset}" is not supported; file an issue to propose a new one'
# Create directories
assert args.datasets_root.exists()
if not hasattr(args, "out_dir"):
args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
args.out_dir.mkdir(exist_ok=True, parents=True)
preprocess_dataset(**vars(args))
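
Several of these CLI scripts pass default=argparse.SUPPRESS so that an omitted --out_dir never shows up on the Namespace at all, then derive it from datasets_root with hasattr. A minimal sketch of that behaviour (paths are illustrative):

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("datasets_root", type=Path)
    # SUPPRESS: if the flag is absent, the attribute is not created at all.
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS)

    args = parser.parse_args(["/data"])        # --out_dir not given
    if not hasattr(args, "out_dir"):           # so the attribute is missing
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    print(args.out_dir)                        # /data/PPGVC/ppg2mel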


@@ -0,0 +1,65 @@
from models.synthesizer.preprocess import preprocess_dataset
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
from pathlib import Path
import argparse
recognized_datasets = [
"aidatatang_200zh",
"magicdata",
"aishell3"
]
if __name__ == "__main__":
print("This method is deprecaded and will not be longer supported, please use 'pre.py'")
parser = argparse.ArgumentParser(
description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
"and writes them to the disk. Audio files are also saved, to be used by the "
"vocoder for training.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("datasets_root", type=Path, help=\
"Path to the directory containing your LibriSpeech/TTS datasets.")
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
"Path to the output directory that will contain the mel spectrograms, the audios and the "
"embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/")
parser.add_argument("-n", "--n_processes", type=int, default=None, help=\
"Number of processes in parallel.")
parser.add_argument("-s", "--skip_existing", action="store_true", help=\
"Whether to overwrite existing files with the same name. Useful if the preprocessing was "
"interrupted.")
parser.add_argument("--hparams", type=str, default="", help=\
"Hyperparameter overrides as a comma-separated list of name-value pairs")
parser.add_argument("--no_trim", action="store_true", help=\
"Preprocess audio without trimming silences (not recommended).")
parser.add_argument("--no_alignments", action="store_true", help=\
"Use this option when dataset does not include alignments\
(these are used to split long audio files into sub-utterances.)")
parser.add_argument("--dataset", type=str, default="aidatatang_200zh", help=\
"Name of the dataset to process, allowing values: magicdata, aidatatang_200zh.")
args = parser.parse_args()
# Process the arguments
if not hasattr(args, "out_dir"):
args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer")
assert args.dataset in recognized_datasets, f'"{args.dataset}" is not supported, please vote for it in https://github.com/babysor/MockingBird/issues/10'
# Create directories
assert args.datasets_root.exists()
args.out_dir.mkdir(exist_ok=True, parents=True)
# Verify webrtcvad is available
if not args.no_trim:
try:
import webrtcvad
except ImportError:
raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
"noise removal and is recommended. Please install and try again. If installation fails, "
"use --no_trim to disable this error message.")
del args.no_trim
# Preprocess the dataset
print_args(args, parser)
args.hparams = hparams.parse(args.hparams)
preprocess_dataset(**vars(args))


@@ -0,0 +1,26 @@
from models.synthesizer.preprocess import create_embeddings
from utils.argutils import print_args
from pathlib import Path
import argparse
if __name__ == "__main__":
print("This method is deprecaded and will not be longer supported, please use 'pre.py'")
parser = argparse.ArgumentParser(
description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("synthesizer_root", type=Path, help=\
"Path to the synthesizer training data that contains the audios and the train.txt file. "
"If you let everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
parser.add_argument("-e", "--encoder_model_fpath", type=Path,
default="encoder/saved_models/pretrained.pt", help=\
"Path your trained encoder model.")
parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
"Number of parallel processes. An encoder is created for each, so you may need to lower "
"this value on GPUs with low memory. Set it to 1 if CUDA is unhappy.")
args = parser.parse_args()
# Preprocess the dataset
print_args(args, parser)
create_embeddings(**vars(args))


@@ -0,0 +1,37 @@
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train
from utils.argutils import print_args
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("run_id", type=str, help= \
"Name for this model instance. If a model state from the same run ID was previously "
"saved, the training will restart from there. Pass -f to overwrite saved states and "
"restart from scratch.")
parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
"the wavs and the embeds.")
parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
"Path to the output directory that will contain the saved model weights and the logs.")
parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
"Number of steps between updates of the model on the disk. Set to 0 to never save the "
"model.")
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
"Number of steps between backups of the model. Set to 0 to never make backups of the "
"model.")
parser.add_argument("-l", "--log_every", type=int, default=200, help= \
"Number of steps between summary the training info in tensorboard")
parser.add_argument("-f", "--force_restart", action="store_true", help= \
"Do not load any saved model and restart from scratch.")
parser.add_argument("--hparams", default="",
help="Hyperparameter overrides as a comma-separated list of name=value "
"pairs")
args = parser.parse_args()
print_args(args, parser)
args.hparams = hparams.parse(args.hparams)
# Run the training
train(**vars(args))


@@ -0,0 +1,59 @@
from models.synthesizer.synthesize import run_synthesis
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
import argparse
import os
if __name__ == "__main__":
class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
pass
parser = argparse.ArgumentParser(
description="Creates ground-truth aligned (GTA) spectrograms from the vocoder.",
formatter_class=MyFormatter
)
parser.add_argument("datasets_root", type=str, help=\
"Path to the directory containing your SV2TTS directory. If you specify both --in_dir and "
"--out_dir, this argument won't be used.")
parser.add_argument("-m", "--model_dir", type=str,
default="synthesizer/saved_models/mandarin/", help=\
"Path to the pretrained model directory.")
parser.add_argument("-i", "--in_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the synthesizer directory that contains the mel spectrograms, the wavs and the "
"embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
parser.add_argument("-o", "--out_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the output vocoder directory that will contain the ground truth aligned mel "
"spectrograms. Defaults to <datasets_root>/SV2TTS/vocoder/.")
parser.add_argument("--hparams", default="",
help="Hyperparameter overrides as a comma-separated list of name=value "
"pairs")
parser.add_argument("--no_trim", action="store_true", help=\
"Preprocess audio without trimming silences (not recommended).")
parser.add_argument("--cpu", action="store_true", help=\
"If True, processing is done on CPU, even when a GPU is available.")
args = parser.parse_args()
print_args(args, parser)
modified_hp = hparams.parse(args.hparams)
if not hasattr(args, "in_dir"):
args.in_dir = os.path.join(args.datasets_root, "SV2TTS", "synthesizer")
if not hasattr(args, "out_dir"):
args.out_dir = os.path.join(args.datasets_root, "SV2TTS", "vocoder")
if args.cpu:
# Hide GPUs from Pytorch to force CPU processing
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Verify webrtcvad is available
if not args.no_trim:
try:
import webrtcvad
except ImportError:
raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
"noise removal and is recommended. Please install and try again. If installation fails, "
"use --no_trim to disable this error message.")
del args.no_trim
run_synthesis(args.in_dir, args.out_dir, args.model_dir, modified_hp)


@@ -0,0 +1,92 @@
from utils.argutils import print_args
from models.vocoder.wavernn.train import train
from models.vocoder.hifigan.train import train as train_hifigan
from models.vocoder.fregan.train import train as train_fregan
from utils.util import AttrDict
from pathlib import Path
import argparse
import json
import torch
import torch.multiprocessing as mp
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Trains the vocoder from the synthesizer audios and the GTA synthesized mels, "
"or ground truth mels.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("run_id", type=str, help= \
"Name for this model instance. If a model state from the same run ID was previously "
"saved, the training will restart from there. Pass -f to overwrite saved states and "
"restart from scratch.")
parser.add_argument("datasets_root", type=str, help= \
"Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
"will take priority over this argument.")
parser.add_argument("vocoder_type", type=str, default="wavernn", help= \
"Choose the vocoder type for train. Defaults to wavernn"
"Now, Support <hifigan> and <wavernn> for choose")
parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
"the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
parser.add_argument("--voc_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the vocoder directory that contains the GTA synthesized mel spectrograms. "
"Defaults to <datasets_root>/SV2TTS/vocoder/. Unused if --ground_truth is passed.")
parser.add_argument("-m", "--models_dir", type=str, default="vocoder/saved_models/", help=\
"Path to the directory that will contain the saved model weights, as well as backups "
"of those weights and wavs generated during training.")
parser.add_argument("-g", "--ground_truth", action="store_true", help= \
"Train on ground truth spectrograms (<datasets_root>/SV2TTS/synthesizer/mels).")
parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
"Number of steps between updates of the model on the disk. Set to 0 to never save the "
"model.")
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
"Number of steps between backups of the model. Set to 0 to never make backups of the "
"model.")
parser.add_argument("-f", "--force_restart", action="store_true", help= \
"Do not load any saved model and restart from scratch.")
parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
args = parser.parse_args()
if not hasattr(args, "syn_dir"):
args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
args.syn_dir = Path(args.syn_dir)
if not hasattr(args, "voc_dir"):
args.voc_dir = Path(args.datasets_root, "SV2TTS", "vocoder")
args.voc_dir = Path(args.voc_dir)
del args.datasets_root
args.models_dir = Path(args.models_dir)
args.models_dir.mkdir(exist_ok=True)
print_args(args, parser)
# Process the arguments
if args.vocoder_type == "wavernn":
# Run the training wavernn
delattr(args, 'vocoder_type')
delattr(args, 'config')
train(**vars(args))
elif args.vocoder_type == "hifigan":
with open(args.config) as f:
json_config = json.load(f)
h = AttrDict(json_config)
if h.num_gpus > 1:
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
else:
train_hifigan(0, args, h)
elif args.vocoder_type == "fregan":
with open('vocoder/fregan/config.json') as f:
json_config = json.load(f)
h = AttrDict(json_config)
if h.num_gpus > 1:
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
else:
train_fregan(0, args, h)
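
In the hifigan and fregan branches above, torch.multiprocessing.spawn launches one process per GPU and prepends the process index to the argument tuple, which is why the single-process fallback calls train_hifigan(0, args, h) with an explicit rank of 0. A minimal sketch of that contract (the worker is a stand-in, not the repo's training function):

    import torch.multiprocessing as mp

    def worker(rank, message):
        # mp.spawn passes the process index (rank) as the first positional argument.
        print(f"rank {rank}: {message}")

    if __name__ == "__main__":
        # Roughly equivalent to running worker(0, "hello") and worker(1, "hello") in parallel.
        mp.spawn(worker, nprocs=2, args=("hello",))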


control/mkgui/app.py (new file)

@@ -0,0 +1,145 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from models.encoder import inference as encoder
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from control.mkgui.base.components.types import FileContent
from models.vocoder.hifigan import inference as gan_vocoder
from models.synthesizer.inference import Synthesizer
from typing import Any, Tuple
import matplotlib.pyplot as plt
# Constants
AUDIO_SAMPLES_DIR = f"data{os.sep}samples{os.sep}"
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
VOC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}vocoder"
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
if not os.path.isdir("wavs"):
os.makedirs("wavs")
# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist. 请将模型文件位置移动到上述位置中进行重试!")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
if os.path.isdir(VOC_MODELS_DIRT):
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
print("Loaded vocoders models: " + str(len(synthesizers)))
else:
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
class Input(BaseModel):
message: str = Field(
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
)
local_audio_file: audio_input_selection = Field(
..., alias="输入语音本地wav",
description="选择本地语音文件."
)
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
description="拖拽或点击上传.", mime_type="audio/wav")
encoder: encoders = Field(
..., alias="编码模型",
description="选择语音编码模型文件."
)
synthesizer: synthesizers = Field(
..., alias="合成模型",
description="选择语音合成模型文件."
)
vocoder: vocoders = Field(
..., alias="语音解码模型",
description="选择语音解码模型文件(目前只支持HifiGan类型)."
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: Tuple[AudioEntity, AudioEntity]
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
src, result = self.__root__
streamlit_app.subheader("Synthesized Audio")
streamlit_app.audio(result.content, format="audio/wav")
fig, ax = plt.subplots()
ax.imshow(src.mel, aspect="equal", interpolation="none")
ax.set_title("mel spectrogram(Source Audio)")
streamlit_app.pyplot(fig)
fig, ax = plt.subplots()
ax.imshow(result.mel, aspect="equal", interpolation="none")
ax.set_title("mel spectrogram(Result Audio)")
streamlit_app.pyplot(fig)
def synthesize(input: Input) -> Output:
"""synthesize(合成)"""
# load models
encoder.load_model(Path(input.encoder.value))
current_synt = Synthesizer(Path(input.synthesizer.value))
gan_vocoder.load_model(Path(input.vocoder.value))
# load file
if input.upload_audio_file != None:
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
f.write(input.upload_audio_file.as_bytes())
f.seek(0)
wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
else:
wav, sample_rate = librosa.load(input.local_audio_file.value)
write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
source_spec = Synthesizer.make_spectrogram(wav)
# preprocess
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
# Load input text
texts = filter(None, input.message.split("\n"))
punctuation = '!,。、,' # punctuate and split/clean text
processed_texts = []
for text in texts:
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
if processed_text:
processed_texts.append(processed_text.strip())
texts = processed_texts
# synthesize and vocode
embeds = [embed] * len(texts)
specs = current_synt.synthesize_spectrograms(texts, embeds)
spec = np.concatenate(specs, axis=1)
sample_rate = Synthesizer.sample_rate
wav, sample_rate = gan_vocoder.infer_waveform(spec)
# write and output
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
with open(TEMP_SOURCE_AUDIO, "rb") as f:
source_file = f.read()
with open(TEMP_RESULT_AUDIO, "rb") as f:
result_file = f.read()
return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
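
The text handling in synthesize() splits the input on Chinese punctuation by replacing every punctuation run with a newline and then splitting on newlines. A small standalone sketch of that step, using the same character class (the sample sentence is illustrative):

    import re

    punctuation = '!,。、,'   # same punctuation set as above
    text = "你好,欢迎使用工具箱。这是第二句!"
    pieces = [piece.strip()
              for piece in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n')
              if piece.strip()]
    print(pieces)   # ['你好', '欢迎使用工具箱', '这是第二句']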

control/mkgui/app_vc.py (new file)

@@ -0,0 +1,166 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Tuple
import librosa
import matplotlib.pyplot as plt
import torch
from pydantic import BaseModel, Field
from scipy.io.wavfile import write
import models.ppg2mel as Convertor
import models.ppg_extractor as Extractor
from control.mkgui.base.components.types import FileContent
from models.encoder import inference as speacker_encoder
from models.synthesizer.inference import Synthesizer
from models.vocoder.hifigan import inference as gan_vocoder
# Constants
AUDIO_SAMPLES_DIR = f'data{os.sep}samples{os.sep}'
EXT_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg_extractor'
CONV_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg2mel'
VOC_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}vocoder'
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'
# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(EXT_MODELS_DIRT):
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
print("Loaded extractor models: " + str(len(extractors)))
else:
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
if os.path.isdir(CONV_MODELS_DIRT):
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
print("Loaded convertor models: " + str(len(convertors)))
else:
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
if os.path.isdir(VOC_MODELS_DIRT):
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
print("Loaded vocoders models: " + str(len(vocoders)))
else:
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
class Input(BaseModel):
local_audio_file: audio_input_selection = Field(
..., alias="输入语音本地wav",
description="选择本地语音文件."
)
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
description="拖拽或点击上传.", mime_type="audio/wav")
local_audio_file_target: audio_input_selection = Field(
..., alias="目标语音本地wav",
description="选择本地语音文件."
)
upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
description="拖拽或点击上传.", mime_type="audio/wav")
extractor: extractors = Field(
..., alias="编码模型",
description="选择语音编码模型文件."
)
convertor: convertors = Field(
..., alias="转换模型",
description="选择语音转换模型文件."
)
vocoder: vocoders = Field(
..., alias="语音解码模型",
description="选择语音解码模型文件(目前只支持HifiGan类型)."
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: Tuple[AudioEntity, AudioEntity, AudioEntity]
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
src, target, result = self.__root__
streamlit_app.subheader("Synthesized Audio")
streamlit_app.audio(result.content, format="audio/wav")
fig, ax = plt.subplots()
ax.imshow(src.mel, aspect="equal", interpolation="none")
ax.set_title("mel spectrogram(Source Audio)")
streamlit_app.pyplot(fig)
fig, ax = plt.subplots()
ax.imshow(target.mel, aspect="equal", interpolation="none")
ax.set_title("mel spectrogram(Target Audio)")
streamlit_app.pyplot(fig)
fig, ax = plt.subplots()
ax.imshow(result.mel, aspect="equal", interpolation="none")
ax.set_title("mel spectrogram(Result Audio)")
streamlit_app.pyplot(fig)
def convert(input: Input) -> Output:
"""convert(转换)"""
# load models
extractor = Extractor.load_model(Path(input.extractor.value))
convertor = Convertor.load_model(Path(input.convertor.value))
# current_synt = Synthesizer(Path(input.synthesizer.value))
gan_vocoder.load_model(Path(input.vocoder.value))
# load file
if input.upload_audio_file != None:
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
f.write(input.upload_audio_file.as_bytes())
f.seek(0)
src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
else:
src_wav, sample_rate = librosa.load(input.local_audio_file.value)
write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
if input.upload_audio_file_target != None:
with open(TEMP_TARGET_AUDIO, "w+b") as f:
f.write(input.upload_audio_file_target.as_bytes())
f.seek(0)
ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
else:
ref_wav, _ = librosa.load(input.local_audio_file_target.value)
write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
ppg = extractor.extract_from_wav(src_wav)
# Import necessary dependency of Voice Conversion
from utils.f0_utils import (compute_f0, compute_mean_std, f02lf0,
get_converted_lf0uv)
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
speacker_encoder.load_model(Path(f"data{os.sep}ckpt{os.sep}encoder{os.sep}pretrained_bak_5805000.pt"))
embed = speacker_encoder.embed_utterance(ref_wav)
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_, mel_pred, att_ws = convertor.inference(
ppg,
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
)
mel_pred= mel_pred.transpose(0, 1)
breaks = [mel_pred.shape[1]]
mel_pred= mel_pred.detach().cpu().numpy()
# synthesize and vocode
wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
# write and output
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
with open(TEMP_SOURCE_AUDIO, "rb") as f:
source_file = f.read()
with open(TEMP_TARGET_AUDIO, "rb") as f:
target_file = f.read()
with open(TEMP_RESULT_AUDIO, "rb") as f:
result_file = f.read()
return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))


@@ -0,0 +1,2 @@
from .core import Opyrator


@@ -0,0 +1 @@
from .fastapi_app import create_api


@@ -0,0 +1,102 @@
"""Collection of utilities for FastAPI apps."""
import inspect
from typing import Any, Type
from fastapi import FastAPI, Form
from pydantic import BaseModel
def as_form(cls: Type[BaseModel]) -> Any:
"""Adds an as_form class method to decorated models.
The as_form class method can be used with FastAPI endpoints
"""
new_params = [
inspect.Parameter(
field.alias,
inspect.Parameter.POSITIONAL_ONLY,
default=(Form(field.default) if not field.required else Form(...)),
)
for field in cls.__fields__.values()
]
async def _as_form(**data): # type: ignore
return cls(**data)
sig = inspect.signature(_as_form)
sig = sig.replace(parameters=new_params)
_as_form.__signature__ = sig # type: ignore
setattr(cls, "as_form", _as_form)
return cls
def patch_fastapi(app: FastAPI) -> None:
"""Patch function to allow relative url resolution.
This patch is required to make fastapi fully functional with a relative url path.
This code snippet can be copy-pasted to any Fastapi application.
"""
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from starlette.requests import Request
from starlette.responses import HTMLResponse
async def redoc_ui_html(req: Request) -> HTMLResponse:
assert app.openapi_url is not None
redoc_ui = get_redoc_html(
openapi_url="./" + app.openapi_url.lstrip("/"),
title=app.title + " - Redoc UI",
)
return HTMLResponse(redoc_ui.body.decode("utf-8"))
async def swagger_ui_html(req: Request) -> HTMLResponse:
assert app.openapi_url is not None
swagger_ui = get_swagger_ui_html(
openapi_url="./" + app.openapi_url.lstrip("/"),
title=app.title + " - Swagger UI",
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
)
# insert request interceptor to have all requests run on a relative path
request_interceptor = (
"requestInterceptor: (e) => {"
"\n\t\t\tvar url = window.location.origin + window.location.pathname"
'\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
"\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
"\n\t\t\te.contextUrl = url"
"\n\t\t\te.url = url"
"\n\t\t\treturn e;}"
)
return HTMLResponse(
swagger_ui.body.decode("utf-8").replace(
"dom_id: '#swagger-ui',",
"dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
)
)
# remove old docs route and add our patched route
routes_new = []
for app_route in app.routes:
if app_route.path == "/docs": # type: ignore
continue
if app_route.path == "/redoc": # type: ignore
continue
routes_new.append(app_route)
app.router.routes = routes_new
assert app.docs_url is not None
app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
assert app.redoc_url is not None
app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
# Make graphql relative
from starlette import graphql
graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
)
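
The as_form helper above follows the widely used FastAPI/pydantic recipe for receiving a pydantic model as HTML form data. A hedged usage sketch, assuming the as_form decorator defined above is in scope; the endpoint and model names are illustrative, and Form fields require python-multipart to be installed:

    from fastapi import Depends, FastAPI
    from pydantic import BaseModel

    app = FastAPI()

    @as_form                       # adds an .as_form dependency builder to the model
    class LoginForm(BaseModel):
        username: str
        password: str

    @app.post("/login")
    async def login(form: LoginForm = Depends(LoginForm.as_form)):
        # Each field is read from form-encoded data instead of a JSON body.
        return {"user": form.username}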


@@ -0,0 +1,43 @@
from typing import List
from pydantic import BaseModel
class ScoredLabel(BaseModel):
label: str
score: float
class ClassificationOutput(BaseModel):
__root__: List[ScoredLabel]
def __iter__(self): # type: ignore
return iter(self.__root__)
def __getitem__(self, item): # type: ignore
return self.__root__[item]
def render_output_ui(self, streamlit) -> None: # type: ignore
import plotly.express as px
sorted_predictions = sorted(
[prediction.dict() for prediction in self.__root__],
key=lambda k: k["score"],
)
num_labels = len(sorted_predictions)
if len(sorted_predictions) > 10:
num_labels = streamlit.slider(
"Maximum labels to show: ",
min_value=1,
max_value=len(sorted_predictions),
value=len(sorted_predictions),
)
fig = px.bar(
sorted_predictions[len(sorted_predictions) - num_labels :],
x="score",
y="label",
orientation="h",
)
streamlit.plotly_chart(fig, use_container_width=True)
# fig.show()


@@ -0,0 +1,46 @@
import base64
from typing import Any, Dict, overload
class FileContent(str):
def as_bytes(self) -> bytes:
return base64.b64decode(self, validate=True)
def as_str(self) -> str:
return self.as_bytes().decode()
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(format="byte")
@classmethod
def __get_validators__(cls) -> Any: # type: ignore
yield cls.validate
@classmethod
def validate(cls, value: Any) -> "FileContent":
if isinstance(value, FileContent):
return value
elif isinstance(value, str):
return FileContent(value)
elif isinstance(value, (bytes, bytearray, memoryview)):
return FileContent(base64.b64encode(value).decode())
else:
raise Exception("Wrong type")
# # Currently unusable, because the browser offers no way to select a folder
# class DirectoryContent(FileContent):
# @classmethod
# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
# field_schema.update(format="path")
# @classmethod
# def validate(cls, value: Any) -> "DirectoryContent":
# if isinstance(value, DirectoryContent):
# return value
# elif isinstance(value, str):
# return DirectoryContent(value)
# elif isinstance(value, (bytes, bytearray, memoryview)):
# return DirectoryContent(base64.b64encode(value).decode())
# else:
# raise Exception("Wrong type")
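
FileContent stores uploaded data as a base64 string while still exposing the original bytes; a small usage sketch of the class above:

    from control.mkgui.base.components.types import FileContent

    raw = b"RIFF....WAVEfmt "            # pretend contents of an uploaded wav
    fc = FileContent.validate(raw)       # raw bytes are base64-encoded on the way in
    assert isinstance(fc, str)           # FileContent is still a plain str subclass
    assert fc.as_bytes() == raw          # and decodes back to the original bytes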

control/mkgui/base/core.py (new file)

@@ -0,0 +1,203 @@
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints
from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as
def name_to_title(name: str) -> str:
"""Converts a camelCase or snake_case name to title case."""
# If camelCase -> convert to snake case
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
# Convert to title case
return name.replace("_", " ").strip().title()
def is_compatible_type(type: Type) -> bool:
"""Returns `True` if the type is opyrator-compatible."""
try:
if issubclass(type, BaseModel):
return True
except Exception:
pass
try:
# valid list type
if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
return True
except Exception:
pass
return False
def get_input_type(func: Callable) -> Type:
"""Returns the input type of a given function (callable).
Args:
func: The function for which to get the input type.
Raises:
ValueError: If the function does not have a valid input type annotation.
"""
type_hints = get_type_hints(func)
if "input" not in type_hints:
raise ValueError(
"The callable MUST have a parameter with the name `input` with typing annotation. "
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
)
input_type = type_hints["input"]
if not is_compatible_type(input_type):
raise ValueError(
"The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
)
# TODO: return warning if more than one input parameters
return input_type
def get_output_type(func: Callable) -> Type:
"""Returns the output type of a given function (callable).
Args:
func: The function for which to get the output type.
Raises:
ValueError: If the function does not have a valid output type annotation.
"""
type_hints = get_type_hints(func)
if "return" not in type_hints:
raise ValueError(
"The return type of the callable MUST be annotated with type hints."
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
)
output_type = type_hints["return"]
if not is_compatible_type(output_type):
raise ValueError(
"The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
)
return output_type
def get_callable(import_string: str) -> Callable:
"""Import a callable from an string."""
callable_seperator = ":"
if callable_seperator not in import_string:
# Use dot as separator
callable_seperator = "."
if callable_seperator not in import_string:
raise ValueError("The callable path MUST specify the function. ")
mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
mod = importlib.import_module(mod_name)
return getattr(mod, callable_name)
class Opyrator:
def __init__(self, func: Union[Callable, str]) -> None:
if isinstance(func, str):
# Try to load the function from a string notation
self.function = get_callable(func)
else:
self.function = func
self._action = "Execute"
self._input_type = None
self._output_type = None
if not callable(self.function):
raise ValueError("The provided function parameters is not a callable.")
if inspect.isclass(self.function):
raise ValueError(
"The provided callable is an uninitialized Class. This is not allowed."
)
if inspect.isfunction(self.function):
# The provided callable is a function
self._input_type = get_input_type(self.function)
self._output_type = get_output_type(self.function)
try:
# Get name
self._name = name_to_title(self.function.__name__)
except Exception:
pass
try:
# Get description from function
doc_string = inspect.getdoc(self.function)
if doc_string:
self._action = doc_string
except Exception:
pass
elif hasattr(self.function, "__call__"):
# The provided callable is an object with a __call__ method
self._input_type = get_input_type(self.function.__call__) # type: ignore
self._output_type = get_output_type(self.function.__call__) # type: ignore
try:
# Get name
self._name = name_to_title(type(self.function).__name__)
except Exception:
pass
try:
# Get action from the __call__ docstring
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
if doc_string:
self._action = doc_string
if (
not self._action
or self._action == "Call"
):
# Get docstring from class instead of __call__ function
doc_string = inspect.getdoc(self.function)
if doc_string:
self._action = doc_string
except Exception:
pass
else:
raise ValueError("Unknown callable type.")
@property
def name(self) -> str:
return self._name
@property
def action(self) -> str:
return self._action
@property
def input_type(self) -> Any:
return self._input_type
@property
def output_type(self) -> Any:
return self._output_type
def __call__(self, input: Any, **kwargs: Any) -> Any:
input_obj = input
if isinstance(input, str):
# Allow json input
input_obj = parse_raw_as(self.input_type, input)
if isinstance(input, dict):
# Allow dict input
input_obj = parse_obj_as(self.input_type, input)
return self.function(input_obj, **kwargs)
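
Opyrator wraps a callable whose single `input` parameter and return annotation are pydantic models, deriving a display name from the function name and accepting dict or JSON input in __call__. A hedged usage sketch against the class above (model and function names are illustrative):

    from pydantic import BaseModel
    from control.mkgui.base import Opyrator

    class EchoInput(BaseModel):
        message: str

    class EchoOutput(BaseModel):
        message: str

    def echo_tool(input: EchoInput) -> EchoOutput:
        """Echo the message back."""
        return EchoOutput(message=input.message)

    op = Opyrator(echo_tool)            # a "module.path:callable" string also works
    print(op.name)                      # "Echo Tool" (via name_to_title)
    print(op.action)                    # "Echo the message back."
    result = op({"message": "hi"})      # dict input is parsed into EchoInput
    print(result.message)               # "hi"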


@@ -0,0 +1 @@
from .streamlit_ui import render_streamlit_ui


@@ -0,0 +1,129 @@
from typing import Dict
def resolve_reference(reference: str, references: Dict) -> Dict:
return references[reference.split("/")[-1]]
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
# Ref can either be directly in the properties or the first element of allOf
reference = property.get("$ref")
if reference is None:
reference = property["allOf"][0]["$ref"]
return resolve_reference(reference, references)
def is_single_string_property(property: Dict) -> bool:
return property.get("type") == "string"
def is_single_datetime_property(property: Dict) -> bool:
if property.get("type") != "string":
return False
return property.get("format") in ["date-time", "time", "date"]
def is_single_boolean_property(property: Dict) -> bool:
return property.get("type") == "boolean"
def is_single_number_property(property: Dict) -> bool:
return property.get("type") in ["integer", "number"]
def is_single_file_property(property: Dict) -> bool:
if property.get("type") != "string":
return False
# TODO: binary?
return property.get("format") == "byte"
def is_single_directory_property(property: Dict) -> bool:
if property.get("type") != "string":
return False
return property.get("format") == "path"
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
if property.get("type") != "array":
return False
if property.get("uniqueItems") is not True:
# Only relevant if it is a set or other datastructures with unique items
return False
try:
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
return True
except Exception:
return False
def is_single_enum_property(property: Dict, references: Dict) -> bool:
try:
_ = get_single_reference_item(property, references)["enum"]
return True
except Exception:
return False
def is_single_dict_property(property: Dict) -> bool:
if property.get("type") != "object":
return False
return "additionalProperties" in property
def is_single_reference(property: Dict) -> bool:
if property.get("type") is not None:
return False
return bool(property.get("$ref"))
def is_multi_file_property(property: Dict) -> bool:
if property.get("type") != "array":
return False
if property.get("items") is None:
return False
try:
# TODO: binary
return property["items"]["format"] == "byte"
except Exception:
return False
def is_single_object(property: Dict, references: Dict) -> bool:
try:
object_reference = get_single_reference_item(property, references)
if object_reference["type"] != "object":
return False
return "properties" in object_reference
except Exception:
return False
def is_property_list(property: Dict) -> bool:
if property.get("type") != "array":
return False
if property.get("items") is None:
return False
try:
return property["items"]["type"] in ["string", "number", "integer"]
except Exception:
return False
def is_object_list_property(property: Dict, references: Dict) -> bool:
if property.get("type") != "array":
return False
try:
object_reference = resolve_reference(property["items"]["$ref"], references)
if object_reference["type"] != "object":
return False
return "properties" in object_reference
except Exception:
return False
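
These predicates all inspect the JSON schema that pydantic generates for a model: the field entries under "properties" plus referenced enums and objects under "definitions", which is what the streamlit renderer below passes in. A small hedged sketch (the model is illustrative; the import path assumes this module lives at control/mkgui/base/ui/schema_utils.py):

    from enum import Enum
    from pydantic import BaseModel
    from control.mkgui.base.ui import schema_utils

    class Color(str, Enum):
        red = "red"
        blue = "blue"

    class Example(BaseModel):
        name: str
        color: Color

    schema = Example.schema(by_alias=True)
    properties = schema.get("properties", {})
    references = schema.get("definitions", {})

    print(schema_utils.is_single_string_property(properties["name"]))             # True
    print(schema_utils.is_single_enum_property(properties["color"], references))  # True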


@@ -0,0 +1,887 @@
import datetime
import inspect
import mimetypes
import sys
from os import getcwd, unlink, path
from platform import system
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Type
from PIL import Image
import pandas as pd
import streamlit as st
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pydantic import BaseModel, ValidationError, parse_obj_as
from control.mkgui.base import Opyrator
from control.mkgui.base.core import name_to_title
from . import schema_utils
from .streamlit_utils import CUSTOM_STREAMLIT_CSS
STREAMLIT_RUNNER_SNIPPET = """
from control.mkgui.base.ui import render_streamlit_ui
import streamlit as st
# TODO: Make it configurable
# Page config can only be setup once
st.set_page_config(
page_title="MockingBird",
page_icon="🧊",
layout="wide")
render_streamlit_ui()
"""
# with st.spinner("Loading MockingBird GUI. Please wait..."):
# opyrator = Opyrator("{opyrator_path}")
def launch_ui(port: int = 8501) -> None:
with NamedTemporaryFile(
suffix=".py", mode="w", encoding="utf-8", delete=False
) as f:
f.write(STREAMLIT_RUNNER_SNIPPET)
f.seek(0)
import subprocess
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
if system() == "Windows":
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
subprocess.run(
f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
shell=True,
)
subprocess.run(
f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
shell=True,
)
f.close()
unlink(f.name)
def function_has_named_arg(func: Callable, parameter: str) -> bool:
try:
sig = inspect.signature(func)
for param in sig.parameters.values():
if param.name == "input":
return True
except Exception:
return False
return False
def has_output_ui_renderer(data_item: BaseModel) -> bool:
return hasattr(data_item, "render_output_ui")
def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
return hasattr(input_class, "render_input_ui")
def is_compatible_audio(mime_type: str) -> bool:
return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
def is_compatible_image(mime_type: str) -> bool:
return mime_type in ["image/png", "image/jpeg"]
def is_compatible_video(mime_type: str) -> bool:
return mime_type in ["video/mp4"]
class InputUI:
def __init__(self, session_state, input_class: Type[BaseModel]):
self._session_state = session_state
self._input_class = input_class
self._schema_properties = input_class.schema(by_alias=True).get(
"properties", {}
)
self._schema_references = input_class.schema(by_alias=True).get(
"definitions", {}
)
def render_ui(self, streamlit_app_root) -> None:
if has_input_ui_renderer(self._input_class):
# The input model has a rendering function
# The rendering also returns the current state of input data
self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
st, self._session_state.input_data
)
return
# print(self._schema_properties)
for property_key in self._schema_properties.keys():
property = self._schema_properties[property_key]
if not property.get("title"):
# Set property key as fallback title
property["title"] = name_to_title(property_key)
try:
if "input_data" in self._session_state:
self._store_value(
property_key,
self._render_property(streamlit_app_root, property_key, property),
)
except Exception as e:
print("Exception!", e)
pass
def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
streamlit_kwargs = {
"label": property.get("title"),
"key": key,
}
if property.get("description"):
streamlit_kwargs["help"] = property.get("description")
return streamlit_kwargs
def _store_value(self, key: str, value: Any) -> None:
data_element = self._session_state.input_data
key_elements = key.split(".")
for i, key_element in enumerate(key_elements):
if i == len(key_elements) - 1:
# add value to this element
data_element[key_element] = value
return
if key_element not in data_element:
data_element[key_element] = {}
data_element = data_element[key_element]
def _get_value(self, key: str) -> Any:
data_element = self._session_state.input_data
key_elements = key.split(".")
for i, key_element in enumerate(key_elements):
if i == len(key_elements) - 1:
# add value to this element
if key_element not in data_element:
return None
return data_element[key_element]
if key_element not in data_element:
data_element[key_element] = {}
data_element = data_element[key_element]
return None
def _render_single_datetime_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
if property.get("format") == "time":
if property.get("default"):
try:
streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
property.get("default")
)
except Exception:
pass
return streamlit_app.time_input(**streamlit_kwargs)
elif property.get("format") == "date":
if property.get("default"):
try:
streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
property.get("default")
)
except Exception:
pass
return streamlit_app.date_input(**streamlit_kwargs)
elif property.get("format") == "date-time":
if property.get("default"):
try:
streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
property.get("default")
)
except Exception:
pass
with streamlit_app.container():
streamlit_app.subheader(streamlit_kwargs.get("label"))
if streamlit_kwargs.get("description"):
streamlit_app.text(streamlit_kwargs.get("description"))
selected_date = None
selected_time = None
date_col, time_col = streamlit_app.columns(2)
with date_col:
date_kwargs = {"label": "Date", "key": key + "-date-input"}
if streamlit_kwargs.get("value"):
try:
date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
"value"
).date()
except Exception:
pass
selected_date = streamlit_app.date_input(**date_kwargs)
with time_col:
time_kwargs = {"label": "Time", "key": key + "-time-input"}
if streamlit_kwargs.get("value"):
try:
time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
"value"
).time()
except Exception:
pass
selected_time = streamlit_app.time_input(**time_kwargs)
return datetime.datetime.combine(selected_date, selected_time)
else:
streamlit_app.warning(
"Date format is not supported: " + str(property.get("format"))
)
def _render_single_file_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
file_extension = None
if "mime_type" in property:
file_extension = mimetypes.guess_extension(property["mime_type"])
uploaded_file = streamlit_app.file_uploader(
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
)
if uploaded_file is None:
return None
bytes = uploaded_file.getvalue()
if property.get("mime_type"):
if is_compatible_audio(property["mime_type"]):
# Show audio
streamlit_app.audio(bytes, format=property.get("mime_type"))
if is_compatible_image(property["mime_type"]):
# Show image
streamlit_app.image(bytes)
if is_compatible_video(property["mime_type"]):
# Show video
streamlit_app.video(bytes, format=property.get("mime_type"))
return bytes
def _render_single_string_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
if property.get("default"):
streamlit_kwargs["value"] = property.get("default")
elif property.get("example"):
# TODO: also use example for other property types
# Use example as value if it is provided
streamlit_kwargs["value"] = property.get("example")
if property.get("maxLength") is not None:
streamlit_kwargs["max_chars"] = property.get("maxLength")
if (
property.get("format")
or (
property.get("maxLength") is not None
and int(property.get("maxLength")) < 140 # type: ignore
)
or property.get("writeOnly")
):
# If any format is set, use single text input
# If max chars is set to less than 140, use single text input
# If write only -> password field
if property.get("writeOnly"):
streamlit_kwargs["type"] = "password"
return streamlit_app.text_input(**streamlit_kwargs)
else:
# Otherwise use multiline text area
return streamlit_app.text_area(**streamlit_kwargs)
def _render_multi_enum_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
reference_item = schema_utils.resolve_reference(
property["items"]["$ref"], self._schema_references
)
# TODO: how to select defaults
return streamlit_app.multiselect(
**streamlit_kwargs, options=reference_item["enum"]
)
def _render_single_enum_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
reference_item = schema_utils.get_single_reference_item(
property, self._schema_references
)
if property.get("default") is not None:
try:
streamlit_kwargs["index"] = reference_item["enum"].index(
property.get("default")
)
except Exception:
# Use default selection
pass
return streamlit_app.selectbox(
**streamlit_kwargs, options=reference_item["enum"]
)
def _render_single_dict_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# Add title and subheader
streamlit_app.subheader(property.get("title"))
if property.get("description"):
streamlit_app.markdown(property.get("description"))
streamlit_app.markdown("---")
current_dict = self._get_value(key)
if not current_dict:
current_dict = {}
key_col, value_col = streamlit_app.columns(2)
with key_col:
updated_key = streamlit_app.text_input(
"Key", value="", key=key + "-new-key"
)
with value_col:
# TODO: also add boolean?
value_kwargs = {"label": "Value", "key": key + "-new-value"}
if property["additionalProperties"].get("type") == "integer":
value_kwargs["value"] = 0 # type: ignore
updated_value = streamlit_app.number_input(**value_kwargs)
elif property["additionalProperties"].get("type") == "number":
value_kwargs["value"] = 0.0 # type: ignore
value_kwargs["format"] = "%f"
updated_value = streamlit_app.number_input(**value_kwargs)
else:
value_kwargs["value"] = ""
updated_value = streamlit_app.text_input(**value_kwargs)
streamlit_app.markdown("---")
with streamlit_app.container():
clear_col, add_col = streamlit_app.columns([1, 2])
with clear_col:
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
current_dict = {}
with add_col:
if (
streamlit_app.button("Add Item", key=key + "-add-item")
and updated_key
):
current_dict[updated_key] = updated_value
streamlit_app.write(current_dict)
return current_dict
def _render_single_reference(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
reference_item = schema_utils.get_single_reference_item(
property, self._schema_references
)
return self._render_property(streamlit_app, key, reference_item)
def _render_multi_file_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
file_extension = None
if "mime_type" in property:
file_extension = mimetypes.guess_extension(property["mime_type"])
uploaded_files = streamlit_app.file_uploader(
**streamlit_kwargs, accept_multiple_files=True, type=file_extension
)
uploaded_files_bytes = []
if uploaded_files:
for uploaded_file in uploaded_files:
uploaded_files_bytes.append(uploaded_file.read())
return uploaded_files_bytes
def _render_single_boolean_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
if property.get("default"):
streamlit_kwargs["value"] = property.get("default")
return streamlit_app.checkbox(**streamlit_kwargs)
def _render_single_number_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
number_transform = int
if property.get("type") == "number":
number_transform = float # type: ignore
streamlit_kwargs["format"] = "%f"
if "multipleOf" in property:
# Set stepcount based on multiple of parameter
streamlit_kwargs["step"] = number_transform(property["multipleOf"])
elif number_transform == int:
# Set step size to 1 as default
streamlit_kwargs["step"] = 1
elif number_transform == float:
# Set step size to 0.01 as default
# TODO: adapt to default value
streamlit_kwargs["step"] = 0.01
if "minimum" in property:
streamlit_kwargs["min_value"] = number_transform(property["minimum"])
if "exclusiveMinimum" in property:
streamlit_kwargs["min_value"] = number_transform(
property["exclusiveMinimum"] + streamlit_kwargs["step"]
)
if "maximum" in property:
streamlit_kwargs["max_value"] = number_transform(property["maximum"])
if "exclusiveMaximum" in property:
streamlit_kwargs["max_value"] = number_transform(
property["exclusiveMaximum"] - streamlit_kwargs["step"]
)
if property.get("default") is not None:
streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
else:
if "min_value" in streamlit_kwargs:
streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
elif number_transform == int:
streamlit_kwargs["value"] = 0
else:
# Set default value to step
streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
# TODO: Only if less than X steps
return streamlit_app.slider(**streamlit_kwargs)
else:
return streamlit_app.number_input(**streamlit_kwargs)
def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
properties = property["properties"]
object_inputs = {}
for property_key in properties:
property = properties[property_key]
if not property.get("title"):
# Set property key as fallback title
property["title"] = name_to_title(property_key)
# construct full key based on key parts -> required later to get the value
full_key = key + "." + property_key
object_inputs[property_key] = self._render_property(
streamlit_app, full_key, property
)
return object_inputs
def _render_single_object_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# Add title and subheader
title = property.get("title")
streamlit_app.subheader(title)
if property.get("description"):
streamlit_app.markdown(property.get("description"))
object_reference = schema_utils.get_single_reference_item(
property, self._schema_references
)
return self._render_object_input(streamlit_app, key, object_reference)
def _render_property_list_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# Add title and subheader
streamlit_app.subheader(property.get("title"))
if property.get("description"):
streamlit_app.markdown(property.get("description"))
streamlit_app.markdown("---")
current_list = self._get_value(key)
if not current_list:
current_list = []
value_kwargs = {"label": "Value", "key": key + "-new-value"}
if property["items"]["type"] == "integer":
value_kwargs["value"] = 0 # type: ignore
new_value = streamlit_app.number_input(**value_kwargs)
elif property["items"]["type"] == "number":
value_kwargs["value"] = 0.0 # type: ignore
value_kwargs["format"] = "%f"
new_value = streamlit_app.number_input(**value_kwargs)
else:
value_kwargs["value"] = ""
new_value = streamlit_app.text_input(**value_kwargs)
streamlit_app.markdown("---")
with streamlit_app.container():
clear_col, add_col = streamlit_app.columns([1, 2])
with clear_col:
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
current_list = []
with add_col:
if (
streamlit_app.button("Add Item", key=key + "-add-item")
and new_value is not None
):
current_list.append(new_value)
streamlit_app.write(current_list)
return current_list
def _render_object_list_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# TODO: support max_items, and min_items properties
# Add title and subheader
streamlit_app.subheader(property.get("title"))
if property.get("description"):
streamlit_app.markdown(property.get("description"))
streamlit_app.markdown("---")
current_list = self._get_value(key)
if not current_list:
current_list = []
object_reference = schema_utils.resolve_reference(
property["items"]["$ref"], self._schema_references
)
input_data = self._render_object_input(streamlit_app, key, object_reference)
streamlit_app.markdown("---")
with streamlit_app.container():
clear_col, add_col = streamlit_app.columns([1, 2])
with clear_col:
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
current_list = []
with add_col:
if (
streamlit_app.button("Add Item", key=key + "-add-item")
and input_data
):
current_list.append(input_data)
streamlit_app.write(current_list)
return current_list
def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
if schema_utils.is_single_enum_property(property, self._schema_references):
return self._render_single_enum_input(streamlit_app, key, property)
if schema_utils.is_multi_enum_property(property, self._schema_references):
return self._render_multi_enum_input(streamlit_app, key, property)
if schema_utils.is_single_file_property(property):
return self._render_single_file_input(streamlit_app, key, property)
if schema_utils.is_multi_file_property(property):
return self._render_multi_file_input(streamlit_app, key, property)
if schema_utils.is_single_datetime_property(property):
return self._render_single_datetime_input(streamlit_app, key, property)
if schema_utils.is_single_boolean_property(property):
return self._render_single_boolean_input(streamlit_app, key, property)
if schema_utils.is_single_dict_property(property):
return self._render_single_dict_input(streamlit_app, key, property)
if schema_utils.is_single_number_property(property):
return self._render_single_number_input(streamlit_app, key, property)
if schema_utils.is_single_string_property(property):
return self._render_single_string_input(streamlit_app, key, property)
if schema_utils.is_single_object(property, self._schema_references):
return self._render_single_object_input(streamlit_app, key, property)
if schema_utils.is_object_list_property(property, self._schema_references):
return self._render_object_list_input(streamlit_app, key, property)
if schema_utils.is_property_list(property):
return self._render_property_list_input(streamlit_app, key, property)
if schema_utils.is_single_reference(property):
return self._render_single_reference(streamlit_app, key, property)
streamlit_app.warning(
"The type of the following property is currently not supported: "
+ str(property.get("title"))
)
raise Exception("Unsupported property")
class OutputUI:
def __init__(self, output_data: Any, input_data: Any):
self._output_data = output_data
self._input_data = input_data
def render_ui(self, streamlit_app) -> None:
try:
if isinstance(self._output_data, BaseModel):
self._render_single_output(streamlit_app, self._output_data)
return
if isinstance(self._output_data, list):
self._render_list_output(streamlit_app, self._output_data)
return
except Exception as ex:
streamlit_app.exception(ex)
# Fall back to rendering the raw JSON output
streamlit_app.json(jsonable_encoder(self._output_data))
def _render_single_text_property(
self, streamlit: st, property_schema: Dict, value: Any
) -> None:
# Add title and subheader
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
if value is None or value == "":
streamlit.info("No value returned!")
else:
streamlit.code(str(value), language="plain")
def _render_single_file_property(
self, streamlit: st, property_schema: Dict, value: Any
) -> None:
# Add title and subheader
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
if value is None or value == "":
streamlit.info("No value returned!")
else:
# TODO: Detect if it is a FileContent instance
# TODO: detect if it is base64
file_extension = ""
if "mime_type" in property_schema:
mime_type = property_schema["mime_type"]
file_extension = mimetypes.guess_extension(mime_type) or ""
if is_compatible_audio(mime_type):
streamlit.audio(value.as_bytes(), format=mime_type)
return
if is_compatible_image(mime_type):
streamlit.image(value.as_bytes())
return
if is_compatible_video(mime_type):
streamlit.video(value.as_bytes(), format=mime_type)
return
filename = (
(property_schema["title"] + file_extension)
.lower()
.strip()
.replace(" ", "-")
)
streamlit.markdown(
f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
unsafe_allow_html=True,
)
def _render_single_complex_property(
self, streamlit: st, property_schema: Dict, value: Any
) -> None:
# Add title and subheader
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
streamlit.json(jsonable_encoder(value))
def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
try:
if has_output_ui_renderer(output_data):
if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
# render method also requests the input data
output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
else:
output_data.render_output_ui(streamlit) # type: ignore
return
except Exception:
# Use default auto-generation methods if the custom rendering throws an exception
logger.exception(
"Failed to execute custom render_output_ui function. Using auto-generation instead"
)
model_schema = output_data.schema(by_alias=False)
model_properties = model_schema.get("properties")
definitions = model_schema.get("definitions")
if model_properties:
for property_key in output_data.__dict__:
property_schema = model_properties.get(property_key)
if property_schema and not property_schema.get("title"):
# Set property key as fallback title
property_schema["title"] = property_key
output_property_value = output_data.__dict__[property_key]
if has_output_ui_renderer(output_property_value):
output_property_value.render_output_ui(streamlit) # type: ignore
continue
if isinstance(output_property_value, BaseModel):
# Render output recursively
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
self._render_single_output(streamlit, output_property_value)
continue
if property_schema:
if schema_utils.is_single_file_property(property_schema):
self._render_single_file_property(
streamlit, property_schema, output_property_value
)
continue
if (
schema_utils.is_single_string_property(property_schema)
or schema_utils.is_single_number_property(property_schema)
or schema_utils.is_single_datetime_property(property_schema)
or schema_utils.is_single_boolean_property(property_schema)
):
self._render_single_text_property(
streamlit, property_schema, output_property_value
)
continue
if definitions and schema_utils.is_single_enum_property(
property_schema, definitions
):
self._render_single_text_property(
streamlit, property_schema, output_property_value.value
)
continue
# TODO: render dict as table
self._render_single_complex_property(
streamlit, property_schema, output_property_value
)
return
def _render_list_output(self, streamlit: st, output_data: List) -> None:
try:
data_items: List = []
for data_item in output_data:
if has_output_ui_renderer(data_item):
# Render using the render function
data_item.render_output_ui(streamlit) # type: ignore
continue
data_items.append(data_item.dict())
# Try to show as dataframe
streamlit.table(pd.DataFrame(data_items))
except Exception:
# Fall back to rendering the raw JSON output
streamlit.json(jsonable_encoder(output_data))
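# Map the sidebar mode string to the corresponding Opyrator-wrapped entry point; unknown modes fall back to the synthesize app.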
def getOpyrator(mode: str) -> Opyrator:
if mode is None or mode.startswith('VC'):
from control.mkgui.app_vc import convert
return Opyrator(convert)
if mode is None or mode.startswith('预处理'):
from control.mkgui.preprocess import preprocess
return Opyrator(preprocess)
# Check the more specific "模型训练(VC)" prefix before the plain "模型训练" one
if mode is None or mode.startswith('模型训练(VC)'):
from control.mkgui.train_vc import train_vc
return Opyrator(train_vc)
if mode is None or mode.startswith('模型训练'):
from control.mkgui.train import train
return Opyrator(train)
from control.mkgui.app import synthesize
return Opyrator(synthesize)
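# Top-level Streamlit page: mode selector in the sidebar, input form on the left, results on the right.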
def render_streamlit_ui() -> None:
# init
session_state = st.session_state
session_state.input_data = {}
# Add custom css settings
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
with st.spinner("Loading MockingBird GUI. Please wait..."):
session_state.mode = st.sidebar.selectbox(
'模式选择',
( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
)
if "mode" in session_state:
mode = session_state.mode
else:
mode = ""
opyrator = getOpyrator(mode)
title = opyrator.name + mode
col1, col2, _ = st.columns(3)
col2.title(title)
col2.markdown("欢迎使用MockingBird Web 2")
image = Image.open(path.join('control','mkgui', 'static', 'mb.png'))
col1.image(image)
st.markdown("---")
left, right = st.columns([0.4, 0.6])
with left:
st.header("Control 控制")
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
execute_selected = st.button(opyrator.action)
if execute_selected:
with st.spinner("Executing operation. Please wait..."):
try:
input_data_obj = parse_obj_as(
opyrator.input_type, session_state.input_data
)
session_state.output_data = opyrator(input=input_data_obj)
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
except ValidationError as ex:
st.error(ex)
else:
# st.success("Operation executed successfully.")
pass
with right:
st.header("Result 结果")
if 'output_data' in session_state:
OutputUI(
session_state.output_data, session_state.latest_operation_input
).render_ui(st)
if st.button("Clear"):
# Clear all state
for key in st.session_state.keys():
del st.session_state[key]
session_state.input_data = {}
st.experimental_rerun()
else:
# placeholder
st.caption("请使用左侧控制板进行输入并运行获得结果")

View File

@@ -0,0 +1,13 @@
CUSTOM_STREAMLIT_CSS = """
div[data-testid="stBlock"] button {
width: 100% !important;
margin-bottom: 20px !important;
border-color: #bfbfbf !important;
}
section[data-testid="stSidebar"] div {
max-width: 10rem;
}
pre code {
white-space: pre-wrap;
}
"""

View File

@@ -0,0 +1,96 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
# Constants
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
if os.path.isdir(EXT_MODELS_DIRT):
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
print("Loaded extractor models: " + str(len(extractors)))
else:
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
class Model(str, Enum):
VC_PPG2MEL = "ppg2mel"
class Dataset(str, Enum):
AIDATATANG_200ZH = "aidatatang_200zh"
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
class Input(BaseModel):
# def render_input_ui(st, input) -> Dict:
# input["selected_dataset"] = st.selectbox(
# '选择数据集',
# ("aidatatang_200zh", "aidatatang_200zh_s")
# )
# return input
model: Model = Field(
Model.VC_PPG2MEL, title="目标模型",
)
dataset: Dataset = Field(
Dataset.AIDATATANG_200ZH, title="数据集选择",
)
datasets_root: str = Field(
..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
format=True,
example="..\\trainning_data\\"
)
output_root: str = Field(
..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
format=True,
example="..\\trainning_data\\"
)
n_processes: int = Field(
2, alias="处理线程数", description="根据CPU线程数来设置",
le=32, ge=1
)
extractor: extractors = Field(
..., alias="特征提取模型",
description="选择PPG特征提取模型文件."
)
encoder: encoders = Field(
..., alias="语音编码模型",
description="选择语音编码模型文件."
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: Tuple[str, int]
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
dataset, count = self.__root__
streamlit_app.subheader(f"Dataset {dataset} done, processed a total of {count}")
def preprocess(input: Input) -> Output:
"""Preprocess(预处理)"""
finished = 0
if input.model == Model.VC_PPG2MEL:
from models.ppg2mel.preprocess import preprocess_dataset
finished = preprocess_dataset(
datasets_root=Path(input.datasets_root),
dataset=input.dataset,
out_dir=Path(input.output_root),
n_processes=input.n_processes,
ppg_encoder_model_fpath=Path(input.extractor.value),
speaker_encoder_model=Path(input.encoder.value)
)
# TODO: pass useful return code
return Output(__root__=(input.dataset, finished))

BIN
control/mkgui/static/mb.png Normal file

Binary file not shown.


106
control/mkgui/train.py Normal file
View File

@@ -0,0 +1,106 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train as synt_train
# Constants
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
# EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
# CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
# ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
class Model(str, Enum):
DEFAULT = "default"
class Input(BaseModel):
model: Model = Field(
Model.DEFAULT, title="模型类型",
)
# datasets_root: str = Field(
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
# format=True,
# example="..\\trainning_data\\"
# )
input_root: str = Field(
..., alias="输入目录", description="预处理数据根目录",
format=True,
example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
)
run_id: str = Field(
"", alias="新模型名/运行ID", description="使用新ID进行重新训练否则选择下面的模型进行继续训练",
)
synthesizer: synthesizers = Field(
..., alias="已有合成模型",
description="选择语音合成模型文件."
)
gpu: bool = Field(
True, alias="GPU训练", description="选择“是”则使用GPU训练",
)
verbose: bool = Field(
True, alias="打印详情", description="选择“是”,输出更多详情",
)
encoder: encoders = Field(
..., alias="语音编码模型",
description="选择语音编码模型文件."
)
save_every: int = Field(
1000, alias="更新间隔", description="每隔n步则更新一次模型",
)
backup_every: int = Field(
10000, alias="保存间隔", description="每隔n步则保存一次模型",
)
log_every: int = Field(
500, alias="打印间隔", description="每隔n步则打印一次训练统计",
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: int
def render_output_ui(self, streamlit_app) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
streamlit_app.subheader(f"Training started with code: {self.__root__}")
def train(input: Input) -> Output:
"""Train(训练)"""
print(">>> Start training ...")
force_restart = len(input.run_id) > 0
if not force_restart:
input.run_id = Path(input.synthesizer.value).name.split('.')[0]
synt_train(
input.run_id,
input.input_root,
f"data{os.sep}ckpt{os.sep}synthesizer",
input.save_every,
input.backup_every,
input.log_every,
force_restart,
hparams
)
return Output(__root__=0)

155
control/mkgui/train_vc.py Normal file
View File

@@ -0,0 +1,155 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch
# Constants
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
if os.path.isdir(EXT_MODELS_DIRT):
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
print("Loaded extractor models: " + str(len(extractors)))
else:
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
if os.path.isdir(CONV_MODELS_DIRT):
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
print("Loaded convertor models: " + str(len(convertors)))
else:
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
class Model(str, Enum):
VC_PPG2MEL = "ppg2mel"
class Dataset(str, Enum):
AIDATATANG_200ZH = "aidatatang_200zh"
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
class Input(BaseModel):
# def render_input_ui(st, input) -> Dict:
# input["selected_dataset"] = st.selectbox(
# '选择数据集',
# ("aidatatang_200zh", "aidatatang_200zh_s")
# )
# return input
model: Model = Field(
Model.VC_PPG2MEL, title="模型类型",
)
# datasets_root: str = Field(
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
# format=True,
# example="..\\trainning_data\\"
# )
output_root: str = Field(
..., alias="输出目录(可选)", description="建议不填,保持默认",
format=True,
example=""
)
continue_mode: bool = Field(
True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
)
gpu: bool = Field(
True, alias="GPU训练", description="选择“是”则使用GPU训练",
)
verbose: bool = Field(
True, alias="打印详情", description="选择“是”,输出更多详情",
)
# TODO: Move to hidden fields by default
convertor: convertors = Field(
..., alias="转换模型",
description="选择语音转换模型文件."
)
extractor: extractors = Field(
..., alias="特征提取模型",
description="选择PPG特征提取模型文件."
)
encoder: encoders = Field(
..., alias="语音编码模型",
description="选择语音编码模型文件."
)
njobs: int = Field(
8, alias="进程数", description="适用于ppg2mel",
)
seed: int = Field(
default=0, alias="初始随机数", description="适用于ppg2mel",
)
model_name: str = Field(
..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
example="test"
)
model_config: str = Field(
..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: Tuple[str, int]
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
name, count = self.__root__
streamlit_app.subheader(f"Training of {name} done, return code: {count}")
def train_vc(input: Input) -> Output:
"""Train VC(训练 VC)"""
print(">>> OneShot VC training ...")
params = AttrDict()
params.update({
"gpu": input.gpu,
"cpu": not input.gpu,
"njobs": input.njobs,
"seed": input.seed,
"verbose": input.verbose,
"load": input.convertor.value,
"warm_start": False,
})
if input.continue_mode:
# Take over the old model's name and config
p = Path(input.convertor.value)
params.name = p.parent.name
# search a config file
model_config_fpaths = list(p.parent.rglob("*.yaml"))
if len(model_config_fpaths) == 0:
raise "No model yaml config found for convertor"
config = HpsYaml(model_config_fpaths[0])
params.ckpdir = p.parent.parent
params.config = model_config_fpaths[0]
params.logdir = os.path.join(p.parent, "log")
else:
# Make the config dict dot visitable
config = HpsYaml(input.model_config)
np.random.seed(input.seed)
torch.manual_seed(input.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(input.seed)
mode = "train"
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
solver = Solver(config, params, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> Oneshot VC train finished!")
# TODO: pass useful return code
# Input has no dataset field; report the model name instead
return Output(__root__=(input.model_name, 0))

475
control/toolbox/__init__.py Normal file
View File

@@ -0,0 +1,475 @@
from control.toolbox.ui import UI
from models.encoder import inference as encoder
from models.synthesizer.inference import Synthesizer
from models.vocoder.wavernn import inference as rnn_vocoder
from models.vocoder.hifigan import inference as gan_vocoder
from models.vocoder.fregan import inference as fgan_vocoder
from pathlib import Path
from time import perf_counter as timer
from control.toolbox.utterance import Utterance
import numpy as np
import traceback
import sys
import torch
import re
# Use WaveRNN as the default vocoder
vocoder = rnn_vocoder
# Use this directory structure for your datasets, or modify it to fit your needs
recognized_datasets = [
"LibriSpeech/dev-clean",
"LibriSpeech/dev-other",
"LibriSpeech/test-clean",
"LibriSpeech/test-other",
"LibriSpeech/train-clean-100",
"LibriSpeech/train-clean-360",
"LibriSpeech/train-other-500",
"LibriTTS/dev-clean",
"LibriTTS/dev-other",
"LibriTTS/test-clean",
"LibriTTS/test-other",
"LibriTTS/train-clean-100",
"LibriTTS/train-clean-360",
"LibriTTS/train-other-500",
"LJSpeech-1.1",
"VoxCeleb1/wav",
"VoxCeleb1/test_wav",
"VoxCeleb2/dev/aac",
"VoxCeleb2/test/aac",
"VCTK-Corpus/wav48",
"aidatatang_200zh/corpus",
"aishell3/test/wav",
"magicdata/train",
]
# Maximum number of generated wavs to keep in memory
MAX_WAVES = 15
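# Qt toolbox controller: wires UI events to the encoder/synthesizer/vocoder models and, in VC mode, to the PPG extractor and convertor.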
class Toolbox:
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
self.no_mp3_support = no_mp3_support
self.vc_mode = vc_mode
sys.excepthook = self.excepthook
self.datasets_root = datasets_root
self.utterances = set()
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
self.synthesizer = None # type: Synthesizer
# for ppg-based voice conversion
self.extractor = None
self.convertor = None # ppg2mel
self.current_wav = None
self.waves_list = []
self.waves_count = 0
self.waves_namelist = []
# Check for webrtcvad (enables removal of silences in vocoder output)
try:
import webrtcvad
self.trim_silences = True
except:
self.trim_silences = False
# Initialize the events and the interface
self.ui = UI(vc_mode)
self.style_idx = 0
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
self.setup_events()
self.ui.start()
def excepthook(self, exc_type, exc_value, exc_tb):
traceback.print_exception(exc_type, exc_value, exc_tb)
self.ui.log("Exception: %s" % exc_value)
def setup_events(self):
# Dataset, speaker and utterance selection
self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
recognized_datasets,
level)
self.ui.random_dataset_button.clicked.connect(random_func(0))
self.ui.random_speaker_button.clicked.connect(random_func(1))
self.ui.random_utterance_button.clicked.connect(random_func(2))
self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
self.ui.speaker_box.currentIndexChanged.connect(random_func(2))
# Model selection
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
def func():
self.synthesizer = None
if self.vc_mode:
self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
else:
self.ui.synthesizer_box.currentIndexChanged.connect(func)
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
# Utterance selection
func = lambda: self.load_from_browser(self.ui.browse_file())
self.ui.browser_browse_button.clicked.connect(func)
func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
self.ui.utterance_history.currentIndexChanged.connect(func)
func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
self.ui.play_button.clicked.connect(func)
self.ui.stop_button.clicked.connect(self.ui.stop)
self.ui.record_button.clicked.connect(self.record)
# Source Utterance selection
if self.vc_mode:
func = lambda: self.load_soruce_button(self.ui.selected_utterance)
self.ui.load_soruce_button.clicked.connect(func)
#Audio
self.ui.setup_audio_devices(Synthesizer.sample_rate)
#Wav playback & save
func = lambda: self.replay_last_wav()
self.ui.replay_wav_button.clicked.connect(func)
func = lambda: self.export_current_wave()
self.ui.export_wav_button.clicked.connect(func)
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
# Generation
self.ui.vocode_button.clicked.connect(self.vocode)
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
if self.vc_mode:
func = lambda: self.convert() or self.vocode()
self.ui.convert_button.clicked.connect(func)
else:
func = lambda: self.synthesize() or self.vocode()
self.ui.generate_button.clicked.connect(func)
self.ui.synthesize_button.clicked.connect(self.synthesize)
# UMAP legend
self.ui.clear_button.clicked.connect(self.clear_utterances)
def set_current_wav(self, index):
self.current_wav = self.waves_list[index]
def export_current_wave(self):
self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)
def replay_last_wav(self):
self.ui.play(self.current_wav, Synthesizer.sample_rate)
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
self.ui.populate_gen_options(seed, self.trim_silences)
def load_from_browser(self, fpath=None):
if fpath is None:
fpath = Path(self.datasets_root,
self.ui.current_dataset_name,
self.ui.current_speaker_name,
self.ui.current_utterance_name)
name = str(fpath.relative_to(self.datasets_root))
speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
# Select the next utterance
if self.ui.auto_next_checkbox.isChecked():
self.ui.browser_select_next()
elif fpath == "":
return
else:
name = fpath.name
speaker_name = fpath.parent.name
if fpath.suffix.lower() == ".mp3" and self.no_mp3_support:
self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used")
return
# Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
# playback, so as to have a fair comparison with the generated audio
wav = Synthesizer.load_preprocess_wav(fpath)
self.ui.log("Loaded %s" % name)
self.add_real_utterance(wav, name, speaker_name)
def load_soruce_button(self, utterance: Utterance):
self.selected_source_utterance = utterance
def record(self):
wav = self.ui.record_one(encoder.sampling_rate, 5)
if wav is None:
return
self.ui.play(wav, encoder.sampling_rate)
speaker_name = "user01"
name = speaker_name + "_rec_%05d" % np.random.randint(100000)
self.add_real_utterance(wav, name, speaker_name)
def add_real_utterance(self, wav, name, speaker_name):
# Compute the mel spectrogram
spec = Synthesizer.make_spectrogram(wav)
self.ui.draw_spec(spec, "current")
# Compute the embedding
if not encoder.is_loaded():
self.init_encoder()
encoder_wav = encoder.preprocess_wav(wav)
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
# Add the utterance
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
self.utterances.add(utterance)
self.ui.register_utterance(utterance, self.vc_mode)
# Plot it
self.ui.draw_embed(embed, name, "current")
self.ui.draw_umap_projections(self.utterances)
def clear_utterances(self):
self.utterances.clear()
self.ui.draw_umap_projections(self.utterances)
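# Split the prompt on Chinese/Latin punctuation, synthesize one spectrogram per segment with the selected utterance's embedding, then concatenate the results.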
def synthesize(self):
self.ui.log("Generating the mel spectrogram...")
self.ui.set_loading(1)
# Update the synthesizer random seed
if self.ui.random_seed_checkbox.isChecked():
seed = int(self.ui.seed_textbox.text())
self.ui.populate_gen_options(seed, self.trim_silences)
else:
seed = None
if seed is not None:
torch.manual_seed(seed)
# Synthesize the spectrogram
if self.synthesizer is None or seed is not None:
self.init_synthesizer()
texts = self.ui.text_prompt.toPlainText().split("\n")
punctuation = '!,。、,' # punctuate and split/clean text
processed_texts = []
for text in texts:
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
if processed_text:
processed_texts.append(processed_text.strip())
texts = processed_texts
embed = self.ui.selected_utterance.embed
embeds = [embed] * len(texts)
min_token = int(self.ui.token_slider.value())
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
breaks = [spec.shape[1] for spec in specs]
spec = np.concatenate(specs, axis=1)
self.ui.draw_spec(spec, "generated")
self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
self.ui.set_loading(0)
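# Turn the generated spectrogram into a waveform with the selected vocoder (or Griffin-Lim), re-insert short silences between segments and register the result as a new utterance.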
def vocode(self):
speaker_name, spec, breaks, _ = self.current_generated
assert spec is not None
# Initialize the vocoder model and make it deterministic, if the user provides a seed
if self.ui.random_seed_checkbox.isChecked():
seed = int(self.ui.seed_textbox.text())
self.ui.populate_gen_options(seed, self.trim_silences)
else:
seed = None
if seed is not None:
torch.manual_seed(seed)
# Synthesize the waveform
if not vocoder.is_loaded() or seed is not None:
self.init_vocoder()
def vocoder_progress(i, seq_len, b_size, gen_rate):
real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
% (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
self.ui.log(line, "overwrite")
self.ui.set_loading(i, seq_len)
if self.ui.current_vocoder_fpath is not None:
self.ui.log("")
wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
else:
self.ui.log("Waveform generation with Griffin-Lim... ")
wav = Synthesizer.griffin_lim(spec)
self.ui.set_loading(0)
self.ui.log(" Done!", "append")
# Add breaks
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
b_starts = np.concatenate(([0], b_ends[:-1]))
wavs = [wav[start:end] for start, end in zip(b_starts, b_ends)]
breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
# Trim excessive silences
if self.ui.trim_silences_checkbox.isChecked():
wav = encoder.preprocess_wav(wav)
# Play it
wav = wav / np.abs(wav).max() * 0.97
self.ui.play(wav, sample_rate)
# Name it (history displayed in combobox)
# TODO better naming for the combobox items?
wav_name = str(self.waves_count + 1)
#Update waves combobox
self.waves_count += 1
if self.waves_count > MAX_WAVES:
self.waves_list.pop()
self.waves_namelist.pop()
self.waves_list.insert(0, wav)
self.waves_namelist.insert(0, wav_name)
self.ui.waves_cb.disconnect()
self.ui.waves_cb_model.setStringList(self.waves_namelist)
self.ui.waves_cb.setCurrentIndex(0)
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
# Update current wav
self.set_current_wav(0)
#Enable replay and save buttons:
self.ui.replay_wav_button.setDisabled(False)
self.ui.export_wav_button.setDisabled(False)
# Compute the embedding
# TODO: this is problematic with different sampling rates, gotta fix it
if not encoder.is_loaded():
self.init_encoder()
encoder_wav = encoder.preprocess_wav(wav)
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
# Add the utterance
name = speaker_name + "_gen_%05d" % np.random.randint(100000)
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
self.utterances.add(utterance)
# Plot it
self.ui.draw_embed(embed, name, "generated")
self.ui.draw_umap_projections(self.utterances)
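# PPG-based voice conversion: extract PPGs from the source utterance, convert the F0 statistics towards the reference speaker, then run the ppg2mel convertor.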
def convert(self):
self.ui.log("Extract PPG and Converting...")
self.ui.set_loading(1)
# Init
if self.convertor is None:
self.init_convertor()
if self.extractor is None:
self.init_extractor()
src_wav = self.selected_source_utterance.wav
# Compute the ppg
if self.extractor is not None:
ppg = self.extractor.extract_from_wav(src_wav)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ref_wav = self.ui.selected_utterance.wav
# Import necessary dependency of Voice Conversion
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
_, mel_pred, att_ws = self.convertor.inference(
ppg,
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
)
mel_pred = mel_pred.transpose(0, 1)
breaks = [mel_pred.shape[1]]
mel_pred = mel_pred.detach().cpu().numpy()
self.ui.draw_spec(mel_pred, "generated")
self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
self.ui.set_loading(0)
def init_extractor(self):
if self.ui.current_extractor_fpath is None:
return
model_fpath = self.ui.current_extractor_fpath
self.ui.log("Loading the extractor %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
import models.ppg_extractor as extractor
self.extractor = extractor.load_model(model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
def init_convertor(self):
if self.ui.current_convertor_fpath is None:
return
model_fpath = self.ui.current_convertor_fpath
self.ui.log("Loading the convertor %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
import models.ppg2mel as convertor
self.convertor = convertor.load_model( model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
def init_encoder(self):
model_fpath = self.ui.current_encoder_fpath
self.ui.log("Loading the encoder %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
encoder.load_model(model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
def init_synthesizer(self):
model_fpath = self.ui.current_synthesizer_fpath
self.ui.log("Loading the synthesizer %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
self.synthesizer = Synthesizer(model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
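# Pick the vocoder backend (HiFi-GAN, Fre-GAN or WaveRNN) from the checkpoint file name and load it together with an optional JSON config found next to it.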
def init_vocoder(self):
global vocoder
model_fpath = self.ui.current_vocoder_fpath
# Case of Griffin-lim
if model_fpath is None:
return
# Select the vocoder based on the model name
model_config_fpath = None
if model_fpath.name is not None and model_fpath.name.find("hifigan") > -1:
vocoder = gan_vocoder
self.ui.log("set hifigan as vocoder")
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
if self.vc_mode and self.ui.current_extractor_fpath is None:
return
if len(model_config_fpaths) > 0:
model_config_fpath = model_config_fpaths[0]
elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1:
vocoder = fgan_vocoder
self.ui.log("set fregan as vocoder")
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
if self.vc_mode and self.ui.current_extractor_fpath is None:
return
if len(model_config_fpaths) > 0:
model_config_fpath = model_config_fpaths[0]
else:
vocoder = rnn_vocoder
self.ui.log("set wavernn as vocoder")
self.ui.log("Loading the vocoder %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
vocoder.load_model(model_fpath, model_config_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
def update_seed_textbox(self):
self.ui.update_seed_textbox()

Binary file not shown.


701
control/toolbox/ui.py Normal file
View File

@@ -0,0 +1,701 @@
from PyQt5.QtCore import Qt, QStringListModel
from PyQt5 import QtGui
from PyQt5.QtWidgets import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from models.encoder.inference import plot_embedding_as_heatmap
from control.toolbox.utterance import Utterance
from pathlib import Path
from typing import List, Set
import sounddevice as sd
import soundfile as sf
import numpy as np
# from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP
from time import sleep
import umap
import sys
from warnings import filterwarnings, warn
filterwarnings("ignore")
colormap = np.array([
[0, 127, 70],
[255, 0, 0],
[255, 217, 38],
[0, 135, 255],
[165, 0, 165],
[255, 167, 255],
[97, 142, 151],
[0, 255, 255],
[255, 96, 38],
[142, 76, 0],
[33, 0, 127],
[0, 0, 0],
[183, 183, 183],
[76, 255, 0],
], dtype=float) / 255
default_text = \
"欢迎使用工具箱, 现已支持中文输入!"
class UI(QDialog):
min_umap_points = 4
max_log_lines = 5
max_saved_utterances = 20
def draw_utterance(self, utterance: Utterance, which):
self.draw_spec(utterance.spec, which)
self.draw_embed(utterance.embed, utterance.name, which)
def draw_embed(self, embed, name, which):
embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
embed_ax.figure.suptitle("" if embed is None else name)
## Embedding
# Clear the plot
if len(embed_ax.images) > 0:
embed_ax.images[0].colorbar.remove()
embed_ax.clear()
# Draw the embed
if embed is not None:
plot_embedding_as_heatmap(embed, embed_ax)
embed_ax.set_title("embedding")
embed_ax.set_aspect("equal", "datalim")
embed_ax.set_xticks([])
embed_ax.set_yticks([])
embed_ax.figure.canvas.draw()
def draw_spec(self, spec, which):
_, spec_ax = self.current_ax if which == "current" else self.gen_ax
## Spectrogram
# Draw the spectrogram
spec_ax.clear()
if spec is not None:
im = spec_ax.imshow(spec, aspect="auto", interpolation="none")
# spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal",
# spec_ax=spec_ax)
spec_ax.set_title("mel spectrogram")
spec_ax.set_xticks([])
spec_ax.set_yticks([])
spec_ax.figure.canvas.draw()
if which != "current":
self.vocode_button.setDisabled(spec is None)
def draw_umap_projections(self, utterances: Set[Utterance]):
self.umap_ax.clear()
speakers = np.unique([u.speaker_name for u in utterances])
colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
embeds = [u.embed for u in utterances]
# Display a message if there aren't enough points
if len(utterances) < self.min_umap_points:
self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
(self.min_umap_points - len(utterances)),
horizontalalignment='center', fontsize=15)
self.umap_ax.set_title("")
# Compute the projections
else:
if not self.umap_hot:
self.log(
"Drawing UMAP projections for the first time, this will take a few seconds.")
self.umap_hot = True
reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
# reducer = TSNE()
projections = reducer.fit_transform(embeds)
speakers_done = set()
for projection, utterance in zip(projections, utterances):
color = colors[utterance.speaker_name]
mark = "x" if "_gen_" in utterance.name else "o"
label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
speakers_done.add(utterance.speaker_name)
self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
label=label)
# self.umap_ax.set_title("UMAP projections")
self.umap_ax.legend(prop={'size': 10})
# Draw the plot
self.umap_ax.set_aspect("equal", "datalim")
self.umap_ax.set_xticks([])
self.umap_ax.set_yticks([])
self.umap_ax.figure.canvas.draw()
def save_audio_file(self, wav, sample_rate):
dialog = QFileDialog()
dialog.setDefaultSuffix(".wav")
fpath, _ = dialog.getSaveFileName(
parent=self,
caption="Select a path to save the audio file",
filter="Audio Files (*.flac *.wav)"
)
if fpath:
#Default format is wav
if Path(fpath).suffix == "":
fpath += ".wav"
sf.write(fpath, wav, sample_rate)
def setup_audio_devices(self, sample_rate):
input_devices = []
output_devices = []
for device in sd.query_devices():
# Check if valid input
try:
sd.check_input_settings(device=device["name"], samplerate=sample_rate)
input_devices.append(device["name"])
except:
pass
# Check if valid output
try:
sd.check_output_settings(device=device["name"], samplerate=sample_rate)
output_devices.append(device["name"])
except Exception as e:
# Log a warning only if the device is not an input
if not device["name"] in input_devices:
warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))
if len(input_devices) == 0:
self.log("No audio input device detected. Recording may not work.")
self.audio_in_device = None
else:
self.audio_in_device = input_devices[0]
if len(output_devices) == 0:
self.log("No supported output audio devices were found! Audio output may not work.")
self.audio_out_devices_cb.addItems(["None"])
self.audio_out_devices_cb.setDisabled(True)
else:
self.audio_out_devices_cb.clear()
self.audio_out_devices_cb.addItems(output_devices)
self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)
self.set_audio_device()
def set_audio_device(self):
output_device = self.audio_out_devices_cb.currentText()
if output_device == "None":
output_device = None
# If None, sounddevice queries portaudio
sd.default.device = (self.audio_in_device, output_device)
def play(self, wav, sample_rate):
try:
sd.stop()
sd.play(wav, sample_rate)
except Exception as e:
print(e)
self.log("Error in audio playback. Try selecting a different audio output device.")
self.log("Your device must be connected before you start the toolbox.")
def stop(self):
sd.stop()
def record_one(self, sample_rate, duration):
self.record_button.setText("Recording...")
self.record_button.setDisabled(True)
self.log("Recording %d seconds of audio" % duration)
sd.stop()
try:
wav = sd.rec(duration * sample_rate, sample_rate, 1)
except Exception as e:
print(e)
self.log("Could not record anything. Is your recording device enabled?")
self.log("Your device must be connected before you start the toolbox.")
return None
for i in np.arange(0, duration, 0.1):
self.set_loading(i, duration)
sleep(0.1)
self.set_loading(duration, duration)
sd.wait()
self.log("Done recording.")
self.record_button.setText("Record")
self.record_button.setDisabled(False)
return wav.squeeze()
@property
def current_dataset_name(self):
return self.dataset_box.currentText()
@property
def current_speaker_name(self):
return self.speaker_box.currentText()
@property
def current_utterance_name(self):
return self.utterance_box.currentText()
def browse_file(self):
fpath = QFileDialog().getOpenFileName(
parent=self,
caption="Select an audio file",
filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
)
return Path(fpath[0]) if fpath[0] != "" else ""
@staticmethod
def repopulate_box(box, items, random=False):
"""
Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
data to the items
"""
box.blockSignals(True)
box.clear()
for item in items:
item = list(item) if isinstance(item, tuple) else [item]
box.addItem(str(item[0]), *item[1:])
if len(items) > 0:
box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
box.setDisabled(len(items) == 0)
box.blockSignals(False)
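# level 0 refreshes the dataset, speaker and utterance boxes; level 1 the speaker and utterance boxes; level 2 only the utterance box.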
def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
random=True):
# Select a random dataset
if level <= 0:
if datasets_root is not None:
datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
self.browser_load_button.setDisabled(len(datasets) == 0)
if datasets_root is None or len(datasets) == 0:
msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
if datasets_root is None else "o not have any of the recognized datasets" \
" in %s \n" \
"Please note use 'E:\datasets' as root path " \
"instead of 'E:\datasets\aidatatang_200zh\corpus\test' as an example " % datasets_root)
self.log(msg)
msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
"can still use the toolbox by recording samples yourself." % \
("\n\t".join(recognized_datasets))
print(msg, file=sys.stderr)
self.random_utterance_button.setDisabled(True)
self.random_speaker_button.setDisabled(True)
self.random_dataset_button.setDisabled(True)
self.utterance_box.setDisabled(True)
self.speaker_box.setDisabled(True)
self.dataset_box.setDisabled(True)
self.browser_load_button.setDisabled(True)
self.auto_next_checkbox.setDisabled(True)
return
self.repopulate_box(self.dataset_box, datasets, random)
# Select a random speaker
if level <= 1:
speakers_root = datasets_root.joinpath(self.current_dataset_name)
speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
self.repopulate_box(self.speaker_box, speaker_names, random)
# Select a random utterance
if level <= 2:
utterances_root = datasets_root.joinpath(
self.current_dataset_name,
self.current_speaker_name
)
utterances = []
for extension in ['mp3', 'flac', 'wav', 'm4a']:
utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
self.repopulate_box(self.utterance_box, utterances, random)
def browser_select_next(self):
index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
self.utterance_box.setCurrentIndex(index)
@property
def current_encoder_fpath(self):
return self.encoder_box.itemData(self.encoder_box.currentIndex())
@property
def current_synthesizer_fpath(self):
return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())
@property
def current_vocoder_fpath(self):
return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
@property
def current_extractor_fpath(self):
return self.extractor_box.itemData(self.extractor_box.currentIndex())
@property
def current_convertor_fpath(self):
return self.convertor_box.itemData(self.convertor_box.currentIndex())
def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool):
# Encoder
encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
if len(encoder_fpaths) == 0:
raise Exception("No encoder models found in %s" % encoder_models_dir)
self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])
if vc_mode:
# Extractor
extractor_fpaths = list(extractor_models_dir.glob("*.pt"))
if len(extractor_fpaths) == 0:
self.log("No extractor models found in %s" % extractor_fpaths)
self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths])
# Convertor
convertor_fpaths = list(convertor_models_dir.glob("*.pth"))
if len(convertor_fpaths) == 0:
self.log("No convertor models found in %s" % convertor_fpaths)
self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths])
else:
# Synthesizer
synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
if len(synthesizer_fpaths) == 0:
raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
# Vocoder
vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
self.repopulate_box(self.vocoder_box, vocoder_items)
@property
def selected_utterance(self):
return self.utterance_history.itemData(self.utterance_history.currentIndex())
def register_utterance(self, utterance: Utterance, vc_mode):
self.utterance_history.blockSignals(True)
self.utterance_history.insertItem(0, utterance.name, utterance)
self.utterance_history.setCurrentIndex(0)
self.utterance_history.blockSignals(False)
if len(self.utterance_history) > self.max_saved_utterances:
self.utterance_history.removeItem(self.max_saved_utterances)
self.play_button.setDisabled(False)
if vc_mode:
self.convert_button.setDisabled(False)
else:
self.generate_button.setDisabled(False)
self.synthesize_button.setDisabled(False)
def log(self, line, mode="newline"):
if mode == "newline":
self.logs.append(line)
if len(self.logs) > self.max_log_lines:
del self.logs[0]
elif mode == "append":
self.logs[-1] += line
elif mode == "overwrite":
self.logs[-1] = line
log_text = '\n'.join(self.logs)
self.log_window.setText(log_text)
self.app.processEvents()
def set_loading(self, value, maximum=1):
self.loading_bar.setValue(value * 100)
self.loading_bar.setMaximum(maximum * 100)
self.loading_bar.setTextVisible(value != 0)
self.app.processEvents()
def populate_gen_options(self, seed, trim_silences):
if seed is not None:
self.random_seed_checkbox.setChecked(True)
self.seed_textbox.setText(str(seed))
self.seed_textbox.setEnabled(True)
else:
self.random_seed_checkbox.setChecked(False)
self.seed_textbox.setText(str(0))
self.seed_textbox.setEnabled(False)
if not trim_silences:
self.trim_silences_checkbox.setChecked(False)
self.trim_silences_checkbox.setDisabled(True)
def update_seed_textbox(self):
if self.random_seed_checkbox.isChecked():
self.seed_textbox.setEnabled(True)
else:
self.seed_textbox.setEnabled(False)
def reset_interface(self, vc_mode):
self.draw_embed(None, None, "current")
self.draw_embed(None, None, "generated")
self.draw_spec(None, "current")
self.draw_spec(None, "generated")
self.draw_umap_projections(set())
self.set_loading(0)
self.play_button.setDisabled(True)
if vc_mode:
self.convert_button.setDisabled(True)
else:
self.generate_button.setDisabled(True)
self.synthesize_button.setDisabled(True)
self.vocode_button.setDisabled(True)
self.replay_wav_button.setDisabled(True)
self.export_wav_button.setDisabled(True)
[self.log("") for _ in range(self.max_log_lines)]
def __init__(self, vc_mode):
## Initialize the application
self.app = QApplication(sys.argv)
super().__init__(None)
self.setWindowTitle("MockingBird GUI")
self.setWindowIcon(QtGui.QIcon('toolbox\\assets\\mb.png'))
self.setWindowFlag(Qt.WindowMinimizeButtonHint, True)
self.setWindowFlag(Qt.WindowMaximizeButtonHint, True)
## Main layouts
# Root
root_layout = QGridLayout()
self.setLayout(root_layout)
# Browser
browser_layout = QGridLayout()
root_layout.addLayout(browser_layout, 0, 0, 1, 8)
# Generation
gen_layout = QVBoxLayout()
root_layout.addLayout(gen_layout, 0, 8)
# Visualizations
vis_layout = QVBoxLayout()
root_layout.addLayout(vis_layout, 1, 0, 2, 8)
# Output
output_layout = QGridLayout()
vis_layout.addLayout(output_layout, 0)
# Projections
self.projections_layout = QVBoxLayout()
root_layout.addLayout(self.projections_layout, 1, 8, 2, 2)
## Projections
# UMap
fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
self.projections_layout.addWidget(FigureCanvas(fig))
self.umap_hot = False
self.clear_button = QPushButton("Clear")
self.projections_layout.addWidget(self.clear_button)
## Browser
# Dataset, speaker and utterance selection
i = 0
source_groupbox = QGroupBox('Source(源音频)')
source_layout = QGridLayout()
source_groupbox.setLayout(source_layout)
browser_layout.addWidget(source_groupbox, i, 0, 1, 5)
self.dataset_box = QComboBox()
source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0)
source_layout.addWidget(self.dataset_box, i, 1)
self.random_dataset_button = QPushButton("Random")
source_layout.addWidget(self.random_dataset_button, i, 2)
i += 1
self.speaker_box = QComboBox()
source_layout.addWidget(QLabel("Speaker(说话者)"), i, 0)
source_layout.addWidget(self.speaker_box, i, 1)
self.random_speaker_button = QPushButton("Random")
source_layout.addWidget(self.random_speaker_button, i, 2)
i += 1
self.utterance_box = QComboBox()
source_layout.addWidget(QLabel("Utterance(音频):"), i, 0)
source_layout.addWidget(self.utterance_box, i, 1)
self.random_utterance_button = QPushButton("Random")
source_layout.addWidget(self.random_utterance_button, i, 2)
i += 1
source_layout.addWidget(QLabel("<b>Use(使用):</b>"), i, 0)
self.browser_load_button = QPushButton("Load Above(加载上面)")
source_layout.addWidget(self.browser_load_button, i, 1, 1, 2)
self.auto_next_checkbox = QCheckBox("Auto select next")
self.auto_next_checkbox.setChecked(True)
source_layout.addWidget(self.auto_next_checkbox, i+1, 1)
self.browser_browse_button = QPushButton("Browse(打开本地)")
source_layout.addWidget(self.browser_browse_button, i, 3)
self.record_button = QPushButton("Record(录音)")
source_layout.addWidget(self.record_button, i+1, 3)
i += 2
# Utterance box
browser_layout.addWidget(QLabel("<b>Current(当前):</b>"), i, 0)
self.utterance_history = QComboBox()
browser_layout.addWidget(self.utterance_history, i, 1)
self.play_button = QPushButton("Play(播放)")
browser_layout.addWidget(self.play_button, i, 2)
self.stop_button = QPushButton("Stop(暂停)")
browser_layout.addWidget(self.stop_button, i, 3)
if vc_mode:
self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)")
browser_layout.addWidget(self.load_soruce_button, i, 4)
i += 1
model_groupbox = QGroupBox('Models(模型选择)')
model_layout = QHBoxLayout()
model_groupbox.setLayout(model_layout)
browser_layout.addWidget(model_groupbox, i, 0, 2, 5)
# Model and audio output selection
self.encoder_box = QComboBox()
model_layout.addWidget(QLabel("Encoder:"))
model_layout.addWidget(self.encoder_box)
self.synthesizer_box = QComboBox()
if vc_mode:
self.extractor_box = QComboBox()
model_layout.addWidget(QLabel("Extractor:"))
model_layout.addWidget(self.extractor_box)
self.convertor_box = QComboBox()
model_layout.addWidget(QLabel("Convertor:"))
model_layout.addWidget(self.convertor_box)
else:
model_layout.addWidget(QLabel("Synthesizer:"))
model_layout.addWidget(self.synthesizer_box)
self.vocoder_box = QComboBox()
model_layout.addWidget(QLabel("Vocoder:"))
model_layout.addWidget(self.vocoder_box)
# Replay & Save Audio
i = 0
output_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
self.waves_cb = QComboBox()
self.waves_cb_model = QStringListModel()
self.waves_cb.setModel(self.waves_cb_model)
self.waves_cb.setToolTip("Select one of the waveforms generated in this session to replay or export")
output_layout.addWidget(self.waves_cb, i, 1)
self.replay_wav_button = QPushButton("Replay")
self.replay_wav_button.setToolTip("Replay the last generated vocoder output")
output_layout.addWidget(self.replay_wav_button, i, 2)
self.export_wav_button = QPushButton("Export")
self.export_wav_button.setToolTip("Save the last generated vocoder audio to the filesystem as a wav file")
output_layout.addWidget(self.export_wav_button, i, 3)
self.audio_out_devices_cb = QComboBox()
i += 1
output_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 0)
output_layout.addWidget(self.audio_out_devices_cb, i, 1)
## Embed & spectrograms
vis_layout.addStretch()
# TODO: add spectrograms for source
gridspec_kw = {"width_ratios": [1, 4]}
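# 1:4 width split: the narrow left axis holds the speaker embedding, the wide right axis the spectrogram.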
fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
gridspec_kw=gridspec_kw)
fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
vis_layout.addWidget(FigureCanvas(fig))
fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
gridspec_kw=gridspec_kw)
fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
vis_layout.addWidget(FigureCanvas(fig))
for ax in self.current_ax.tolist() + self.gen_ax.tolist():
ax.set_facecolor("#F0F0F0")
for side in ["top", "right", "bottom", "left"]:
ax.spines[side].set_visible(False)
## Generation
self.text_prompt = QPlainTextEdit(default_text)
gen_layout.addWidget(self.text_prompt, stretch=1)
if vc_mode:
layout = QHBoxLayout()
self.convert_button = QPushButton("Extract and Convert")
layout.addWidget(self.convert_button)
gen_layout.addLayout(layout)
else:
self.generate_button = QPushButton("Synthesize and vocode")
gen_layout.addWidget(self.generate_button)
layout = QHBoxLayout()
self.synthesize_button = QPushButton("Synthesize only")
layout.addWidget(self.synthesize_button)
self.vocode_button = QPushButton("Vocode only")
layout.addWidget(self.vocode_button)
gen_layout.addLayout(layout)
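# Generation options grid: random seed, silence trimming, and the style / accuracy / max-length sliders.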
layout_seed = QGridLayout()
self.random_seed_checkbox = QCheckBox("Random seed:")
self.random_seed_checkbox.setToolTip("When checked, uses the seed in the textbox, making the synthesizer and vocoder deterministic.")
layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
self.seed_textbox = QLineEdit()
self.seed_textbox.setMaximumWidth(80)
layout_seed.addWidget(self.seed_textbox, 0, 1)
self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
" This feature requires `webrtcvad` to be installed.")
layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
self.style_slider = QSlider(Qt.Horizontal)
self.style_slider.setTickInterval(1)
self.style_slider.setFocusPolicy(Qt.NoFocus)
self.style_slider.setSingleStep(1)
self.style_slider.setRange(-1, 9)
self.style_value_label = QLabel("-1")
self.style_slider.setValue(-1)
layout_seed.addWidget(QLabel("Style:"), 1, 0)
self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
layout_seed.addWidget(self.style_value_label, 1, 1)
layout_seed.addWidget(self.style_slider, 1, 3)
self.token_slider = QSlider(Qt.Horizontal)
self.token_slider.setTickInterval(1)
self.token_slider.setFocusPolicy(Qt.NoFocus)
self.token_slider.setSingleStep(1)
self.token_slider.setRange(3, 9)
self.token_value_label = QLabel("5")
self.token_slider.setValue(4)
layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0)
self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
layout_seed.addWidget(self.token_value_label, 2, 1)
layout_seed.addWidget(self.token_slider, 2, 3)
self.length_slider = QSlider(Qt.Horizontal)
self.length_slider.setTickInterval(1)
self.length_slider.setFocusPolicy(Qt.NoFocus)
self.length_slider.setSingleStep(1)
self.length_slider.setRange(1, 10)
self.length_value_label = QLabel("2")
self.length_slider.setValue(2)
layout_seed.addWidget(QLabel("MaxLength(最大句长):"), 3, 0)
self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s))
layout_seed.addWidget(self.length_value_label, 3, 1)
layout_seed.addWidget(self.length_slider, 3, 3)
gen_layout.addLayout(layout_seed)
self.loading_bar = QProgressBar()
gen_layout.addWidget(self.loading_bar)
self.log_window = QLabel()
self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
gen_layout.addWidget(self.log_window)
self.logs = []
gen_layout.addStretch()
## Set the size of the window and of the elements
max_size = QDesktopWidget().availableGeometry(self).size() * 0.5
self.resize(max_size)
## Finalize the display
self.reset_interface(vc_mode)
self.show()
def start(self):
self.app.exec_()

View File

@@ -0,0 +1,5 @@
from collections import namedtuple
Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
Utterance.__eq__ = lambda x, y: x.name == y.name
Utterance.__hash__ = lambda x: hash(x.name)
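# Equality and hashing are keyed on the utterance name only, so two tuples with
# the same name collapse to a single entry, e.g.:
#   a = Utterance("u1", "spk", None, None, None, None, False)
#   b = Utterance("u1", "spk", None, None, None, None, True)
#   assert a == b and len({a, b}) == 1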