mirror of
https://github.com/babysor/Realtime-Voice-Clone-Chinese.git
synced 2026-04-23 18:31:14 +08:00
Refactor the project into 3 parts: Models, Control, Data.
The README still needs to be written.
This commit is contained in:
64
control/cli/encoder_preprocess.py
Normal file
64
control/cli/encoder_preprocess.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""CLI: preprocess speaker-verification datasets into mel spectrograms for encoder training."""
import argparse
from pathlib import Path

from models.encoder.preprocess import (preprocess_aidatatang_200zh,
                                       preprocess_librispeech, preprocess_voxceleb1,
                                       preprocess_voxceleb2)
from utils.argutils import print_args


if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        # Shows argument defaults while keeping the description's raw layout.
        pass

    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
        "writes them to the disk. This will allow you to train the encoder. The "
        "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms. If left out, "
        "defaults to <datasets_root>/SV2TTS/encoder/")
    parser.add_argument("-d", "--datasets", type=str,
                        default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
        "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
        "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
        "voxceleb2.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to skip existing output files with the same name. Useful if this script was "
        "interrupted.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    args = parser.parse_args()

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:  # was a bare except: only a missing package should trigger this
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.") from None
    del args.no_trim

    # Process the arguments
    args.datasets = args.datasets.split(",")
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Preprocess the datasets
    print_args(args, parser)
    # Dispatch table: dataset name -> preprocessing function.
    preprocess_func = {
        "librispeech_other": preprocess_librispeech,
        "voxceleb1": preprocess_voxceleb1,
        "voxceleb2": preprocess_voxceleb2,
        "aidatatang_200zh": preprocess_aidatatang_200zh,
    }
    # Remaining args (out_dir, skip_existing, ...) are forwarded verbatim.
    args = vars(args)
    for dataset in args.pop("datasets"):
        print("Preprocessing %s" % dataset)
        preprocess_func[dataset](**args)
|
||||
47
control/cli/encoder_train.py
Normal file
47
control/cli/encoder_train.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""CLI: train the speaker encoder on data produced by encoder_preprocess.py."""
from utils.argutils import print_args
from models.encoder.train import train
from pathlib import Path
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help= \
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("clean_data_root", type=Path, help= \
        "Path to the output directory of encoder_preprocess.py. If you left the default "
        "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
    parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
        "Path to the output directory that will contain the saved model weights, as well as "
        "backups of those weights and plots generated during training.")
    parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
        "Number of steps between updates of the loss and the plots.")
    parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
        "Number of steps between updates of the umap projection. Set to 0 to never update the "
        "projections.")
    parser.add_argument("-s", "--save_every", type=int, default=500, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model.")
    parser.add_argument("--visdom_server", type=str, default="http://localhost")
    parser.add_argument("--no_visdom", action="store_true", help= \
        "Disable visdom.")
    args = parser.parse_args()

    # Process the arguments.
    # parents=True: the default "encoder/saved_models/" is a nested path, so
    # create missing intermediate directories instead of raising FileNotFoundError.
    args.models_dir.mkdir(exist_ok=True, parents=True)

    # Run the training
    print_args(args, parser)
    train(**vars(args))
|
||||
|
||||
67
control/cli/ppg2mel_train.py
Normal file
67
control/cli/ppg2mel_train.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""CLI: train the PPG2Mel one-shot voice-conversion model."""
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility, comment these may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def main():
    """Parse CLI options, seed the RNGs, then run one-shot VC training."""
    parser = argparse.ArgumentParser(description=
        'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducable results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    opts = parser.parse_args()
    # Derive the positive convenience flags from their negated CLI switches.
    opts.gpu = not opts.cpu
    opts.pin_memory = not opts.no_pin
    opts.verbose = not opts.no_msg
    # Make the config dict dot visitable
    config = HpsYaml(opts.config)

    # Seed every RNG involved so a run can be repeated exactly.
    np.random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opts.seed)

    print(">>> OneShot VC training ...")
    solver = Solver(config, opts, "train")
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)


if __name__ == "__main__":
    main()
|
||||
49
control/cli/pre4ppg.py
Normal file
49
control/cli/pre4ppg.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""CLI: preprocess a dataset for PPG2Mel training."""
from pathlib import Path
import argparse

from models.ppg2mel.preprocess import preprocess_dataset

# Datasets this preprocessor knows how to handle.
recognized_datasets = [
    "aidatatang_200zh",
    "aidatatang_200zh_s",  # sample
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, to be used by the "
        "ppg2mel model for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your datasets.")
    parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process, allowing values: aidatatang_200zh.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
    parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
        "Number of processes in parallel.")
    parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
        "Path your trained ppg encoder model.")
    parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
        "Path your trained speaker encoder model.")
    args = parser.parse_args()

    # Name the offending dataset in the message so the user knows what failed.
    assert args.dataset in recognized_datasets, \
        "%s is not supported, file an issue to propose a new one" % args.dataset

    # Create directories
    assert args.datasets_root.exists()
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    args.out_dir.mkdir(exist_ok=True, parents=True)

    preprocess_dataset(**vars(args))
|
||||
65
control/cli/synthesizer_preprocess_audio.py
Normal file
65
control/cli/synthesizer_preprocess_audio.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""CLI (deprecated, use pre.py): preprocess audio datasets for synthesizer training."""
from models.synthesizer.preprocess import preprocess_dataset
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
from pathlib import Path
import argparse


# Datasets this (deprecated) preprocessor can handle.
recognized_datasets = [
    "aidatatang_200zh",
    "magicdata",
    "aishell3"
]

if __name__ == "__main__":
    print("This method is deprecated and will no longer be supported, please use 'pre.py'")
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
        "and writes them to the disk. Audio files are also saved, to be used by the "
        "vocoder for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/")
    parser.add_argument("-n", "--n_processes", type=int, default=None, help=\
        "Number of processes in parallel.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
        "interrupted.")
    parser.add_argument("--hparams", type=str, default="", help=\
        "Hyperparameter overrides as a comma-separated list of name-value pairs")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("--no_alignments", action="store_true", help=\
        "Use this option when dataset does not include alignments "
        "(these are used to split long audio files into sub-utterances.)")
    parser.add_argument("--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process, allowing values: magicdata, aidatatang_200zh.")
    args = parser.parse_args()

    # Process the arguments
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer")
    assert args.dataset in recognized_datasets, 'is not supported, please vote for it in https://github.com/babysor/MockingBird/issues/10'
    # Create directories
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:  # was a bare except: only a missing package should trigger this
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.") from None
    del args.no_trim

    # Preprocess the dataset
    print_args(args, parser)
    args.hparams = hparams.parse(args.hparams)

    preprocess_dataset(**vars(args))
|
||||
26
control/cli/synthesizer_preprocess_embeds.py
Normal file
26
control/cli/synthesizer_preprocess_embeds.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""CLI (deprecated, use pre.py): create speaker embeddings for synthesizer training data."""
from models.synthesizer.preprocess import create_embeddings
from utils.argutils import print_args
from pathlib import Path
import argparse


if __name__ == "__main__":
    print("This method is deprecated and will no longer be supported, please use 'pre.py'")
    parser = argparse.ArgumentParser(
        description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("synthesizer_root", type=Path, help=\
        "Path to the synthesizer training data that contains the audios and the train.txt file. "
        "If you let everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("-e", "--encoder_model_fpath", type=Path,
                        default="encoder/saved_models/pretrained.pt", help=\
        "Path your trained encoder model.")
    parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
        "Number of parallel processes. An encoder is created for each, so you may need to lower "
        "this value on GPUs with low memory. Set it to 1 if CUDA is unhappy.")
    args = parser.parse_args()

    # Preprocess the dataset
    print_args(args, parser)
    create_embeddings(**vars(args))
|
||||
37
control/cli/synthesizer_train.py
Normal file
37
control/cli/synthesizer_train.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""CLI: train the synthesizer."""
from models.synthesizer.hparams import hparams
from models.synthesizer.train import train
from utils.argutils import print_args
import argparse


if __name__ == "__main__":
    # Collect options, fold hparam overrides in, then hand off to train().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "run_id", type=str,
        help="Name for this model instance. If a model state from the same run ID was previously "
             "saved, the training will restart from there. Pass -f to overwrite saved states and "
             "restart from scratch.")
    parser.add_argument(
        "syn_dir", type=str, default=argparse.SUPPRESS,
        help="Path to the synthesizer directory that contains the ground truth mel spectrograms, "
             "the wavs and the embeds.")
    parser.add_argument(
        "-m", "--models_dir", type=str, default="synthesizer/saved_models/",
        help="Path to the output directory that will contain the saved model weights and the logs.")
    parser.add_argument(
        "-s", "--save_every", type=int, default=1000,
        help="Number of steps between updates of the model on the disk. Set to 0 to never save the "
             "model.")
    parser.add_argument(
        "-b", "--backup_every", type=int, default=25000,
        help="Number of steps between backups of the model. Set to 0 to never make backups of the "
             "model.")
    parser.add_argument(
        "-l", "--log_every", type=int, default=200,
        help="Number of steps between summary the training info in tensorboard")
    parser.add_argument(
        "-f", "--force_restart", action="store_true",
        help="Do not load any saved model and restart from scratch.")
    parser.add_argument(
        "--hparams", default="",
        help="Hyperparameter overrides as a comma-separated list of name=value "
             "pairs")
    args = parser.parse_args()
    print_args(args, parser)

    # Replace the raw override string with the parsed hparams object expected by train().
    args.hparams = hparams.parse(args.hparams)

    # Run the training
    train(**vars(args))
|
||||
59
control/cli/vocoder_preprocess.py
Normal file
59
control/cli/vocoder_preprocess.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""CLI: generate ground-truth aligned (GTA) mel spectrograms for vocoder training."""
from models.synthesizer.synthesize import run_synthesis
from models.synthesizer.hparams import hparams
from utils.argutils import print_args
import argparse
import os


if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        # Shows argument defaults while keeping the description's raw layout.
        pass

    parser = argparse.ArgumentParser(
        description="Creates ground-truth aligned (GTA) spectrograms from the vocoder.",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=str, help=\
        "Path to the directory containing your SV2TTS directory. If you specify both --in_dir and "
        "--out_dir, this argument won't be used.")
    parser.add_argument("-m", "--model_dir", type=str,
                        default="synthesizer/saved_models/mandarin/", help=\
        "Path to the pretrained model directory.")
    parser.add_argument("-i", "--in_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the synthesizer directory that contains the mel spectrograms, the wavs and the "
        "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("-o", "--out_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the output vocoder directory that will contain the ground truth aligned mel "
        "spectrograms. Defaults to <datasets_root>/SV2TTS/vocoder/.")
    parser.add_argument("--hparams", default="",
                        help="Hyperparameter overrides as a comma-separated list of name=value "
                        "pairs")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("--cpu", action="store_true", help=\
        "If True, processing is done on CPU, even when a GPU is available.")
    args = parser.parse_args()
    print_args(args, parser)
    modified_hp = hparams.parse(args.hparams)

    # Fill in the directory defaults that argparse.SUPPRESS left unset.
    if not hasattr(args, "in_dir"):
        args.in_dir = os.path.join(args.datasets_root, "SV2TTS", "synthesizer")
    if not hasattr(args, "out_dir"):
        args.out_dir = os.path.join(args.datasets_root, "SV2TTS", "vocoder")

    if args.cpu:
        # Hide GPUs from Pytorch to force CPU processing
        os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except ImportError:  # was a bare except: only a missing package should trigger this
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.") from None
    del args.no_trim

    run_synthesis(args.in_dir, args.out_dir, args.model_dir, modified_hp)
|
||||
|
||||
92
control/cli/vocoder_train.py
Normal file
92
control/cli/vocoder_train.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""CLI: train a vocoder (wavernn, hifigan or fregan) from synthesizer audios and mels."""
from utils.argutils import print_args
from models.vocoder.wavernn.train import train
from models.vocoder.hifigan.train import train as train_hifigan
from models.vocoder.fregan.train import train as train_fregan
from utils.util import AttrDict
from pathlib import Path
import argparse
import json
import torch
import torch.multiprocessing as mp


def _launch_gan_training(train_fn, config_fpath, args):
    # Shared launcher for the GAN vocoders (hifigan/fregan): load the JSON
    # config, then either spawn one training process per GPU or run in-process.
    with open(config_fpath) as f:
        json_config = json.load(f)
    h = AttrDict(json_config)
    if h.num_gpus > 1:
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
        print('Batch size per GPU :', h.batch_size)
        mp.spawn(train_fn, nprocs=h.num_gpus, args=(args, h,))
    else:
        train_fn(0, args, h)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the vocoder from the synthesizer audios and the GTA synthesized mels, "
        "or ground truth mels.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help= \
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("datasets_root", type=str, help= \
        "Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
        "will take priority over this argument.")
    # nargs="?" makes the documented default actually take effect: a bare
    # positional ignores `default` and would have been mandatory.
    parser.add_argument("vocoder_type", type=str, nargs="?", default="wavernn", help= \
        "Choose the vocoder type for train. Defaults to wavernn. "
        "Now, Support <hifigan> and <wavernn> for choose")
    parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
        "the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
    parser.add_argument("--voc_dir", type=str, default=argparse.SUPPRESS, help= \
        "Path to the vocoder directory that contains the GTA synthesized mel spectrograms. "
        "Defaults to <datasets_root>/SV2TTS/vocoder/. Unused if --ground_truth is passed.")
    parser.add_argument("-m", "--models_dir", type=str, default="vocoder/saved_models/", help=\
        "Path to the directory that will contain the saved model weights, as well as backups "
        "of those weights and wavs generated during training.")
    parser.add_argument("-g", "--ground_truth", action="store_true", help= \
        "Train on ground truth spectrograms (<datasets_root>/SV2TTS/synthesizer/mels).")
    parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model and restart from scratch.")
    parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
    args = parser.parse_args()

    # Resolve the data directories, then drop datasets_root since the train
    # functions don't accept it.
    if not hasattr(args, "syn_dir"):
        args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
    args.syn_dir = Path(args.syn_dir)
    if not hasattr(args, "voc_dir"):
        args.voc_dir = Path(args.datasets_root, "SV2TTS", "vocoder")
    args.voc_dir = Path(args.voc_dir)
    del args.datasets_root
    args.models_dir = Path(args.models_dir)
    # parents=True: the default "vocoder/saved_models/" is a nested path.
    args.models_dir.mkdir(exist_ok=True, parents=True)

    print_args(args, parser)

    # Process the arguments
    if args.vocoder_type == "wavernn":
        # Run the training wavernn
        delattr(args, 'vocoder_type')
        delattr(args, 'config')
        train(**vars(args))
    elif args.vocoder_type == "hifigan":
        _launch_gan_training(train_hifigan, args.config, args)
    elif args.vocoder_type == "fregan":
        _launch_gan_training(train_fregan, 'vocoder/fregan/config.json', args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user