mirror of
https://github.com/babysor/Realtime-Voice-Clone-Chinese.git
synced 2026-02-04 02:54:07 +08:00
Compare commits
10 Commits
refactor
...
ppg-vc-ini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a2d50c862 | ||
|
|
d786e78121 | ||
|
|
6befb700e9 | ||
|
|
dd3abebc4d | ||
|
|
eeee32f3e3 | ||
|
|
8ef5e1411d | ||
|
|
20bea3546b | ||
|
|
fad5023fca | ||
|
|
19eaa68202 | ||
|
|
379fd2b9fd |
1
.github/FUNDING.yml
vendored
1
.github/FUNDING.yml
vendored
@@ -1 +0,0 @@
|
||||
github: babysor
|
||||
17
.github/ISSUE_TEMPLATE/issue.md
vendored
17
.github/ISSUE_TEMPLATE/issue.md
vendored
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: Issue
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Summary[问题简述(一句话)]**
|
||||
A clear and concise description of what the issue is.
|
||||
|
||||
**Env & To Reproduce[复现与环境]**
|
||||
描述你用的环境、代码版本、模型
|
||||
|
||||
**Screenshots[截图(如有)]**
|
||||
If applicable, add screenshots to help
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -13,6 +13,7 @@
|
||||
*.bbl
|
||||
*.bcf
|
||||
*.toc
|
||||
*.wav
|
||||
*.sh
|
||||
*/saved_models
|
||||
!vocoder/saved_models/pretrained/**
|
||||
|
||||
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
@@ -60,6 +60,6 @@
|
||||
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
|
||||
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
|
||||
]
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
40
README-CN.md
40
README-CN.md
@@ -18,19 +18,10 @@
|
||||
|
||||
🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用
|
||||
|
||||
### 进行中的工作
|
||||
* GUI/客户端大升级与合并
|
||||
[X] 初始化框架 `./mkgui` (基于streamlit + fastapi)和 [技术设计](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
|
||||
[X] 增加 Voice Cloning and Conversion的演示页面
|
||||
[X] 增加Voice Conversion的预处理preprocessing 和训练 training 页面
|
||||
[ ] 增加其他的的预处理preprocessing 和训练 training 页面
|
||||
* 模型后端基于ESPnet2升级
|
||||
|
||||
|
||||
## 开始
|
||||
### 1. 安装要求
|
||||
> 按照原始存储库测试您是否已准备好所有环境。
|
||||
运行工具箱(demo_toolbox.py)需要 **Python 3.7 或更高版本** 。
|
||||
**Python 3.7 或更高版本** 需要运行工具箱。
|
||||
|
||||
* 安装 [PyTorch](https://pytorch.org/get-started/locally/)。
|
||||
> 如果在用 pip 方式安装的时候出现 `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)` 这个错误可能是 python 版本过低,3.9 可以安装成功
|
||||
@@ -77,7 +68,7 @@
|
||||
对效果影响不大,已经预置3款,如果希望自己训练可以参考以下命令。
|
||||
* 预处理数据:
|
||||
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
|
||||
> `<datasets_root>`替换为你的数据集目录,`<synthesizer_model_path>`替换为一个你最好的synthesizer模型目录,例如 *sythensizer\saved_models\xxx*
|
||||
> `<datasets_root>`替换为你的数据集目录,`<synthesizer_model_path>`替换为一个你最好的synthesizer模型目录,例如 *sythensizer\saved_mode\xxx*
|
||||
|
||||
|
||||
* 训练wavernn声码器:
|
||||
@@ -87,17 +78,19 @@
|
||||
* 训练hifigan声码器:
|
||||
`python vocoder_train.py <trainid> <datasets_root> hifigan`
|
||||
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
||||
* 训练fregan声码器:
|
||||
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
|
||||
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
||||
* 将GAN声码器的训练切换为多GPU模式:修改GAN文件夹下.json文件中的"num_gpus"参数
|
||||
|
||||
### 3. 启动程序或工具箱
|
||||
您可以尝试使用以下命令:
|
||||
|
||||
### 3.1 启动Web程序(v2):
|
||||
### 3.1 启动Web程序:
|
||||
`python web.py`
|
||||
运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`
|
||||

|
||||
> 注:目前界面比较buggy,
|
||||
> * 第一次点击`录制`要等待几秒浏览器正常启动录音,否则会有重音
|
||||
> * 录制结束不要再点`录制`而是`停止`
|
||||
> * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
|
||||
> * 默认使用第一个找到的模型,有动手能力的可以看代码修改 `web\__init__.py`。
|
||||
|
||||
### 3.2 启动工具箱:
|
||||
`python demo_toolbox.py -d <datasets_root>`
|
||||
@@ -108,12 +101,11 @@
|
||||
### 4. 番外:语音转换Voice Conversion(PPG based)
|
||||
想像柯南拿着变声器然后发出毛利小五郎的声音吗?本项目现基于PPG-VC,引入额外两个模块(PPG extractor + PPG2Mel), 可以实现变声功能。(文档不全,尤其是训练部分,正在努力补充中)
|
||||
#### 4.0 准备环境
|
||||
* 确保项目以上环境已经安装ok,运行`pip install espnet` 来安装剩余的必要包。
|
||||
* 下载以下模型 链接:https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
|
||||
提取码:gh41
|
||||
* 24K采样率专用的vocoder(hifigan)到 *vocoder\saved_models\xxx*
|
||||
* 预训练的ppg特征encoder(ppg_extractor)到 *ppg_extractor\saved_models\xxx*
|
||||
* 预训练的PPG2Mel到 *ppg2mel\saved_models\xxx*
|
||||
* 确保项目以上环境已经安装ok,运行`pip install -r requirements.txt` 来安装剩余的必要包。
|
||||
* 下载以下模型
|
||||
* 24K采样率专用的vocoder(hifigan)到 *vocoder\saved_mode\xxx*
|
||||
* 预训练的ppg特征encoder(ppg_extractor)到 *ppg_extractor\saved_mode\xxx*
|
||||
* 预训练的PPG2Mel到 *ppg2mel\saved_mode\xxx*
|
||||
|
||||
#### 4.1 使用数据集自己训练PPG2Mel模型 (可选)
|
||||
|
||||
@@ -131,9 +123,8 @@
|
||||
|
||||
#### 4.2 启动工具箱VC模式
|
||||
您可以尝试使用以下命令:
|
||||
`python demo_toolbox.py -vc -d <datasets_root>`
|
||||
`python demo_toolbox.py vc -d <datasets_root>`
|
||||
> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
|
||||
<img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">
|
||||
|
||||
## 引用及论文
|
||||
> 该库一开始从仅支持英语的[Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) 分叉出来的,鸣谢作者。
|
||||
@@ -142,7 +133,6 @@
|
||||
| --- | ----------- | ----- | --------------------- |
|
||||
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | 本代码库 |
|
||||
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | 本代码库 |
|
||||
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | 本代码库 |
|
||||
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | 本代码库 |
|
||||
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|
||||
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|
||||
|
||||
18
README.md
18
README.md
@@ -18,14 +18,6 @@
|
||||
|
||||
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
|
||||
|
||||
### Ongoing Works(Helps Needed)
|
||||
* Major upgrade on GUI/Client and unifying web and toolbox
|
||||
[X] Init framework `./mkgui` and [tech design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
|
||||
[X] Add demo part of Voice Cloning and Conversion
|
||||
[X] Add preprocessing and training for Voice Conversion
|
||||
[ ] Add preprocessing and training for Encoder/Synthesizer/Vocoder
|
||||
* Major upgrade on model backend based on ESPnet2(not yet started)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Install Requirements
|
||||
@@ -37,7 +29,7 @@
|
||||
* Install [ffmpeg](https://ffmpeg.org/download.html#get-packages).
|
||||
* Run `pip install -r requirements.txt` to install the remaining necessary packages.
|
||||
* Install webrtcvad `pip install webrtcvad-wheels`(If you need)
|
||||
> Note that we are using the pretrained encoder/vocoder but synthesizer since the original model is incompatible with the Chinese symbols. It means the demo_cli is not working at this moment.
|
||||
> Note that we are using the pretrained encoder/vocoder but synthesizer, since the original model is incompatible with the Chinese sympols. It means the demo_cli is not working at this moment.
|
||||
### 2. Prepare your models
|
||||
You can either train your models or use existing ones:
|
||||
|
||||
@@ -68,7 +60,7 @@ Allowing parameter `--dataset {dataset}` to support aidatatang_200zh, magicdata,
|
||||
| @author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps trained by multiple datasets
|
||||
| @author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code:om7f | | 25k steps trained by multiple datasets, only works under version 0.0.1
|
||||
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan, only works under version 0.0.1
|
||||
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 https://www.aliyundrive.com/s/AwPsbo8mcSP code: z2m0 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
|
||||
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
|
||||
|
||||
#### 2.4 Train vocoder (Optional)
|
||||
> note: vocoder has little difference in effect, so you may not need to train a new one.
|
||||
@@ -90,11 +82,6 @@ You can then try to run:`python web.py` and open it in browser, default as `http
|
||||
You can then try the toolbox:
|
||||
`python demo_toolbox.py -d <datasets_root>`
|
||||
|
||||
#### 3.3 Using the command line
|
||||
You can then try the command:
|
||||
`python gen_voice.py <text_file.txt> your_wav_file.wav`
|
||||
you may need to install cn2an by "pip install cn2an" for better digital number result.
|
||||
|
||||
## Reference
|
||||
> This repository is forked from [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) which only support English.
|
||||
|
||||
@@ -102,7 +89,6 @@ you may need to install cn2an by "pip install cn2an" for better digital number r
|
||||
| --- | ----------- | ----- | --------------------- |
|
||||
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
|
||||
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
|
||||
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
|
||||
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|
||||
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|
||||
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|
||||
|
||||
@@ -56,8 +56,8 @@ def wav_to_mel_spectrogram(wav):
|
||||
Note: this not a log-mel spectrogram.
|
||||
"""
|
||||
frames = librosa.feature.melspectrogram(
|
||||
y=wav,
|
||||
sr=sampling_rate,
|
||||
wav,
|
||||
sampling_rate,
|
||||
n_fft=int(sampling_rate * mel_window_length / 1000),
|
||||
hop_length=int(sampling_rate * mel_window_step / 1000),
|
||||
n_mels=mel_n_channels
|
||||
|
||||
128
gen_voice.py
128
gen_voice.py
@@ -1,128 +0,0 @@
|
||||
from encoder.params_model import model_embedding_size as speaker_embedding_size
|
||||
from utils.argutils import print_args
|
||||
from utils.modelutils import check_model_paths
|
||||
from synthesizer.inference import Synthesizer
|
||||
from encoder import inference as encoder
|
||||
from vocoder.wavernn import inference as rnn_vocoder
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import librosa
|
||||
import argparse
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import cn2an
|
||||
import glob
|
||||
|
||||
from audioread.exceptions import NoBackendError
|
||||
vocoder = gan_vocoder
|
||||
|
||||
def gen_one_wav(synthesizer, in_fpath, embed, texts, file_name, seq):
|
||||
embeds = [embed] * len(texts)
|
||||
# If you know what the attention layer alignments are, you can retrieve them here by
|
||||
# passing return_alignments=True
|
||||
specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=-1, min_stop_token=4, steps=400)
|
||||
#spec = specs[0]
|
||||
breaks = [spec.shape[1] for spec in specs]
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
|
||||
# If seed is specified, reset torch seed and reload vocoder
|
||||
# Synthesizing the waveform is fairly straightforward. Remember that the longer the
|
||||
# spectrogram, the more time-efficient the vocoder.
|
||||
generated_wav, output_sample_rate = vocoder.infer_waveform(spec)
|
||||
|
||||
# Add breaks
|
||||
b_ends = np.cumsum(np.array(breaks) * synthesizer.hparams.hop_size)
|
||||
b_starts = np.concatenate(([0], b_ends[:-1]))
|
||||
wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
||||
breaks = [np.zeros(int(0.15 * synthesizer.sample_rate))] * len(breaks)
|
||||
generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
||||
|
||||
## Post-generation
|
||||
# There's a bug with sounddevice that makes the audio cut one second earlier, so we
|
||||
# pad it.
|
||||
|
||||
# Trim excess silences to compensate for gaps in spectrograms (issue #53)
|
||||
generated_wav = encoder.preprocess_wav(generated_wav)
|
||||
generated_wav = generated_wav / np.abs(generated_wav).max() * 0.97
|
||||
|
||||
# Save it on the disk
|
||||
model=os.path.basename(in_fpath)
|
||||
filename = "%s_%d_%s.wav" %(file_name, seq, model)
|
||||
sf.write(filename, generated_wav, synthesizer.sample_rate)
|
||||
|
||||
print("\nSaved output as %s\n\n" % filename)
|
||||
|
||||
|
||||
def generate_wav(enc_model_fpath, syn_model_fpath, voc_model_fpath, in_fpath, input_txt, file_name):
|
||||
if torch.cuda.is_available():
|
||||
device_id = torch.cuda.current_device()
|
||||
gpu_properties = torch.cuda.get_device_properties(device_id)
|
||||
## Print some environment information (for debugging purposes)
|
||||
print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
|
||||
"%.1fGb total memory.\n" %
|
||||
(torch.cuda.device_count(),
|
||||
device_id,
|
||||
gpu_properties.name,
|
||||
gpu_properties.major,
|
||||
gpu_properties.minor,
|
||||
gpu_properties.total_memory / 1e9))
|
||||
else:
|
||||
print("Using CPU for inference.\n")
|
||||
|
||||
print("Preparing the encoder, the synthesizer and the vocoder...")
|
||||
encoder.load_model(enc_model_fpath)
|
||||
synthesizer = Synthesizer(syn_model_fpath)
|
||||
vocoder.load_model(voc_model_fpath)
|
||||
|
||||
encoder_wav = synthesizer.load_preprocess_wav(in_fpath)
|
||||
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
||||
texts = input_txt.split("\n")
|
||||
seq=0
|
||||
each_num=1500
|
||||
|
||||
punctuation = '!,。、,' # punctuate and split/clean text
|
||||
processed_texts = []
|
||||
cur_num = 0
|
||||
for text in texts:
|
||||
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
||||
if processed_text:
|
||||
processed_texts.append(processed_text.strip())
|
||||
cur_num += len(processed_text.strip())
|
||||
if cur_num > each_num:
|
||||
seq = seq +1
|
||||
gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
|
||||
processed_texts = []
|
||||
cur_num = 0
|
||||
|
||||
if len(processed_texts)>0:
|
||||
seq = seq +1
|
||||
gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
|
||||
|
||||
if (len(sys.argv)>=3):
|
||||
my_txt = ""
|
||||
print("reading from :", sys.argv[1])
|
||||
with open(sys.argv[1], "r") as f:
|
||||
for line in f.readlines():
|
||||
#line = line.strip('\n')
|
||||
my_txt += line
|
||||
txt_file_name = sys.argv[1]
|
||||
wav_file_name = sys.argv[2]
|
||||
|
||||
output = cn2an.transform(my_txt, "an2cn")
|
||||
print(output)
|
||||
generate_wav(
|
||||
Path("encoder/saved_models/pretrained.pt"),
|
||||
Path("synthesizer/saved_models/mandarin.pt"),
|
||||
Path("vocoder/saved_models/pretrained/g_hifigan.pt"), wav_file_name, output, txt_file_name
|
||||
)
|
||||
|
||||
else:
|
||||
print("please input the file name")
|
||||
exit(1)
|
||||
|
||||
|
||||
145
mkgui/app.py
145
mkgui/app.py
@@ -1,145 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from encoder import inference as encoder
|
||||
import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from mkgui.base.components.types import FileContent
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from synthesizer.inference import Synthesizer
|
||||
from typing import Any, Tuple
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = f"samples{os.sep}"
|
||||
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
|
||||
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
|
||||
VOC_MODELS_DIRT = f"vocoder{os.sep}saved_models"
|
||||
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
|
||||
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
|
||||
if not os.path.isdir("wavs"):
|
||||
os.makedirs("wavs")
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(SYN_MODELS_DIRT):
|
||||
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
|
||||
)
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
encoder: encoders = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
synthesizer: synthesizers = Field(
|
||||
..., alias="合成模型",
|
||||
description="选择语音合成模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音解码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
|
||||
def synthesize(input: Input) -> Output:
|
||||
"""synthesize(合成)"""
|
||||
# load models
|
||||
encoder.load_model(Path(input.encoder.value))
|
||||
current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
|
||||
source_spec = Synthesizer.make_spectrogram(wav)
|
||||
|
||||
# preprocess
|
||||
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
|
||||
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
||||
# Load input text
|
||||
texts = filter(None, input.message.split("\n"))
|
||||
punctuation = '!,。、,' # punctuate and split/clean text
|
||||
processed_texts = []
|
||||
for text in texts:
|
||||
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
||||
if processed_text:
|
||||
processed_texts.append(processed_text.strip())
|
||||
texts = processed_texts
|
||||
|
||||
# synthesize and vocode
|
||||
embeds = [embed] * len(texts)
|
||||
specs = current_synt.synthesize_spectrograms(texts, embeds)
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
sample_rate = Synthesizer.sample_rate
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(spec)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
|
||||
166
mkgui/app_vc.py
166
mkgui/app_vc.py
@@ -1,166 +0,0 @@
|
||||
from synthesizer.inference import Synthesizer
|
||||
from pydantic import BaseModel, Field
|
||||
from encoder import inference as speacker_encoder
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import ppg_extractor as Extractor
|
||||
import ppg2mel as Convertor
|
||||
import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from mkgui.base.components.types import FileContent
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from typing import Any, Tuple
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = f'sample{os.sep}'
|
||||
EXT_MODELS_DIRT = f'ppg_extractor{os.sep}saved_models'
|
||||
CONV_MODELS_DIRT = f'ppg2mel{os.sep}saved_models'
|
||||
VOC_MODELS_DIRT = f'vocoder{os.sep}saved_models'
|
||||
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
|
||||
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
|
||||
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(vocoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Input(BaseModel):
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
local_audio_file_target: audio_input_selection = Field(
|
||||
..., alias="目标语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
extractor: extractors = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音解码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[AudioEntity, AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, target, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(target.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Target Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
def convert(input: Input) -> Output:
|
||||
"""convert(转换)"""
|
||||
# load models
|
||||
extractor = Extractor.load_model(Path(input.extractor.value))
|
||||
convertor = Convertor.load_model(Path(input.convertor.value))
|
||||
# current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
src_wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
|
||||
|
||||
if input.upload_audio_file_target != None:
|
||||
with open(TEMP_TARGET_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file_target.as_bytes())
|
||||
f.seek(0)
|
||||
ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
|
||||
else:
|
||||
ref_wav, _ = librosa.load(input.local_audio_file_target.value)
|
||||
write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
|
||||
|
||||
ppg = extractor.extract_from_wav(src_wav)
|
||||
# Import necessary dependency of Voice Conversion
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
speacker_encoder.load_model(Path("encoder{os.sep}saved_models{os.sep}pretrained_bak_5805000.pt"))
|
||||
embed = speacker_encoder.embed_utterance(ref_wav)
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_, mel_pred, att_ws = convertor.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
|
||||
)
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
breaks = [mel_pred.shape[1]]
|
||||
mel_pred= mel_pred.detach().cpu().numpy()
|
||||
|
||||
# synthesize and vocode
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_TARGET_AUDIO, "rb") as f:
|
||||
target_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
|
||||
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
|
||||
@@ -1,2 +0,0 @@
|
||||
|
||||
from .core import Opyrator
|
||||
@@ -1 +0,0 @@
|
||||
from .fastapi_app import create_api
|
||||
@@ -1,102 +0,0 @@
|
||||
"""Collection of utilities for FastAPI apps."""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Type
|
||||
|
||||
from fastapi import FastAPI, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def as_form(cls: Type[BaseModel]) -> Any:
|
||||
"""Adds an as_form class method to decorated models.
|
||||
|
||||
The as_form class method can be used with FastAPI endpoints
|
||||
"""
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
field.alias,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
default=(Form(field.default) if not field.required else Form(...)),
|
||||
)
|
||||
for field in cls.__fields__.values()
|
||||
]
|
||||
|
||||
async def _as_form(**data): # type: ignore
|
||||
return cls(**data)
|
||||
|
||||
sig = inspect.signature(_as_form)
|
||||
sig = sig.replace(parameters=new_params)
|
||||
_as_form.__signature__ = sig # type: ignore
|
||||
setattr(cls, "as_form", _as_form)
|
||||
return cls
|
||||
|
||||
|
||||
def patch_fastapi(app: FastAPI) -> None:
|
||||
"""Patch function to allow relative url resolution.
|
||||
|
||||
This patch is required to make fastapi fully functional with a relative url path.
|
||||
This code snippet can be copy-pasted to any Fastapi application.
|
||||
"""
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
async def redoc_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
redoc_ui = get_redoc_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Redoc UI",
|
||||
)
|
||||
|
||||
return HTMLResponse(redoc_ui.body.decode("utf-8"))
|
||||
|
||||
async def swagger_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
swagger_ui = get_swagger_ui_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||
)
|
||||
|
||||
# insert request interceptor to have all request run on relativ path
|
||||
request_interceptor = (
|
||||
"requestInterceptor: (e) => {"
|
||||
"\n\t\t\tvar url = window.location.origin + window.location.pathname"
|
||||
'\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
|
||||
"\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
|
||||
"\n\t\t\te.contextUrl = url"
|
||||
"\n\t\t\te.url = url"
|
||||
"\n\t\t\treturn e;}"
|
||||
)
|
||||
|
||||
return HTMLResponse(
|
||||
swagger_ui.body.decode("utf-8").replace(
|
||||
"dom_id: '#swagger-ui',",
|
||||
"dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
|
||||
)
|
||||
)
|
||||
|
||||
# remove old docs route and add our patched route
|
||||
routes_new = []
|
||||
for app_route in app.routes:
|
||||
if app_route.path == "/docs": # type: ignore
|
||||
continue
|
||||
|
||||
if app_route.path == "/redoc": # type: ignore
|
||||
continue
|
||||
|
||||
routes_new.append(app_route)
|
||||
|
||||
app.router.routes = routes_new
|
||||
|
||||
assert app.docs_url is not None
|
||||
app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
|
||||
assert app.redoc_url is not None
|
||||
app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
|
||||
|
||||
# Make graphql realtive
|
||||
from starlette import graphql
|
||||
|
||||
graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
|
||||
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
|
||||
)
|
||||
@@ -1,43 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScoredLabel(BaseModel):
|
||||
label: str
|
||||
score: float
|
||||
|
||||
|
||||
class ClassificationOutput(BaseModel):
|
||||
__root__: List[ScoredLabel]
|
||||
|
||||
def __iter__(self): # type: ignore
|
||||
return iter(self.__root__)
|
||||
|
||||
def __getitem__(self, item): # type: ignore
|
||||
return self.__root__[item]
|
||||
|
||||
def render_output_ui(self, streamlit) -> None: # type: ignore
|
||||
import plotly.express as px
|
||||
|
||||
sorted_predictions = sorted(
|
||||
[prediction.dict() for prediction in self.__root__],
|
||||
key=lambda k: k["score"],
|
||||
)
|
||||
|
||||
num_labels = len(sorted_predictions)
|
||||
if len(sorted_predictions) > 10:
|
||||
num_labels = streamlit.slider(
|
||||
"Maximum labels to show: ",
|
||||
min_value=1,
|
||||
max_value=len(sorted_predictions),
|
||||
value=len(sorted_predictions),
|
||||
)
|
||||
fig = px.bar(
|
||||
sorted_predictions[len(sorted_predictions) - num_labels :],
|
||||
x="score",
|
||||
y="label",
|
||||
orientation="h",
|
||||
)
|
||||
streamlit.plotly_chart(fig, use_container_width=True)
|
||||
# fig.show()
|
||||
@@ -1,46 +0,0 @@
|
||||
import base64
|
||||
from typing import Any, Dict, overload
|
||||
|
||||
|
||||
class FileContent(str):
|
||||
def as_bytes(self) -> bytes:
|
||||
return base64.b64decode(self, validate=True)
|
||||
|
||||
def as_str(self) -> str:
|
||||
return self.as_bytes().decode()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(format="byte")
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Any: # type: ignore
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> "FileContent":
|
||||
if isinstance(value, FileContent):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
return FileContent(value)
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return FileContent(base64.b64encode(value).decode())
|
||||
else:
|
||||
raise Exception("Wrong type")
|
||||
|
||||
# # 暂时无法使用,因为浏览器中没有考虑选择文件夹
|
||||
# class DirectoryContent(FileContent):
|
||||
# @classmethod
|
||||
# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
# field_schema.update(format="path")
|
||||
|
||||
# @classmethod
|
||||
# def validate(cls, value: Any) -> "DirectoryContent":
|
||||
# if isinstance(value, DirectoryContent):
|
||||
# return value
|
||||
# elif isinstance(value, str):
|
||||
# return DirectoryContent(value)
|
||||
# elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# return DirectoryContent(base64.b64encode(value).decode())
|
||||
# else:
|
||||
# raise Exception("Wrong type")
|
||||
@@ -1,203 +0,0 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Type, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from pydantic.tools import parse_obj_as
|
||||
|
||||
|
||||
def name_to_title(name: str) -> str:
|
||||
"""Converts a camelCase or snake_case name to title case."""
|
||||
# If camelCase -> convert to snake case
|
||||
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
||||
name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
||||
# Convert to title case
|
||||
return name.replace("_", " ").strip().title()
|
||||
|
||||
|
||||
def is_compatible_type(type: Type) -> bool:
|
||||
"""Returns `True` if the type is opyrator-compatible."""
|
||||
try:
|
||||
if issubclass(type, BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# valid list type
|
||||
if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_input_type(func: Callable) -> Type:
|
||||
"""Returns the input type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the input type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid input type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
if "input" not in type_hints:
|
||||
raise ValueError(
|
||||
"The callable MUST have a parameter with the name `input` with typing annotation. "
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
input_type = type_hints["input"]
|
||||
|
||||
if not is_compatible_type(input_type):
|
||||
raise ValueError(
|
||||
"The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
# TODO: return warning if more than one input parameters
|
||||
|
||||
return input_type
|
||||
|
||||
|
||||
def get_output_type(func: Callable) -> Type:
|
||||
"""Returns the output type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the output type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid output type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
if "return" not in type_hints:
|
||||
raise ValueError(
|
||||
"The return type of the callable MUST be annotated with type hints."
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
output_type = type_hints["return"]
|
||||
|
||||
if not is_compatible_type(output_type):
|
||||
raise ValueError(
|
||||
"The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
return output_type
|
||||
|
||||
|
||||
def get_callable(import_string: str) -> Callable:
|
||||
"""Import a callable from an string."""
|
||||
callable_seperator = ":"
|
||||
if callable_seperator not in import_string:
|
||||
# Use dot as seperator
|
||||
callable_seperator = "."
|
||||
|
||||
if callable_seperator not in import_string:
|
||||
raise ValueError("The callable path MUST specify the function. ")
|
||||
|
||||
mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
|
||||
mod = importlib.import_module(mod_name)
|
||||
return getattr(mod, callable_name)
|
||||
|
||||
|
||||
class Opyrator:
|
||||
def __init__(self, func: Union[Callable, str]) -> None:
|
||||
if isinstance(func, str):
|
||||
# Try to load the function from a string notion
|
||||
self.function = get_callable(func)
|
||||
else:
|
||||
self.function = func
|
||||
|
||||
self._action = "Execute"
|
||||
self._input_type = None
|
||||
self._output_type = None
|
||||
|
||||
if not callable(self.function):
|
||||
raise ValueError("The provided function parameters is not a callable.")
|
||||
|
||||
if inspect.isclass(self.function):
|
||||
raise ValueError(
|
||||
"The provided callable is an uninitialized Class. This is not allowed."
|
||||
)
|
||||
|
||||
if inspect.isfunction(self.function):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function)
|
||||
self._output_type = get_output_type(self.function)
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(self.function.__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get description from function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(self.function, "__call__"):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function.__call__) # type: ignore
|
||||
self._output_type = get_output_type(self.function.__call__) # type: ignore
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(type(self.function).__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get action from
|
||||
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
|
||||
if (
|
||||
not self._action
|
||||
or self._action == "Call"
|
||||
):
|
||||
# Get docstring from class instead of __call__ function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown callable type.")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def action(self) -> str:
|
||||
return self._action
|
||||
|
||||
@property
|
||||
def input_type(self) -> Any:
|
||||
return self._input_type
|
||||
|
||||
@property
|
||||
def output_type(self) -> Any:
|
||||
return self._output_type
|
||||
|
||||
def __call__(self, input: Any, **kwargs: Any) -> Any:
|
||||
|
||||
input_obj = input
|
||||
|
||||
if isinstance(input, str):
|
||||
# Allow json input
|
||||
input_obj = parse_raw_as(self.input_type, input)
|
||||
|
||||
if isinstance(input, dict):
|
||||
# Allow dict input
|
||||
input_obj = parse_obj_as(self.input_type, input)
|
||||
|
||||
return self.function(input_obj, **kwargs)
|
||||
@@ -1 +0,0 @@
|
||||
from .streamlit_ui import render_streamlit_ui
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def resolve_reference(reference: str, references: Dict) -> Dict:
|
||||
return references[reference.split("/")[-1]]
|
||||
|
||||
|
||||
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
|
||||
# Ref can either be directly in the properties or the first element of allOf
|
||||
reference = property.get("$ref")
|
||||
if reference is None:
|
||||
reference = property["allOf"][0]["$ref"]
|
||||
return resolve_reference(reference, references)
|
||||
|
||||
|
||||
def is_single_string_property(property: Dict) -> bool:
|
||||
return property.get("type") == "string"
|
||||
|
||||
|
||||
def is_single_datetime_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") in ["date-time", "time", "date"]
|
||||
|
||||
|
||||
def is_single_boolean_property(property: Dict) -> bool:
|
||||
return property.get("type") == "boolean"
|
||||
|
||||
|
||||
def is_single_number_property(property: Dict) -> bool:
|
||||
return property.get("type") in ["integer", "number"]
|
||||
|
||||
|
||||
def is_single_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
# TODO: binary?
|
||||
return property.get("format") == "byte"
|
||||
|
||||
|
||||
def is_single_directory_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") == "path"
|
||||
|
||||
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("uniqueItems") is not True:
|
||||
# Only relevant if it is a set or other datastructures with unique items
|
||||
return False
|
||||
|
||||
try:
|
||||
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_enum_property(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
_ = get_single_reference_item(property, references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_dict_property(property: Dict) -> bool:
|
||||
if property.get("type") != "object":
|
||||
return False
|
||||
return "additionalProperties" in property
|
||||
|
||||
|
||||
def is_single_reference(property: Dict) -> bool:
|
||||
if property.get("type") is not None:
|
||||
return False
|
||||
|
||||
return bool(property.get("$ref"))
|
||||
|
||||
|
||||
def is_multi_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO: binary
|
||||
return property["items"]["format"] == "byte"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_object(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
object_reference = get_single_reference_item(property, references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_property_list(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
return property["items"]["type"] in ["string", "number", "integer"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_object_list_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
try:
|
||||
object_reference = resolve_reference(property["items"]["$ref"], references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1,888 +0,0 @@
|
||||
import datetime
|
||||
import inspect
|
||||
import mimetypes
|
||||
import sys
|
||||
from os import getcwd, unlink
|
||||
from platform import system
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
from PIL import Image
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError, parse_obj_as
|
||||
|
||||
from mkgui.base import Opyrator
|
||||
from mkgui.base.core import name_to_title
|
||||
from mkgui.base.ui import schema_utils
|
||||
from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS
|
||||
|
||||
STREAMLIT_RUNNER_SNIPPET = """
|
||||
from mkgui.base.ui import render_streamlit_ui
|
||||
from mkgui.base import Opyrator
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# TODO: Make it configurable
|
||||
# Page config can only be setup once
|
||||
st.set_page_config(
|
||||
page_title="MockingBird",
|
||||
page_icon="🧊",
|
||||
layout="wide")
|
||||
|
||||
render_streamlit_ui()
|
||||
"""
|
||||
|
||||
# with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
# opyrator = Opyrator("{opyrator_path}")
|
||||
|
||||
|
||||
def launch_ui(port: int = 8501) -> None:
|
||||
with NamedTemporaryFile(
|
||||
suffix=".py", mode="w", encoding="utf-8", delete=False
|
||||
) as f:
|
||||
f.write(STREAMLIT_RUNNER_SNIPPET)
|
||||
f.seek(0)
|
||||
|
||||
import subprocess
|
||||
|
||||
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
||||
if system() == "Windows":
|
||||
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
|
||||
subprocess.run(
|
||||
f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
subprocess.run(
|
||||
f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
f.close()
|
||||
unlink(f.name)
|
||||
|
||||
|
||||
def function_has_named_arg(func: Callable, parameter: str) -> bool:
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
for param in sig.parameters.values():
|
||||
if param.name == "input":
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def has_output_ui_renderer(data_item: BaseModel) -> bool:
|
||||
return hasattr(data_item, "render_output_ui")
|
||||
|
||||
|
||||
def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
|
||||
return hasattr(input_class, "render_input_ui")
|
||||
|
||||
|
||||
def is_compatible_audio(mime_type: str) -> bool:
|
||||
return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
|
||||
|
||||
|
||||
def is_compatible_image(mime_type: str) -> bool:
|
||||
return mime_type in ["image/png", "image/jpeg"]
|
||||
|
||||
|
||||
def is_compatible_video(mime_type: str) -> bool:
|
||||
return mime_type in ["video/mp4"]
|
||||
|
||||
|
||||
class InputUI:
|
||||
def __init__(self, session_state, input_class: Type[BaseModel]):
|
||||
self._session_state = session_state
|
||||
self._input_class = input_class
|
||||
|
||||
self._schema_properties = input_class.schema(by_alias=True).get(
|
||||
"properties", {}
|
||||
)
|
||||
self._schema_references = input_class.schema(by_alias=True).get(
|
||||
"definitions", {}
|
||||
)
|
||||
|
||||
def render_ui(self, streamlit_app_root) -> None:
|
||||
if has_input_ui_renderer(self._input_class):
|
||||
# The input model has a rendering function
|
||||
# The rendering also returns the current state of input data
|
||||
self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
|
||||
st, self._session_state.input_data
|
||||
)
|
||||
return
|
||||
|
||||
# print(self._schema_properties)
|
||||
for property_key in self._schema_properties.keys():
|
||||
property = self._schema_properties[property_key]
|
||||
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
|
||||
try:
|
||||
if "input_data" in self._session_state:
|
||||
self._store_value(
|
||||
property_key,
|
||||
self._render_property(streamlit_app_root, property_key, property),
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception!", e)
|
||||
pass
|
||||
|
||||
def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
|
||||
streamlit_kwargs = {
|
||||
"label": property.get("title"),
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if property.get("description"):
|
||||
streamlit_kwargs["help"] = property.get("description")
|
||||
return streamlit_kwargs
|
||||
|
||||
def _store_value(self, key: str, value: Any) -> None:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
data_element[key_element] = value
|
||||
return
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
|
||||
def _get_value(self, key: str) -> Any:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
if key_element not in data_element:
|
||||
return None
|
||||
return data_element[key_element]
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
return None
|
||||
|
||||
def _render_single_datetime_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("format") == "time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.time_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.date_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date-time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
with streamlit_app.container():
|
||||
streamlit_app.subheader(streamlit_kwargs.get("label"))
|
||||
if streamlit_kwargs.get("description"):
|
||||
streamlit_app.text(streamlit_kwargs.get("description"))
|
||||
selected_date = None
|
||||
selected_time = None
|
||||
date_col, time_col = streamlit_app.columns(2)
|
||||
with date_col:
|
||||
date_kwargs = {"label": "Date", "key": key + "-date-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).date()
|
||||
except Exception:
|
||||
pass
|
||||
selected_date = streamlit_app.date_input(**date_kwargs)
|
||||
|
||||
with time_col:
|
||||
time_kwargs = {"label": "Time", "key": key + "-time-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).time()
|
||||
except Exception:
|
||||
pass
|
||||
selected_time = streamlit_app.time_input(**time_kwargs)
|
||||
return datetime.datetime.combine(selected_date, selected_time)
|
||||
else:
|
||||
streamlit_app.warning(
|
||||
"Date format is not supported: " + str(property.get("format"))
|
||||
)
|
||||
|
||||
def _render_single_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_file = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
|
||||
)
|
||||
if uploaded_file is None:
|
||||
return None
|
||||
|
||||
bytes = uploaded_file.getvalue()
|
||||
if property.get("mime_type"):
|
||||
if is_compatible_audio(property["mime_type"]):
|
||||
# Show audio
|
||||
streamlit_app.audio(bytes, format=property.get("mime_type"))
|
||||
if is_compatible_image(property["mime_type"]):
|
||||
# Show image
|
||||
streamlit_app.image(bytes)
|
||||
if is_compatible_video(property["mime_type"]):
|
||||
# Show video
|
||||
streamlit_app.video(bytes, format=property.get("mime_type"))
|
||||
return bytes
|
||||
|
||||
def _render_single_string_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
elif property.get("example"):
|
||||
# TODO: also use example for other property types
|
||||
# Use example as value if it is provided
|
||||
streamlit_kwargs["value"] = property.get("example")
|
||||
|
||||
if property.get("maxLength") is not None:
|
||||
streamlit_kwargs["max_chars"] = property.get("maxLength")
|
||||
|
||||
if (
|
||||
property.get("format")
|
||||
or (
|
||||
property.get("maxLength") is not None
|
||||
and int(property.get("maxLength")) < 140 # type: ignore
|
||||
)
|
||||
or property.get("writeOnly")
|
||||
):
|
||||
# If any format is set, use single text input
|
||||
# If max chars is set to less than 140, use single text input
|
||||
# If write only -> password field
|
||||
if property.get("writeOnly"):
|
||||
streamlit_kwargs["type"] = "password"
|
||||
return streamlit_app.text_input(**streamlit_kwargs)
|
||||
else:
|
||||
# Otherwise use multiline text area
|
||||
return streamlit_app.text_area(**streamlit_kwargs)
|
||||
|
||||
def _render_multi_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
# TODO: how to select defaults
|
||||
return streamlit_app.multiselect(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
try:
|
||||
streamlit_kwargs["index"] = reference_item["enum"].index(
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
# Use default selection
|
||||
pass
|
||||
|
||||
return streamlit_app.selectbox(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_dict_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_dict = self._get_value(key)
|
||||
if not current_dict:
|
||||
current_dict = {}
|
||||
|
||||
key_col, value_col = streamlit_app.columns(2)
|
||||
|
||||
with key_col:
|
||||
updated_key = streamlit_app.text_input(
|
||||
"Key", value="", key=key + "-new-key"
|
||||
)
|
||||
|
||||
with value_col:
|
||||
# TODO: also add boolean?
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["additionalProperties"].get("type") == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["additionalProperties"].get("type") == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
updated_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_dict = {}
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and updated_key
|
||||
):
|
||||
current_dict[updated_key] = updated_value
|
||||
|
||||
streamlit_app.write(current_dict)
|
||||
|
||||
return current_dict
|
||||
|
||||
def _render_single_reference(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_property(streamlit_app, key, reference_item)
|
||||
|
||||
def _render_multi_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_files = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=True, type=file_extension
|
||||
)
|
||||
uploaded_files_bytes = []
|
||||
if uploaded_files:
|
||||
for uploaded_file in uploaded_files:
|
||||
uploaded_files_bytes.append(uploaded_file.read())
|
||||
return uploaded_files_bytes
|
||||
|
||||
def _render_single_boolean_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
return streamlit_app.checkbox(**streamlit_kwargs)
|
||||
|
||||
def _render_single_number_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
number_transform = int
|
||||
if property.get("type") == "number":
|
||||
number_transform = float # type: ignore
|
||||
streamlit_kwargs["format"] = "%f"
|
||||
|
||||
if "multipleOf" in property:
|
||||
# Set stepcount based on multiple of parameter
|
||||
streamlit_kwargs["step"] = number_transform(property["multipleOf"])
|
||||
elif number_transform == int:
|
||||
# Set step size to 1 as default
|
||||
streamlit_kwargs["step"] = 1
|
||||
elif number_transform == float:
|
||||
# Set step size to 0.01 as default
|
||||
# TODO: adapt to default value
|
||||
streamlit_kwargs["step"] = 0.01
|
||||
|
||||
if "minimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(property["minimum"])
|
||||
if "exclusiveMinimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(
|
||||
property["exclusiveMinimum"] + streamlit_kwargs["step"]
|
||||
)
|
||||
if "maximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(property["maximum"])
|
||||
|
||||
if "exclusiveMaximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(
|
||||
property["exclusiveMaximum"] - streamlit_kwargs["step"]
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
|
||||
else:
|
||||
if "min_value" in streamlit_kwargs:
|
||||
streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
|
||||
elif number_transform == int:
|
||||
streamlit_kwargs["value"] = 0
|
||||
else:
|
||||
# Set default value to step
|
||||
streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
|
||||
|
||||
if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
|
||||
# TODO: Only if less than X steps
|
||||
return streamlit_app.slider(**streamlit_kwargs)
|
||||
else:
|
||||
return streamlit_app.number_input(**streamlit_kwargs)
|
||||
|
||||
def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
properties = property["properties"]
|
||||
object_inputs = {}
|
||||
for property_key in properties:
|
||||
property = properties[property_key]
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
# construct full key based on key parts -> required later to get the value
|
||||
full_key = key + "." + property_key
|
||||
object_inputs[property_key] = self._render_property(
|
||||
streamlit_app, full_key, property
|
||||
)
|
||||
return object_inputs
|
||||
|
||||
def _render_single_object_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
# Add title and subheader
|
||||
title = property.get("title")
|
||||
streamlit_app.subheader(title)
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
object_reference = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
def _render_property_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["items"]["type"] == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["items"]["type"] == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
new_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and new_value is not None
|
||||
):
|
||||
current_list.append(new_value)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
|
||||
return current_list
|
||||
|
||||
def _render_object_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# TODO: support max_items, and min_items properties
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
object_reference = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
input_data = self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and input_data
|
||||
):
|
||||
current_list.append(input_data)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
return current_list
|
||||
|
||||
def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
if schema_utils.is_single_enum_property(property, self._schema_references):
|
||||
return self._render_single_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_enum_property(property, self._schema_references):
|
||||
return self._render_multi_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_file_property(property):
|
||||
return self._render_single_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_file_property(property):
|
||||
return self._render_multi_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_datetime_property(property):
|
||||
return self._render_single_datetime_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_boolean_property(property):
|
||||
return self._render_single_boolean_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_dict_property(property):
|
||||
return self._render_single_dict_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_number_property(property):
|
||||
return self._render_single_number_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_string_property(property):
|
||||
return self._render_single_string_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_object(property, self._schema_references):
|
||||
return self._render_single_object_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_object_list_property(property, self._schema_references):
|
||||
return self._render_object_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_property_list(property):
|
||||
return self._render_property_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_reference(property):
|
||||
return self._render_single_reference(streamlit_app, key, property)
|
||||
|
||||
streamlit_app.warning(
|
||||
"The type of the following property is currently not supported: "
|
||||
+ str(property.get("title"))
|
||||
)
|
||||
raise Exception("Unsupported property")
|
||||
|
||||
|
||||
class OutputUI:
|
||||
def __init__(self, output_data: Any, input_data: Any):
|
||||
self._output_data = output_data
|
||||
self._input_data = input_data
|
||||
|
||||
def render_ui(self, streamlit_app) -> None:
|
||||
try:
|
||||
if isinstance(self._output_data, BaseModel):
|
||||
self._render_single_output(streamlit_app, self._output_data)
|
||||
return
|
||||
if type(self._output_data) == list:
|
||||
self._render_list_output(streamlit_app, self._output_data)
|
||||
return
|
||||
except Exception as ex:
|
||||
streamlit_app.exception(ex)
|
||||
# Fallback to
|
||||
streamlit_app.json(jsonable_encoder(self._output_data))
|
||||
|
||||
def _render_single_text_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
streamlit.code(str(value), language="plain")
|
||||
|
||||
def _render_single_file_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
# TODO: Detect if it is a FileContent instance
|
||||
# TODO: detect if it is base64
|
||||
file_extension = ""
|
||||
if "mime_type" in property_schema:
|
||||
mime_type = property_schema["mime_type"]
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ""
|
||||
|
||||
if is_compatible_audio(mime_type):
|
||||
streamlit.audio(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
if is_compatible_image(mime_type):
|
||||
streamlit.image(value.as_bytes())
|
||||
return
|
||||
|
||||
if is_compatible_video(mime_type):
|
||||
streamlit.video(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
filename = (
|
||||
(property_schema["title"] + file_extension)
|
||||
.lower()
|
||||
.strip()
|
||||
.replace(" ", "-")
|
||||
)
|
||||
streamlit.markdown(
|
||||
f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
def _render_single_complex_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
|
||||
streamlit.json(jsonable_encoder(value))
|
||||
|
||||
def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
|
||||
try:
|
||||
if has_output_ui_renderer(output_data):
|
||||
if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
|
||||
# render method also requests the input data
|
||||
output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
|
||||
else:
|
||||
output_data.render_output_ui(streamlit) # type: ignore
|
||||
return
|
||||
except Exception:
|
||||
# Use default auto-generation methods if the custom rendering throws an exception
|
||||
logger.exception(
|
||||
"Failed to execute custom render_output_ui function. Using auto-generation instead"
|
||||
)
|
||||
|
||||
model_schema = output_data.schema(by_alias=False)
|
||||
model_properties = model_schema.get("properties")
|
||||
definitions = model_schema.get("definitions")
|
||||
|
||||
if model_properties:
|
||||
for property_key in output_data.__dict__:
|
||||
property_schema = model_properties.get(property_key)
|
||||
if not property_schema.get("title"):
|
||||
# Set property key as fallback title
|
||||
property_schema["title"] = property_key
|
||||
|
||||
output_property_value = output_data.__dict__[property_key]
|
||||
|
||||
if has_output_ui_renderer(output_property_value):
|
||||
output_property_value.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
|
||||
if isinstance(output_property_value, BaseModel):
|
||||
# Render output recursivly
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
self._render_single_output(streamlit, output_property_value)
|
||||
continue
|
||||
|
||||
if property_schema:
|
||||
if schema_utils.is_single_file_property(property_schema):
|
||||
self._render_single_file_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
schema_utils.is_single_string_property(property_schema)
|
||||
or schema_utils.is_single_number_property(property_schema)
|
||||
or schema_utils.is_single_datetime_property(property_schema)
|
||||
or schema_utils.is_single_boolean_property(property_schema)
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
if definitions and schema_utils.is_single_enum_property(
|
||||
property_schema, definitions
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value.value
|
||||
)
|
||||
continue
|
||||
|
||||
# TODO: render dict as table
|
||||
|
||||
self._render_single_complex_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
return
|
||||
|
||||
def _render_list_output(self, streamlit: st, output_data: List) -> None:
|
||||
try:
|
||||
data_items: List = []
|
||||
for data_item in output_data:
|
||||
if has_output_ui_renderer(data_item):
|
||||
# Render using the render function
|
||||
data_item.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
data_items.append(data_item.dict())
|
||||
# Try to show as dataframe
|
||||
streamlit.table(pd.DataFrame(data_items))
|
||||
except Exception:
|
||||
# Fallback to
|
||||
streamlit.json(jsonable_encoder(output_data))
|
||||
|
||||
|
||||
def getOpyrator(mode: str) -> Opyrator:
|
||||
if mode == None or mode.startswith('VC'):
|
||||
from mkgui.app_vc import convert
|
||||
return Opyrator(convert)
|
||||
if mode == None or mode.startswith('预处理'):
|
||||
from mkgui.preprocess import preprocess
|
||||
return Opyrator(preprocess)
|
||||
if mode == None or mode.startswith('模型训练'):
|
||||
from mkgui.train import train
|
||||
return Opyrator(train)
|
||||
if mode == None or mode.startswith('模型训练(VC)'):
|
||||
from mkgui.train_vc import train_vc
|
||||
return Opyrator(train_vc)
|
||||
from mkgui.app import synthesize
|
||||
return Opyrator(synthesize)
|
||||
|
||||
|
||||
def render_streamlit_ui() -> None:
|
||||
# init
|
||||
session_state = st.session_state
|
||||
session_state.input_data = {}
|
||||
# Add custom css settings
|
||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||
|
||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
session_state.mode = st.sidebar.selectbox(
|
||||
'模式选择',
|
||||
( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
|
||||
)
|
||||
if "mode" in session_state:
|
||||
mode = session_state.mode
|
||||
else:
|
||||
mode = ""
|
||||
opyrator = getOpyrator(mode)
|
||||
title = opyrator.name + mode
|
||||
|
||||
col1, col2, _ = st.columns(3)
|
||||
col2.title(title)
|
||||
col2.markdown("欢迎使用MockingBird Web 2")
|
||||
|
||||
image = Image.open('.\\mkgui\\static\\mb.png')
|
||||
col1.image(image)
|
||||
|
||||
st.markdown("---")
|
||||
left, right = st.columns([0.4, 0.6])
|
||||
|
||||
with left:
|
||||
st.header("Control 控制")
|
||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
||||
execute_selected = st.button(opyrator.action)
|
||||
if execute_selected:
|
||||
with st.spinner("Executing operation. Please wait..."):
|
||||
try:
|
||||
input_data_obj = parse_obj_as(
|
||||
opyrator.input_type, session_state.input_data
|
||||
)
|
||||
session_state.output_data = opyrator(input=input_data_obj)
|
||||
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
|
||||
except ValidationError as ex:
|
||||
st.error(ex)
|
||||
else:
|
||||
# st.success("Operation executed successfully.")
|
||||
pass
|
||||
|
||||
with right:
|
||||
st.header("Result 结果")
|
||||
if 'output_data' in session_state:
|
||||
OutputUI(
|
||||
session_state.output_data, session_state.latest_operation_input
|
||||
).render_ui(st)
|
||||
if st.button("Clear"):
|
||||
# Clear all state
|
||||
for key in st.session_state.keys():
|
||||
del st.session_state[key]
|
||||
session_state.input_data = {}
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
# placeholder
|
||||
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
CUSTOM_STREAMLIT_CSS = """
|
||||
div[data-testid="stBlock"] button {
|
||||
width: 100% !important;
|
||||
margin-bottom: 20px !important;
|
||||
border-color: #bfbfbf !important;
|
||||
}
|
||||
section[data-testid="stSidebar"] div {
|
||||
max-width: 10rem;
|
||||
}
|
||||
pre code {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
"""
|
||||
@@ -1,96 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any, Tuple
|
||||
|
||||
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
|
||||
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="目标模型",
|
||||
)
|
||||
dataset: Dataset = Field(
|
||||
Dataset.AIDATATANG_200ZH, title="数据集选择",
|
||||
)
|
||||
datasets_root: str = Field(
|
||||
..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
output_root: str = Field(
|
||||
..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
n_processes: int = Field(
|
||||
2, alias="处理线程数", description="根据CPU线程数来设置",
|
||||
le=32, ge=1
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def preprocess(input: Input) -> Output:
|
||||
"""Preprocess(预处理)"""
|
||||
finished = 0
|
||||
if input.model == Model.VC_PPG2MEL:
|
||||
from ppg2mel.preprocess import preprocess_dataset
|
||||
finished = preprocess_dataset(
|
||||
datasets_root=Path(input.datasets_root),
|
||||
dataset=input.dataset,
|
||||
out_dir=Path(input.output_root),
|
||||
n_processes=input.n_processes,
|
||||
ppg_encoder_model_fpath=Path(input.extractor.value),
|
||||
speaker_encoder_model=Path(input.encoder.value)
|
||||
)
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, finished))
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 5.6 KiB |
106
mkgui/train.py
106
mkgui/train.py
@@ -1,106 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from synthesizer.hparams import hparams
|
||||
from synthesizer.train import train as synt_train
|
||||
|
||||
# Constants
|
||||
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
|
||||
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
|
||||
|
||||
|
||||
# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
|
||||
# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
|
||||
# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
|
||||
|
||||
# Pre-Load models
|
||||
if os.path.isdir(SYN_MODELS_DIRT):
|
||||
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
class Input(BaseModel):
|
||||
model: Model = Field(
|
||||
Model.DEFAULT, title="模型类型",
|
||||
)
|
||||
# datasets_root: str = Field(
|
||||
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
|
||||
# format=True,
|
||||
# example="..\\trainning_data\\"
|
||||
# )
|
||||
input_root: str = Field(
|
||||
..., alias="输入目录", description="预处理数据根目录",
|
||||
format=True,
|
||||
example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
|
||||
)
|
||||
run_id: str = Field(
|
||||
"", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
|
||||
)
|
||||
synthesizer: synthesizers = Field(
|
||||
..., alias="已有合成模型",
|
||||
description="选择语音合成模型文件."
|
||||
)
|
||||
gpu: bool = Field(
|
||||
True, alias="GPU训练", description="选择“是”,则使用GPU训练",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
True, alias="打印详情", description="选择“是”,输出更多详情",
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
save_every: int = Field(
|
||||
1000, alias="更新间隔", description="每隔n步则更新一次模型",
|
||||
)
|
||||
backup_every: int = Field(
|
||||
10000, alias="保存间隔", description="每隔n步则保存一次模型",
|
||||
)
|
||||
log_every: int = Field(
|
||||
500, alias="打印间隔", description="每隔n步则打印一次训练统计",
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: int
|
||||
|
||||
def render_output_ui(self, streamlit_app) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
streamlit_app.subheader(f"Training started with code: {self.__root__}")
|
||||
|
||||
def train(input: Input) -> Output:
|
||||
"""Train(训练)"""
|
||||
|
||||
print(">>> Start training ...")
|
||||
force_restart = len(input.run_id) > 0
|
||||
if not force_restart:
|
||||
input.run_id = Path(input.synthesizer.value).name.split('.')[0]
|
||||
|
||||
synt_train(
|
||||
input.run_id,
|
||||
input.input_root,
|
||||
f"synthesizer{os.sep}saved_models",
|
||||
input.save_every,
|
||||
input.backup_every,
|
||||
input.log_every,
|
||||
force_restart,
|
||||
hparams
|
||||
)
|
||||
return Output(__root__=0)
|
||||
@@ -1,155 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any, Tuple
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from utils.util import AttrDict
|
||||
import torch
|
||||
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
|
||||
CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
|
||||
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="模型类型",
|
||||
)
|
||||
# datasets_root: str = Field(
|
||||
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
|
||||
# format=True,
|
||||
# example="..\\trainning_data\\"
|
||||
# )
|
||||
output_root: str = Field(
|
||||
..., alias="输出目录(可选)", description="建议不填,保持默认",
|
||||
format=True,
|
||||
example=""
|
||||
)
|
||||
continue_mode: bool = Field(
|
||||
True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
|
||||
)
|
||||
gpu: bool = Field(
|
||||
True, alias="GPU训练", description="选择“是”,则使用GPU训练",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
True, alias="打印详情", description="选择“是”,输出更多详情",
|
||||
)
|
||||
# TODO: Move to hiden fields by default
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
njobs: int = Field(
|
||||
8, alias="进程数", description="适用于ppg2mel",
|
||||
)
|
||||
seed: int = Field(
|
||||
default=0, alias="初始随机数", description="适用于ppg2mel",
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example="test"
|
||||
)
|
||||
model_config: str = Field(
|
||||
..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def train_vc(input: Input) -> Output:
|
||||
"""Train VC(训练 VC)"""
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
params = AttrDict()
|
||||
params.update({
|
||||
"gpu": input.gpu,
|
||||
"cpu": not input.gpu,
|
||||
"njobs": input.njobs,
|
||||
"seed": input.seed,
|
||||
"verbose": input.verbose,
|
||||
"load": input.convertor.value,
|
||||
"warm_start": False,
|
||||
})
|
||||
if input.continue_mode:
|
||||
# trace old model and config
|
||||
p = Path(input.convertor.value)
|
||||
params.name = p.parent.name
|
||||
# search a config file
|
||||
model_config_fpaths = list(p.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
config = HpsYaml(model_config_fpaths[0])
|
||||
params.ckpdir = p.parent.parent
|
||||
params.config = model_config_fpaths[0]
|
||||
params.logdir = os.path.join(p.parent, "log")
|
||||
else:
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(input.config)
|
||||
np.random.seed(input.seed)
|
||||
torch.manual_seed(input.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(input.seed)
|
||||
mode = "train"
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
solver = Solver(config, params, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, 0))
|
||||
@@ -191,15 +191,12 @@ class MelDecoderMOLv2(AbsMelDecoder):
|
||||
|
||||
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
|
||||
|
||||
def load_model(model_file, device=None):
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
def load_model(train_config, model_file, device=None):
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model_config = HpsYaml(model_config_fpaths[0])
|
||||
model_config = HpsYaml(train_config)
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
|
||||
@@ -110,4 +110,3 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
|
||||
t_fid_file.close()
|
||||
d_fid_file.close()
|
||||
e_fid_file.close()
|
||||
return len(wav_file_list)
|
||||
|
||||
@@ -31,10 +31,15 @@ def main():
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
# parser.add_argument('--no-pin', action='store_true',
|
||||
# help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
|
||||
@@ -93,7 +93,6 @@ class BaseSolver():
|
||||
|
||||
def load_ckpt(self):
|
||||
''' Load ckpt if --load option is specified '''
|
||||
print(self.paras)
|
||||
if self.paras.load is not None:
|
||||
if self.paras.warm_start:
|
||||
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
|
||||
@@ -101,7 +100,7 @@ class BaseSolver():
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
model_dict = ckpt['model']
|
||||
if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
|
||||
if len(self.config.model.ignore_layers) > 0:
|
||||
model_dict = {k:v for k, v in model_dict.items()
|
||||
if k not in self.config.model.ignore_layers}
|
||||
dummy_dict = self.model.state_dict()
|
||||
|
||||
@@ -21,8 +21,6 @@ flask_cors==3.0.10
|
||||
gevent==21.8.0
|
||||
flask_restx
|
||||
tensorboard
|
||||
streamlit==1.8.0
|
||||
PyYAML==5.4.1
|
||||
torch_complex
|
||||
espnet
|
||||
PyWavelets
|
||||
espnet
|
||||
142
run.py
142
run.py
@@ -1,142 +0,0 @@
|
||||
import time
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from ppg_extractor import load_model
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
from encoder.audio import preprocess_wav
|
||||
from encoder import inference as speacker_encoder
|
||||
from vocoder.hifigan import inference as vocoder
|
||||
from ppg2mel import MelDecoderMOLv2
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
|
||||
|
||||
def _build_ppg2mel_model(model_config, model_file, device):
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
ckpt = torch.load(model_file, map_location=device)
|
||||
ppg2mel_model.load_state_dict(ckpt["model"])
|
||||
ppg2mel_model.eval()
|
||||
return ppg2mel_model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert(args):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
|
||||
|
||||
# Build models
|
||||
print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
|
||||
ppg_model = load_model(
|
||||
Path('./ppg_extractor/saved_models/24epoch.pt'),
|
||||
device,
|
||||
)
|
||||
ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
|
||||
# vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
|
||||
vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')
|
||||
# Data related
|
||||
ref_wav_path = args.ref_wav_path
|
||||
ref_wav = preprocess_wav(ref_wav_path)
|
||||
ref_fid = os.path.basename(ref_wav_path)[:-4]
|
||||
|
||||
# TODO: specify encoder
|
||||
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
|
||||
ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
|
||||
ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
|
||||
source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
|
||||
print(f"Number of source utterances: {len(source_file_list)}.")
|
||||
|
||||
total_rtf = 0.0
|
||||
cnt = 0
|
||||
for src_wav_path in tqdm(source_file_list):
|
||||
# Load the audio to a numpy array:
|
||||
src_wav, _ = librosa.load(src_wav_path, sr=16000)
|
||||
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
|
||||
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
|
||||
ppg = ppg_model(src_wav_tensor, src_wav_lengths)
|
||||
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
|
||||
start = time.time()
|
||||
_, mel_pred, att_ws = ppg2mel_model.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=ref_spk_dvec,
|
||||
)
|
||||
src_fid = os.path.basename(src_wav_path)[:-4]
|
||||
wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
|
||||
mel_len = mel_pred.shape[0]
|
||||
rtf = (time.time() - start) / (0.01 * mel_len)
|
||||
total_rtf += rtf
|
||||
cnt += 1
|
||||
# continue
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
|
||||
sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")
|
||||
|
||||
print("RTF:")
|
||||
print(total_rtf / cnt)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(description="Conversion from wave input")
|
||||
parser.add_argument(
|
||||
"--wav_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Source wave directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_wav_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Reference wave file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppg2mel_model_train_config", "-c",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training config file (yaml file)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppg2mel_model_file", "-m",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="ppg2mel model checkpoint file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", "-o",
|
||||
type=str,
|
||||
default="vc_gens_vctk_oneshot",
|
||||
help="Output folder to save the converted wave."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
@@ -167,7 +167,7 @@ def _mel_to_linear(mel_spectrogram, hparams):
|
||||
|
||||
def _build_mel_basis(hparams):
|
||||
assert hparams.fmax <= hparams.sample_rate // 2
|
||||
return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
|
||||
return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
|
||||
fmin=hparams.fmin, fmax=hparams.fmax)
|
||||
|
||||
def _amp_to_db(x, hparams):
|
||||
|
||||
@@ -22,8 +22,7 @@ class HParams(object):
|
||||
def loadJson(self, dict):
|
||||
print("\Loading the json with %s\n", dict)
|
||||
for k in dict.keys():
|
||||
if k not in ["tts_schedule", "tts_finetune_layers"]:
|
||||
self.__dict__[k] = dict[k]
|
||||
self.__dict__[k] = dict[k]
|
||||
return self
|
||||
|
||||
def dumpJson(self, fp):
|
||||
|
||||
@@ -149,7 +149,7 @@ class Synthesizer:
|
||||
Loads and preprocesses an audio file under the same conditions the audio files were used to
|
||||
train the synthesizer.
|
||||
"""
|
||||
wav = librosa.load(path=str(fpath), sr=hparams.sample_rate)[0]
|
||||
wav = librosa.load(str(fpath), hparams.sample_rate)[0]
|
||||
if hparams.rescale:
|
||||
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
||||
# denoise
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import imp
|
||||
import numpy as np
|
||||
|
||||
class Base(nn.Module):
|
||||
def __init__(self, stop_threshold):
|
||||
super().__init__()
|
||||
|
||||
self.init_model()
|
||||
self.num_params()
|
||||
|
||||
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
||||
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
||||
|
||||
@property
|
||||
def r(self):
|
||||
return self.decoder.r.item()
|
||||
|
||||
@r.setter
|
||||
def r(self, value):
|
||||
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
||||
|
||||
def init_model(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
||||
|
||||
def finetune_partial(self, whitelist_layers):
|
||||
self.zero_grad()
|
||||
for name, child in self.named_children():
|
||||
if name in whitelist_layers:
|
||||
print("Trainable Layer: %s" % name)
|
||||
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
|
||||
for param in child.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_step(self):
|
||||
return self.step.data.item()
|
||||
|
||||
def reset_step(self):
|
||||
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
||||
self.step = self.step.data.new_tensor(1)
|
||||
|
||||
def log(self, path, msg):
|
||||
with open(path, "a") as f:
|
||||
print(msg, file=f)
|
||||
|
||||
def load(self, path, device, optimizer=None):
|
||||
# Use device of model params as location for loaded state
|
||||
checkpoint = torch.load(str(path), map_location=device)
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
|
||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||
|
||||
def save(self, path, optimizer=None):
|
||||
if optimizer is not None:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
}, str(path))
|
||||
else:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
}, str(path))
|
||||
|
||||
|
||||
def num_params(self, print_out=True):
|
||||
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
||||
if print_out:
|
||||
print("Trainable Parameters: %.3fM" % parameters)
|
||||
return parameters
|
||||
@@ -97,7 +97,7 @@ class STL(nn.Module):
|
||||
def forward(self, inputs):
|
||||
N = inputs.size(0)
|
||||
query = inputs.unsqueeze(1) # [N, 1, E//2]
|
||||
keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
|
||||
keys = tFunctional.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
|
||||
style_embed = self.attention(query, keys)
|
||||
|
||||
return style_embed
|
||||
@@ -1 +0,0 @@
|
||||
#
|
||||
@@ -1,85 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .common.batch_norm_conv import BatchNormConv
|
||||
from .common.highway_network import HighwayNetwork
|
||||
|
||||
class CBHG(nn.Module):
|
||||
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
||||
super().__init__()
|
||||
|
||||
# List of all rnns to call `flatten_parameters()` on
|
||||
self._to_flatten = []
|
||||
|
||||
self.bank_kernels = [i for i in range(1, K + 1)]
|
||||
self.conv1d_bank = nn.ModuleList()
|
||||
for k in self.bank_kernels:
|
||||
conv = BatchNormConv(in_channels, channels, k)
|
||||
self.conv1d_bank.append(conv)
|
||||
|
||||
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
|
||||
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
||||
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
||||
|
||||
# Fix the highway input if necessary
|
||||
if proj_channels[-1] != channels:
|
||||
self.highway_mismatch = True
|
||||
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
||||
else:
|
||||
self.highway_mismatch = False
|
||||
|
||||
self.highways = nn.ModuleList()
|
||||
for i in range(num_highways):
|
||||
hn = HighwayNetwork(channels)
|
||||
self.highways.append(hn)
|
||||
|
||||
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
||||
self._to_flatten.append(self.rnn)
|
||||
|
||||
# Avoid fragmentation of RNN parameters and associated warning
|
||||
self._flatten_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
# Although we `_flatten_parameters()` on init, when using DataParallel
|
||||
# the model gets replicated, making it no longer guaranteed that the
|
||||
# weights are contiguous in GPU memory. Hence, we must call it again
|
||||
self.rnn.flatten_parameters()
|
||||
|
||||
# Save these for later
|
||||
residual = x
|
||||
seq_len = x.size(-1)
|
||||
conv_bank = []
|
||||
|
||||
# Convolution Bank
|
||||
for conv in self.conv1d_bank:
|
||||
c = conv(x) # Convolution
|
||||
conv_bank.append(c[:, :, :seq_len])
|
||||
|
||||
# Stack along the channel axis
|
||||
conv_bank = torch.cat(conv_bank, dim=1)
|
||||
|
||||
# dump the last padding to fit residual
|
||||
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
||||
|
||||
# Conv1d projections
|
||||
x = self.conv_project1(x)
|
||||
x = self.conv_project2(x)
|
||||
|
||||
# Residual Connect
|
||||
x = x + residual
|
||||
|
||||
# Through the highways
|
||||
x = x.transpose(1, 2)
|
||||
if self.highway_mismatch is True:
|
||||
x = self.pre_highway(x)
|
||||
for h in self.highways: x = h(x)
|
||||
|
||||
# And then the RNN
|
||||
x, _ = self.rnn(x)
|
||||
return x
|
||||
|
||||
def _flatten_parameters(self):
|
||||
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
||||
to improve efficiency and avoid PyTorch yelling at us."""
|
||||
[m.flatten_parameters() for m in self._to_flatten]
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class BatchNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
||||
self.bnorm = nn.BatchNorm1d(out_channels)
|
||||
self.relu = relu
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = F.relu(x) if self.relu is True else x
|
||||
return self.bnorm(x)
|
||||
@@ -1,17 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class HighwayNetwork(nn.Module):
|
||||
def __init__(self, size):
|
||||
super().__init__()
|
||||
self.W1 = nn.Linear(size, size)
|
||||
self.W2 = nn.Linear(size, size)
|
||||
self.W1.bias.data.fill_(0.)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.W1(x)
|
||||
x2 = self.W2(x)
|
||||
g = torch.sigmoid(x2)
|
||||
y = g * F.relu(x1) + (1. - g) * x
|
||||
return y
|
||||
@@ -1,42 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LSA(nn.Module):
|
||||
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
||||
self.L = nn.Linear(filters, attn_dim, bias=False)
|
||||
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
||||
self.v = nn.Linear(attn_dim, 1, bias=False)
|
||||
self.cumulative = None
|
||||
self.attention = None
|
||||
|
||||
def init_attention(self, encoder_seq_proj):
|
||||
device = encoder_seq_proj.device # use same device as parameters
|
||||
b, t, c = encoder_seq_proj.size()
|
||||
self.cumulative = torch.zeros(b, t, device=device)
|
||||
self.attention = torch.zeros(b, t, device=device)
|
||||
|
||||
def forward(self, encoder_seq_proj, query, times, chars):
|
||||
|
||||
if times == 0: self.init_attention(encoder_seq_proj)
|
||||
|
||||
processed_query = self.W(query).unsqueeze(1)
|
||||
|
||||
location = self.cumulative.unsqueeze(1)
|
||||
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
||||
|
||||
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
||||
u = u.squeeze(-1)
|
||||
|
||||
# Mask zero padding chars
|
||||
u = u * (chars != 0).float()
|
||||
|
||||
# Smooth Attention
|
||||
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
||||
scores = F.softmax(u, dim=1)
|
||||
self.attention = scores
|
||||
self.cumulative = self.cumulative + self.attention
|
||||
|
||||
return scores.unsqueeze(-1).transpose(1, 2)
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class PreNet(nn.Module):
|
||||
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
||||
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
||||
self.p = dropout
|
||||
|
||||
def forward(self, x):
|
||||
"""forward
|
||||
|
||||
Args:
|
||||
x (3D tensor with size `[batch_size, num_chars, tts_embed_dims]`): input texts list
|
||||
|
||||
Returns:
|
||||
3D tensor with size `[batch_size, num_chars, encoder_dims]`
|
||||
|
||||
"""
|
||||
x = self.fc1(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
x = self.fc2(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
return x
|
||||
@@ -1,88 +1,277 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .sublayer.global_style_token import GlobalStyleToken
|
||||
from .sublayer.pre_net import PreNet
|
||||
from .sublayer.cbhg import CBHG
|
||||
from .sublayer.lsa import LSA
|
||||
from .base import Base
|
||||
import torch.nn.functional as F
|
||||
from synthesizer.models.global_style_token import GlobalStyleToken
|
||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
||||
from synthesizer.hparams import hparams
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, num_chars, embed_dims=512, encoder_dims=256, K=5, num_highways=4, dropout=0.5):
|
||||
""" Encoder for SV2TTS
|
||||
|
||||
Args:
|
||||
num_chars (int): length of symbols
|
||||
embed_dims (int, optional): embedding dim for input texts. Defaults to 512.
|
||||
encoder_dims (int, optional): output dim for encoder. Defaults to 256.
|
||||
K (int, optional): _description_. Defaults to 5.
|
||||
num_highways (int, optional): _description_. Defaults to 4.
|
||||
dropout (float, optional): _description_. Defaults to 0.5.
|
||||
"""
|
||||
class HighwayNetwork(nn.Module):
|
||||
def __init__(self, size):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(num_chars, embed_dims)
|
||||
self.pre_net = PreNet(embed_dims, fc1_dims=encoder_dims, fc2_dims=encoder_dims,
|
||||
dropout=dropout)
|
||||
self.cbhg = CBHG(K=K, in_channels=encoder_dims, channels=encoder_dims,
|
||||
proj_channels=[encoder_dims, encoder_dims],
|
||||
num_highways=num_highways)
|
||||
self.W1 = nn.Linear(size, size)
|
||||
self.W2 = nn.Linear(size, size)
|
||||
self.W1.bias.data.fill_(0.)
|
||||
|
||||
def forward(self, x):
|
||||
"""forward pass for encoder
|
||||
x1 = self.W1(x)
|
||||
x2 = self.W2(x)
|
||||
g = torch.sigmoid(x2)
|
||||
y = g * F.relu(x1) + (1. - g) * x
|
||||
return y
|
||||
|
||||
Args:
|
||||
x (2D tensor with size `[batch_size, text_num_chars]`): input texts list
|
||||
|
||||
Returns:
|
||||
3D tensor with size `[batch_size, text_num_chars, encoder_dims]`
|
||||
|
||||
"""
|
||||
x = self.embedding(x) # return: [batch_size, text_num_chars, tts_embed_dims]
|
||||
x = self.pre_net(x) # return: [batch_size, text_num_chars, encoder_dims]
|
||||
x.transpose_(1, 2) # return: [batch_size, encoder_dims, text_num_chars]
|
||||
return self.cbhg(x) # return: [batch_size, text_num_chars, encoder_dims]
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
|
||||
super().__init__()
|
||||
prenet_dims = (encoder_dims, encoder_dims)
|
||||
cbhg_channels = encoder_dims
|
||||
self.embedding = nn.Embedding(num_chars, embed_dims)
|
||||
self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
||||
dropout=dropout)
|
||||
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
|
||||
proj_channels=[cbhg_channels, cbhg_channels],
|
||||
num_highways=num_highways)
|
||||
|
||||
def forward(self, x, speaker_embedding=None):
|
||||
x = self.embedding(x)
|
||||
x = self.pre_net(x)
|
||||
x.transpose_(1, 2)
|
||||
x = self.cbhg(x)
|
||||
if speaker_embedding is not None:
|
||||
x = self.add_speaker_embedding(x, speaker_embedding)
|
||||
return x
|
||||
|
||||
def add_speaker_embedding(self, x, speaker_embedding):
|
||||
# SV2TTS
|
||||
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
||||
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
||||
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
||||
# This concats the speaker embedding for each char in the encoder output
|
||||
|
||||
# Save the dimensions as human-readable names
|
||||
batch_size = x.size()[0]
|
||||
num_chars = x.size()[1]
|
||||
|
||||
if speaker_embedding.dim() == 1:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
|
||||
# Start by making a copy of each speaker embedding to match the input text length
|
||||
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
|
||||
speaker_embedding_size = speaker_embedding.size()[idx]
|
||||
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
||||
|
||||
# Reshape it and transpose
|
||||
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
||||
e = e.transpose(1, 2)
|
||||
|
||||
# Concatenate the tiled speaker embedding with the encoder output
|
||||
x = torch.cat((x, e), 2)
|
||||
return x
|
||||
|
||||
|
||||
class BatchNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
||||
self.bnorm = nn.BatchNorm1d(out_channels)
|
||||
self.relu = relu
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = F.relu(x) if self.relu is True else x
|
||||
return self.bnorm(x)
|
||||
|
||||
|
||||
class CBHG(nn.Module):
|
||||
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
||||
super().__init__()
|
||||
|
||||
# List of all rnns to call `flatten_parameters()` on
|
||||
self._to_flatten = []
|
||||
|
||||
self.bank_kernels = [i for i in range(1, K + 1)]
|
||||
self.conv1d_bank = nn.ModuleList()
|
||||
for k in self.bank_kernels:
|
||||
conv = BatchNormConv(in_channels, channels, k)
|
||||
self.conv1d_bank.append(conv)
|
||||
|
||||
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
|
||||
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
||||
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
||||
|
||||
# Fix the highway input if necessary
|
||||
if proj_channels[-1] != channels:
|
||||
self.highway_mismatch = True
|
||||
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
||||
else:
|
||||
self.highway_mismatch = False
|
||||
|
||||
self.highways = nn.ModuleList()
|
||||
for i in range(num_highways):
|
||||
hn = HighwayNetwork(channels)
|
||||
self.highways.append(hn)
|
||||
|
||||
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
||||
self._to_flatten.append(self.rnn)
|
||||
|
||||
# Avoid fragmentation of RNN parameters and associated warning
|
||||
self._flatten_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
# Although we `_flatten_parameters()` on init, when using DataParallel
|
||||
# the model gets replicated, making it no longer guaranteed that the
|
||||
# weights are contiguous in GPU memory. Hence, we must call it again
|
||||
self.rnn.flatten_parameters()
|
||||
|
||||
# Save these for later
|
||||
residual = x
|
||||
seq_len = x.size(-1)
|
||||
conv_bank = []
|
||||
|
||||
# Convolution Bank
|
||||
for conv in self.conv1d_bank:
|
||||
c = conv(x) # Convolution
|
||||
conv_bank.append(c[:, :, :seq_len])
|
||||
|
||||
# Stack along the channel axis
|
||||
conv_bank = torch.cat(conv_bank, dim=1)
|
||||
|
||||
# dump the last padding to fit residual
|
||||
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
||||
|
||||
# Conv1d projections
|
||||
x = self.conv_project1(x)
|
||||
x = self.conv_project2(x)
|
||||
|
||||
# Residual Connect
|
||||
x = x + residual
|
||||
|
||||
# Through the highways
|
||||
x = x.transpose(1, 2)
|
||||
if self.highway_mismatch is True:
|
||||
x = self.pre_highway(x)
|
||||
for h in self.highways: x = h(x)
|
||||
|
||||
# And then the RNN
|
||||
x, _ = self.rnn(x)
|
||||
return x
|
||||
|
||||
def _flatten_parameters(self):
|
||||
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
||||
to improve efficiency and avoid PyTorch yelling at us."""
|
||||
[m.flatten_parameters() for m in self._to_flatten]
|
||||
|
||||
class PreNet(nn.Module):
|
||||
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
||||
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
||||
self.p = dropout
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
x = self.fc2(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, attn_dims):
|
||||
super().__init__()
|
||||
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
||||
self.v = nn.Linear(attn_dims, 1, bias=False)
|
||||
|
||||
def forward(self, encoder_seq_proj, query, t):
|
||||
|
||||
# print(encoder_seq_proj.shape)
|
||||
# Transform the query vector
|
||||
query_proj = self.W(query).unsqueeze(1)
|
||||
|
||||
# Compute the scores
|
||||
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
||||
scores = F.softmax(u, dim=1)
|
||||
|
||||
return scores.transpose(1, 2)
|
||||
|
||||
|
||||
class LSA(nn.Module):
|
||||
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
||||
self.L = nn.Linear(filters, attn_dim, bias=False)
|
||||
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
||||
self.v = nn.Linear(attn_dim, 1, bias=False)
|
||||
self.cumulative = None
|
||||
self.attention = None
|
||||
|
||||
def init_attention(self, encoder_seq_proj):
|
||||
device = encoder_seq_proj.device # use same device as parameters
|
||||
b, t, c = encoder_seq_proj.size()
|
||||
self.cumulative = torch.zeros(b, t, device=device)
|
||||
self.attention = torch.zeros(b, t, device=device)
|
||||
|
||||
def forward(self, encoder_seq_proj, query, t, chars):
|
||||
|
||||
if t == 0: self.init_attention(encoder_seq_proj)
|
||||
|
||||
processed_query = self.W(query).unsqueeze(1)
|
||||
|
||||
location = self.cumulative.unsqueeze(1)
|
||||
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
||||
|
||||
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
||||
u = u.squeeze(-1)
|
||||
|
||||
# Mask zero padding chars
|
||||
u = u * (chars != 0).float()
|
||||
|
||||
# Smooth Attention
|
||||
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
||||
scores = F.softmax(u, dim=1)
|
||||
self.attention = scores
|
||||
self.cumulative = self.cumulative + self.attention
|
||||
|
||||
return scores.unsqueeze(-1).transpose(1, 2)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
# Class variable because its value doesn't change between classes
|
||||
# yet ought to be scoped by class because its a property of a Decoder
|
||||
max_r = 20
|
||||
def __init__(self, n_mels, input_dims, decoder_dims, lstm_dims,
|
||||
def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
|
||||
dropout, speaker_embedding_size):
|
||||
super().__init__()
|
||||
self.register_buffer("r", torch.tensor(1, dtype=torch.int))
|
||||
self.n_mels = n_mels
|
||||
self.prenet = PreNet(n_mels, fc1_dims=decoder_dims * 2, fc2_dims=decoder_dims * 2,
|
||||
prenet_dims = (decoder_dims * 2, decoder_dims * 2)
|
||||
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
||||
dropout=dropout)
|
||||
self.attn_net = LSA(decoder_dims)
|
||||
if hparams.use_gst:
|
||||
speaker_embedding_size += gst_hp.E
|
||||
self.attn_rnn = nn.GRUCell(input_dims + decoder_dims * 2, decoder_dims)
|
||||
self.rnn_input = nn.Linear(input_dims + decoder_dims, lstm_dims)
|
||||
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
||||
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
||||
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
||||
self.stop_proj = nn.Linear(input_dims + lstm_dims, 1)
|
||||
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
||||
|
||||
def zoneout(self, prev, current, device, p=0.1):
|
||||
mask = torch.zeros(prev.size(),device=device).bernoulli_(p)
|
||||
return prev * mask + current * (1 - mask)
|
||||
|
||||
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
|
||||
hidden_states, cell_states, context_vec, times, chars):
|
||||
"""_summary_
|
||||
hidden_states, cell_states, context_vec, t, chars):
|
||||
|
||||
Args:
|
||||
encoder_seq (3D tensor `[batch_size, text_num_chars, project_dim(default to 512)]`): _description_
|
||||
encoder_seq_proj (3D tensor `[batch_size, text_num_chars, decoder_dims(default to 128)]`): _description_
|
||||
prenet_in (2D tensor `[batch_size, n_mels]`): _description_
|
||||
hidden_states (_type_): _description_
|
||||
cell_states (_type_): _description_
|
||||
context_vec (2D tensor `[batch_size, project_dim(default to 512)]`): _description_
|
||||
times (int): the number of times runned
|
||||
chars (2D tensor with size `[batch_size, text_num_chars]`): original texts list input
|
||||
|
||||
"""
|
||||
# Need this for reshaping mels
|
||||
batch_size = encoder_seq.size(0)
|
||||
device = encoder_seq.device
|
||||
@@ -91,25 +280,25 @@ class Decoder(nn.Module):
|
||||
rnn1_cell, rnn2_cell = cell_states
|
||||
|
||||
# PreNet for the Attention RNN
|
||||
prenet_out = self.prenet(prenet_in) # return: `[batch_size, decoder_dims * 2(256)]`
|
||||
prenet_out = self.prenet(prenet_in)
|
||||
|
||||
# Compute the Attention RNN hidden state
|
||||
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) # `[batch_size, project_dim + decoder_dims * 2 (768)]`
|
||||
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) # `[batch_size, decoder_dims (128)]`
|
||||
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
|
||||
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
|
||||
|
||||
# Compute the attention scores
|
||||
scores = self.attn_net(encoder_seq_proj, attn_hidden, times, chars)
|
||||
scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
|
||||
|
||||
# Dot product to create the context vector
|
||||
context_vec = scores @ encoder_seq
|
||||
context_vec = context_vec.squeeze(1)
|
||||
|
||||
# Concat Attention RNN output w. Context Vector & project
|
||||
x = torch.cat([context_vec, attn_hidden], dim=1) # `[batch_size, project_dim + decoder_dims (630)]`
|
||||
x = self.rnn_input(x) # `[batch_size, lstm_dims(1024)]`
|
||||
x = torch.cat([context_vec, attn_hidden], dim=1)
|
||||
x = self.rnn_input(x)
|
||||
|
||||
# Compute first Residual RNN, training with fixed zoneout rate 0.1
|
||||
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) # `[batch_size, lstm_dims(1024)]`
|
||||
# Compute first Residual RNN
|
||||
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
||||
if self.training:
|
||||
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device)
|
||||
else:
|
||||
@@ -117,7 +306,7 @@ class Decoder(nn.Module):
|
||||
x = x + rnn1_hidden
|
||||
|
||||
# Compute second Residual RNN
|
||||
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) # `[batch_size, lstm_dims(1024)]`
|
||||
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
||||
if self.training:
|
||||
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
|
||||
else:
|
||||
@@ -125,8 +314,8 @@ class Decoder(nn.Module):
|
||||
x = x + rnn2_hidden
|
||||
|
||||
# Project Mels
|
||||
mels = self.mel_proj(x) # `[batch_size, 1600]`
|
||||
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] # `[batch_size, n_mels, r]`
|
||||
mels = self.mel_proj(x)
|
||||
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
|
||||
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
||||
cell_states = (rnn1_cell, rnn2_cell)
|
||||
|
||||
@@ -137,30 +326,45 @@ class Decoder(nn.Module):
|
||||
|
||||
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
||||
|
||||
class Tacotron(Base):
|
||||
|
||||
class Tacotron(nn.Module):
|
||||
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
|
||||
fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
|
||||
dropout, stop_threshold, speaker_embedding_size):
|
||||
super().__init__(stop_threshold)
|
||||
super().__init__()
|
||||
self.n_mels = n_mels
|
||||
self.lstm_dims = lstm_dims
|
||||
self.encoder_dims = encoder_dims
|
||||
self.decoder_dims = decoder_dims
|
||||
self.speaker_embedding_size = speaker_embedding_size
|
||||
self.encoder = Encoder(num_chars, embed_dims, encoder_dims,
|
||||
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
||||
encoder_K, num_highways, dropout)
|
||||
self.project_dims = encoder_dims + speaker_embedding_size
|
||||
project_dims = encoder_dims + speaker_embedding_size
|
||||
if hparams.use_gst:
|
||||
self.project_dims += gst_hp.E
|
||||
self.encoder_proj = nn.Linear(self.project_dims, decoder_dims, bias=False)
|
||||
project_dims += gst_hp.E
|
||||
self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
|
||||
if hparams.use_gst:
|
||||
self.gst = GlobalStyleToken(speaker_embedding_size)
|
||||
self.decoder = Decoder(n_mels, self.project_dims, decoder_dims, lstm_dims,
|
||||
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
||||
dropout, speaker_embedding_size)
|
||||
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
||||
[postnet_dims, fft_bins], num_highways)
|
||||
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
|
||||
|
||||
self.init_model()
|
||||
self.num_params()
|
||||
|
||||
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
||||
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
||||
|
||||
@property
|
||||
def r(self):
|
||||
return self.decoder.r.item()
|
||||
|
||||
@r.setter
|
||||
def r(self, value):
|
||||
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
@@ -168,52 +372,11 @@ class Tacotron(Base):
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(x, speaker_embedding):
|
||||
"""Add speaker embedding
|
||||
This concats the speaker embedding for each char in the encoder output
|
||||
Args:
|
||||
x (3D tensor with size `[batch_size, text_num_chars, encoder_dims]`): the encoder output
|
||||
speaker_embedding (2D tensor `[batch_size, speaker_embedding_size]`): the speaker embedding
|
||||
|
||||
Returns:
|
||||
3D tensor with size `[batch_size, text_num_chars, encoder_dims+speaker_embedding_size]`
|
||||
"""
|
||||
# Save the dimensions as human-readable names
|
||||
batch_size = x.size()[0]
|
||||
text_num_chars = x.size()[1]
|
||||
|
||||
# Start by making a copy of each speaker embedding to match the input text length
|
||||
# The output of this has size (batch_size, text_num_chars * speaker_embedding_size)
|
||||
speaker_embedding_size = speaker_embedding.size()[1]
|
||||
e = speaker_embedding.repeat_interleave(text_num_chars, dim=1)
|
||||
|
||||
# Reshape it and transpose
|
||||
e = e.reshape(batch_size, speaker_embedding_size, text_num_chars)
|
||||
e = e.transpose(1, 2)
|
||||
|
||||
# Concatenate the tiled speaker embedding with the encoder output
|
||||
x = torch.cat((x, e), 2)
|
||||
return x
|
||||
|
||||
def forward(self, texts, mels, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
|
||||
"""Forward pass for Tacotron
|
||||
|
||||
Args:
|
||||
texts (`[batch_size, text_num_chars]`): input texts list
|
||||
mels (`[batch_size, varied_mel_lengths, steps]`): mels for comparison (training only)
|
||||
speaker_embedding (`[batch_size, speaker_embedding_size(default to 256)]`): referring embedding.
|
||||
steps (int, optional): . Defaults to 2000.
|
||||
style_idx (int, optional): GST style selected. Defaults to 0.
|
||||
min_stop_token (int, optional): decoder min_stop_token. Defaults to 5.
|
||||
"""
|
||||
def forward(self, texts, mels, speaker_embedding):
|
||||
device = texts.device # use same device as parameters
|
||||
|
||||
if self.training:
|
||||
self.step += 1
|
||||
batch_size, _, steps = mels.size()
|
||||
else:
|
||||
batch_size, _ = texts.size()
|
||||
self.step += 1
|
||||
batch_size, _, steps = mels.size()
|
||||
|
||||
# Initialise all hidden states and pack into tuple
|
||||
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
||||
@@ -229,50 +392,35 @@ class Tacotron(Base):
|
||||
# <GO> Frame for start of decoder loop
|
||||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# Need an initial context vector
|
||||
size = self.encoder_dims + self.speaker_embedding_size
|
||||
if hparams.use_gst:
|
||||
size += gst_hp.E
|
||||
context_vec = torch.zeros(batch_size, size, device=device)
|
||||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(texts)
|
||||
|
||||
encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)
|
||||
|
||||
encoder_seq = self.encoder(texts, speaker_embedding)
|
||||
# put after encoder
|
||||
if hparams.use_gst and self.gst is not None:
|
||||
if self.training:
|
||||
style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
|
||||
# style_embed = style_embed.expand_as(encoder_seq)
|
||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
elif style_idx >= 0 and style_idx < 10:
|
||||
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
||||
if device.type == 'cuda':
|
||||
query = query.cuda()
|
||||
gst_embed = torch.tanh(self.gst.stl.embed)
|
||||
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||
style_embed = self.gst.stl.attention(query, key)
|
||||
else:
|
||||
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed) # return: [batch_size, text_num_chars, project_dims]
|
||||
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq) # return: [batch_size, text_num_chars, decoder_dims]
|
||||
style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
|
||||
# style_embed = style_embed.expand_as(encoder_seq)
|
||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
mel_outputs, attn_scores, stop_outputs = [], [], []
|
||||
|
||||
# Need an initial context vector
|
||||
context_vec = torch.zeros(batch_size, self.project_dims, device=device)
|
||||
|
||||
# Run the decoder loop
|
||||
for t in range(0, steps, self.r):
|
||||
if self.training:
|
||||
prenet_in = mels[:, :, t -1] if t > 0 else go_frame
|
||||
else:
|
||||
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
||||
prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
|
||||
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
||||
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
||||
hidden_states, cell_states, context_vec, t, texts)
|
||||
mel_outputs.append(mel_frames)
|
||||
attn_scores.append(scores)
|
||||
stop_outputs.extend([stop_tokens] * self.r)
|
||||
if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break
|
||||
|
||||
# Concat the mel outputs into sequence
|
||||
mel_outputs = torch.cat(mel_outputs, dim=2)
|
||||
@@ -287,12 +435,135 @@ class Tacotron(Base):
|
||||
# attn_scores = attn_scores.cpu().data.numpy()
|
||||
stop_outputs = torch.cat(stop_outputs, 1)
|
||||
|
||||
if self.training:
|
||||
self.train()
|
||||
|
||||
return mel_outputs, linear, attn_scores, stop_outputs
|
||||
|
||||
def generate(self, x, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
|
||||
def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
|
||||
self.eval()
|
||||
mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
|
||||
device = x.device # use same device as parameters
|
||||
|
||||
batch_size, _ = x.size()
|
||||
|
||||
# Need to initialise all hidden states and pack into tuple for tidyness
|
||||
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
||||
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
||||
|
||||
# Need to initialise all lstm cell states and pack into tuple for tidyness
|
||||
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
cell_states = (rnn1_cell, rnn2_cell)
|
||||
|
||||
# Need a <GO> Frame for start of decoder loop
|
||||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# Need an initial context vector
|
||||
size = self.encoder_dims + self.speaker_embedding_size
|
||||
if hparams.use_gst:
|
||||
size += gst_hp.E
|
||||
context_vec = torch.zeros(batch_size, size, device=device)
|
||||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(x, speaker_embedding)
|
||||
|
||||
# put after encoder
|
||||
if hparams.use_gst and self.gst is not None:
|
||||
if style_idx >= 0 and style_idx < 10:
|
||||
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
||||
if device.type == 'cuda':
|
||||
query = query.cuda()
|
||||
gst_embed = torch.tanh(self.gst.stl.embed)
|
||||
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||
style_embed = self.gst.stl.attention(query, key)
|
||||
else:
|
||||
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||
# style_embed = style_embed.expand_as(encoder_seq)
|
||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
mel_outputs, attn_scores, stop_outputs = [], [], []
|
||||
|
||||
# Run the decoder loop
|
||||
for t in range(0, steps, self.r):
|
||||
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
||||
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
||||
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
||||
hidden_states, cell_states, context_vec, t, x)
|
||||
mel_outputs.append(mel_frames)
|
||||
attn_scores.append(scores)
|
||||
stop_outputs.extend([stop_tokens] * self.r)
|
||||
# Stop the loop when all stop tokens in batch exceed threshold
|
||||
if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
|
||||
|
||||
# Concat the mel outputs into sequence
|
||||
mel_outputs = torch.cat(mel_outputs, dim=2)
|
||||
|
||||
# Post-Process for Linear Spectrograms
|
||||
postnet_out = self.postnet(mel_outputs)
|
||||
linear = self.post_proj(postnet_out)
|
||||
|
||||
|
||||
linear = linear.transpose(1, 2)
|
||||
|
||||
# For easy visualisation
|
||||
attn_scores = torch.cat(attn_scores, 1)
|
||||
stop_outputs = torch.cat(stop_outputs, 1)
|
||||
|
||||
self.train()
|
||||
|
||||
return mel_outputs, linear, attn_scores
|
||||
|
||||
def init_model(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
||||
|
||||
def finetune_partial(self, whitelist_layers):
|
||||
self.zero_grad()
|
||||
for name, child in self.named_children():
|
||||
if name in whitelist_layers:
|
||||
print("Trainable Layer: %s" % name)
|
||||
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
|
||||
for param in child.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_step(self):
|
||||
return self.step.data.item()
|
||||
|
||||
def reset_step(self):
|
||||
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
||||
self.step = self.step.data.new_tensor(1)
|
||||
|
||||
def log(self, path, msg):
|
||||
with open(path, "a") as f:
|
||||
print(msg, file=f)
|
||||
|
||||
def load(self, path, device, optimizer=None):
|
||||
# Use device of model params as location for loaded state
|
||||
checkpoint = torch.load(str(path), map_location=device)
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
|
||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||
|
||||
def save(self, path, optimizer=None):
|
||||
if optimizer is not None:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
}, str(path))
|
||||
else:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
}, str(path))
|
||||
|
||||
|
||||
def num_params(self, print_out=True):
|
||||
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
||||
if print_out:
|
||||
print("Trainable Parameters: %.3fM" % parameters)
|
||||
return parameters
|
||||
|
||||
@@ -63,7 +63,7 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
|
||||
|
||||
def _split_on_silences(wav_fpath, words, hparams):
|
||||
# Load the audio waveform
|
||||
wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate)
|
||||
wav, _ = librosa.load(wav_fpath, hparams.sample_rate)
|
||||
wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0]
|
||||
if hparams.rescale:
|
||||
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
||||
|
||||
@@ -15,8 +15,9 @@ from datetime import datetime
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
|
||||
|
||||
def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
|
||||
|
||||
@@ -264,19 +265,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
loss=loss,
|
||||
hparams=hparams,
|
||||
sw=sw)
|
||||
MAX_SAVED_COUNT = 20
|
||||
if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
|
||||
# clean up and save last MAX_SAVED_COUNT;
|
||||
plots = next(os.walk(plot_dir), (None, None, []))[2]
|
||||
for plot in plots[-MAX_SAVED_COUNT:]:
|
||||
os.remove(plot_dir.joinpath(plot))
|
||||
mel_files = next(os.walk(mel_output_dir), (None, None, []))[2]
|
||||
for mel_file in mel_files[-MAX_SAVED_COUNT:]:
|
||||
os.remove(mel_output_dir.joinpath(mel_file))
|
||||
wavs = next(os.walk(wav_dir), (None, None, []))[2]
|
||||
for w in wavs[-MAX_SAVED_COUNT:]:
|
||||
os.remove(wav_dir.joinpath(w))
|
||||
|
||||
|
||||
# Break out of loop to update training schedule
|
||||
if step >= max_step:
|
||||
break
|
||||
|
||||
@@ -3,10 +3,12 @@ from encoder import inference as encoder
|
||||
from synthesizer.inference import Synthesizer
|
||||
from vocoder.wavernn import inference as rnn_vocoder
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from vocoder.fregan import inference as fgan_vocoder
|
||||
import ppg_extractor as extractor
|
||||
import ppg2mel as convertor
|
||||
from pathlib import Path
|
||||
from time import perf_counter as timer
|
||||
from toolbox.utterance import Utterance
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
import numpy as np
|
||||
import traceback
|
||||
import sys
|
||||
@@ -371,8 +373,6 @@ class Toolbox:
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
ref_wav = self.ui.selected_utterance.wav
|
||||
# Import necessary dependency of Voice Conversion
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
@@ -397,7 +397,6 @@ class Toolbox:
|
||||
self.ui.log("Loading the extractor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
import ppg_extractor as extractor
|
||||
self.extractor = extractor.load_model(model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
@@ -406,11 +405,15 @@ class Toolbox:
|
||||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_fpath = self.ui.current_convertor_fpath
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.yaml"))
|
||||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
self.ui.log("Loading the convertor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
import ppg2mel as convertor
|
||||
self.convertor = convertor.load_model( model_fpath)
|
||||
self.convertor = convertor.load_model(model_config_fpath, model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
@@ -443,24 +446,14 @@ class Toolbox:
|
||||
return
|
||||
# Sekect vocoder based on model name
|
||||
model_config_fpath = None
|
||||
if model_fpath.name is not None and model_fpath.name.find("hifigan") > -1:
|
||||
if model_fpath.name[0] == "g":
|
||||
vocoder = gan_vocoder
|
||||
self.ui.log("set hifigan as vocoder")
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
|
||||
if self.vc_mode and self.ui.current_extractor_fpath is None:
|
||||
if self.ui.current_extractor_fpath is None:
|
||||
return
|
||||
if len(model_config_fpaths) > 0:
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1:
|
||||
vocoder = fgan_vocoder
|
||||
self.ui.log("set fregan as vocoder")
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
|
||||
if self.vc_mode and self.ui.current_extractor_fpath is None:
|
||||
return
|
||||
if len(model_config_fpaths) > 0:
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
else:
|
||||
vocoder = rnn_vocoder
|
||||
self.ui.log("set wavernn as vocoder")
|
||||
|
||||
67
train.py
67
train.py
@@ -1,67 +0,0 @@
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducable results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not paras.no_pin)
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -42,9 +42,3 @@ def human_format(num):
|
||||
# add more suffixes if you need them
|
||||
return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
|
||||
|
||||
|
||||
# provide easy access of attribute from dict, such abc.key
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
129
vocoder/fregan/.gitignore
vendored
129
vocoder/fregan/.gitignore
vendored
@@ -1,129 +0,0 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 Rishikesh (ऋषिकेश)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,42 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0002,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"disc_start_step":0,
|
||||
|
||||
|
||||
"upsample_rates": [5,5,2,2,2],
|
||||
"upsample_kernel_sizes": [10,10,4,4,4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5, 7], [1,3,5,7], [1,3,5,7]],
|
||||
|
||||
"segment_size": 6400,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 200,
|
||||
"win_size": 800,
|
||||
|
||||
"sampling_rate": 16000,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 7600,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -1,303 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
from vocoder.fregan.utils import get_padding
|
||||
from vocoder.fregan.stft_loss import stft
|
||||
from vocoder.fregan.dwt import DWT_1D
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
|
||||
class SpecDiscriminator(nn.Module):
|
||||
"""docstring for Discriminator."""
|
||||
|
||||
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
|
||||
super(SpecDiscriminator, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
self.window = getattr(torch, window)(win_length)
|
||||
self.discriminators = nn.ModuleList([
|
||||
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
|
||||
])
|
||||
|
||||
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
||||
|
||||
def forward(self, y):
|
||||
|
||||
fmap = []
|
||||
with torch.no_grad():
|
||||
y = y.squeeze(1)
|
||||
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device()))
|
||||
y = y.unsqueeze(1)
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y = d(y)
|
||||
y = F.leaky_relu(y, LRELU_SLOPE)
|
||||
fmap.append(y)
|
||||
|
||||
y = self.out(y)
|
||||
fmap.append(y)
|
||||
|
||||
return torch.flatten(y, 1, -1), fmap
|
||||
|
||||
class MultiResSpecDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
window="hann_window"):
|
||||
|
||||
super(MultiResSpecDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
||||
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
||||
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
|
||||
])
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.dwt1d = DWT_1D()
|
||||
self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
|
||||
self.dwt_proj1 = norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
|
||||
self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
|
||||
self.dwt_proj2 = norm_f(Conv2d(1, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
|
||||
self.dwt_conv3 = norm_f(Conv1d(8, 1, 1))
|
||||
self.dwt_proj3 = norm_f(Conv2d(1, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
])
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# DWT 1
|
||||
x_d1_high1, x_d1_low1 = self.dwt1d(x)
|
||||
x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
|
||||
# 1d to 2d
|
||||
b, c, t = x_d1.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x_d1 = F.pad(x_d1, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x_d1 = x_d1.view(b, c, t // self.period, self.period)
|
||||
|
||||
x_d1 = self.dwt_proj1(x_d1)
|
||||
|
||||
# DWT 2
|
||||
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
|
||||
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
|
||||
x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
|
||||
# 1d to 2d
|
||||
b, c, t = x_d2.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x_d2 = F.pad(x_d2, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x_d2 = x_d2.view(b, c, t // self.period, self.period)
|
||||
|
||||
x_d2 = self.dwt_proj2(x_d2)
|
||||
|
||||
# DWT 3
|
||||
|
||||
x_d3_high1, x_d3_low1 = self.dwt1d(x_d2_high1)
|
||||
x_d3_high2, x_d3_low2 = self.dwt1d(x_d2_low1)
|
||||
x_d3_high3, x_d3_low3 = self.dwt1d(x_d2_high2)
|
||||
x_d3_high4, x_d3_low4 = self.dwt1d(x_d2_low2)
|
||||
x_d3 = self.dwt_conv3(
|
||||
torch.cat([x_d3_high1, x_d3_low1, x_d3_high2, x_d3_low2, x_d3_high3, x_d3_low3, x_d3_high4, x_d3_low4],
|
||||
dim=1))
|
||||
# 1d to 2d
|
||||
b, c, t = x_d3.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x_d3 = F.pad(x_d3, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x_d3 = x_d3.view(b, c, t // self.period, self.period)
|
||||
|
||||
x_d3 = self.dwt_proj3(x_d3)
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
i = 0
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
|
||||
fmap.append(x)
|
||||
if i == 0:
|
||||
x = torch.cat([x, x_d1], dim=2)
|
||||
elif i == 1:
|
||||
x = torch.cat([x, x_d2], dim=2)
|
||||
elif i == 2:
|
||||
x = torch.cat([x, x_d3], dim=2)
|
||||
else:
|
||||
x = x
|
||||
i = i + 1
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class ResWiseMultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ResWiseMultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
])
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.dwt1d = DWT_1D()
|
||||
self.dwt_conv1 = norm_f(Conv1d(2, 128, 15, 1, padding=7))
|
||||
self.dwt_conv2 = norm_f(Conv1d(4, 128, 41, 2, padding=20))
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
])
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# DWT 1
|
||||
x_d1_high1, x_d1_low1 = self.dwt1d(x)
|
||||
x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
|
||||
|
||||
# DWT 2
|
||||
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
|
||||
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
|
||||
x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
|
||||
|
||||
i = 0
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
if i == 0:
|
||||
x = torch.cat([x, x_d1], dim=2)
|
||||
if i == 1:
|
||||
x = torch.cat([x, x_d2], dim=2)
|
||||
i = i + 1
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class ResWiseMultiScaleDiscriminator(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(ResWiseMultiScaleDiscriminator, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.dwt1d = DWT_1D()
|
||||
self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
|
||||
self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
|
||||
self.discriminators = nn.ModuleList([
|
||||
DiscriminatorS(use_spectral_norm=True),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
])
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
# DWT 1
|
||||
y_hi, y_lo = self.dwt1d(y)
|
||||
y_1 = self.dwt_conv1(torch.cat([y_hi, y_lo], dim=1))
|
||||
x_d1_high1, x_d1_low1 = self.dwt1d(y_hat)
|
||||
y_hat_1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
|
||||
|
||||
# DWT 2
|
||||
x_d2_high1, x_d2_low1 = self.dwt1d(y_hi)
|
||||
x_d2_high2, x_d2_low2 = self.dwt1d(y_lo)
|
||||
y_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
|
||||
|
||||
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
|
||||
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
|
||||
y_hat_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
|
||||
|
||||
for i, d in enumerate(self.discriminators):
|
||||
|
||||
if i == 1:
|
||||
y = y_1
|
||||
y_hat = y_hat_1
|
||||
if i == 2:
|
||||
y = y_2
|
||||
y_hat = y_hat_2
|
||||
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
@@ -1,76 +0,0 @@
|
||||
# Copyright (c) 2019, Adobe Inc. All rights reserved.
|
||||
#
|
||||
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
|
||||
# 4.0 International Public License. To view a copy of this license, visit
|
||||
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
|
||||
|
||||
# DWT code borrow from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py
|
||||
|
||||
|
||||
import pywt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['DWT_1D']
|
||||
Pad_Mode = ['constant', 'reflect', 'replicate', 'circular']
|
||||
|
||||
|
||||
class DWT_1D(nn.Module):
|
||||
def __init__(self, pad_type='reflect', wavename='haar',
|
||||
stride=2, in_channels=1, out_channels=None, groups=None,
|
||||
kernel_size=None, trainable=False):
|
||||
|
||||
super(DWT_1D, self).__init__()
|
||||
self.trainable = trainable
|
||||
self.kernel_size = kernel_size
|
||||
if not self.trainable:
|
||||
assert self.kernel_size == None
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = self.in_channels if out_channels == None else out_channels
|
||||
self.groups = self.in_channels if groups == None else groups
|
||||
assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
|
||||
self.stride = stride
|
||||
assert self.stride == 2
|
||||
self.wavename = wavename
|
||||
self.pad_type = pad_type
|
||||
assert self.pad_type in Pad_Mode
|
||||
self.get_filters()
|
||||
self.initialization()
|
||||
|
||||
def get_filters(self):
|
||||
wavelet = pywt.Wavelet(self.wavename)
|
||||
band_low = torch.tensor(wavelet.rec_lo)
|
||||
band_high = torch.tensor(wavelet.rec_hi)
|
||||
length_band = band_low.size()[0]
|
||||
self.kernel_size = length_band if self.kernel_size == None else self.kernel_size
|
||||
assert self.kernel_size >= length_band
|
||||
a = (self.kernel_size - length_band) // 2
|
||||
b = - (self.kernel_size - length_band - a)
|
||||
b = None if b == 0 else b
|
||||
self.filt_low = torch.zeros(self.kernel_size)
|
||||
self.filt_high = torch.zeros(self.kernel_size)
|
||||
self.filt_low[a:b] = band_low
|
||||
self.filt_high[a:b] = band_high
|
||||
|
||||
def initialization(self):
|
||||
self.filter_low = self.filt_low[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
|
||||
self.filter_high = self.filt_high[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
|
||||
if torch.cuda.is_available():
|
||||
self.filter_low = self.filter_low.cuda()
|
||||
self.filter_high = self.filter_high.cuda()
|
||||
if self.trainable:
|
||||
self.filter_low = nn.Parameter(self.filter_low)
|
||||
self.filter_high = nn.Parameter(self.filter_high)
|
||||
if self.kernel_size % 2 == 0:
|
||||
self.pad_sizes = [self.kernel_size // 2 - 1, self.kernel_size // 2 - 1]
|
||||
else:
|
||||
self.pad_sizes = [self.kernel_size // 2, self.kernel_size // 2]
|
||||
|
||||
def forward(self, input):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
assert len(input.size()) == 3
|
||||
assert input.size()[1] == self.in_channels
|
||||
input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
|
||||
return F.conv1d(input, self.filter_low.to(input.device), stride=self.stride, groups=self.groups), \
|
||||
F.conv1d(input, self.filter_high.to(input.device), stride=self.stride, groups=self.groups)
|
||||
@@ -1,210 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from vocoder.fregan.utils import init_weights, get_padding
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5, 7)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.h = h
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[3],
|
||||
padding=get_padding(kernel_size, dilation[3])))
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.h = h
|
||||
self.convs = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])))
|
||||
])
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class FreGAN(torch.nn.Module):
|
||||
def __init__(self, h, top_k=4):
|
||||
super(FreGAN, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
self.upsample_rates = h.upsample_rates
|
||||
self.up_kernels = h.upsample_kernel_sizes
|
||||
self.cond_level = self.num_upsamples - top_k
|
||||
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
self.cond_up = nn.ModuleList()
|
||||
self.res_output = nn.ModuleList()
|
||||
upsample_ = 1
|
||||
kr = 80
|
||||
|
||||
for i, (u, k) in enumerate(zip(self.upsample_rates, self.up_kernels)):
|
||||
# self.ups.append(weight_norm(
|
||||
# ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
|
||||
# k, u, padding=(k - u) // 2)))
|
||||
self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
|
||||
h.upsample_initial_channel//(2**(i+1)),
|
||||
k, u, padding=(u//2 + u%2), output_padding=u%2)))
|
||||
|
||||
if i > (self.num_upsamples - top_k):
|
||||
self.res_output.append(
|
||||
nn.Sequential(
|
||||
nn.Upsample(scale_factor=u, mode='nearest'),
|
||||
weight_norm(nn.Conv1d(h.upsample_initial_channel // (2 ** i),
|
||||
h.upsample_initial_channel // (2 ** (i + 1)), 1))
|
||||
)
|
||||
)
|
||||
if i >= (self.num_upsamples - top_k):
|
||||
self.cond_up.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(kr, h.upsample_initial_channel // (2 ** i),
|
||||
self.up_kernels[i - 1], self.upsample_rates[i - 1],
|
||||
padding=(self.upsample_rates[i-1]//2+self.upsample_rates[i-1]%2), output_padding=self.upsample_rates[i-1]%2))
|
||||
)
|
||||
kr = h.upsample_initial_channel // (2 ** i)
|
||||
|
||||
upsample_ *= u
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.cond_up.apply(init_weights)
|
||||
self.res_output.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
mel = x
|
||||
x = self.conv_pre(x)
|
||||
output = None
|
||||
for i in range(self.num_upsamples):
|
||||
if i >= self.cond_level:
|
||||
mel = self.cond_up[i - self.cond_level](mel)
|
||||
x += mel
|
||||
if i > self.cond_level:
|
||||
if output is None:
|
||||
output = self.res_output[i - self.cond_level - 1](x)
|
||||
else:
|
||||
output = self.res_output[i - self.cond_level - 1](output)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
if output is not None:
|
||||
output = output + x
|
||||
|
||||
x = F.leaky_relu(output)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
for l in self.cond_up:
|
||||
remove_weight_norm(l)
|
||||
for l in self.res_output:
|
||||
remove_weight_norm(l[1])
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
'''
|
||||
to run this, fix
|
||||
from . import ResStack
|
||||
into
|
||||
from res_stack import ResStack
|
||||
'''
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
torch.Size([3, 80, 10])
|
||||
torch.Size([3, 1, 2000])
|
||||
4527362
|
||||
'''
|
||||
with open('config.json') as f:
|
||||
data = f.read()
|
||||
from utils import AttrDict
|
||||
import json
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
model = FreGAN(h)
|
||||
|
||||
c = torch.randn(3, 80, 10) # (B, channels, T).
|
||||
print(c.shape)
|
||||
|
||||
y = model(c) # (B, 1, T ** prod(upsample_scales)
|
||||
print(y.shape)
|
||||
assert y.shape == torch.Size([3, 1, 2560]) # For normal melgan torch.Size([3, 1, 2560])
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(pytorch_total_params)
|
||||
@@ -1,74 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from utils.util import AttrDict
|
||||
from vocoder.fregan.generator import FreGAN
|
||||
|
||||
generator = None # type: FreGAN
|
||||
output_sample_rate = None
|
||||
_device = None
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print("Loading '{}'".format(filepath))
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def load_model(weights_fpath, config_fpath=None, verbose=True):
|
||||
global generator, _device, output_sample_rate
|
||||
|
||||
if verbose:
|
||||
print("Building fregan")
|
||||
|
||||
if config_fpath == None:
|
||||
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
|
||||
if len(model_config_fpaths) > 0:
|
||||
config_fpath = model_config_fpaths[0]
|
||||
else:
|
||||
config_fpath = "./vocoder/fregan/config.json"
|
||||
with open(config_fpath) as f:
|
||||
data = f.read()
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
output_sample_rate = h.sampling_rate
|
||||
torch.manual_seed(h.seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# _model = _model.cuda()
|
||||
_device = torch.device('cuda')
|
||||
else:
|
||||
_device = torch.device('cpu')
|
||||
|
||||
generator = FreGAN(h).to(_device)
|
||||
state_dict_g = load_checkpoint(
|
||||
weights_fpath, _device
|
||||
)
|
||||
generator.load_state_dict(state_dict_g['generator'])
|
||||
generator.eval()
|
||||
generator.remove_weight_norm()
|
||||
|
||||
|
||||
def is_loaded():
|
||||
return generator is not None
|
||||
|
||||
|
||||
def infer_waveform(mel, progress_callback=None):
|
||||
|
||||
if generator is None:
|
||||
raise Exception("Please load fre-gan in memory before using it")
|
||||
|
||||
mel = torch.FloatTensor(mel).to(_device)
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
y_g_hat = generator(mel)
|
||||
audio = y_g_hat.squeeze()
|
||||
audio = audio.cpu().numpy()
|
||||
|
||||
return audio, output_sample_rate
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss*2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1-dr)**2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += (r_loss + g_loss)
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1-dg)**2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
||||
@@ -1,176 +0,0 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
from librosa.util import normalize
|
||||
from scipy.io.wavfile import read
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def load_wav(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return data, sampling_rate
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.:
|
||||
print('min value is ', torch.min(y))
|
||||
if torch.max(y) > 1.:
|
||||
print('max value is ', torch.max(y))
|
||||
|
||||
global mel_basis, hann_window
|
||||
if fmax not in mel_basis:
|
||||
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
||||
mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
||||
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
||||
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
|
||||
|
||||
spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def get_dataset_filelist(a):
|
||||
#with open(a.input_training_file, 'r', encoding='utf-8') as fi:
|
||||
# training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
|
||||
# for x in fi.read().split('\n') if len(x) > 0]
|
||||
|
||||
#with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
|
||||
# validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
|
||||
# for x in fi.read().split('\n') if len(x) > 0]
|
||||
files = os.listdir(a.input_wavs_dir)
|
||||
random.shuffle(files)
|
||||
files = [os.path.join(a.input_wavs_dir, f) for f in files]
|
||||
training_files = files[: -int(len(files) * 0.05)]
|
||||
validation_files = files[-int(len(files) * 0.05):]
|
||||
return training_files, validation_files
|
||||
|
||||
|
||||
class MelDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, training_files, segment_size, n_fft, num_mels,
|
||||
hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
|
||||
device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
|
||||
self.audio_files = training_files
|
||||
random.seed(1234)
|
||||
if shuffle:
|
||||
random.shuffle(self.audio_files)
|
||||
self.segment_size = segment_size
|
||||
self.sampling_rate = sampling_rate
|
||||
self.split = split
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.fmax_loss = fmax_loss
|
||||
self.cached_wav = None
|
||||
self.n_cache_reuse = n_cache_reuse
|
||||
self._cache_ref_count = 0
|
||||
self.device = device
|
||||
self.fine_tuning = fine_tuning
|
||||
self.base_mels_path = base_mels_path
|
||||
|
||||
def __getitem__(self, index):
|
||||
filename = self.audio_files[index]
|
||||
if self._cache_ref_count == 0:
|
||||
#audio, sampling_rate = load_wav(filename)
|
||||
#audio = audio / MAX_WAV_VALUE
|
||||
audio = np.load(filename)
|
||||
if not self.fine_tuning:
|
||||
audio = normalize(audio) * 0.95
|
||||
self.cached_wav = audio
|
||||
#if sampling_rate != self.sampling_rate:
|
||||
# raise ValueError("{} SR doesn't match target {} SR".format(
|
||||
# sampling_rate, self.sampling_rate))
|
||||
self._cache_ref_count = self.n_cache_reuse
|
||||
else:
|
||||
audio = self.cached_wav
|
||||
self._cache_ref_count -= 1
|
||||
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
if not self.fine_tuning:
|
||||
if self.split:
|
||||
if audio.size(1) >= self.segment_size:
|
||||
max_audio_start = audio.size(1) - self.segment_size
|
||||
audio_start = random.randint(0, max_audio_start)
|
||||
audio = audio[:, audio_start:audio_start+self.segment_size]
|
||||
else:
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
|
||||
|
||||
mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
|
||||
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
|
||||
center=False)
|
||||
else:
|
||||
mel_path = os.path.join(self.base_mels_path, "mel" + "-" + filename.split("/")[-1].split("-")[-1])
|
||||
mel = np.load(mel_path).T
|
||||
#mel = np.load(
|
||||
# os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
|
||||
mel = torch.from_numpy(mel)
|
||||
|
||||
if len(mel.shape) < 3:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
if self.split:
|
||||
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
||||
|
||||
if audio.size(1) >= self.segment_size:
|
||||
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
||||
mel = mel[:, :, mel_start:mel_start + frames_per_seg]
|
||||
audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
|
||||
else:
|
||||
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
|
||||
|
||||
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
|
||||
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
|
||||
center=False)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_files)
|
||||
@@ -1,201 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class KernelPredictor(torch.nn.Module):
|
||||
''' Kernel predictor for the location-variable convolutions
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
cond_channels,
|
||||
conv_in_channels,
|
||||
conv_out_channels,
|
||||
conv_layers,
|
||||
conv_kernel_size=3,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
kpnet_nonlinear_activation="LeakyReLU",
|
||||
kpnet_nonlinear_activation_params={"negative_slope": 0.1}
|
||||
):
|
||||
'''
|
||||
Args:
|
||||
cond_channels (int): number of channel for the conditioning sequence,
|
||||
conv_in_channels (int): number of channel for the input sequence,
|
||||
conv_out_channels (int): number of channel for the output sequence,
|
||||
conv_layers (int):
|
||||
kpnet_
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
self.conv_in_channels = conv_in_channels
|
||||
self.conv_out_channels = conv_out_channels
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_layers = conv_layers
|
||||
|
||||
l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
|
||||
l_b = conv_out_channels * conv_layers
|
||||
|
||||
padding = (kpnet_conv_size - 1) // 2
|
||||
self.input_conv = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.residual_conv = torch.nn.Sequential(
|
||||
torch.nn.Dropout(kpnet_dropout),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Dropout(kpnet_dropout),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Dropout(kpnet_dropout),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
|
||||
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size,
|
||||
padding=padding, bias=True)
|
||||
self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding,
|
||||
bias=True)
|
||||
|
||||
def forward(self, c):
|
||||
'''
|
||||
Args:
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
Returns:
|
||||
'''
|
||||
batch, cond_channels, cond_length = c.shape
|
||||
|
||||
c = self.input_conv(c)
|
||||
c = c + self.residual_conv(c)
|
||||
k = self.kernel_conv(c)
|
||||
b = self.bias_conv(c)
|
||||
|
||||
kernels = k.contiguous().view(batch,
|
||||
self.conv_layers,
|
||||
self.conv_in_channels,
|
||||
self.conv_out_channels,
|
||||
self.conv_kernel_size,
|
||||
cond_length)
|
||||
bias = b.contiguous().view(batch,
|
||||
self.conv_layers,
|
||||
self.conv_out_channels,
|
||||
cond_length)
|
||||
return kernels, bias
|
||||
|
||||
|
||||
class LVCBlock(torch.nn.Module):
|
||||
''' the location-variable convolutions
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
cond_channels,
|
||||
upsample_ratio,
|
||||
conv_layers=4,
|
||||
conv_kernel_size=3,
|
||||
cond_hop_length=256,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_hop_length = cond_hop_length
|
||||
self.conv_layers = conv_layers
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.convs = torch.nn.ModuleList()
|
||||
|
||||
self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels,
|
||||
kernel_size=upsample_ratio*2, stride=upsample_ratio,
|
||||
padding=upsample_ratio // 2 + upsample_ratio % 2,
|
||||
output_padding=upsample_ratio % 2)
|
||||
|
||||
|
||||
self.kernel_predictor = KernelPredictor(
|
||||
cond_channels=cond_channels,
|
||||
conv_in_channels=in_channels,
|
||||
conv_out_channels=2 * in_channels,
|
||||
conv_layers=conv_layers,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=kpnet_dropout
|
||||
)
|
||||
|
||||
|
||||
for i in range(conv_layers):
|
||||
padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
|
||||
conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i)
|
||||
|
||||
self.convs.append(conv)
|
||||
|
||||
|
||||
def forward(self, x, c):
|
||||
''' forward propagation of the location-variable convolutions.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length)
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
|
||||
Returns:
|
||||
Tensor: the output sequence (batch, in_channels, in_length)
|
||||
'''
|
||||
batch, in_channels, in_length = x.shape
|
||||
|
||||
|
||||
kernels, bias = self.kernel_predictor(c)
|
||||
|
||||
x = F.leaky_relu(x, 0.2)
|
||||
x = self.upsample(x)
|
||||
|
||||
for i in range(self.conv_layers):
|
||||
y = F.leaky_relu(x, 0.2)
|
||||
y = self.convs[i](y)
|
||||
y = F.leaky_relu(y, 0.2)
|
||||
|
||||
k = kernels[:, i, :, :, :, :]
|
||||
b = bias[:, i, :, :]
|
||||
y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
|
||||
x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
|
||||
return x
|
||||
|
||||
def location_variable_convolution(self, x, kernel, bias, dilation, hop_size):
|
||||
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
||||
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length).
|
||||
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
||||
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
||||
dilation (int): the dilation of convolution.
|
||||
hop_size (int): the hop_size of the conditioning sequence.
|
||||
Returns:
|
||||
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
||||
'''
|
||||
batch, in_channels, in_length = x.shape
|
||||
batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape
|
||||
|
||||
|
||||
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||
|
||||
padding = dilation * int((kernel_size - 1) / 2)
|
||||
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
|
||||
if hop_size < dilation:
|
||||
x = F.pad(x, (0, dilation), 'constant', 0)
|
||||
x = x.unfold(3, dilation,
|
||||
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||
x = x[:, :, :, :, :hop_size]
|
||||
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
|
||||
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
|
||||
o = o + bias.unsqueeze(-1).unsqueeze(-1)
|
||||
o = o.contiguous().view(batch, out_channels, -1)
|
||||
return o
|
||||
@@ -1,136 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Tomoki Hayashi
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
"""STFT-based Loss modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
"""
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
||||
real = x_stft[..., 0]
|
||||
imag = x_stft[..., 1]
|
||||
|
||||
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
||||
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
||||
|
||||
|
||||
class SpectralConvergengeLoss(torch.nn.Module):
|
||||
"""Spectral convergence loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize spectral convergence loss module."""
|
||||
super(SpectralConvergengeLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
"""
|
||||
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
|
||||
|
||||
|
||||
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
||||
"""Log STFT magnitude loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize los STFT magnitude loss module."""
|
||||
super(LogSTFTMagnitudeLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns:
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
"""
|
||||
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
||||
|
||||
|
||||
class STFTLoss(torch.nn.Module):
|
||||
"""STFT loss module."""
|
||||
|
||||
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
|
||||
"""Initialize STFT loss module."""
|
||||
super(STFTLoss, self).__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
self.window = getattr(torch, window)(win_length)
|
||||
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
||||
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
"""
|
||||
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device()))
|
||||
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device()))
|
||||
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
||||
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||
|
||||
return sc_loss, mag_loss
|
||||
|
||||
|
||||
class MultiResolutionSTFTLoss(torch.nn.Module):
|
||||
"""Multi resolution STFT loss module."""
|
||||
|
||||
def __init__(self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
window="hann_window"):
|
||||
"""Initialize Multi resolution STFT loss module.
|
||||
Args:
|
||||
fft_sizes (list): List of FFT sizes.
|
||||
hop_sizes (list): List of hop sizes.
|
||||
win_lengths (list): List of window lengths.
|
||||
window (str): Window function type.
|
||||
"""
|
||||
super(MultiResolutionSTFTLoss, self).__init__()
|
||||
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
||||
self.stft_losses = torch.nn.ModuleList()
|
||||
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
||||
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
Returns:
|
||||
Tensor: Multi resolution spectral convergence loss value.
|
||||
Tensor: Multi resolution log STFT magnitude loss value.
|
||||
"""
|
||||
sc_loss = 0.0
|
||||
mag_loss = 0.0
|
||||
for f in self.stft_losses:
|
||||
sc_l, mag_l = f(x, y)
|
||||
sc_loss += sc_l
|
||||
mag_loss += mag_l
|
||||
sc_loss /= len(self.stft_losses)
|
||||
mag_loss /= len(self.stft_losses)
|
||||
|
||||
return sc_loss, mag_loss
|
||||
@@ -1,246 +0,0 @@
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.data import DistributedSampler, DataLoader
|
||||
from torch.distributed import init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from vocoder.fregan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
||||
from vocoder.fregan.generator import FreGAN
|
||||
from vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator, ResWiseMultiScaleDiscriminator
|
||||
from vocoder.fregan.loss import feature_loss, generator_loss, discriminator_loss
|
||||
from vocoder.fregan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
|
||||
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def train(rank, a, h):
|
||||
|
||||
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_fregan')
|
||||
a.checkpoint_path.mkdir(exist_ok=True)
|
||||
a.training_epochs = 3100
|
||||
a.stdout_interval = 5
|
||||
a.checkpoint_interval = a.backup_every
|
||||
a.summary_interval = 5000
|
||||
a.validation_interval = 1000
|
||||
a.fine_tuning = True
|
||||
|
||||
a.input_wavs_dir = a.syn_dir.joinpath("audio")
|
||||
a.input_mels_dir = a.syn_dir.joinpath("mels")
|
||||
|
||||
if h.num_gpus > 1:
|
||||
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
||||
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
||||
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
device = torch.device('cuda:{:d}'.format(rank))
|
||||
|
||||
generator = FreGAN(h).to(device)
|
||||
mpd = ResWiseMultiPeriodDiscriminator().to(device)
|
||||
msd = ResWiseMultiScaleDiscriminator().to(device)
|
||||
|
||||
if rank == 0:
|
||||
print(generator)
|
||||
os.makedirs(a.checkpoint_path, exist_ok=True)
|
||||
print("checkpoints directory : ", a.checkpoint_path)
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, 'g_fregan_')
|
||||
cp_do = scan_checkpoint(a.checkpoint_path, 'do_fregan_')
|
||||
|
||||
steps = 0
|
||||
if cp_g is None or cp_do is None:
|
||||
state_dict_do = None
|
||||
last_epoch = -1
|
||||
else:
|
||||
state_dict_g = load_checkpoint(cp_g, device)
|
||||
state_dict_do = load_checkpoint(cp_do, device)
|
||||
generator.load_state_dict(state_dict_g['generator'])
|
||||
mpd.load_state_dict(state_dict_do['mpd'])
|
||||
msd.load_state_dict(state_dict_do['msd'])
|
||||
steps = state_dict_do['steps'] + 1
|
||||
last_epoch = state_dict_do['epoch']
|
||||
|
||||
if h.num_gpus > 1:
|
||||
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
||||
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
|
||||
if state_dict_do is not None:
|
||||
optim_g.load_state_dict(state_dict_do['optim_g'])
|
||||
optim_d.load_state_dict(state_dict_do['optim_d'])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
|
||||
training_filelist, validation_filelist = get_dataset_filelist(a)
|
||||
|
||||
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
|
||||
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
||||
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
||||
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
|
||||
|
||||
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
||||
|
||||
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
||||
sampler=train_sampler,
|
||||
batch_size=h.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
if rank == 0:
|
||||
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
|
||||
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
||||
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir)
|
||||
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
|
||||
sampler=None,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
|
||||
|
||||
generator.train()
|
||||
mpd.train()
|
||||
msd.train()
|
||||
for epoch in range(max(0, last_epoch), a.training_epochs):
|
||||
if rank == 0:
|
||||
start = time.time()
|
||||
print("Epoch: {}".format(epoch + 1))
|
||||
|
||||
if h.num_gpus > 1:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
for i, batch in enumerate(train_loader):
|
||||
if rank == 0:
|
||||
start_b = time.time()
|
||||
x, y, _, y_mel = batch
|
||||
x = torch.autograd.Variable(x.to(device, non_blocking=True))
|
||||
y = torch.autograd.Variable(y.to(device, non_blocking=True))
|
||||
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
||||
y = y.unsqueeze(1)
|
||||
y_g_hat = generator(x)
|
||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
# MSD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
loss_disc_all.backward()
|
||||
optim_d.step()
|
||||
|
||||
# Generator
|
||||
optim_g.zero_grad()
|
||||
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
||||
|
||||
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
|
||||
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
|
||||
else:
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
with torch.no_grad():
|
||||
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
|
||||
|
||||
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
|
||||
format(steps, loss_gen_all, mel_error, time.time() - start_b))
|
||||
|
||||
# checkpointing
|
||||
if steps % a.checkpoint_interval == 0 and steps != 0:
|
||||
checkpoint_path = "{}/g_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1
|
||||
else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1
|
||||
else msd).state_dict(),
|
||||
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
||||
'epoch': epoch})
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
|
||||
# Validation
|
||||
if steps % a.validation_interval == 0: # and steps != 0:
|
||||
generator.eval()
|
||||
torch.cuda.empty_cache()
|
||||
val_err_tot = 0
|
||||
with torch.no_grad():
|
||||
for j, batch in enumerate(validation_loader):
|
||||
x, y, _, y_mel = batch
|
||||
y_g_hat = generator(x.to(device))
|
||||
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
||||
h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
#val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
|
||||
|
||||
if j <= 4:
|
||||
if steps == 0:
|
||||
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
|
||||
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
|
||||
|
||||
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
|
||||
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
|
||||
h.sampling_rate, h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax)
|
||||
sw.add_figure('generated/y_hat_spec_{}'.format(j),
|
||||
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
|
||||
|
||||
val_err = val_err_tot / (j + 1)
|
||||
sw.add_scalar("validation/mel_spec_error", val_err, steps)
|
||||
|
||||
generator.train()
|
||||
|
||||
steps += 1
|
||||
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
|
||||
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
import glob
|
||||
import os
|
||||
import matplotlib
|
||||
import torch
|
||||
from torch.nn.utils import weight_norm
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
import shutil
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
||||
interpolation='none')
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def apply_weight_norm(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
weight_norm(m)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print("Loading '{}'".format(filepath))
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def save_checkpoint(filepath, obj):
|
||||
print("Saving checkpoint to {}".format(filepath))
|
||||
torch.save(obj, filepath)
|
||||
print("Complete.")
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + '????????.pt')
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
return sorted(cp_list)[-1]
|
||||
@@ -7,7 +7,6 @@
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"disc_start_step":0,
|
||||
|
||||
"upsample_rates": [5,5,4,2],
|
||||
"upsample_kernel_sizes": [10,10,8,4],
|
||||
@@ -28,11 +27,5 @@
|
||||
"fmax": 7600,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
"num_workers": 4
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from utils.util import AttrDict
|
||||
from vocoder.hifigan.env import AttrDict
|
||||
from vocoder.hifigan.models import Generator
|
||||
|
||||
generator = None # type: Generator
|
||||
@@ -19,18 +19,12 @@ def load_checkpoint(filepath, device):
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def load_model(weights_fpath, config_fpath=None, verbose=True):
|
||||
def load_model(weights_fpath, config_fpath="./vocoder/saved_models/24k/config.json", verbose=True):
|
||||
global generator, _device, output_sample_rate
|
||||
|
||||
if verbose:
|
||||
print("Building hifigan")
|
||||
|
||||
if config_fpath == None:
|
||||
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
|
||||
if len(model_config_fpaths) > 0:
|
||||
config_fpath = model_config_fpaths[0]
|
||||
else:
|
||||
config_fpath = "./vocoder/hifigan/config_16k_.json"
|
||||
with open(config_fpath) as f:
|
||||
data = f.read()
|
||||
json_config = json.loads(data)
|
||||
|
||||
@@ -12,6 +12,7 @@ from torch.utils.data import DistributedSampler, DataLoader
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed import init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from vocoder.hifigan.env import AttrDict, build_env
|
||||
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
||||
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
|
||||
discriminator_loss
|
||||
@@ -51,8 +52,8 @@ def train(rank, a, h):
|
||||
print("checkpoints directory : ", a.checkpoint_path)
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, 'g_hifigan_')
|
||||
cp_do = scan_checkpoint(a.checkpoint_path, 'do_hifigan_')
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
||||
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
||||
|
||||
steps = 0
|
||||
if cp_g is None or cp_do is None:
|
||||
@@ -137,21 +138,21 @@ def train(rank, a, h):
|
||||
y_g_hat = generator(x)
|
||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
if steps > h.disc_start_step:
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MSD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
# MSD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all.backward()
|
||||
optim_d.step()
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
loss_disc_all.backward()
|
||||
optim_d.step()
|
||||
|
||||
# Generator
|
||||
optim_g.zero_grad()
|
||||
@@ -159,16 +160,13 @@ def train(rank, a, h):
|
||||
# L1 Mel-Spectrogram Loss
|
||||
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
else:
|
||||
loss_gen_all = loss_mel
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
optim_g.step()
|
||||
@@ -184,10 +182,10 @@ def train(rank, a, h):
|
||||
|
||||
# checkpointing
|
||||
if steps % a.checkpoint_interval == 0 and steps != 0:
|
||||
checkpoint_path = "{}/g_hifigan_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_hifigan_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
|
||||
@@ -205,7 +203,7 @@ def train(rank, a, h):
|
||||
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_hifigan.pt".format(a.checkpoint_path)
|
||||
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
|
||||
|
||||
@@ -50,7 +50,7 @@ def save_checkpoint(filepath, obj):
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + '????????.pt')
|
||||
pattern = os.path.join(cp_dir, prefix + '????????')
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
|
||||
@@ -59,7 +59,7 @@ class WaveRNN(nn.Module) :
|
||||
# Compute all gates for coarse and fine
|
||||
u = F.sigmoid(R_u + I_u + self.bias_u)
|
||||
r = F.sigmoid(R_r + I_r + self.bias_r)
|
||||
e = torch.tanh(r * R_e + I_e + self.bias_e)
|
||||
e = F.tanh(r * R_e + I_e + self.bias_e)
|
||||
hidden = u * prev_hidden + (1. - u) * e
|
||||
|
||||
# Split the hidden state
|
||||
@@ -118,7 +118,7 @@ class WaveRNN(nn.Module) :
|
||||
# Compute the coarse gates
|
||||
u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
|
||||
r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
|
||||
e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
|
||||
e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
|
||||
hidden_coarse = u * hidden_coarse + (1. - u) * e
|
||||
|
||||
# Compute the coarse output
|
||||
@@ -138,7 +138,7 @@ class WaveRNN(nn.Module) :
|
||||
# Compute the fine gates
|
||||
u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
|
||||
r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
|
||||
e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
|
||||
e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
|
||||
hidden_fine = u * hidden_fine + (1. - u) * e
|
||||
|
||||
# Compute the fine output
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from utils.argutils import print_args
|
||||
from vocoder.wavernn.train import train
|
||||
from vocoder.hifigan.train import train as train_hifigan
|
||||
from vocoder.fregan.train import train as train_fregan
|
||||
from utils.util import AttrDict
|
||||
from vocoder.hifigan.env import AttrDict
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -63,30 +61,11 @@ if __name__ == "__main__":
|
||||
# Process the arguments
|
||||
if args.vocoder_type == "wavernn":
|
||||
# Run the training wavernn
|
||||
delattr(args, 'vocoder_type')
|
||||
delattr(args, 'config')
|
||||
train(**vars(args))
|
||||
elif args.vocoder_type == "hifigan":
|
||||
with open(args.config) as f:
|
||||
json_config = json.load(f)
|
||||
h = AttrDict(json_config)
|
||||
if h.num_gpus > 1:
|
||||
h.num_gpus = torch.cuda.device_count()
|
||||
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||
print('Batch size per GPU :', h.batch_size)
|
||||
mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
|
||||
else:
|
||||
train_hifigan(0, args, h)
|
||||
elif args.vocoder_type == "fregan":
|
||||
with open('vocoder/fregan/config.json') as f:
|
||||
json_config = json.load(f)
|
||||
h = AttrDict(json_config)
|
||||
if h.num_gpus > 1:
|
||||
h.num_gpus = torch.cuda.device_count()
|
||||
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||
print('Batch size per GPU :', h.batch_size)
|
||||
mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
|
||||
else:
|
||||
train_fregan(0, args, h)
|
||||
train_hifigan(0, args, h)
|
||||
|
||||
|
||||
26
web.py
26
web.py
@@ -1,21 +1,11 @@
|
||||
import os
|
||||
import sys
|
||||
import typer
|
||||
from web import webApp
|
||||
from gevent import pywsgi as wsgi
|
||||
|
||||
cli = typer.Typer()
|
||||
|
||||
@cli.command()
|
||||
def launch(port: int = typer.Option(8080, "--port", "-p")) -> None:
|
||||
"""Start a graphical UI server for the opyrator.
|
||||
|
||||
The UI is auto-generated from the input- and output-schema of the given function.
|
||||
"""
|
||||
# Add the current working directory to the sys path
|
||||
# This is required to resolve the opyrator path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from mkgui.base.ui.streamlit_ui import launch_ui
|
||||
launch_ui(port)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
app = webApp()
|
||||
host = app.config.get("HOST")
|
||||
port = app.config.get("PORT")
|
||||
print(f"Web server: http://{host}:{port}")
|
||||
server = wsgi.WSGIServer((host, port), app)
|
||||
server.serve_forever()
|
||||
|
||||
@@ -94,7 +94,7 @@ def webApp():
|
||||
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
||||
# Load input text
|
||||
texts = filter(None, request.form["text"].split("\n"))
|
||||
texts = request.form["text"].split("\n")
|
||||
punctuation = '!,。、,' # punctuate and split/clean text
|
||||
processed_texts = []
|
||||
for text in texts:
|
||||
@@ -109,7 +109,7 @@ def webApp():
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
sample_rate = Synthesizer.sample_rate
|
||||
if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN":
|
||||
wav, sample_rate = rnn_vocoder.infer_waveform(spec)
|
||||
wav = rnn_vocoder.infer_waveform(spec)
|
||||
else:
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(spec)
|
||||
|
||||
@@ -132,4 +132,4 @@ def webApp():
|
||||
return app
|
||||
|
||||
if __name__ == "__main__":
|
||||
webApp()
|
||||
webApp()
|
||||
Reference in New Issue
Block a user