diff --git a/mkgui/base/ui/streamlit_ui.py b/mkgui/base/ui/streamlit_ui.py
index 08232f7..2e5159d 100644
--- a/mkgui/base/ui/streamlit_ui.py
+++ b/mkgui/base/ui/streamlit_ui.py
@@ -51,9 +51,13 @@ def launch_ui(port: int = 8501) -> None:
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
if system() == "Windows":
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
+ subprocess.run(
+ f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
+ shell=True,
+ )
subprocess.run(
- f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
+ f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
shell=True,
)
@@ -122,10 +126,11 @@ class InputUI:
property["title"] = name_to_title(property_key)
try:
- self._store_value(
- property_key,
- self._render_property(streamlit_app_root, property_key, property),
- )
+ if "input_data" in self._session_state:
+ self._store_value(
+ property_key,
+ self._render_property(streamlit_app_root, property_key, property),
+ )
except Exception as e:
print("Exception!", e)
pass
@@ -807,6 +812,9 @@ def getOpyrator(mode: str) -> Opyrator:
if mode == None or mode.startswith('预处理'):
from mkgui.preprocess import preprocess
return Opyrator(preprocess)
+ if mode == None or mode.startswith('模型训练'):
+ from mkgui.train import train
+ return Opyrator(train)
from mkgui.app import synthesize
return Opyrator(synthesize)
@@ -815,11 +823,13 @@ def render_streamlit_ui() -> None:
# init
session_state = st.session_state
session_state.input_data = {}
+ # Add custom css settings
+ st.markdown(f"", unsafe_allow_html=True)
with st.spinner("Loading MockingBird GUI. Please wait..."):
session_state.mode = st.sidebar.selectbox(
'模式选择',
- ( "AI拟音", "VC拟音", "预处理")
+ ( "AI拟音", "VC拟音", "预处理", "模型训练")
)
if "mode" in session_state:
mode = session_state.mode
@@ -872,6 +882,4 @@ def render_streamlit_ui() -> None:
# placeholder
st.caption("请使用左侧控制板进行输入并运行获得结果")
- # Add custom css settings
- st.markdown(f"", unsafe_allow_html=True)
-
+
diff --git a/mkgui/train.py b/mkgui/train.py
new file mode 100644
index 0000000..02287d7
--- /dev/null
+++ b/mkgui/train.py
@@ -0,0 +1,148 @@
+from pydantic import BaseModel, Field
+import os
+from pathlib import Path
+from enum import Enum
+from typing import Any
+import numpy as np
+from utils.load_yaml import HpsYaml
+from utils.util import AttrDict
+import torch
+
+# TODO: seperator for *unix systems
+# Constants
+EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
+CONV_MODELS_DIRT = "ppg2mel\\saved_models"
+ENC_MODELS_DIRT = "encoder\\saved_models"
+
+
+if os.path.isdir(EXT_MODELS_DIRT):
+ extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
+ print("Loaded extractor models: " + str(len(extractors)))
+if os.path.isdir(CONV_MODELS_DIRT):
+ convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
+ print("Loaded convertor models: " + str(len(convertors)))
+if os.path.isdir(ENC_MODELS_DIRT):
+ encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
+ print("Loaded encoders models: " + str(len(encoders)))
+
+class Model(str, Enum):
+ VC_PPG2MEL = "ppg2mel"
+
+class Dataset(str, Enum):
+ AIDATATANG_200ZH = "aidatatang_200zh"
+ AIDATATANG_200ZH_S = "aidatatang_200zh_s"
+
+class Input(BaseModel):
+ # def render_input_ui(st, input) -> Dict:
+ # input["selected_dataset"] = st.selectbox(
+ # '选择数据集',
+ # ("aidatatang_200zh", "aidatatang_200zh_s")
+ # )
+ # return input
+ model: Model = Field(
+ Model.VC_PPG2MEL, title="模型类型",
+ )
+ # datasets_root: str = Field(
+ # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
+ # format=True,
+ # example="..\\trainning_data\\"
+ # )
+ output_root: str = Field(
+ ..., alias="输出目录(可选)", description="建议不填,保持默认",
+ format=True,
+ example=""
+ )
+ continue_mode: bool = Field(
+ True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
+ )
+ gpu: bool = Field(
+ True, alias="GPU训练", description="选择“是”,则使用GPU训练",
+ )
+ verbose: bool = Field(
+ True, alias="打印详情", description="选择“是”,输出更多详情",
+ )
+ # TODO: Move to hiden fields by default
+ convertor: convertors = Field(
+ ..., alias="转换模型",
+ description="选择语音转换模型文件."
+ )
+ extractor: extractors = Field(
+ ..., alias="特征提取模型",
+ description="选择PPG特征提取模型文件."
+ )
+ encoder: encoders = Field(
+ ..., alias="语音编码模型",
+ description="选择语音编码模型文件."
+ )
+ njobs: int = Field(
+ 8, alias="进程数", description="适用于ppg2mel",
+ )
+ seed: int = Field(
+ default=0, alias="初始随机数", description="适用于ppg2mel",
+ )
+ model_name: str = Field(
+ ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
+ example="test"
+ )
+ model_config: str = Field(
+ ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
+ example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+ )
+
+class AudioEntity(BaseModel):
+ content: bytes
+ mel: Any
+
+class Output(BaseModel):
+ __root__: tuple[str, int]
+
+ def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
+ """Custom output UI.
+ If this method is implmeneted, it will be used instead of the default Output UI renderer.
+ """
+ sr, count = self.__root__
+ streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
+
+def train(input: Input) -> Output:
+ """Train(训练)"""
+
+ print(">>> OneShot VC training ...")
+ params = AttrDict()
+ params.update({
+ "gpu": input.gpu,
+ "cpu": not input.gpu,
+ "njobs": input.njobs,
+ "seed": input.seed,
+ "verbose": input.verbose,
+ "load": input.convertor.value,
+ "warm_start": False,
+ })
+ if input.continue_mode:
+ # trace old model and config
+ p = Path(input.convertor.value)
+ params.name = p.parent.name
+ # search a config file
+ model_config_fpaths = list(p.parent.rglob("*.yaml"))
+ if len(model_config_fpaths) == 0:
+ raise "No model yaml config found for convertor"
+ config = HpsYaml(model_config_fpaths[0])
+ params.ckpdir = p.parent.parent
+ params.config = model_config_fpaths[0]
+ params.logdir = os.path.join(p.parent, "log")
+ else:
+ # Make the config dict dot visitable
+ config = HpsYaml(input.config)
+ np.random.seed(input.seed)
+ torch.manual_seed(input.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(input.seed)
+ mode = "train"
+ from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+ solver = Solver(config, params, mode)
+ solver.load_data()
+ solver.set_model()
+ solver.exec()
+ print(">>> Oneshot VC train finished!")
+
+ # TODO: pass useful return code
+ return Output(__root__=(input.dataset, 0))
\ No newline at end of file
diff --git a/ppg2mel/train.py b/ppg2mel/train.py
index fed7501..d3ef729 100644
--- a/ppg2mel/train.py
+++ b/ppg2mel/train.py
@@ -31,15 +31,10 @@ def main():
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
- parser.add_argument('--no-pin', action='store_true',
- help='Disable pin-memory for dataloader')
- parser.add_argument('--test', action='store_true', help='Test the model.')
+ # parser.add_argument('--no-pin', action='store_true',
+ # help='Disable pin-memory for dataloader')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
- parser.add_argument('--finetune', action='store_true', help='Finetune model')
- parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
- parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
- parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
-
+
###
paras = parser.parse_args()
diff --git a/ppg2mel/train/solver.py b/ppg2mel/train/solver.py
index 264a91c..9ca71cb 100644
--- a/ppg2mel/train/solver.py
+++ b/ppg2mel/train/solver.py
@@ -93,6 +93,7 @@ class BaseSolver():
def load_ckpt(self):
''' Load ckpt if --load option is specified '''
+ print(self.paras)
if self.paras.load is not None:
if self.paras.warm_start:
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
@@ -100,7 +101,7 @@ class BaseSolver():
self.paras.load, map_location=self.device if self.mode == 'train'
else 'cpu')
model_dict = ckpt['model']
- if len(self.config.model.ignore_layers) > 0:
+ if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
model_dict = {k:v for k, v in model_dict.items()
if k not in self.config.model.ignore_layers}
dummy_dict = self.model.state_dict()
diff --git a/utils/util.py b/utils/util.py
index 5227538..34bcffd 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -42,3 +42,9 @@ def human_format(num):
# add more suffixes if you need them
return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
+
+# provide easy access of attribute from dict, such abc.key
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
diff --git a/vocoder/hifigan/env.py b/vocoder/hifigan/env.py
index 2bdbc95..8f0d306 100644
--- a/vocoder/hifigan/env.py
+++ b/vocoder/hifigan/env.py
@@ -1,13 +1,6 @@
import os
import shutil
-
-class AttrDict(dict):
- def __init__(self, *args, **kwargs):
- super(AttrDict, self).__init__(*args, **kwargs)
- self.__dict__ = self
-
-
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
diff --git a/vocoder/hifigan/inference.py b/vocoder/hifigan/inference.py
index 423cbc6..8caf348 100644
--- a/vocoder/hifigan/inference.py
+++ b/vocoder/hifigan/inference.py
@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import json
import torch
-from vocoder.hifigan.env import AttrDict
+from utils.util import AttrDict
from vocoder.hifigan.models import Generator
generator = None # type: Generator
diff --git a/vocoder/hifigan/train.py b/vocoder/hifigan/train.py
index 987bcca..8760274 100644
--- a/vocoder/hifigan/train.py
+++ b/vocoder/hifigan/train.py
@@ -12,7 +12,6 @@ from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
-from vocoder.hifigan.env import AttrDict, build_env
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
discriminator_loss
diff --git a/vocoder_train.py b/vocoder_train.py
index d3ad0f5..1ef0e30 100644
--- a/vocoder_train.py
+++ b/vocoder_train.py
@@ -1,7 +1,7 @@
from utils.argutils import print_args
from vocoder.wavernn.train import train
from vocoder.hifigan.train import train as train_hifigan
-from vocoder.hifigan.env import AttrDict
+from utils.util import AttrDict
from pathlib import Path
import argparse
import json