Mirror of https://github.com/babysor/Realtime-Voice-Clone-Chinese.git (synced 2026-05-12 11:35:56 +08:00)
Refactor Project to 3 parts: Models, Control, Data
Needs a README.
0  control/__init__.py  (new file)
64  control/cli/encoder_preprocess.py  (new file)
@@ -0,0 +1,64 @@
import argparse
from pathlib import Path

from models.encoder.preprocess import (preprocess_aidatatang_200zh,
                                       preprocess_librispeech, preprocess_voxceleb1,
                                       preprocess_voxceleb2)
from utils.argutils import print_args


if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        pass

    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
                    "writes them to the disk. This will allow you to train the encoder. The "
                    "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2 and "
                    "aidatatang_200zh.",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms. If left out, "
        "defaults to <datasets_root>/SV2TTS/encoder/")
    parser.add_argument("-d", "--datasets", type=str,
                        default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
        "Comma-separated list of the names of the datasets you want to preprocess. Only the train "
        "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
        "voxceleb2, aidatatang_200zh.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to skip existing output files with the same name. Useful if this script was "
        "interrupted.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    args = parser.parse_args()

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install it and try again. If installation "
                "fails, use --no_trim to disable this error message.")
    del args.no_trim

    # Process the arguments
    args.datasets = args.datasets.split(",")
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Preprocess the datasets
    print_args(args, parser)
    preprocess_func = {
        "librispeech_other": preprocess_librispeech,
        "voxceleb1": preprocess_voxceleb1,
        "voxceleb2": preprocess_voxceleb2,
        "aidatatang_200zh": preprocess_aidatatang_200zh,
    }
    args = vars(args)
    for dataset in args.pop("datasets"):
        print("Preprocessing %s" % dataset)
        preprocess_func[dataset](**args)
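
Note: the -o/--out_dir flag above relies on argparse.SUPPRESS, so the attribute is simply absent from the namespace when the flag is omitted; the later hasattr check then fills in a default derived from datasets_root. A standalone sketch of that pattern (hypothetical paths, not part of this commit):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("datasets_root", type=Path)
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS)
args = parser.parse_args(["/data"])

# With SUPPRESS, an omitted flag never appears in the namespace,
# so a derived default can be computed after parsing:
if not hasattr(args, "out_dir"):
    args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
print(args.out_dir)  # /data/SV2TTS/encoder
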
47  control/cli/encoder_train.py  (new file)
@@ -0,0 +1,47 @@
from utils.argutils import print_args
from models.encoder.train import train
from pathlib import Path
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help=\
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("clean_data_root", type=Path, help=\
        "Path to the output directory of encoder_preprocess.py. If you left the default "
        "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
    parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
        "Path to the output directory that will contain the saved model weights, as well as "
        "backups of those weights and plots generated during training.")
    parser.add_argument("-v", "--vis_every", type=int, default=10, help=\
        "Number of steps between updates of the loss and the plots.")
    parser.add_argument("-u", "--umap_every", type=int, default=100, help=\
        "Number of steps between updates of the umap projection. Set to 0 to never update the "
        "projections.")
    parser.add_argument("-s", "--save_every", type=int, default=500, help=\
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=7500, help=\
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help=\
        "Do not load any saved model.")
    parser.add_argument("--visdom_server", type=str, default="http://localhost")
    parser.add_argument("--no_visdom", action="store_true", help=\
        "Disable visdom.")
    args = parser.parse_args()

    # Process the arguments
    args.models_dir.mkdir(exist_ok=True)

    # Run the training
    print_args(args, parser)
    train(**vars(args))
67  control/cli/ppg2mel_train.py  (new file)
@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description='Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load a pre-trained model (for training only).', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for the dataloader.')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune the model.')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model.')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model.')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA).')

    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot-accessible
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> OneShot VC training finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
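
Note: HpsYaml is project-internal; the "dot-accessible" config it returns behaves roughly like the sketch below (an assumption based on the usage here, not the actual implementation):

# A minimal dict wrapper whose keys are readable as attributes,
# e.g. config.data.train_dir instead of config["data"]["train_dir"].
class DotDict(dict):
    def __getattr__(self, key):
        value = self[key]
        return DotDict(value) if isinstance(value, dict) else value

config = DotDict({"data": {"train_dir": "feats/train"}, "hparams": {"lr": 1e-4}})
assert config.data.train_dir == "feats/train"
assert config.hparams.lr == 1e-4
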
49  control/cli/pre4ppg.py  (new file)
@@ -0,0 +1,49 @@
from pathlib import Path
import argparse

from models.ppg2mel.preprocess import preprocess_dataset

recognized_datasets = [
    "aidatatang_200zh",
    "aidatatang_200zh_s",  # sample
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, to be used by the "
                    "ppg2mel model for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your datasets.")
    parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process. Allowed values: aidatatang_200zh.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
    parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
        "Number of processes in parallel.")
    # parser.add_argument("-s", "--skip_existing", action="store_true", help=\
    #     "Whether to skip existing files with the same name. Useful if the preprocessing was "
    #     "interrupted.")
    # parser.add_argument("--hparams", type=str, default="", help=\
    #     "Hyperparameter overrides as a comma-separated list of name=value pairs")
    # parser.add_argument("--no_trim", action="store_true", help=\
    #     "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path,
                        default="ppg_extractor/saved_models/24epoch.pt", help=\
        "Path to your trained ppg encoder model.")
    parser.add_argument("-sf", "--speaker_encoder_model", type=Path,
                        default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
        "Path to your trained speaker encoder model.")
    args = parser.parse_args()

    assert args.dataset in recognized_datasets, \
        f"Dataset '{args.dataset}' is not supported; file an issue to propose a new one."

    # Create directories
    assert args.datasets_root.exists()
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    args.out_dir.mkdir(exist_ok=True, parents=True)

    preprocess_dataset(**vars(args))
65  control/cli/synthesizer_preprocess_audio.py  (new file)
@@ -0,0 +1,65 @@
from models.synthesizer.preprocess import preprocess_dataset
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
from pathlib import Path
import argparse


recognized_datasets = [
    "aidatatang_200zh",
    "magicdata",
    "aishell3"
]

if __name__ == "__main__":
    print("This method is deprecated and will no longer be supported, please use 'pre.py'")
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
                    "and writes them to the disk. Audio files are also saved, to be used by the "
                    "vocoder for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/")
    parser.add_argument("-n", "--n_processes", type=int, default=None, help=\
        "Number of processes in parallel.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to skip existing files with the same name. Useful if the preprocessing was "
        "interrupted.")
    parser.add_argument("--hparams", type=str, default="", help=\
        "Hyperparameter overrides as a comma-separated list of name=value pairs.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("--no_alignments", action="store_true", help=\
        "Use this option when the dataset does not include alignments "
        "(these are used to split long audio files into sub-utterances).")
    parser.add_argument("--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process. Allowed values: aidatatang_200zh, magicdata, aishell3.")
    args = parser.parse_args()

    # Process the arguments
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer")
    assert args.dataset in recognized_datasets, \
        f"Dataset '{args.dataset}' is not supported, please vote for it in https://github.com/babysor/MockingBird/issues/10"

    # Create directories
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install it and try again. If installation "
                "fails, use --no_trim to disable this error message.")
    del args.no_trim

    # Preprocess the dataset
    print_args(args, parser)
    args.hparams = hparams.parse(args.hparams)

    preprocess_dataset(**vars(args))
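
Note: the --hparams override string is consumed by the project's hparams.parse; the comma-separated name=value convention works roughly like this illustration (not the project's actual parser):

import ast

def parse_overrides(override_str: str) -> dict:
    # "sample_rate=16000,rescale=False" -> {"sample_rate": 16000, "rescale": False}
    overrides = {}
    if override_str:
        for pair in override_str.split(","):
            name, value = pair.split("=", 1)
            try:
                overrides[name] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                overrides[name] = value  # keep unparseable values as plain strings
    return overrides

assert parse_overrides("sample_rate=16000,rescale=False") == {"sample_rate": 16000, "rescale": False}
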
26  control/cli/synthesizer_preprocess_embeds.py  (new file)
@@ -0,0 +1,26 @@
from models.synthesizer.preprocess import create_embeddings
from utils.argutils import print_args
from pathlib import Path
import argparse


if __name__ == "__main__":
    print("This method is deprecated and will no longer be supported, please use 'pre.py'")
    parser = argparse.ArgumentParser(
        description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("synthesizer_root", type=Path, help=\
        "Path to the synthesizer training data that contains the audios and the train.txt file. "
        "If you left everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("-e", "--encoder_model_fpath", type=Path,
                        default="encoder/saved_models/pretrained.pt", help=\
        "Path to your trained encoder model.")
    parser.add_argument("-n", "--n_processes", type=int, default=4, help=\
        "Number of parallel processes. An encoder is created for each, so you may need to lower "
        "this value on GPUs with low memory. Set it to 1 if CUDA is unhappy.")
    args = parser.parse_args()

    # Preprocess the dataset
    print_args(args, parser)
    create_embeddings(**vars(args))
37  control/cli/synthesizer_train.py  (new file)
@@ -0,0 +1,37 @@
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train
from utils.argutils import print_args
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("run_id", type=str, help=\
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help=\
        "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
        "the wavs and the embeds.")
    parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
        "Path to the output directory that will contain the saved model weights and the logs.")
    parser.add_argument("-s", "--save_every", type=int, default=1000, help=\
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=25000, help=\
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-l", "--log_every", type=int, default=200, help=\
        "Number of steps between summaries of the training info in tensorboard.")
    parser.add_argument("-f", "--force_restart", action="store_true", help=\
        "Do not load any saved model and restart from scratch.")
    parser.add_argument("--hparams", default="", help=\
        "Hyperparameter overrides as a comma-separated list of name=value pairs.")
    args = parser.parse_args()
    print_args(args, parser)

    args.hparams = hparams.parse(args.hparams)

    # Run the training
    train(**vars(args))
59  control/cli/vocoder_preprocess.py  (new file)
@@ -0,0 +1,59 @@
from models.synthesizer.synthesize import run_synthesis
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
import argparse
import os


if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        pass

    parser = argparse.ArgumentParser(
        description="Creates ground-truth aligned (GTA) spectrograms with the synthesizer, "
                    "for training the vocoder.",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=str, help=\
        "Path to the directory containing your SV2TTS directory. If you specify both --in_dir and "
        "--out_dir, this argument won't be used.")
    parser.add_argument("-m", "--model_dir", type=str,
                        default="synthesizer/saved_models/mandarin/", help=\
        "Path to the pretrained model directory.")
    parser.add_argument("-i", "--in_dir", type=str, default=argparse.SUPPRESS, help=\
        "Path to the synthesizer directory that contains the mel spectrograms, the wavs and the "
        "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("-o", "--out_dir", type=str, default=argparse.SUPPRESS, help=\
        "Path to the output vocoder directory that will contain the ground truth aligned mel "
        "spectrograms. Defaults to <datasets_root>/SV2TTS/vocoder/.")
    parser.add_argument("--hparams", default="", help=\
        "Hyperparameter overrides as a comma-separated list of name=value pairs.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("--cpu", action="store_true", help=\
        "If True, processing is done on CPU, even when a GPU is available.")
    args = parser.parse_args()
    print_args(args, parser)
    modified_hp = hparams.parse(args.hparams)

    if not hasattr(args, "in_dir"):
        args.in_dir = os.path.join(args.datasets_root, "SV2TTS", "synthesizer")
    if not hasattr(args, "out_dir"):
        args.out_dir = os.path.join(args.datasets_root, "SV2TTS", "vocoder")

    if args.cpu:
        # Hide GPUs from Pytorch to force CPU processing
        os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install it and try again. If installation "
                "fails, use --no_trim to disable this error message.")
    del args.no_trim

    run_synthesis(args.in_dir, args.out_dir, args.model_dir, modified_hp)
92  control/cli/vocoder_train.py  (new file)
@@ -0,0 +1,92 @@
from utils.argutils import print_args
from models.vocoder.wavernn.train import train
from models.vocoder.hifigan.train import train as train_hifigan
from models.vocoder.fregan.train import train as train_fregan
from utils.util import AttrDict
from pathlib import Path
import argparse
import json
import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the vocoder from the synthesizer audios and the GTA synthesized mels, "
                    "or ground truth mels.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help=\
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("datasets_root", type=str, help=\
        "Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
        "will take priority over this argument.")
    parser.add_argument("vocoder_type", type=str, nargs="?", default="wavernn", help=\
        "Vocoder type to train. Supported values: wavernn, hifigan, fregan. Defaults to wavernn.")
    parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help=\
        "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
        "the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("--voc_dir", type=str, default=argparse.SUPPRESS, help=\
        "Path to the vocoder directory that contains the GTA synthesized mel spectrograms. "
        "Defaults to <datasets_root>/SV2TTS/vocoder/. Unused if --ground_truth is passed.")
    parser.add_argument("-m", "--models_dir", type=str, default="vocoder/saved_models/", help=\
        "Path to the directory that will contain the saved model weights, as well as backups "
        "of those weights and wavs generated during training.")
    parser.add_argument("-g", "--ground_truth", action="store_true", help=\
        "Train on ground truth spectrograms (<datasets_root>/SV2TTS/synthesizer/mels).")
    parser.add_argument("-s", "--save_every", type=int, default=1000, help=\
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=25000, help=\
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help=\
        "Do not load any saved model and restart from scratch.")
    parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
    args = parser.parse_args()

    if not hasattr(args, "syn_dir"):
        args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
    args.syn_dir = Path(args.syn_dir)
    if not hasattr(args, "voc_dir"):
        args.voc_dir = Path(args.datasets_root, "SV2TTS", "vocoder")
    args.voc_dir = Path(args.voc_dir)
    del args.datasets_root
    args.models_dir = Path(args.models_dir)
    args.models_dir.mkdir(exist_ok=True)

    print_args(args, parser)

    # Process the arguments
    if args.vocoder_type == "wavernn":
        # Run the wavernn training
        delattr(args, 'vocoder_type')
        delattr(args, 'config')
        train(**vars(args))
    elif args.vocoder_type == "hifigan":
        with open(args.config) as f:
            json_config = json.load(f)
        h = AttrDict(json_config)
        if h.num_gpus > 1:
            h.num_gpus = torch.cuda.device_count()
            h.batch_size = int(h.batch_size / h.num_gpus)
            print('Batch size per GPU :', h.batch_size)
            mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
        else:
            train_hifigan(0, args, h)
    elif args.vocoder_type == "fregan":
        with open('vocoder/fregan/config.json') as f:
            json_config = json.load(f)
        h = AttrDict(json_config)
        if h.num_gpus > 1:
            h.num_gpus = torch.cuda.device_count()
            h.batch_size = int(h.batch_size / h.num_gpus)
            print('Batch size per GPU :', h.batch_size)
            mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
        else:
            train_fregan(0, args, h)
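
Note: for the GAN vocoders above, mp.spawn launches one training process per GPU, and the configured batch size is divided across them. A self-contained sketch of that launch pattern (dummy worker, not the repo's trainer):

import torch
import torch.multiprocessing as mp

def worker(rank, batch_size_per_gpu):
    # Each spawned process receives its rank as the first argument,
    # mirroring train_hifigan(rank, args, h) above.
    print(f"rank {rank}: training with per-GPU batch size {batch_size_per_gpu}")

if __name__ == "__main__":
    num_gpus = max(torch.cuda.device_count(), 1)
    global_batch_size = 32
    batch_size_per_gpu = global_batch_size // num_gpus
    if num_gpus > 1:
        mp.spawn(worker, nprocs=num_gpus, args=(batch_size_per_gpu,))
    else:
        worker(0, batch_size_per_gpu)
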
0  control/mkgui/__init__.py  (new file)
145  control/mkgui/app.py  (new file)
@@ -0,0 +1,145 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from models.encoder import inference as encoder
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from control.mkgui.base.components.types import FileContent
from models.vocoder.hifigan import inference as gan_vocoder
from models.synthesizer.inference import Synthesizer
from typing import Any, Tuple
import matplotlib.pyplot as plt

# Constants
AUDIO_SAMPLES_DIR = f"data{os.sep}samples{os.sep}"
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
VOC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}vocoder"
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
if not os.path.isdir("wavs"):
    os.makedirs("wavs")

# Load local sample audios as options  TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist. Please move the model files there and retry!")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoder models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")


class Input(BaseModel):
    message: str = Field(
        ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
    )
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    encoder: encoders = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    synthesizer: synthesizers = Field(
        ..., alias="合成模型",
        description="选择语音合成模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )


class AudioEntity(BaseModel):
    content: bytes
    mel: Any


class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Result Audio)")
        streamlit_app.pyplot(fig)


def synthesize(input: Input) -> Output:
    """synthesize(合成)"""
    # Load models
    encoder.load_model(Path(input.encoder.value))
    current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # Load the source file
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, wav)  # Make sure we get the correct wav

    source_spec = Synthesizer.make_spectrogram(wav)

    # Preprocess
    encoder_wav = encoder.preprocess_wav(wav, sample_rate)
    embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    # Load the input text: punctuate, split and clean it
    texts = filter(None, input.message.split("\n"))
    punctuation = '!,。、,'
    processed_texts = []
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
    texts = processed_texts

    # Synthesize and vocode
    embeds = [embed] * len(texts)
    specs = current_synt.synthesize_spectrograms(texts, embeds)
    spec = np.concatenate(specs, axis=1)
    sample_rate = Synthesizer.sample_rate
    wav, sample_rate = gan_vocoder.infer_waveform(spec)

    # Write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()
    return Output(__root__=(AudioEntity(content=source_file, mel=source_spec),
                            AudioEntity(content=result_file, mel=spec)))
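
Note: the punctuation-splitting loop in synthesize turns each input line into short sub-sentences so the synthesizer sees bounded-length text. A worked example of the same regex:

import re

punctuation = '!,。、,'
text = "欢迎使用工具箱,现已支持中文输入!"
pieces = [p.strip() for p in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n') if p]
print(pieces)  # ['欢迎使用工具箱', '现已支持中文输入']
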
166  control/mkgui/app_vc.py  (new file)
@@ -0,0 +1,166 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Tuple

import librosa
import matplotlib.pyplot as plt
import torch
from pydantic import BaseModel, Field
from scipy.io.wavfile import write

import models.ppg2mel as Convertor
import models.ppg_extractor as Extractor
from control.mkgui.base.components.types import FileContent
from models.encoder import inference as speacker_encoder
from models.synthesizer.inference import Synthesizer
from models.vocoder.hifigan import inference as gan_vocoder

# Constants
AUDIO_SAMPLES_DIR = f'data{os.sep}samples{os.sep}'
EXT_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg_extractor'
CONV_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg2mel'
VOC_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}vocoder'
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'

# Load local sample audios as options  TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-load models
if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoder models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")


class Input(BaseModel):
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    local_audio_file_target: audio_input_selection = Field(
        ..., alias="目标语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    extractor: extractors = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )


class AudioEntity(BaseModel):
    content: bytes
    mel: Any


class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, target, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(target.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Target Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram (Result Audio)")
        streamlit_app.pyplot(fig)


def convert(input: Input) -> Output:
    """convert(转换)"""
    # Load models
    extractor = Extractor.load_model(Path(input.extractor.value))
    convertor = Convertor.load_model(Path(input.convertor.value))
    # current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # Load the source and target files
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        src_wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, src_wav)  # Make sure we get the correct wav

    if input.upload_audio_file_target is not None:
        with open(TEMP_TARGET_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file_target.as_bytes())
            f.seek(0)
        ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
    else:
        ref_wav, _ = librosa.load(input.local_audio_file_target.value)
        write(TEMP_TARGET_AUDIO, sample_rate, ref_wav)  # Make sure we get the correct wav

    ppg = extractor.extract_from_wav(src_wav)
    # Import the necessary dependencies of voice conversion
    from utils.f0_utils import (compute_f0, compute_mean_std, f02lf0,
                                get_converted_lf0uv)
    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
    speacker_encoder.load_model(Path(f"data{os.sep}ckpt{os.sep}encoder{os.sep}pretrained_bak_5805000.pt"))
    embed = speacker_encoder.embed_utterance(ref_wav)
    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
    min_len = min(ppg.shape[1], len(lf0_uv))
    ppg = ppg[:, :min_len]
    lf0_uv = lf0_uv[:min_len]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, mel_pred, att_ws = convertor.inference(
        ppg,
        logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
        spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
    )
    mel_pred = mel_pred.transpose(0, 1)
    breaks = [mel_pred.shape[1]]
    mel_pred = mel_pred.detach().cpu().numpy()

    # Vocode
    wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)

    # Write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_TARGET_AUDIO, "rb") as f:
        target_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()

    return Output(__root__=(
        AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)),
        AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)),
        AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav)),
    ))
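
Note: the PPG frames and the converted lf0/UV track come from different front-ends and can differ by a frame or two, so convert crops both to the shorter length before inference. The same trick on toy arrays:

import numpy as np

ppg = np.zeros((1, 503, 144))   # (batch, frames, ppg_dims) -- toy shapes
lf0_uv = np.zeros((501, 2))     # (frames, lf0 + uv)

min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
assert ppg.shape[1] == len(lf0_uv) == 501
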
2  control/mkgui/base/__init__.py  (new file)
@@ -0,0 +1,2 @@
from .core import Opyrator
1  control/mkgui/base/api/__init__.py  (new file)
@@ -0,0 +1 @@
from .fastapi_app import create_api
102  control/mkgui/base/api/fastapi_utils.py  (new file)
@@ -0,0 +1,102 @@
"""Collection of utilities for FastAPI apps."""

import inspect
from typing import Any, Type

from fastapi import FastAPI, Form
from pydantic import BaseModel


def as_form(cls: Type[BaseModel]) -> Any:
    """Adds an as_form class method to decorated models.

    The as_form class method can be used with FastAPI endpoints.
    """
    new_params = [
        inspect.Parameter(
            field.alias,
            inspect.Parameter.POSITIONAL_ONLY,
            default=(Form(field.default) if not field.required else Form(...)),
        )
        for field in cls.__fields__.values()
    ]

    async def _as_form(**data):  # type: ignore
        return cls(**data)

    sig = inspect.signature(_as_form)
    sig = sig.replace(parameters=new_params)
    _as_form.__signature__ = sig  # type: ignore
    setattr(cls, "as_form", _as_form)
    return cls


def patch_fastapi(app: FastAPI) -> None:
    """Patch function to allow relative url resolution.

    This patch is required to make fastapi fully functional with a relative url path.
    This code snippet can be copy-pasted to any FastAPI application.
    """
    from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
    from starlette.requests import Request
    from starlette.responses import HTMLResponse

    async def redoc_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        redoc_ui = get_redoc_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Redoc UI",
        )

        return HTMLResponse(redoc_ui.body.decode("utf-8"))

    async def swagger_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        swagger_ui = get_swagger_ui_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Swagger UI",
            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        )

        # Insert a request interceptor so that all requests run on a relative path
        request_interceptor = (
            "requestInterceptor: (e) => {"
            "\n\t\t\tvar url = window.location.origin + window.location.pathname"
            '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
            "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);"  # noqa: W605
            "\n\t\t\te.contextUrl = url"
            "\n\t\t\te.url = url"
            "\n\t\t\treturn e;}"
        )

        return HTMLResponse(
            swagger_ui.body.decode("utf-8").replace(
                "dom_id: '#swagger-ui',",
                "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
            )
        )

    # Remove the old docs routes and add our patched routes
    routes_new = []
    for app_route in app.routes:
        if app_route.path == "/docs":  # type: ignore
            continue
        if app_route.path == "/redoc":  # type: ignore
            continue
        routes_new.append(app_route)

    app.router.routes = routes_new

    assert app.docs_url is not None
    app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
    assert app.redoc_url is not None
    app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)

    # Make graphql relative
    from starlette import graphql

    graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
        "({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
    )
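
Note: a minimal usage sketch for the as_form decorator above, wiring a Pydantic model into a form-encoded FastAPI endpoint (hypothetical app and model names):

from fastapi import Depends, FastAPI
from pydantic import BaseModel

app = FastAPI()

@as_form
class LoginForm(BaseModel):
    username: str
    password: str

@app.post("/login")
async def login(form: LoginForm = Depends(LoginForm.as_form)):
    # FastAPI reads the fields from form data thanks to the patched signature.
    return {"user": form.username}
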
0  control/mkgui/base/components/__init__.py  (new file)
43  control/mkgui/base/components/outputs.py  (new file)
@@ -0,0 +1,43 @@
from typing import List

from pydantic import BaseModel


class ScoredLabel(BaseModel):
    label: str
    score: float


class ClassificationOutput(BaseModel):
    __root__: List[ScoredLabel]

    def __iter__(self):  # type: ignore
        return iter(self.__root__)

    def __getitem__(self, item):  # type: ignore
        return self.__root__[item]

    def render_output_ui(self, streamlit) -> None:  # type: ignore
        import plotly.express as px

        sorted_predictions = sorted(
            [prediction.dict() for prediction in self.__root__],
            key=lambda k: k["score"],
        )

        num_labels = len(sorted_predictions)
        if len(sorted_predictions) > 10:
            num_labels = streamlit.slider(
                "Maximum labels to show: ",
                min_value=1,
                max_value=len(sorted_predictions),
                value=len(sorted_predictions),
            )
        fig = px.bar(
            sorted_predictions[len(sorted_predictions) - num_labels :],
            x="score",
            y="label",
            orientation="h",
        )
        streamlit.plotly_chart(fig, use_container_width=True)
        # fig.show()
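
Note: a small usage sketch for ClassificationOutput; the __root__ model behaves like a list of scored labels:

output = ClassificationOutput(__root__=[
    ScoredLabel(label="cat", score=0.92),
    ScoredLabel(label="dog", score=0.08),
])
assert output[0].label == "cat"
assert [p.label for p in output] == ["cat", "dog"]
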
46  control/mkgui/base/components/types.py  (new file)
@@ -0,0 +1,46 @@
import base64
from typing import Any, Dict


class FileContent(str):
    def as_bytes(self) -> bytes:
        return base64.b64decode(self, validate=True)

    def as_str(self) -> str:
        return self.as_bytes().decode()

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(format="byte")

    @classmethod
    def __get_validators__(cls) -> Any:  # type: ignore
        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> "FileContent":
        if isinstance(value, FileContent):
            return value
        elif isinstance(value, str):
            return FileContent(value)
        elif isinstance(value, (bytes, bytearray, memoryview)):
            return FileContent(base64.b64encode(value).decode())
        else:
            raise Exception("Wrong type")

# # Not usable for now, because browsers provide no way to select a folder
# class DirectoryContent(FileContent):
#     @classmethod
#     def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
#         field_schema.update(format="path")

#     @classmethod
#     def validate(cls, value: Any) -> "DirectoryContent":
#         if isinstance(value, DirectoryContent):
#             return value
#         elif isinstance(value, str):
#             return DirectoryContent(value)
#         elif isinstance(value, (bytes, bytearray, memoryview)):
#             return DirectoryContent(base64.b64encode(value).decode())
#         else:
#             raise Exception("Wrong type")
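
Note: FileContent stores uploads as a base64 string; validate accepts raw bytes or an already-encoded string. A round-trip sketch:

raw = b"RIFF....WAVE"                # e.g. the head of a wav upload
fc = FileContent.validate(raw)       # bytes are base64-encoded on the way in
assert isinstance(fc, str)           # it subclasses str
assert fc.as_bytes() == raw          # and decodes back to the original bytes
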
203  control/mkgui/base/core.py  (new file)
@@ -0,0 +1,203 @@
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints

from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as


def name_to_title(name: str) -> str:
    """Converts a camelCase or snake_case name to title case."""
    # If camelCase -> convert to snake_case
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
    # Convert to title case
    return name.replace("_", " ").strip().title()


def is_compatible_type(type: Type) -> bool:
    """Returns `True` if the type is opyrator-compatible."""
    try:
        if issubclass(type, BaseModel):
            return True
    except Exception:
        pass

    try:
        # valid list type
        if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
            return True
    except Exception:
        pass

    return False


def get_input_type(func: Callable) -> Type:
    """Returns the input type of a given function (callable).

    Args:
        func: The function for which to get the input type.

    Raises:
        ValueError: If the function does not have a valid input type annotation.
    """
    type_hints = get_type_hints(func)

    if "input" not in type_hints:
        raise ValueError(
            "The callable MUST have a parameter with the name `input` with a typing annotation. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    input_type = type_hints["input"]

    if not is_compatible_type(input_type):
        raise ValueError(
            "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    # TODO: return a warning if there is more than one input parameter

    return input_type


def get_output_type(func: Callable) -> Type:
    """Returns the output type of a given function (callable).

    Args:
        func: The function for which to get the output type.

    Raises:
        ValueError: If the function does not have a valid output type annotation.
    """
    type_hints = get_type_hints(func)
    if "return" not in type_hints:
        raise ValueError(
            "The return type of the callable MUST be annotated with type hints. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    output_type = type_hints["return"]

    if not is_compatible_type(output_type):
        raise ValueError(
            "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    return output_type


def get_callable(import_string: str) -> Callable:
    """Import a callable from a string."""
    callable_separator = ":"
    if callable_separator not in import_string:
        # Use dot as separator
        callable_separator = "."

    if callable_separator not in import_string:
        raise ValueError("The callable path MUST specify the function.")

    mod_name, callable_name = import_string.rsplit(callable_separator, 1)
    mod = importlib.import_module(mod_name)
    return getattr(mod, callable_name)


class Opyrator:
    def __init__(self, func: Union[Callable, str]) -> None:
        if isinstance(func, str):
            # Try to load the function from a string notation
            self.function = get_callable(func)
        else:
            self.function = func

        self._name = ""
        self._action = "Execute"
        self._input_type = None
        self._output_type = None

        if not callable(self.function):
            raise ValueError("The provided function parameter is not a callable.")

        if inspect.isclass(self.function):
            raise ValueError(
                "The provided callable is an uninitialized Class. This is not allowed."
            )

        if inspect.isfunction(self.function):
            # The provided callable is a function
            self._input_type = get_input_type(self.function)
            self._output_type = get_output_type(self.function)

            try:
                # Get name
                self._name = name_to_title(self.function.__name__)
            except Exception:
                pass

            try:
                # Get description from the function docstring
                doc_string = inspect.getdoc(self.function)
                if doc_string:
                    self._action = doc_string
            except Exception:
                pass
        elif hasattr(self.function, "__call__"):
            # The provided callable is a callable class instance
            self._input_type = get_input_type(self.function.__call__)  # type: ignore
            self._output_type = get_output_type(self.function.__call__)  # type: ignore

            try:
                # Get name
                self._name = name_to_title(type(self.function).__name__)
            except Exception:
                pass

            try:
                # Get action from the __call__ docstring
                doc_string = inspect.getdoc(self.function.__call__)  # type: ignore
                if doc_string:
                    self._action = doc_string

                if (
                    not self._action
                    or self._action == "Call"
                ):
                    # Get docstring from class instead of __call__ function
                    doc_string = inspect.getdoc(self.function)
                    if doc_string:
                        self._action = doc_string
            except Exception:
                pass
        else:
            raise ValueError("Unknown callable type.")

    @property
    def name(self) -> str:
        return self._name

    @property
    def action(self) -> str:
        return self._action

    @property
    def input_type(self) -> Any:
        return self._input_type

    @property
    def output_type(self) -> Any:
        return self._output_type

    def __call__(self, input: Any, **kwargs: Any) -> Any:
        input_obj = input

        if isinstance(input, str):
            # Allow JSON input
            input_obj = parse_raw_as(self.input_type, input)

        if isinstance(input, dict):
            # Allow dict input
            input_obj = parse_obj_as(self.input_type, input)

        return self.function(input_obj, **kwargs)
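
Note: a minimal end-to-end sketch of the Opyrator wrapper above; the function's input annotation and return annotation drive input_type/output_type, and JSON or dict payloads are parsed automatically:

from pydantic import BaseModel

class InputModel(BaseModel):
    text: str

class OutputModel(BaseModel):
    text: str

def shout(input: InputModel) -> OutputModel:
    """Uppercases the given text."""
    return OutputModel(text=input.text.upper())

op = Opyrator(shout)
assert op.name == "Shout"
assert op.input_type is InputModel
assert op({"text": "hi"}).text == "HI"      # dict input goes through parse_obj_as
assert op('{"text": "hi"}').text == "HI"    # JSON input goes through parse_raw_as
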
1  control/mkgui/base/ui/__init__.py  (new file)
@@ -0,0 +1 @@
from .streamlit_ui import render_streamlit_ui
129  control/mkgui/base/ui/schema_utils.py  (new file)
@@ -0,0 +1,129 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def resolve_reference(reference: str, references: Dict) -> Dict:
|
||||
return references[reference.split("/")[-1]]
|
||||
|
||||
|
||||
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
|
||||
# Ref can either be directly in the properties or the first element of allOf
|
||||
reference = property.get("$ref")
|
||||
if reference is None:
|
||||
reference = property["allOf"][0]["$ref"]
|
||||
return resolve_reference(reference, references)
|
||||
|
||||
|
||||
def is_single_string_property(property: Dict) -> bool:
|
||||
return property.get("type") == "string"
|
||||
|
||||
|
||||
def is_single_datetime_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") in ["date-time", "time", "date"]
|
||||
|
||||
|
||||
def is_single_boolean_property(property: Dict) -> bool:
|
||||
return property.get("type") == "boolean"
|
||||
|
||||
|
||||
def is_single_number_property(property: Dict) -> bool:
|
||||
return property.get("type") in ["integer", "number"]
|
||||
|
||||
|
||||
def is_single_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
# TODO: binary?
|
||||
return property.get("format") == "byte"
|
||||
|
||||
|
||||
def is_single_directory_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") == "path"
|
||||
|
||||
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("uniqueItems") is not True:
|
||||
# Only relevant if it is a set or other datastructures with unique items
|
||||
return False
|
||||
|
||||
try:
|
||||
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_enum_property(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
_ = get_single_reference_item(property, references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_dict_property(property: Dict) -> bool:
|
||||
if property.get("type") != "object":
|
||||
return False
|
||||
return "additionalProperties" in property
|
||||
|
||||
|
||||
def is_single_reference(property: Dict) -> bool:
|
||||
if property.get("type") is not None:
|
||||
return False
|
||||
|
||||
return bool(property.get("$ref"))
|
||||
|
||||
|
||||
def is_multi_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO: binary
|
||||
return property["items"]["format"] == "byte"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_object(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
object_reference = get_single_reference_item(property, references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_property_list(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
return property["items"]["type"] in ["string", "number", "integer"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_object_list_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
try:
|
||||
object_reference = resolve_reference(property["items"]["$ref"], references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
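

# Illustrative self-check (a sketch, not part of the upstream module): exercises
# the predicates above on schema fragments like the ones pydantic v1 emits for a
# plain string field and for an enum field referenced through `allOf`.
if __name__ == "__main__":
    demo_property = {"title": "Name", "type": "string", "default": "demo"}
    assert is_single_string_property(demo_property)
    assert not is_single_number_property(demo_property)
    demo_refs = {"Color": {"title": "Color", "enum": ["red", "green"]}}
    enum_property = {"allOf": [{"$ref": "#/definitions/Color"}]}
    assert is_single_enum_property(enum_property, demo_refs)
    print("schema_utils demo checks passed")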
887
control/mkgui/base/ui/streamlit_ui.py
Normal file
887
control/mkgui/base/ui/streamlit_ui.py
Normal file
@@ -0,0 +1,887 @@
import datetime
import inspect
import mimetypes
import sys
from os import getcwd, unlink, path
from platform import system
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Type
from PIL import Image

import pandas as pd
import streamlit as st
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pydantic import BaseModel, ValidationError, parse_obj_as

from control.mkgui.base import Opyrator
from control.mkgui.base.core import name_to_title
from . import schema_utils
from .streamlit_utils import CUSTOM_STREAMLIT_CSS

STREAMLIT_RUNNER_SNIPPET = """
from control.mkgui.base.ui import render_streamlit_ui

import streamlit as st

# TODO: Make it configurable
# Page config can only be setup once
st.set_page_config(
    page_title="MockingBird",
    page_icon="🧊",
    layout="wide")

render_streamlit_ui()
"""

# with st.spinner("Loading MockingBird GUI. Please wait..."):
#     opyrator = Opyrator("{opyrator_path}")


def launch_ui(port: int = 8501) -> None:
    with NamedTemporaryFile(
        suffix=".py", mode="w", encoding="utf-8", delete=False
    ) as f:
        f.write(STREAMLIT_RUNNER_SNIPPET)
        f.seek(0)

    import subprocess

    python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
    if system() == "Windows":
        python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
        subprocess.run(
            f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
            shell=True,
        )

    subprocess.run(
        f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
        shell=True,
    )

    f.close()
    unlink(f.name)

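# Usage note (illustrative): launch_ui() materializes STREAMLIT_RUNNER_SNIPPET
# as a temporary .py file and shells out to
#     python -m streamlit run <tmpfile> --server.port=<port> ...
# with the current working directory appended to PYTHONPATH so the `control`
# package resolves inside the spawned Streamlit process, e.g.
#     launch_ui(port=8502)   # serve the GUI on http://localhost:8502
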
def function_has_named_arg(func: Callable, parameter: str) -> bool:
    try:
        sig = inspect.signature(func)
        for param in sig.parameters.values():
            if param.name == parameter:
                return True
    except Exception:
        return False
    return False


def has_output_ui_renderer(data_item: BaseModel) -> bool:
    return hasattr(data_item, "render_output_ui")


def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
    return hasattr(input_class, "render_input_ui")


def is_compatible_audio(mime_type: str) -> bool:
    return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]


def is_compatible_image(mime_type: str) -> bool:
    return mime_type in ["image/png", "image/jpeg"]


def is_compatible_video(mime_type: str) -> bool:
    return mime_type in ["video/mp4"]

class InputUI:
    def __init__(self, session_state, input_class: Type[BaseModel]):
        self._session_state = session_state
        self._input_class = input_class

        self._schema_properties = input_class.schema(by_alias=True).get(
            "properties", {}
        )
        self._schema_references = input_class.schema(by_alias=True).get(
            "definitions", {}
        )

    def render_ui(self, streamlit_app_root) -> None:
        if has_input_ui_renderer(self._input_class):
            # The input model has a rendering function
            # The rendering also returns the current state of input data
            self._session_state.input_data = self._input_class.render_input_ui(  # type: ignore
                st, self._session_state.input_data
            )
            return

        # print(self._schema_properties)
        for property_key in self._schema_properties.keys():
            property = self._schema_properties[property_key]

            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)

            try:
                if "input_data" in self._session_state:
                    self._store_value(
                        property_key,
                        self._render_property(streamlit_app_root, property_key, property),
                    )
            except Exception as e:
                print("Exception!", e)

    def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
        streamlit_kwargs = {
            "label": property.get("title"),
            "key": key,
        }

        if property.get("description"):
            streamlit_kwargs["help"] = property.get("description")
        return streamlit_kwargs

    def _store_value(self, key: str, value: Any) -> None:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # add value to this element
                data_element[key_element] = value
                return
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]

    def _get_value(self, key: str) -> Any:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # return the value stored at the final key element
                if key_element not in data_element:
                    return None
                return data_element[key_element]
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]
        return None

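    # Example of the dotted-key convention used above (illustrative):
    #     self._store_value("speaker.name", "demo")
    # turns session_state.input_data into {"speaker": {"name": "demo"}}, and
    #     self._get_value("speaker.name")
    # walks the same path back down and returns "demo".
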
    def _render_single_datetime_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("format") == "time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.time.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.time_input(**streamlit_kwargs)
        elif property.get("format") == "date":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.date.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.date_input(**streamlit_kwargs)
        elif property.get("format") == "date-time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.datetime.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            with streamlit_app.container():
                streamlit_app.subheader(streamlit_kwargs.get("label"))
                # the default kwargs store the description under "help"
                if streamlit_kwargs.get("help"):
                    streamlit_app.text(streamlit_kwargs.get("help"))
                selected_date = None
                selected_time = None
                date_col, time_col = streamlit_app.columns(2)
                with date_col:
                    date_kwargs = {"label": "Date", "key": key + "-date-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            date_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).date()
                        except Exception:
                            pass
                    selected_date = streamlit_app.date_input(**date_kwargs)

                with time_col:
                    time_kwargs = {"label": "Time", "key": key + "-time-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            time_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).time()
                        except Exception:
                            pass
                    selected_time = streamlit_app.time_input(**time_kwargs)
                return datetime.datetime.combine(selected_date, selected_time)
        else:
            streamlit_app.warning(
                "Date format is not supported: " + str(property.get("format"))
            )

    def _render_single_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_file = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=False, type=file_extension
        )
        if uploaded_file is None:
            return None

        file_bytes = uploaded_file.getvalue()
        if property.get("mime_type"):
            if is_compatible_audio(property["mime_type"]):
                # Show audio
                streamlit_app.audio(file_bytes, format=property.get("mime_type"))
            if is_compatible_image(property["mime_type"]):
                # Show image
                streamlit_app.image(file_bytes)
            if is_compatible_video(property["mime_type"]):
                # Show video
                streamlit_app.video(file_bytes, format=property.get("mime_type"))
        return file_bytes

    def _render_single_string_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        elif property.get("example"):
            # TODO: also use example for other property types
            # Use example as value if it is provided
            streamlit_kwargs["value"] = property.get("example")

        if property.get("maxLength") is not None:
            streamlit_kwargs["max_chars"] = property.get("maxLength")

        if (
            property.get("format")
            or (
                property.get("maxLength") is not None
                and int(property.get("maxLength")) < 140  # type: ignore
            )
            or property.get("writeOnly")
        ):
            # If any format is set, use single text input
            # If max chars is set to less than 140, use single text input
            # If write only -> password field
            if property.get("writeOnly"):
                streamlit_kwargs["type"] = "password"
            return streamlit_app.text_input(**streamlit_kwargs)
        else:
            # Otherwise use multiline text area
            return streamlit_app.text_area(**streamlit_kwargs)

    def _render_multi_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        # TODO: how to select defaults
        return streamlit_app.multiselect(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )

        if property.get("default") is not None:
            try:
                streamlit_kwargs["index"] = reference_item["enum"].index(
                    property.get("default")
                )
            except Exception:
                # Use default selection
                pass

        return streamlit_app.selectbox(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_dict_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_dict = self._get_value(key)
        if not current_dict:
            current_dict = {}

        key_col, value_col = streamlit_app.columns(2)

        with key_col:
            updated_key = streamlit_app.text_input(
                "Key", value="", key=key + "-new-key"
            )

        with value_col:
            # TODO: also add boolean?
            value_kwargs = {"label": "Value", "key": key + "-new-value"}
            if property["additionalProperties"].get("type") == "integer":
                value_kwargs["value"] = 0  # type: ignore
                updated_value = streamlit_app.number_input(**value_kwargs)
            elif property["additionalProperties"].get("type") == "number":
                value_kwargs["value"] = 0.0  # type: ignore
                value_kwargs["format"] = "%f"
                updated_value = streamlit_app.number_input(**value_kwargs)
            else:
                value_kwargs["value"] = ""
                updated_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_dict = {}

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and updated_key
                ):
                    current_dict[updated_key] = updated_value

        streamlit_app.write(current_dict)

        return current_dict

    def _render_single_reference(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_property(streamlit_app, key, reference_item)

    def _render_multi_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_files = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=True, type=file_extension
        )
        uploaded_files_bytes = []
        if uploaded_files:
            for uploaded_file in uploaded_files:
                uploaded_files_bytes.append(uploaded_file.read())
        return uploaded_files_bytes

    def _render_single_boolean_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        return streamlit_app.checkbox(**streamlit_kwargs)

    def _render_single_number_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        number_transform = int
        if property.get("type") == "number":
            number_transform = float  # type: ignore
            streamlit_kwargs["format"] = "%f"

        if "multipleOf" in property:
            # Set step count based on the multipleOf parameter
            streamlit_kwargs["step"] = number_transform(property["multipleOf"])
        elif number_transform == int:
            # Set step size to 1 as default
            streamlit_kwargs["step"] = 1
        elif number_transform == float:
            # Set step size to 0.01 as default
            # TODO: adapt to default value
            streamlit_kwargs["step"] = 0.01

        if "minimum" in property:
            streamlit_kwargs["min_value"] = number_transform(property["minimum"])
        if "exclusiveMinimum" in property:
            streamlit_kwargs["min_value"] = number_transform(
                property["exclusiveMinimum"] + streamlit_kwargs["step"]
            )
        if "maximum" in property:
            streamlit_kwargs["max_value"] = number_transform(property["maximum"])

        if "exclusiveMaximum" in property:
            streamlit_kwargs["max_value"] = number_transform(
                property["exclusiveMaximum"] - streamlit_kwargs["step"]
            )

        if property.get("default") is not None:
            streamlit_kwargs["value"] = number_transform(property.get("default"))  # type: ignore
        else:
            if "min_value" in streamlit_kwargs:
                streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
            elif number_transform == int:
                streamlit_kwargs["value"] = 0
            else:
                # Set default value to step
                streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])

        if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
            # TODO: Only if less than X steps
            return streamlit_app.slider(**streamlit_kwargs)
        else:
            return streamlit_app.number_input(**streamlit_kwargs)

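    # Example (illustrative): an int field declared with ge=1, le=32 and
    # multiple_of=1 arrives here as {"minimum": 1, "maximum": 32, "multipleOf": 1};
    # both bounds are present, so a slider from 1 to 32 with step 1 is rendered
    # instead of a free-form number input.
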
    def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
        properties = property["properties"]
        object_inputs = {}
        for property_key in properties:
            property = properties[property_key]
            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)
            # construct full key based on key parts -> required later to get the value
            full_key = key + "." + property_key
            object_inputs[property_key] = self._render_property(
                streamlit_app, full_key, property
            )
        return object_inputs

    def _render_single_object_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # Add title and subheader
        title = property.get("title")
        streamlit_app.subheader(title)
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        object_reference = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_object_input(streamlit_app, key, object_reference)

    def _render_property_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        value_kwargs = {"label": "Value", "key": key + "-new-value"}
        if property["items"]["type"] == "integer":
            value_kwargs["value"] = 0  # type: ignore
            new_value = streamlit_app.number_input(**value_kwargs)
        elif property["items"]["type"] == "number":
            value_kwargs["value"] = 0.0  # type: ignore
            value_kwargs["format"] = "%f"
            new_value = streamlit_app.number_input(**value_kwargs)
        else:
            value_kwargs["value"] = ""
            new_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and new_value is not None
                ):
                    current_list.append(new_value)

        streamlit_app.write(current_list)

        return current_list

    def _render_object_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # TODO: support max_items, and min_items properties

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        object_reference = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        input_data = self._render_object_input(streamlit_app, key, object_reference)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and input_data
                ):
                    current_list.append(input_data)

        streamlit_app.write(current_list)
        return current_list

    def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
        if schema_utils.is_single_enum_property(property, self._schema_references):
            return self._render_single_enum_input(streamlit_app, key, property)

        if schema_utils.is_multi_enum_property(property, self._schema_references):
            return self._render_multi_enum_input(streamlit_app, key, property)

        if schema_utils.is_single_file_property(property):
            return self._render_single_file_input(streamlit_app, key, property)

        if schema_utils.is_multi_file_property(property):
            return self._render_multi_file_input(streamlit_app, key, property)

        if schema_utils.is_single_datetime_property(property):
            return self._render_single_datetime_input(streamlit_app, key, property)

        if schema_utils.is_single_boolean_property(property):
            return self._render_single_boolean_input(streamlit_app, key, property)

        if schema_utils.is_single_dict_property(property):
            return self._render_single_dict_input(streamlit_app, key, property)

        if schema_utils.is_single_number_property(property):
            return self._render_single_number_input(streamlit_app, key, property)

        if schema_utils.is_single_string_property(property):
            return self._render_single_string_input(streamlit_app, key, property)

        if schema_utils.is_single_object(property, self._schema_references):
            return self._render_single_object_input(streamlit_app, key, property)

        if schema_utils.is_object_list_property(property, self._schema_references):
            return self._render_object_list_input(streamlit_app, key, property)

        if schema_utils.is_property_list(property):
            return self._render_property_list_input(streamlit_app, key, property)

        if schema_utils.is_single_reference(property):
            return self._render_single_reference(streamlit_app, key, property)

        streamlit_app.warning(
            "The type of the following property is currently not supported: "
            + str(property.get("title"))
        )
        raise Exception("Unsupported property")

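    # Note on dispatch order (illustrative): the more specific checks must run
    # first — an enum property also satisfies is_single_reference, and a
    # "byte"-format string would pass the plain string check — so reordering
    # these branches would silently change how such fields are rendered.
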
class OutputUI:
    def __init__(self, output_data: Any, input_data: Any):
        self._output_data = output_data
        self._input_data = input_data

    def render_ui(self, streamlit_app) -> None:
        try:
            if isinstance(self._output_data, BaseModel):
                self._render_single_output(streamlit_app, self._output_data)
                return
            if isinstance(self._output_data, list):
                self._render_list_output(streamlit_app, self._output_data)
                return
        except Exception as ex:
            streamlit_app.exception(ex)
        # Fall back to plain JSON rendering
        streamlit_app.json(jsonable_encoder(self._output_data))

    def _render_single_text_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            streamlit.code(str(value), language="plain")

    def _render_single_file_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            # TODO: Detect if it is a FileContent instance
            # TODO: detect if it is base64
            file_extension = ""
            if "mime_type" in property_schema:
                mime_type = property_schema["mime_type"]
                file_extension = mimetypes.guess_extension(mime_type) or ""

                if is_compatible_audio(mime_type):
                    streamlit.audio(value.as_bytes(), format=mime_type)
                    return

                if is_compatible_image(mime_type):
                    streamlit.image(value.as_bytes())
                    return

                if is_compatible_video(mime_type):
                    streamlit.video(value.as_bytes(), format=mime_type)
                    return

            filename = (
                (property_schema["title"] + file_extension)
                .lower()
                .strip()
                .replace(" ", "-")
            )
            streamlit.markdown(
                f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
                unsafe_allow_html=True,
            )

    def _render_single_complex_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))

        streamlit.json(jsonable_encoder(value))

    def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
        try:
            if has_output_ui_renderer(output_data):
                if function_has_named_arg(output_data.render_output_ui, "input"):  # type: ignore
                    # render method also requests the input data
                    output_data.render_output_ui(streamlit, input=self._input_data)  # type: ignore
                else:
                    output_data.render_output_ui(streamlit)  # type: ignore
                return
        except Exception:
            # Use default auto-generation methods if the custom rendering throws an exception
            logger.exception(
                "Failed to execute custom render_output_ui function. Using auto-generation instead"
            )

        model_schema = output_data.schema(by_alias=False)
        model_properties = model_schema.get("properties")
        definitions = model_schema.get("definitions")

        if model_properties:
            for property_key in output_data.__dict__:
                property_schema = model_properties.get(property_key)
                if not property_schema.get("title"):
                    # Set property key as fallback title
                    property_schema["title"] = property_key

                output_property_value = output_data.__dict__[property_key]

                if has_output_ui_renderer(output_property_value):
                    output_property_value.render_output_ui(streamlit)  # type: ignore
                    continue

                if isinstance(output_property_value, BaseModel):
                    # Render output recursively
                    streamlit.subheader(property_schema.get("title"))
                    if property_schema.get("description"):
                        streamlit.markdown(property_schema.get("description"))
                    self._render_single_output(streamlit, output_property_value)
                    continue

                if property_schema:
                    if schema_utils.is_single_file_property(property_schema):
                        self._render_single_file_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue

                    if (
                        schema_utils.is_single_string_property(property_schema)
                        or schema_utils.is_single_number_property(property_schema)
                        or schema_utils.is_single_datetime_property(property_schema)
                        or schema_utils.is_single_boolean_property(property_schema)
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue
                    if definitions and schema_utils.is_single_enum_property(
                        property_schema, definitions
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value.value
                        )
                        continue

                    # TODO: render dict as table

                    self._render_single_complex_property(
                        streamlit, property_schema, output_property_value
                    )
        return

    def _render_list_output(self, streamlit: st, output_data: List) -> None:
        try:
            data_items: List = []
            for data_item in output_data:
                if has_output_ui_renderer(data_item):
                    # Render using the render function
                    data_item.render_output_ui(streamlit)  # type: ignore
                    continue
                data_items.append(data_item.dict())
            # Try to show as dataframe
            streamlit.table(pd.DataFrame(data_items))
        except Exception:
            # Fall back to plain JSON rendering
            streamlit.json(jsonable_encoder(output_data))


def getOpyrator(mode: str) -> Opyrator:
    if mode is None or mode.startswith('VC'):
        from control.mkgui.app_vc import convert
        return Opyrator(convert)
    if mode.startswith('预处理'):
        from control.mkgui.preprocess import preprocess
        return Opyrator(preprocess)
    # Check the more specific "模型训练(VC)" prefix first; "模型训练" would
    # otherwise match it and make the VC training branch unreachable.
    if mode.startswith('模型训练(VC)'):
        from control.mkgui.train_vc import train_vc
        return Opyrator(train_vc)
    if mode.startswith('模型训练'):
        from control.mkgui.train import train
        return Opyrator(train)
    from control.mkgui.app import synthesize
    return Opyrator(synthesize)

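# Example (illustrative): getOpyrator("VC拟音") imports control.mkgui.app_vc
# lazily and wraps its `convert` function; any unrecognized mode string falls
# through to the default synthesize app. The lazy imports keep each mode's
# heavy model dependencies out of the other modes' startup path.
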
def render_streamlit_ui() -> None:
    # init
    session_state = st.session_state
    session_state.input_data = {}
    # Add custom css settings
    st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)

    with st.spinner("Loading MockingBird GUI. Please wait..."):
        session_state.mode = st.sidebar.selectbox(
            '模式选择',
            ("AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
        )
        if "mode" in session_state:
            mode = session_state.mode
        else:
            mode = ""
        opyrator = getOpyrator(mode)
        title = opyrator.name + mode

        col1, col2, _ = st.columns(3)
        col2.title(title)
        col2.markdown("欢迎使用MockingBird Web 2")

        image = Image.open(path.join('control', 'mkgui', 'static', 'mb.png'))
        col1.image(image)

        st.markdown("---")
        left, right = st.columns([0.4, 0.6])

    with left:
        st.header("Control 控制")
        InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
        execute_selected = st.button(opyrator.action)
        if execute_selected:
            with st.spinner("Executing operation. Please wait..."):
                try:
                    input_data_obj = parse_obj_as(
                        opyrator.input_type, session_state.input_data
                    )
                    session_state.output_data = opyrator(input=input_data_obj)
                    session_state.latest_operation_input = input_data_obj  # should this really be saved as an additional session object?
                except ValidationError as ex:
                    st.error(ex)
                else:
                    # st.success("Operation executed successfully.")
                    pass

    with right:
        st.header("Result 结果")
        if 'output_data' in session_state:
            OutputUI(
                session_state.output_data, session_state.latest_operation_input
            ).render_ui(st)
            if st.button("Clear"):
                # Clear all state
                for key in st.session_state.keys():
                    del st.session_state[key]
                session_state.input_data = {}
                st.experimental_rerun()
        else:
            # placeholder
            st.caption("请使用左侧控制板进行输入并运行获得结果")

13
control/mkgui/base/ui/streamlit_utils.py
Normal file
13
control/mkgui/base/ui/streamlit_utils.py
Normal file
@@ -0,0 +1,13 @@
CUSTOM_STREAMLIT_CSS = """
div[data-testid="stBlock"] button {
  width: 100% !important;
  margin-bottom: 20px !important;
  border-color: #bfbfbf !important;
}
section[data-testid="stSidebar"] div {
  max-width: 10rem;
}
pre code {
  white-space: pre-wrap;
}
"""
96
control/mkgui/preprocess.py
Normal file
96
control/mkgui/preprocess.py
Normal file
@@ -0,0 +1,96 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple


# Constants
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
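
# Illustrative note (hypothetical file name): if data/ckpt/ppg_extractor holds
# a checkpoint 24epoch.pt, the dynamic Enum above is equivalent to
#     extractors = Enum('extractors', [('24epoch.pt', Path('data/ckpt/ppg_extractor/24epoch.pt'))])
# so the UI shows the file name while `.value` maps back to the checkpoint Path.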

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="目标模型",
    )
    dataset: Dataset = Field(
        Dataset.AIDATATANG_200ZH, title="数据集选择",
    )
    datasets_root: str = Field(
        ..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
        format=True,
        example="..\\training_data\\"
    )
    output_root: str = Field(
        ..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
        format=True,
        example="..\\training_data\\"
    )
    n_processes: int = Field(
        2, alias="处理线程数", description="根据CPU线程数来设置",
        le=32, ge=1
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        dataset, count = self.__root__
        streamlit_app.subheader(f"Dataset {dataset} preprocessing finished, {count} utterances in total")

def preprocess(input: Input) -> Output:
    """Preprocess(预处理)"""
    finished = 0
    if input.model == Model.VC_PPG2MEL:
        from models.ppg2mel.preprocess import preprocess_dataset
        finished = preprocess_dataset(
            datasets_root=Path(input.datasets_root),
            dataset=input.dataset,
            out_dir=Path(input.output_root),
            n_processes=input.n_processes,
            ppg_encoder_model_fpath=Path(input.extractor.value),
            speaker_encoder_model=Path(input.encoder.value)
        )
    # TODO: pass useful return code
    return Output(__root__=(input.dataset, finished))
BIN
control/mkgui/static/mb.png
Normal file
BIN
control/mkgui/static/mb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 5.6 KiB |
106
control/mkgui/train.py
Normal file
106
control/mkgui/train.py
Normal file
@@ -0,0 +1,106 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train as synt_train

# Constants
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"


# EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
# CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
# ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"

# Pre-load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    DEFAULT = "default"

class Input(BaseModel):
    model: Model = Field(
        Model.DEFAULT, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\training_data\\"
    # )
    input_root: str = Field(
        ..., alias="输入目录", description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
    )
    run_id: str = Field(
        "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
    )
    synthesizer: synthesizers = Field(
        ..., alias="已有合成模型",
        description="选择语音合成模型文件."
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    save_every: int = Field(
        1000, alias="更新间隔", description="每隔n步则更新一次模型",
    )
    backup_every: int = Field(
        10000, alias="保存间隔", description="每隔n步则保存一次模型",
    )
    log_every: int = Field(
        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        streamlit_app.subheader(f"Training started with code: {self.__root__}")
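
# Note: `__root__` is pydantic v1's custom-root-type mechanism — Output wraps a
# bare int (the training return code) rather than a named field, and
# render_output_ui is the hook the mkgui OutputUI looks for before falling back
# to its auto-generated rendering.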

def train(input: Input) -> Output:
    """Train(训练)"""

    print(">>> Start training ...")
    force_restart = len(input.run_id) > 0
    if not force_restart:
        # Continue training: reuse the name of the selected synthesizer checkpoint
        input.run_id = Path(input.synthesizer.value).name.split('.')[0]

    synt_train(
        input.run_id,
        input.input_root,
        f"data{os.sep}ckpt{os.sep}synthesizer",
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams
    )
    return Output(__root__=0)
155
control/mkgui/train_vc.py
Normal file
155
control/mkgui/train_vc.py
Normal file
@@ -0,0 +1,155 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch

# Constants
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\training_data\\"
    # )
    output_root: str = Field(
        ..., alias="输出目录(可选)", description="建议不填,保持默认",
        format=True,
        example=""
    )
    continue_mode: bool = Field(
        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    # TODO: Move to hidden fields by default
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    njobs: int = Field(
        8, alias="进程数", description="适用于ppg2mel",
    )
    seed: int = Field(
        default=0, alias="初始随机数", description="适用于ppg2mel",
    )
    model_name: str = Field(
        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
        example="test"
    )
    model_config: str = Field(
        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        name, count = self.__root__
        streamlit_app.subheader(f"Training run {name} finished with code {count}")

def train_vc(input: Input) -> Output:
    """Train VC(训练 VC)"""

    print(">>> OneShot VC training ...")
    params = AttrDict()
    params.update({
        "gpu": input.gpu,
        "cpu": not input.gpu,
        "njobs": input.njobs,
        "seed": input.seed,
        "verbose": input.verbose,
        "load": input.convertor.value,
        "warm_start": False,
    })
    if input.continue_mode:
        # trace old model and config
        p = Path(input.convertor.value)
        params.name = p.parent.name
        # search for a config file
        model_config_fpaths = list(p.parent.rglob("*.yaml"))
        if len(model_config_fpaths) == 0:
            raise Exception("No model yaml config found for convertor")
        config = HpsYaml(model_config_fpaths[0])
        params.ckpdir = p.parent.parent
        params.config = model_config_fpaths[0]
        params.logdir = os.path.join(p.parent, "log")
    else:
        # Fresh training run: take the run name and config path from the form fields
        params.name = input.model_name
        # Make the config dict dot visitable
        config = HpsYaml(input.model_config)
    np.random.seed(input.seed)
    torch.manual_seed(input.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(input.seed)
    mode = "train"
    from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
    solver = Solver(config, params, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")

    # TODO: pass useful return code
    return Output(__root__=(params.name, 0))
475
control/toolbox/__init__.py
Normal file
475
control/toolbox/__init__.py
Normal file
@@ -0,0 +1,475 @@
from control.toolbox.ui import UI
from models.encoder import inference as encoder
from models.synthesizer.inference import Synthesizer
from models.vocoder.wavernn import inference as rnn_vocoder
from models.vocoder.hifigan import inference as gan_vocoder
from models.vocoder.fregan import inference as fgan_vocoder
from pathlib import Path
from time import perf_counter as timer
from control.toolbox.utterance import Utterance
import numpy as np
import traceback
import sys
import torch
import re

# Use WaveRNN by default
vocoder = rnn_vocoder

# Use this directory structure for your datasets, or modify it to fit your needs
recognized_datasets = [
    "LibriSpeech/dev-clean",
    "LibriSpeech/dev-other",
    "LibriSpeech/test-clean",
    "LibriSpeech/test-other",
    "LibriSpeech/train-clean-100",
    "LibriSpeech/train-clean-360",
    "LibriSpeech/train-other-500",
    "LibriTTS/dev-clean",
    "LibriTTS/dev-other",
    "LibriTTS/test-clean",
    "LibriTTS/test-other",
    "LibriTTS/train-clean-100",
    "LibriTTS/train-clean-360",
    "LibriTTS/train-other-500",
    "LJSpeech-1.1",
    "VoxCeleb1/wav",
    "VoxCeleb1/test_wav",
    "VoxCeleb2/dev/aac",
    "VoxCeleb2/test/aac",
    "VCTK-Corpus/wav48",
    "aidatatang_200zh/corpus",
    "aishell3/test/wav",
    "magicdata/train",
]

# Maximum number of generated wavs to keep in memory
MAX_WAVES = 15

class Toolbox:
    def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
        self.no_mp3_support = no_mp3_support
        self.vc_mode = vc_mode
        sys.excepthook = self.excepthook
        self.datasets_root = datasets_root
        self.utterances = set()
        self.current_generated = (None, None, None, None)  # speaker_name, spec, breaks, wav

        self.synthesizer = None  # type: Synthesizer

        # for ppg-based voice conversion
        self.extractor = None
        self.convertor = None  # ppg2mel

        self.current_wav = None
        self.waves_list = []
        self.waves_count = 0
        self.waves_namelist = []

        # Check for webrtcvad (enables removal of silences in vocoder output)
        try:
            import webrtcvad
            self.trim_silences = True
        except ImportError:
            self.trim_silences = False

        # Initialize the events and the interface
        self.ui = UI(vc_mode)
        self.style_idx = 0
        self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
        self.setup_events()
        self.ui.start()

    def excepthook(self, exc_type, exc_value, exc_tb):
        traceback.print_exception(exc_type, exc_value, exc_tb)
        self.ui.log("Exception: %s" % exc_value)

    def setup_events(self):
        # Dataset, speaker and utterance selection
        self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
        random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
                                                                     recognized_datasets,
                                                                     level)
        self.ui.random_dataset_button.clicked.connect(random_func(0))
        self.ui.random_speaker_button.clicked.connect(random_func(1))
        self.ui.random_utterance_button.clicked.connect(random_func(2))
        self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
        self.ui.speaker_box.currentIndexChanged.connect(random_func(2))

        # Model selection
        self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
        def func():
            self.synthesizer = None
        if self.vc_mode:
            self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
        else:
            self.ui.synthesizer_box.currentIndexChanged.connect(func)

        self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)

        # Utterance selection
        func = lambda: self.load_from_browser(self.ui.browse_file())
        self.ui.browser_browse_button.clicked.connect(func)
        func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
        self.ui.utterance_history.currentIndexChanged.connect(func)
        func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
        self.ui.play_button.clicked.connect(func)
        self.ui.stop_button.clicked.connect(self.ui.stop)
        self.ui.record_button.clicked.connect(self.record)

        # Source utterance selection
        if self.vc_mode:
            func = lambda: self.load_soruce_button(self.ui.selected_utterance)
            self.ui.load_soruce_button.clicked.connect(func)

        # Audio
        self.ui.setup_audio_devices(Synthesizer.sample_rate)

        # Wav playback & save
        func = lambda: self.replay_last_wav()
        self.ui.replay_wav_button.clicked.connect(func)
        func = lambda: self.export_current_wave()
        self.ui.export_wav_button.clicked.connect(func)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Generation
        self.ui.vocode_button.clicked.connect(self.vocode)
        self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)

        if self.vc_mode:
            func = lambda: self.convert() or self.vocode()
            self.ui.convert_button.clicked.connect(func)
        else:
            func = lambda: self.synthesize() or self.vocode()
            self.ui.generate_button.clicked.connect(func)
            self.ui.synthesize_button.clicked.connect(self.synthesize)

        # UMAP legend
        self.ui.clear_button.clicked.connect(self.clear_utterances)

    def set_current_wav(self, index):
        self.current_wav = self.waves_list[index]

    def export_current_wave(self):
        self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)

    def replay_last_wav(self):
        self.ui.play(self.current_wav, Synthesizer.sample_rate)

    def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
        self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
        self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
        self.ui.populate_gen_options(seed, self.trim_silences)

    def load_from_browser(self, fpath=None):
        if fpath is None:
            fpath = Path(self.datasets_root,
                         self.ui.current_dataset_name,
                         self.ui.current_speaker_name,
                         self.ui.current_utterance_name)
            name = str(fpath.relative_to(self.datasets_root))
            speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name

            # Select the next utterance
            if self.ui.auto_next_checkbox.isChecked():
                self.ui.browser_select_next()
        elif fpath == "":
            return
        else:
            name = fpath.name
            speaker_name = fpath.parent.name

        if fpath.suffix.lower() == ".mp3" and self.no_mp3_support:
            self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used")
            return

        # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
        # playback, so as to have a fair comparison with the generated audio
        wav = Synthesizer.load_preprocess_wav(fpath)
        self.ui.log("Loaded %s" % name)

        self.add_real_utterance(wav, name, speaker_name)

    def load_soruce_button(self, utterance: Utterance):
        self.selected_source_utterance = utterance

    def record(self):
        wav = self.ui.record_one(encoder.sampling_rate, 5)
        if wav is None:
            return
        self.ui.play(wav, encoder.sampling_rate)

        speaker_name = "user01"
        name = speaker_name + "_rec_%05d" % np.random.randint(100000)
        self.add_real_utterance(wav, name, speaker_name)

    def add_real_utterance(self, wav, name, speaker_name):
        # Compute the mel spectrogram
        spec = Synthesizer.make_spectrogram(wav)
        self.ui.draw_spec(spec, "current")

        # Compute the embedding
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
        self.utterances.add(utterance)
        self.ui.register_utterance(utterance, self.vc_mode)

        # Plot it
        self.ui.draw_embed(embed, name, "current")
        self.ui.draw_umap_projections(self.utterances)

    def clear_utterances(self):
        self.utterances.clear()
        self.ui.draw_umap_projections(self.utterances)

    def synthesize(self):
        self.ui.log("Generating the mel spectrogram...")
        self.ui.set_loading(1)

        # Update the synthesizer random seed
        if self.ui.random_seed_checkbox.isChecked():
            seed = int(self.ui.seed_textbox.text())
            self.ui.populate_gen_options(seed, self.trim_silences)
        else:
            seed = None

        if seed is not None:
            torch.manual_seed(seed)

        # Synthesize the spectrogram
        if self.synthesizer is None or seed is not None:
            self.init_synthesizer()

        texts = self.ui.text_prompt.toPlainText().split("\n")
        punctuation = '!,。、,'  # punctuation used to split/clean the text
        processed_texts = []
        for text in texts:
            for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
                if processed_text:
                    processed_texts.append(processed_text.strip())
        texts = processed_texts
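        # Example (illustrative): "你好,世界。测试" becomes
        # ["你好", "世界", "测试"] — every run of the listed punctuation marks is
        # replaced by a newline and each non-empty piece is synthesized on its own.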
embed = self.ui.selected_utterance.embed
|
||||
embeds = [embed] * len(texts)
|
||||
min_token = int(self.ui.token_slider.value())
|
||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
|
||||
breaks = [spec.shape[1] for spec in specs]
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
|
||||
self.ui.draw_spec(spec, "generated")
|
||||
self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
|
||||
self.ui.set_loading(0)
|
||||
|
||||
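What the punctuation split in synthesize() does to a prompt, on a toy string (illustrative, not part of this commit; only `re` is required):

# --- sketch ---
import re

punctuation = '!,。、,'
text = "大家好!欢迎使用工具箱,现已支持中文输入。"
chunks = [t.strip() for t in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n') if t]
print(chunks)  # ['大家好', '欢迎使用工具箱', '现已支持中文输入']
# --- end sketch ---
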
    def vocode(self):
        speaker_name, spec, breaks, _ = self.current_generated
        assert spec is not None

        # Initialize the vocoder model and make it deterministic, if user provides a seed
        if self.ui.random_seed_checkbox.isChecked():
            seed = int(self.ui.seed_textbox.text())
            self.ui.populate_gen_options(seed, self.trim_silences)
        else:
            seed = None

        if seed is not None:
            torch.manual_seed(seed)

        # Synthesize the waveform
        if not vocoder.is_loaded() or seed is not None:
            self.init_vocoder()

        def vocoder_progress(i, seq_len, b_size, gen_rate):
            real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
            line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
                   % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
            self.ui.log(line, "overwrite")
            self.ui.set_loading(i, seq_len)

        if self.ui.current_vocoder_fpath is not None:
            self.ui.log("")
            wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
        else:
            self.ui.log("Waveform generation with Griffin-Lim... ")
            wav = Synthesizer.griffin_lim(spec)
            sample_rate = Synthesizer.sample_rate  # Griffin-Lim outputs at the synthesizer rate
        self.ui.set_loading(0)
        self.ui.log(" Done!", "append")

        # Add breaks
        b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
        b_starts = np.concatenate(([0], b_ends[:-1]))
        wavs = [wav[start:end] for start, end in zip(b_starts, b_ends)]
        breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
        wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])

        # Trim excessive silences
        if self.ui.trim_silences_checkbox.isChecked():
            wav = encoder.preprocess_wav(wav)

        # Play it
        wav = wav / np.abs(wav).max() * 0.97
        self.ui.play(wav, sample_rate)

        # Name it (history displayed in combobox)
        # TODO: better naming for the combobox items?
        wav_name = str(self.waves_count + 1)

        # Update waves combobox
        self.waves_count += 1
        if self.waves_count > MAX_WAVES:
            self.waves_list.pop()
            self.waves_namelist.pop()
        self.waves_list.insert(0, wav)
        self.waves_namelist.insert(0, wav_name)

        self.ui.waves_cb.disconnect()
        self.ui.waves_cb_model.setStringList(self.waves_namelist)
        self.ui.waves_cb.setCurrentIndex(0)
        self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)

        # Update current wav
        self.set_current_wav(0)

        # Enable replay and save buttons
        self.ui.replay_wav_button.setDisabled(False)
        self.ui.export_wav_button.setDisabled(False)

        # Compute the embedding
        # TODO: this is problematic with different sampling rates, gotta fix it
        if not encoder.is_loaded():
            self.init_encoder()
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Add the utterance
        name = speaker_name + "_gen_%05d" % np.random.randint(100000)
        utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
        self.utterances.add(utterance)

        # Plot it
        self.ui.draw_embed(embed, name, "generated")
        self.ui.draw_umap_projections(self.utterances)

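The break re-insertion in vocode() on toy numbers (illustrative, not part of this commit; hop size and sample rate are hypothetical stand-ins for Synthesizer.hparams):

# --- sketch ---
import numpy as np

hop_size, sample_rate = 256, 16000     # hypothetical hparams
breaks = [10, 20, 15]                  # frames per synthesized chunk
wav = np.ones(sum(breaks) * hop_size)  # stand-in for the vocoded waveform

b_ends = np.cumsum(np.array(breaks) * hop_size)
b_starts = np.concatenate(([0], b_ends[:-1]))
wavs = [wav[start:end] for start, end in zip(b_starts, b_ends)]
silences = [np.zeros(int(0.15 * sample_rate))] * len(breaks)  # 0.15 s after each chunk
out = np.concatenate([x for w, s in zip(wavs, silences) for x in (w, s)])
assert len(out) == len(wav) + len(breaks) * int(0.15 * sample_rate)
# --- end sketch ---
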
    def convert(self):
        self.ui.log("Extracting the PPG and converting...")
        self.ui.set_loading(1)

        # Init
        if self.convertor is None:
            self.init_convertor()
        if self.extractor is None:
            self.init_extractor()

        src_wav = self.selected_source_utterance.wav

        # Compute the PPG
        if self.extractor is not None:
            ppg = self.extractor.extract_from_wav(src_wav)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        ref_wav = self.ui.selected_utterance.wav
        # Import the voice-conversion dependencies only when needed
        from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
        ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
        lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
        min_len = min(ppg.shape[1], len(lf0_uv))
        ppg = ppg[:, :min_len]
        lf0_uv = lf0_uv[:min_len]
        _, mel_pred, att_ws = self.convertor.inference(
            ppg,
            logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
            spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
        )
        mel_pred = mel_pred.transpose(0, 1)
        breaks = [mel_pred.shape[1]]
        mel_pred = mel_pred.detach().cpu().numpy()
        self.ui.draw_spec(mel_pred, "generated")
        self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
        self.ui.set_loading(0)

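The frame alignment convert() performs before inference, on random stand-ins (illustrative, not part of this commit; the (batch, frames, dims) PPG layout is an assumption inferred from ppg.shape[1] being cropped against the per-frame lf0 track):

# --- sketch ---
import numpy as np

ppg = np.random.rand(1, 320, 144)  # hypothetical PPG
lf0_uv = np.random.rand(318, 2)    # hypothetical lf0 + voicing track

min_len = min(ppg.shape[1], len(lf0_uv))
ppg, lf0_uv = ppg[:, :min_len], lf0_uv[:min_len]
assert ppg.shape[1] == len(lf0_uv) == 318
# --- end sketch ---
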
    def init_extractor(self):
        if self.ui.current_extractor_fpath is None:
            return
        model_fpath = self.ui.current_extractor_fpath
        self.ui.log("Loading the extractor %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        import models.ppg_extractor as extractor
        self.extractor = extractor.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_convertor(self):
        if self.ui.current_convertor_fpath is None:
            return
        model_fpath = self.ui.current_convertor_fpath
        self.ui.log("Loading the convertor %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        import models.ppg2mel as convertor
        self.convertor = convertor.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_encoder(self):
        model_fpath = self.ui.current_encoder_fpath

        self.ui.log("Loading the encoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        encoder.load_model(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_synthesizer(self):
        model_fpath = self.ui.current_synthesizer_fpath

        self.ui.log("Loading the synthesizer %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        self.synthesizer = Synthesizer(model_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def init_vocoder(self):
        global vocoder
        model_fpath = self.ui.current_vocoder_fpath
        # Case of Griffin-Lim: there is no model to load
        if model_fpath is None:
            return
        # Select the vocoder based on the model file name
        model_config_fpath = None
        if model_fpath.name is not None and model_fpath.name.find("hifigan") > -1:
            vocoder = gan_vocoder
            self.ui.log("set hifigan as vocoder")
            # Search for a config file next to the model
            model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
            if self.vc_mode and self.ui.current_extractor_fpath is None:
                return
            if len(model_config_fpaths) > 0:
                model_config_fpath = model_config_fpaths[0]
        elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1:
            vocoder = fgan_vocoder
            self.ui.log("set fregan as vocoder")
            # Search for a config file next to the model
            model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
            if self.vc_mode and self.ui.current_extractor_fpath is None:
                return
            if len(model_config_fpaths) > 0:
                model_config_fpath = model_config_fpaths[0]
        else:
            vocoder = rnn_vocoder
            self.ui.log("set wavernn as vocoder")

        self.ui.log("Loading the vocoder %s... " % model_fpath)
        self.ui.set_loading(1)
        start = timer()
        vocoder.load_model(model_fpath, model_config_fpath)
        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
        self.ui.set_loading(0)

    def update_seed_textbox(self):
        self.ui.update_seed_textbox()
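How init_vocoder() pairs a checkpoint with a config: dispatch on the file name, then take the first *.json found under the checkpoint's directory (illustrative, not part of this commit; the paths are hypothetical):

# --- sketch ---
from pathlib import Path

model_fpath = Path("data/ckpt/vocoder/hifigan/g_hifigan.pt")  # hypothetical
is_hifigan = model_fpath.name.find("hifigan") > -1            # dispatch on the file name
config_fpaths = list(model_fpath.parent.rglob("*.json"))
model_config_fpath = config_fpaths[0] if config_fpaths else None
# --- end sketch ---
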
BIN
control/toolbox/assets/mb.png
Normal file
Binary file not shown (new image, 5.6 KiB).
701
control/toolbox/ui.py
Normal file
@@ -0,0 +1,701 @@
from PyQt5.QtCore import Qt, QStringListModel
from PyQt5 import QtGui
from PyQt5.QtWidgets import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from models.encoder.inference import plot_embedding_as_heatmap
from control.toolbox.utterance import Utterance
from pathlib import Path
from typing import List, Set
import sounddevice as sd
import soundfile as sf
import numpy as np
# from sklearn.manifold import TSNE  # You can try with TSNE if you like, I prefer UMAP
from time import sleep
import umap
import sys
from warnings import filterwarnings, warn
filterwarnings("ignore")


colormap = np.array([
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [97, 142, 151],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
    [76, 255, 0],
], dtype=float) / 255  # np.float was removed in NumPy 1.24; the builtin float is equivalent here

default_text = "欢迎使用工具箱, 现已支持中文输入!"


class UI(QDialog):
    min_umap_points = 4
    max_log_lines = 5
    max_saved_utterances = 20

    def draw_utterance(self, utterance: Utterance, which):
        self.draw_spec(utterance.spec, which)
        self.draw_embed(utterance.embed, utterance.name, which)

    def draw_embed(self, embed, name, which):
        embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
        embed_ax.figure.suptitle("" if embed is None else name)

        ## Embedding
        # Clear the plot
        if len(embed_ax.images) > 0:
            embed_ax.images[0].colorbar.remove()
        embed_ax.clear()

        # Draw the embed
        if embed is not None:
            plot_embedding_as_heatmap(embed, embed_ax)
            embed_ax.set_title("embedding")
        embed_ax.set_aspect("equal", "datalim")
        embed_ax.set_xticks([])
        embed_ax.set_yticks([])
        embed_ax.figure.canvas.draw()

    def draw_spec(self, spec, which):
        _, spec_ax = self.current_ax if which == "current" else self.gen_ax

        ## Spectrogram
        # Draw the spectrogram
        spec_ax.clear()
        if spec is not None:
            im = spec_ax.imshow(spec, aspect="auto", interpolation="none")
            # spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal",
            #                         spec_ax=spec_ax)
            spec_ax.set_title("mel spectrogram")

        spec_ax.set_xticks([])
        spec_ax.set_yticks([])
        spec_ax.figure.canvas.draw()
        if which != "current":
            self.vocode_button.setDisabled(spec is None)

    def draw_umap_projections(self, utterances: Set[Utterance]):
        self.umap_ax.clear()

        speakers = np.unique([u.speaker_name for u in utterances])
        colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
        embeds = [u.embed for u in utterances]

        # Display a message if there aren't enough points
        if len(utterances) < self.min_umap_points:
            self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
                              (self.min_umap_points - len(utterances)),
                              horizontalalignment='center', fontsize=15)
            self.umap_ax.set_title("")

        # Compute the projections
        else:
            if not self.umap_hot:
                self.log(
                    "Drawing UMAP projections for the first time, this will take a few seconds.")
                self.umap_hot = True

            reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
            # reducer = TSNE()
            projections = reducer.fit_transform(embeds)

            speakers_done = set()
            for projection, utterance in zip(projections, utterances):
                color = colors[utterance.speaker_name]
                mark = "x" if "_gen_" in utterance.name else "o"
                label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
                speakers_done.add(utterance.speaker_name)
                self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
                                     label=label)
            # self.umap_ax.set_title("UMAP projections")
            self.umap_ax.legend(prop={'size': 10})

        # Draw the plot
        self.umap_ax.set_aspect("equal", "datalim")
        self.umap_ax.set_xticks([])
        self.umap_ax.set_yticks([])
        self.umap_ax.figure.canvas.draw()

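The UMAP call made by draw_umap_projections, on random stand-in embeddings (illustrative, not part of this commit; requires the umap-learn package, and the 256-dim shape is a hypothetical stand-in):

# --- sketch ---
import numpy as np
import umap

embeds = np.random.rand(16, 256)
reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
projections = reducer.fit_transform(embeds)  # (16, 2) points ready for scattering
print(projections.shape)
# --- end sketch ---
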
    def save_audio_file(self, wav, sample_rate):
        dialog = QFileDialog()
        dialog.setDefaultSuffix(".wav")
        fpath, _ = dialog.getSaveFileName(
            parent=self,
            caption="Select a path to save the audio file",
            filter="Audio Files (*.flac *.wav)"
        )
        if fpath:
            # Default format is wav
            if Path(fpath).suffix == "":
                fpath += ".wav"
            sf.write(fpath, wav, sample_rate)

    def setup_audio_devices(self, sample_rate):
        input_devices = []
        output_devices = []
        for device in sd.query_devices():
            # Check if valid input
            try:
                sd.check_input_settings(device=device["name"], samplerate=sample_rate)
                input_devices.append(device["name"])
            except Exception:
                pass

            # Check if valid output
            try:
                sd.check_output_settings(device=device["name"], samplerate=sample_rate)
                output_devices.append(device["name"])
            except Exception as e:
                # Log a warning only if the device is not an input
                if not device["name"] in input_devices:
                    warn("Unsupported output device %s for the sample rate: %d \nError: %s"
                         % (device["name"], sample_rate, str(e)))

        if len(input_devices) == 0:
            self.log("No audio input device detected. Recording may not work.")
            self.audio_in_device = None
        else:
            self.audio_in_device = input_devices[0]

        if len(output_devices) == 0:
            self.log("No supported output audio devices were found! Audio output may not work.")
            self.audio_out_devices_cb.addItems(["None"])
            self.audio_out_devices_cb.setDisabled(True)
        else:
            self.audio_out_devices_cb.clear()
            self.audio_out_devices_cb.addItems(output_devices)
            self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)

        self.set_audio_device()

    def set_audio_device(self):
        output_device = self.audio_out_devices_cb.currentText()
        if output_device == "None":
            output_device = None

        # If None, sounddevice queries portaudio
        sd.default.device = (self.audio_in_device, output_device)

    def play(self, wav, sample_rate):
        try:
            sd.stop()
            sd.play(wav, sample_rate)
        except Exception as e:
            print(e)
            self.log("Error in audio playback. Try selecting a different audio output device.")
            self.log("Your device must be connected before you start the toolbox.")

    def stop(self):
        sd.stop()

    def record_one(self, sample_rate, duration):
        self.record_button.setText("Recording...")
        self.record_button.setDisabled(True)

        self.log("Recording %d seconds of audio" % duration)
        sd.stop()
        try:
            wav = sd.rec(duration * sample_rate, sample_rate, 1)
        except Exception as e:
            print(e)
            self.log("Could not record anything. Is your recording device enabled?")
            self.log("Your device must be connected before you start the toolbox.")
            return None

        for i in np.arange(0, duration, 0.1):
            self.set_loading(i, duration)
            sleep(0.1)
        self.set_loading(duration, duration)
        sd.wait()

        self.log("Done recording.")
        self.record_button.setText("Record")
        self.record_button.setDisabled(False)

        return wav.squeeze()

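record_one() boils down to this sounddevice pattern; sd.rec returns immediately, so sd.wait() must run before the buffer is complete (illustrative, not part of this commit):

# --- sketch ---
import sounddevice as sd

sample_rate, duration = 16000, 5
wav = sd.rec(duration * sample_rate, sample_rate, 1)  # frames, samplerate, channels
sd.wait()                                             # block until recording finishes
sd.play(wav.squeeze(), sample_rate)
# --- end sketch ---
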
    @property
    def current_dataset_name(self):
        return self.dataset_box.currentText()

    @property
    def current_speaker_name(self):
        return self.speaker_box.currentText()

    @property
    def current_utterance_name(self):
        return self.utterance_box.currentText()

    def browse_file(self):
        fpath = QFileDialog().getOpenFileName(
            parent=self,
            caption="Select an audio file",
            filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
        )
        return Path(fpath[0]) if fpath[0] != "" else ""

    @staticmethod
    def repopulate_box(box, items, random=False):
        """
        Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to
        attach data to the items.
        """
        box.blockSignals(True)
        box.clear()
        for item in items:
            item = list(item) if isinstance(item, tuple) else [item]
            box.addItem(str(item[0]), *item[1:])
        if len(items) > 0:
            box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
        box.setDisabled(len(items) == 0)
        box.blockSignals(False)

    def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
                         random=True):
        # Select a random dataset
        if level <= 0:
            if datasets_root is not None:
                datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
                datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
                self.browser_load_button.setDisabled(len(datasets) == 0)
            if datasets_root is None or len(datasets) == 0:
                msg = "Warning: you " + ("did not pass a root directory for datasets as argument"
                                         if datasets_root is None else
                                         "do not have any of the recognized datasets in %s.\n"
                                         "Please pass the dataset root, e.g. 'E:\\datasets', rather than "
                                         "a leaf folder such as 'E:\\datasets\\aidatatang_200zh\\corpus\\test'"
                                         % datasets_root)
                self.log(msg)
                msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
                       "can still use the toolbox by recording samples yourself." % \
                       ("\n\t".join(recognized_datasets))
                print(msg, file=sys.stderr)

                self.random_utterance_button.setDisabled(True)
                self.random_speaker_button.setDisabled(True)
                self.random_dataset_button.setDisabled(True)
                self.utterance_box.setDisabled(True)
                self.speaker_box.setDisabled(True)
                self.dataset_box.setDisabled(True)
                self.browser_load_button.setDisabled(True)
                self.auto_next_checkbox.setDisabled(True)
                return
            self.repopulate_box(self.dataset_box, datasets, random)

        # Select a random speaker
        if level <= 1:
            speakers_root = datasets_root.joinpath(self.current_dataset_name)
            speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
            self.repopulate_box(self.speaker_box, speaker_names, random)

        # Select a random utterance
        if level <= 2:
            utterances_root = datasets_root.joinpath(
                self.current_dataset_name,
                self.current_speaker_name
            )
            utterances = []
            for extension in ['mp3', 'flac', 'wav', 'm4a']:
                utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
            utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
            self.repopulate_box(self.utterance_box, utterances, random)

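Why the dataset-root hint above needs escaped backslashes: in a plain string literal '\a' is a bell character and '\t' a tab, so an unescaped Windows path prints mangled (illustrative, not part of this commit):

# --- sketch ---
bad = 'E:\datasets\aidatatang_200zh\corpus\test'  # '\a' and '\t' are escape sequences
good = 'E:\\datasets\\aidatatang_200zh\\corpus\\test'
also_good = r'E:\datasets\aidatatang_200zh\corpus\test'
print(good == also_good, bad == good)             # True False
# --- end sketch ---
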
    def browser_select_next(self):
        index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
        self.utterance_box.setCurrentIndex(index)

    @property
    def current_encoder_fpath(self):
        return self.encoder_box.itemData(self.encoder_box.currentIndex())

    @property
    def current_synthesizer_fpath(self):
        return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())

    @property
    def current_vocoder_fpath(self):
        return self.vocoder_box.itemData(self.vocoder_box.currentIndex())

    @property
    def current_extractor_fpath(self):
        return self.extractor_box.itemData(self.extractor_box.currentIndex())

    @property
    def current_convertor_fpath(self):
        return self.convertor_box.itemData(self.convertor_box.currentIndex())

    def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
                        vocoder_models_dir: Path, extractor_models_dir: Path,
                        convertor_models_dir: Path, vc_mode: bool):
        # Encoder
        encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
        if len(encoder_fpaths) == 0:
            raise Exception("No encoder models found in %s" % encoder_models_dir)
        self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])

        if vc_mode:
            # Extractor
            extractor_fpaths = list(extractor_models_dir.glob("*.pt"))
            if len(extractor_fpaths) == 0:
                self.log("No extractor models found in %s" % extractor_models_dir)
            self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths])

            # Convertor
            convertor_fpaths = list(convertor_models_dir.glob("*.pth"))
            if len(convertor_fpaths) == 0:
                self.log("No convertor models found in %s" % convertor_models_dir)
            self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths])
        else:
            # Synthesizer
            synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
            if len(synthesizer_fpaths) == 0:
                raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
            self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])

        # Vocoder
        vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
        vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
        self.repopulate_box(self.vocoder_box, vocoder_items)

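The on-disk layout populate_models() expects, expressed with the same glob patterns (illustrative, not part of this commit; the root directory and subfolder names are hypothetical):

# --- sketch ---
from pathlib import Path

root = Path("data/ckpt")                                           # hypothetical models root
encoder_fpaths = list((root / "encoder").glob("*.pt"))             # flat *.pt files
synthesizer_fpaths = list((root / "synthesizer").glob("**/*.pt"))  # searched recursively
convertor_fpaths = list((root / "ppg2mel").glob("*.pth"))          # convertors use .pth
vocoder_items = [(f.stem, f) for f in (root / "vocoder").glob("**/*.pt")] + [("Griffin-Lim", None)]
# --- end sketch ---
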
    @property
    def selected_utterance(self):
        return self.utterance_history.itemData(self.utterance_history.currentIndex())

    def register_utterance(self, utterance: Utterance, vc_mode):
        self.utterance_history.blockSignals(True)
        self.utterance_history.insertItem(0, utterance.name, utterance)
        self.utterance_history.setCurrentIndex(0)
        self.utterance_history.blockSignals(False)

        if len(self.utterance_history) > self.max_saved_utterances:
            self.utterance_history.removeItem(self.max_saved_utterances)

        self.play_button.setDisabled(False)
        if vc_mode:
            self.convert_button.setDisabled(False)
        else:
            self.generate_button.setDisabled(False)
            self.synthesize_button.setDisabled(False)

    def log(self, line, mode="newline"):
        if mode == "newline":
            self.logs.append(line)
            if len(self.logs) > self.max_log_lines:
                del self.logs[0]
        elif mode == "append":
            self.logs[-1] += line
        elif mode == "overwrite":
            self.logs[-1] = line
        log_text = '\n'.join(self.logs)

        self.log_window.setText(log_text)
        self.app.processEvents()

    def set_loading(self, value, maximum=1):
        self.loading_bar.setValue(value * 100)
        self.loading_bar.setMaximum(maximum * 100)
        self.loading_bar.setTextVisible(value != 0)
        self.app.processEvents()

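The three log() modes against a plain list, mirroring the bounded buffer above (illustrative, not part of this commit):

# --- sketch ---
logs, max_log_lines = [], 5

def log(line, mode="newline"):
    if mode == "newline":
        logs.append(line)
        if len(logs) > max_log_lines:
            del logs[0]          # drop the oldest line once the buffer is full
    elif mode == "append":
        logs[-1] += line
    elif mode == "overwrite":
        logs[-1] = line

log("Loading the vocoder...")
log(" Done!", "append")                   # -> "Loading the vocoder... Done!"
log("generation step 2/10", "overwrite")  # replaces the last line in place
# --- end sketch ---
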
    def populate_gen_options(self, seed, trim_silences):
        if seed is not None:
            self.random_seed_checkbox.setChecked(True)
            self.seed_textbox.setText(str(seed))
            self.seed_textbox.setEnabled(True)
        else:
            self.random_seed_checkbox.setChecked(False)
            self.seed_textbox.setText(str(0))
            self.seed_textbox.setEnabled(False)

        if not trim_silences:
            self.trim_silences_checkbox.setChecked(False)
            self.trim_silences_checkbox.setDisabled(True)

    def update_seed_textbox(self):
        # The seed box is editable only while the "Random seed" checkbox is ticked
        self.seed_textbox.setEnabled(self.random_seed_checkbox.isChecked())

    def reset_interface(self, vc_mode):
        self.draw_embed(None, None, "current")
        self.draw_embed(None, None, "generated")
        self.draw_spec(None, "current")
        self.draw_spec(None, "generated")
        self.draw_umap_projections(set())
        self.set_loading(0)
        self.play_button.setDisabled(True)
        if vc_mode:
            self.convert_button.setDisabled(True)
        else:
            self.generate_button.setDisabled(True)
            self.synthesize_button.setDisabled(True)
        self.vocode_button.setDisabled(True)
        self.replay_wav_button.setDisabled(True)
        self.export_wav_button.setDisabled(True)
        for _ in range(self.max_log_lines):
            self.log("")

    def __init__(self, vc_mode):
        ## Initialize the application
        self.app = QApplication(sys.argv)
        super().__init__(None)
        self.setWindowTitle("MockingBird GUI")
        self.setWindowIcon(QtGui.QIcon('control/toolbox/assets/mb.png'))
        self.setWindowFlag(Qt.WindowMinimizeButtonHint, True)
        self.setWindowFlag(Qt.WindowMaximizeButtonHint, True)


        ## Main layouts
        # Root
        root_layout = QGridLayout()
        self.setLayout(root_layout)

        # Browser
        browser_layout = QGridLayout()
        root_layout.addLayout(browser_layout, 0, 0, 1, 8)

        # Generation
        gen_layout = QVBoxLayout()
        root_layout.addLayout(gen_layout, 0, 8)

        # Visualizations
        vis_layout = QVBoxLayout()
        root_layout.addLayout(vis_layout, 1, 0, 2, 8)

        # Output
        output_layout = QGridLayout()
        vis_layout.addLayout(output_layout, 0)

        # Projections
        self.projections_layout = QVBoxLayout()
        root_layout.addLayout(self.projections_layout, 1, 8, 2, 2)

        ## Projections
        # UMap
        fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
        fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
        self.projections_layout.addWidget(FigureCanvas(fig))
        self.umap_hot = False
        self.clear_button = QPushButton("Clear")
        self.projections_layout.addWidget(self.clear_button)


        ## Browser
        # Dataset, speaker and utterance selection
        i = 0

        source_groupbox = QGroupBox('Source(源音频)')
        source_layout = QGridLayout()
        source_groupbox.setLayout(source_layout)
        browser_layout.addWidget(source_groupbox, i, 0, 1, 5)

        self.dataset_box = QComboBox()
        source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0)
        source_layout.addWidget(self.dataset_box, i, 1)
        self.random_dataset_button = QPushButton("Random")
        source_layout.addWidget(self.random_dataset_button, i, 2)
        i += 1
        self.speaker_box = QComboBox()
        source_layout.addWidget(QLabel("Speaker(说话者):"), i, 0)
        source_layout.addWidget(self.speaker_box, i, 1)
        self.random_speaker_button = QPushButton("Random")
        source_layout.addWidget(self.random_speaker_button, i, 2)
        i += 1
        self.utterance_box = QComboBox()
        source_layout.addWidget(QLabel("Utterance(音频):"), i, 0)
        source_layout.addWidget(self.utterance_box, i, 1)
        self.random_utterance_button = QPushButton("Random")
        source_layout.addWidget(self.random_utterance_button, i, 2)

        i += 1
        source_layout.addWidget(QLabel("<b>Use(使用):</b>"), i, 0)
        self.browser_load_button = QPushButton("Load Above(加载上面)")
        source_layout.addWidget(self.browser_load_button, i, 1, 1, 2)
        self.auto_next_checkbox = QCheckBox("Auto select next")
        self.auto_next_checkbox.setChecked(True)
        source_layout.addWidget(self.auto_next_checkbox, i + 1, 1)
        self.browser_browse_button = QPushButton("Browse(打开本地)")
        source_layout.addWidget(self.browser_browse_button, i, 3)
        self.record_button = QPushButton("Record(录音)")
        source_layout.addWidget(self.record_button, i + 1, 3)

        i += 2
        # Utterance box
        browser_layout.addWidget(QLabel("<b>Current(当前):</b>"), i, 0)
        self.utterance_history = QComboBox()
        browser_layout.addWidget(self.utterance_history, i, 1)
        self.play_button = QPushButton("Play(播放)")
        browser_layout.addWidget(self.play_button, i, 2)
        self.stop_button = QPushButton("Stop(暂停)")
        browser_layout.addWidget(self.stop_button, i, 3)
        if vc_mode:
            self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)")
            browser_layout.addWidget(self.load_soruce_button, i, 4)

        i += 1
        model_groupbox = QGroupBox('Models(模型选择)')
        model_layout = QHBoxLayout()
        model_groupbox.setLayout(model_layout)
        browser_layout.addWidget(model_groupbox, i, 0, 2, 5)

        # Model and audio output selection
        self.encoder_box = QComboBox()
        model_layout.addWidget(QLabel("Encoder:"))
        model_layout.addWidget(self.encoder_box)
        self.synthesizer_box = QComboBox()
        if vc_mode:
            self.extractor_box = QComboBox()
            model_layout.addWidget(QLabel("Extractor:"))
            model_layout.addWidget(self.extractor_box)
            self.convertor_box = QComboBox()
            model_layout.addWidget(QLabel("Convertor:"))
            model_layout.addWidget(self.convertor_box)
        else:
            model_layout.addWidget(QLabel("Synthesizer:"))
            model_layout.addWidget(self.synthesizer_box)
        self.vocoder_box = QComboBox()
        model_layout.addWidget(QLabel("Vocoder:"))
        model_layout.addWidget(self.vocoder_box)

        # Replay & save audio
        i = 0
        output_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
        self.waves_cb = QComboBox()
        self.waves_cb_model = QStringListModel()
        self.waves_cb.setModel(self.waves_cb_model)
        self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
        output_layout.addWidget(self.waves_cb, i, 1)
        self.replay_wav_button = QPushButton("Replay")
        self.replay_wav_button.setToolTip("Replay the last generated vocoder output")
        output_layout.addWidget(self.replay_wav_button, i, 2)
        self.export_wav_button = QPushButton("Export")
        self.export_wav_button.setToolTip("Save the last generated vocoder audio to the filesystem as a wav file")
        output_layout.addWidget(self.export_wav_button, i, 3)
        self.audio_out_devices_cb = QComboBox()
        i += 1
        output_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 0)
        output_layout.addWidget(self.audio_out_devices_cb, i, 1)

        ## Embed & spectrograms
        vis_layout.addStretch()
        # TODO: add spectrograms for source
        gridspec_kw = {"width_ratios": [1, 4]}
        fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                            gridspec_kw=gridspec_kw)
        fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
        vis_layout.addWidget(FigureCanvas(fig))

        fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                        gridspec_kw=gridspec_kw)
        fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
        vis_layout.addWidget(FigureCanvas(fig))

        for ax in self.current_ax.tolist() + self.gen_ax.tolist():
            ax.set_facecolor("#F0F0F0")
            for side in ["top", "right", "bottom", "left"]:
                ax.spines[side].set_visible(False)

        ## Generation
        self.text_prompt = QPlainTextEdit(default_text)
        gen_layout.addWidget(self.text_prompt, stretch=1)

        if vc_mode:
            layout = QHBoxLayout()
            self.convert_button = QPushButton("Extract and Convert")
            layout.addWidget(self.convert_button)
            gen_layout.addLayout(layout)
        else:
            self.generate_button = QPushButton("Synthesize and vocode")
            gen_layout.addWidget(self.generate_button)
            layout = QHBoxLayout()
            self.synthesize_button = QPushButton("Synthesize only")
            layout.addWidget(self.synthesize_button)

            self.vocode_button = QPushButton("Vocode only")
            layout.addWidget(self.vocode_button)
            gen_layout.addLayout(layout)


        layout_seed = QGridLayout()
        self.random_seed_checkbox = QCheckBox("Random seed:")
        self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
        layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
        self.seed_textbox = QLineEdit()
        self.seed_textbox.setMaximumWidth(80)
        layout_seed.addWidget(self.seed_textbox, 0, 1)
        self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
        self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
                                               " This feature requires `webrtcvad` to be installed.")
        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
        self.style_slider = QSlider(Qt.Horizontal)
        self.style_slider.setTickInterval(1)
        self.style_slider.setFocusPolicy(Qt.NoFocus)
        self.style_slider.setSingleStep(1)
        self.style_slider.setRange(-1, 9)
        self.style_value_label = QLabel("-1")
        self.style_slider.setValue(-1)
        layout_seed.addWidget(QLabel("Style:"), 1, 0)

        self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
        layout_seed.addWidget(self.style_value_label, 1, 1)
        layout_seed.addWidget(self.style_slider, 1, 3)

        self.token_slider = QSlider(Qt.Horizontal)
        self.token_slider.setTickInterval(1)
        self.token_slider.setFocusPolicy(Qt.NoFocus)
        self.token_slider.setSingleStep(1)
        self.token_slider.setRange(3, 9)
        self.token_value_label = QLabel("4")  # matches the initial slider value set below
        self.token_slider.setValue(4)
        layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0)

        self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
        layout_seed.addWidget(self.token_value_label, 2, 1)
        layout_seed.addWidget(self.token_slider, 2, 3)

        self.length_slider = QSlider(Qt.Horizontal)
        self.length_slider.setTickInterval(1)
        self.length_slider.setFocusPolicy(Qt.NoFocus)
        self.length_slider.setSingleStep(1)
        self.length_slider.setRange(1, 10)
        self.length_value_label = QLabel("2")
        self.length_slider.setValue(2)
        layout_seed.addWidget(QLabel("MaxLength(最大句长):"), 3, 0)

        self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s))
        layout_seed.addWidget(self.length_value_label, 3, 1)
        layout_seed.addWidget(self.length_slider, 3, 3)

        gen_layout.addLayout(layout_seed)

        self.loading_bar = QProgressBar()
        gen_layout.addWidget(self.loading_bar)

        self.log_window = QLabel()
        self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
        gen_layout.addWidget(self.log_window)
        self.logs = []
        gen_layout.addStretch()


        ## Set the size of the window and of the elements
        max_size = QDesktopWidget().availableGeometry(self).size() * 0.5
        self.resize(max_size)

        ## Finalize the display
        self.reset_interface(vc_mode)
        self.show()

    def start(self):
        self.app.exec_()
5
control/toolbox/utterance.py
Normal file
@@ -0,0 +1,5 @@
from collections import namedtuple

Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
Utterance.__eq__ = lambda x, y: x.name == y.name
Utterance.__hash__ = lambda x: hash(x.name)
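Because __eq__ and __hash__ compare only the name, a set of Utterances deduplicates by name regardless of the audio attached (illustrative, not part of this commit):

# --- sketch ---
u1 = Utterance("a", "spk1", None, None, None, None, False)
u2 = Utterance("a", "spk2", None, None, None, None, True)
print(len({u1, u2}))  # 1
# --- end sketch ---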