45 Commits

Author SHA1 Message Date
babysor00
f9ee4d7890 Avoid recursive calls of web ui for M1 2022-08-12 23:13:35 +08:00
babysor00
8d0d22bc00 Fix #657 2022-07-19 23:43:19 +08:00
babysor00
87f4859874 update launch json 2022-07-17 14:27:26 +08:00
babysor00
c3590bffb2 Add description for 2022-07-17 11:56:13 +08:00
babysor00
efbdb21b70 Refactor model 2022-07-17 11:55:53 +08:00
Vega
6abdd0ebf0 Refactor (#649)
* Refactor model

* Refactor and fix bug to save plots
2022-07-17 09:58:17 +08:00
wenqingl
400a7207e3 Update README.md (#640) 2022-07-14 17:41:00 +08:00
babysor00
6f023e313d Add web gui of training and reconstruct taco model methods 2022-06-26 23:21:32 +08:00
babysor00
a39b6d3117 Remove breaking import for Macos 2022-06-26 11:56:50 +08:00
babysor00
885225045d Fix compatibility - mac + linux 2022-06-25 20:17:06 +08:00
babysor00
ee643d7cbc Fix compatibility issue 2022-06-18 23:46:44 +08:00
flysmart
6a793cea84 Added missing files for Fre-GAN (#579)
* The new vocoder Fre-GAN is now supported

* Improved some fregan details

* Fixed the problem that the existing model could not be loaded to continue training when training GAN

* Updated reference papers

* GAN training now supports DistributedDataParallel (DDP)

* Added requirements.txt

* GAN training uses single card training by default

* Added note about GAN vocoder training with multiple GPUs

* Added missing files for Fre-GAN
2022-05-25 23:29:59 +08:00
Evers
7317ba5ffe Add gen_voice.py to generate voice from a Python command instead of the demo_toolbox GUI. (#560) 2022-05-22 16:28:58 +08:00
flysmart
05f886162c GAN training now supports DistributedDataParallel (DDP) (#558)
* The new vocoder Fre-GAN is now supported

* Improved some fregan details

* Fixed the problem that the existing model could not be loaded to continue training when training GAN

* Updated reference papers

* GAN training now supports DistributedDataParallel (DDP)

* Added requirements.txt

* GAN training uses single card training by default

* Added note about GAN vocoder training with multiple GPUs
2022-05-22 16:24:50 +08:00
babysor00
e726c2eb12 Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese 2022-05-15 16:09:09 +08:00
babysor00
c00474525a Fix nits of file path 2022-05-15 16:08:58 +08:00
flysmart
350b190662 Solved the problem that the existing model could not be loaded when training the GAN model (#549)
* The new vocoder Fre-GAN is now supported

* Improved some fregan details

* Fixed the problem that the existing model could not be loaded to continue training when training GAN

* Updated reference papers
2022-05-13 13:41:03 +08:00
flysmart
0caed984e3 The new vocoder Fre-GAN is now supported (#546)
* The new vocoder Fre-GAN is now supported

* Improved some fregan details
2022-05-12 12:27:17 +08:00
Vega
c5d03fb3cb Upgrade to new web service (#529)
* Init new GUI

* Remove unused codes

* Reset layout

* Add samples

* Make framework to support multiple pages

* Add vc mode

* Add preprocessing mode

* Add training mode

* Remove text input in vc mode

* Add entry for GUI and revise readme

* Move requirement together

* Add error raise when no model folder found

* Add readme
2022-05-09 18:44:02 +08:00
babysor00
7f799d322f Tell the hifigan type of a vocoder model by searching full text 2022-04-30 10:31:01 +08:00
LZY
a1f2e4a790 Update README-CN.md (#523) 2022-04-28 16:16:37 +08:00
Lix Zhou
b136f80f43 Add an aliyunpan download link (#505)
Baidu Yun Pan is so fxxking slow
2022-04-26 21:20:39 +08:00
Moose W. Oler
f082a82420 fix issue #496 (#498)
Pass `wav` and `sampling_rate` (in encoder/audio.py, line 59) as keyword args instead of positional args to prevent warning messages from cluttering console output when using librosa 0.9.1.
2022-04-11 17:26:52 +08:00
babysor00
7f0d983da7 Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese 2022-04-02 18:11:52 +08:00
babysor00
0353bfc6e6 New GUI in order to combine web and toolbox in future 2022-04-02 18:11:49 +08:00
Vega
9ec114a7c1 Create FUNDING.yml 2022-04-02 10:16:02 +08:00
1itt1eB0y
ddf612e87c Update README-CN.md (#470)
Fix a simple translation issue
2022-03-24 12:52:47 +08:00
babysor00
374cc89cfa Fix web generate with rnn bug 2022-03-19 12:16:55 +08:00
babysor00
6009da7072 Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese 2022-03-19 12:14:24 +08:00
babysor00
1c61a601d1 Remove dependency of pyworld for non-vc mode 2022-03-19 12:14:21 +08:00
Vega
02ee514aa3 Update issue templates 2022-03-12 19:28:03 +08:00
babysor00
6532d65153 Add default param for inferencing of vocoder 2022-03-12 19:14:01 +08:00
babysor00
3fe0690cc6 Update README-CN.md 2022-03-09 09:39:24 +08:00
babysor00
79f424d614 Add default path for hifi 2022-03-08 09:17:56 +08:00
babysor00
3c97d22938 Fix bug when searching for vocoder 2022-03-07 20:08:29 +08:00
babysor00
fc26c38152 Fix compatibility issue 2022-03-07 17:05:22 +08:00
babysor00
6c01b92703 Fix bug introduced by config file reading 2022-03-06 09:35:25 +08:00
babysor00
c36f02634a Add link and separate requirement file 2022-03-05 12:20:09 +08:00
Vega
b05e7441ff Fix nit in readme 2022-03-05 00:55:08 +08:00
Vega
693de98f4d Add instruction image 2022-03-05 00:54:31 +08:00
Vega
252a5e11b3 Ppg vc init (#421)
* Init  ppg extractor and ppg2mel

* add preprocess and training

* Fix known issues

* Update __init__.py

Allow to gen audio

* Fix length issue

* Fix bug of preparing fid

* Fix sample issues

* Add UI usage of PPG-vc

* Add readme
2022-03-05 00:52:36 +08:00
Vega
b617a87ee4 Init ppg extractor and ppg2mel (#375)
* Init  ppg extractor and ppg2mel

* add preprocess and training

* Fix known issues

* Update __init__.py

Allow to gen audio

* Fix length issue

* Fix bug of preparing fid

* Fix sample issues

* Add UI usage of PPG-vc
2022-03-03 23:38:12 +08:00
AyahaShirane
ad22997614 fixed the issues #372 (#379)
Fixed some issues caused by parameter passing, and replaced the deprecated torch.nn.functional.tanh() with torch.tanh()
2022-02-27 11:02:01 +08:00
hertz
9e072c2619 Hifigan Support train from existed checkpoint. (#389)
* 1k steps to save tmp hifigan model

* hifigan support train from existed ckpt
2022-02-27 11:01:47 +08:00
Alex Newton
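b79e9d68e4 Fix an extra None caused by consecutive line breaks (#405)
A minor issue; the GUI does not seem to have it. I found it while calling the function directly when testing the web version.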
2022-02-27 10:55:00 +08:00
72 changed files with 4763 additions and 521 deletions

1
.github/FUNDING.yml vendored Normal file

@@ -0,0 +1 @@
github: babysor

17
.github/ISSUE_TEMPLATE/issue.md vendored Normal file

@@ -0,0 +1,17 @@
---
name: Issue
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Summary[问题简述(一句话)]**
A clear and concise description of what the issue is.
**Env & To Reproduce[复现与环境]**
Describe the environment, code version, and model you are using
**Screenshots[截图(如有)]**
If applicable, add screenshots to help

1
.gitignore vendored

@@ -13,7 +13,6 @@
*.bbl
*.bcf
*.toc
*.wav
*.sh
*/saved_models
!vocoder/saved_models/pretrained/**

2
.vscode/launch.json vendored

@@ -60,6 +60,6 @@
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
]
},
}
]
}

View File

@@ -18,10 +18,19 @@
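🌍 **Webserver Ready** Serve your training results for remote calls
### Ongoing Work
* Major GUI/client upgrade and merge
[X] Initialize the framework `./mkgui` (based on streamlit + fastapi) and the [technical design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
[X] Add demo pages for Voice Cloning and Conversion
[X] Add preprocessing and training pages for Voice Conversion
[ ] Add preprocessing and training pages for the other modules
* Upgrade the model backend based on ESPnet2
## Getting Started
### 1. Installation Requirements
> Follow the original repository to check that your environment is fully prepared.
**Python 3.7 or higher** is required to run the toolbox
Running the toolbox (demo_toolbox.py) requires **Python 3.7 or higher**
* Install [PyTorch](https://pytorch.org/get-started/locally/).
> If installing with pip fails with `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)`, the Python version is probably too low; 3.9 installs successfully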
@@ -68,7 +77,7 @@
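The vocoder has little effect on the results and 3 are already bundled; if you want to train your own, refer to the commands below.
* Preprocess the data:
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
> Replace `<datasets_root>` with your dataset directory and `<synthesizer_model_path>` with the directory of your best synthesizer model, e.g. *sythensizer\saved_mode\xxx*
> Replace `<datasets_root>` with your dataset directory and `<synthesizer_model_path>` with the directory of your best synthesizer model, e.g. *sythensizer\saved_models\xxx*
* Train the wavernn vocoder: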
@@ -78,19 +87,17 @@
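* Train the hifigan vocoder:
`python vocoder_train.py <trainid> <datasets_root> hifigan`
> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
* Train the fregan vocoder:
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
* To switch GAN vocoder training to multi-GPU mode, change the "num_gpus" parameter in the .json file under the GAN folder
### 3. Launch the program or toolbox
You can try the following commands:
### 3.1 Launch the web program
### 3.1 Launch the web program (v2)
`python web.py`
Once it is running, open the address in your browser; the default is `http://localhost:8080`
![123](https://user-images.githubusercontent.com/12797292/135494044-ae59181c-fe3a-406f-9c7d-d21d12fdb4cb.png)
> The interface is currently rather buggy:
> * The first time you click `录制` (Record), wait a few seconds for the browser to start recording properly, otherwise the recording will contain doubled audio
> * When you finish recording, click `停止` (Stop) rather than `录制` (Record) again
> * Only freshly hand-recorded audio (16 kHz) is supported; recordings over 4 MB are not supported; the best length is 5~15 seconds
> * By default the first model found is used; if you are comfortable with code, you can change this in `web\__init__.py`.
### 3.2 Launch the toolbox:
`python demo_toolbox.py -d <datasets_root>`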
@@ -101,11 +108,12 @@
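### 4. Extra: Voice Conversion (PPG based)
Ever imagined holding a voice changer like Conan and speaking in Kogoro Mouri's voice? Based on PPG-VC, this project now introduces two additional modules (PPG extractor + PPG2Mel) to enable voice conversion. (The documentation is incomplete, especially the training part; we are working on it.)
#### 4.0 Prepare the environment
* Make sure the environment above is installed, then run `pip install -r requirements.txt` to install the remaining required packages.
* Download the following models
* A vocoder (hifigan) dedicated to the 24 kHz sampling rate, into *vocoder\saved_mode\xxx*
* The pretrained PPG feature encoder (ppg_extractor), into *ppg_extractor\saved_mode\xxx*
* The pretrained PPG2Mel, into *ppg2mel\saved_mode\xxx*
* Make sure the environment above is installed, then run `pip install espnet` to install the remaining required packages.
* Download the following models. Link: https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
Extraction code: gh41
* A vocoder (hifigan) dedicated to the 24 kHz sampling rate, into *vocoder\saved_models\xxx*
* The pretrained PPG feature encoder (ppg_extractor), into *ppg_extractor\saved_models\xxx*
* The pretrained PPG2Mel, into *ppg2mel\saved_models\xxx*
#### 4.1 Train the PPG2Mel model yourself with a dataset (optional)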
@@ -123,8 +131,9 @@
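#### 4.2 Launch the toolbox in VC mode
You can try the following commands:
`python demo_toolbox.py vc -d <datasets_root>`
`python demo_toolbox.py -vc -d <datasets_root>`
> Specify a valid dataset path; if a supported dataset is present it is loaded automatically for debugging, and the path also serves as the storage directory for manually recorded audio.
<img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">
## References and Papers
> This repository was originally forked from the English-only [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning); thanks to its author.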
@@ -133,6 +142,7 @@
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)

View File

@@ -18,6 +18,14 @@
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
### Ongoing Works(Helps Needed)
* Major upgrade on GUI/Client and unifying web and toolbox
[X] Init framework `./mkgui` and [tech design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
[X] Add demo part of Voice Cloning and Conversion
[X] Add preprocessing and training for Voice Conversion
[ ] Add preprocessing and training for Encoder/Synthesizer/Vocoder
* Major upgrade on model backend based on ESPnet2(not yet started)
## Quick Start
### 1. Install Requirements
@@ -29,7 +37,7 @@
* Install [ffmpeg](https://ffmpeg.org/download.html#get-packages).
* Run `pip install -r requirements.txt` to install the remaining necessary packages.
* Install webrtcvad `pip install webrtcvad-wheels`(If you need)
> Note that we are using the pretrained encoder/vocoder but synthesizer, since the original model is incompatible with the Chinese sympols. It means the demo_cli is not working at this moment.
> Note that we are using the pretrained encoder/vocoder but synthesizer since the original model is incompatible with the Chinese symbols. It means the demo_cli is not working at this moment.
### 2. Prepare your models
You can either train your models or use existing ones:
@@ -60,7 +68,7 @@ Allowing parameter `--dataset {dataset}` to support aidatatang_200zh, magicdata,
| @author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps trained by multiple datasets
| @author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) codeom7f | | 25k steps trained by multiple datasets, only works under version 0.0.1
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan, only works under version 0.0.1
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 https://www.aliyundrive.com/s/AwPsbo8mcSP code: z2m0 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
#### 2.4 Train vocoder (Optional)
> note: vocoder has little difference in effect, so you may not need to train a new one.
@@ -82,6 +90,11 @@ You can then try to run:`python web.py` and open it in browser, default as `http
You can then try the toolbox:
`python demo_toolbox.py -d <datasets_root>`
#### 3.3 Using the command line
You can then try the command:
`python gen_voice.py <text_file.txt> your_wav_file.wav`
You may need to install cn2an (`pip install cn2an`) for better handling of numbers in the text.
## Reference
> This repository is forked from [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) which only support English.
@@ -89,6 +102,7 @@ You can then try the toolbox:
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)

View File

@@ -56,8 +56,8 @@ def wav_to_mel_spectrogram(wav):
Note: this is not a log-mel spectrogram.
"""
frames = librosa.feature.melspectrogram(
wav,
sampling_rate,
y=wav,
sr=sampling_rate,
n_fft=int(sampling_rate * mel_window_length / 1000),
hop_length=int(sampling_rate * mel_window_step / 1000),
n_mels=mel_n_channels
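
The hunk above switches the call to keyword arguments (`y=`, `sr=`) and derives `n_fft` and `hop_length` from millisecond settings. The sketch below is a stand-in rather than librosa itself: `mel_frame_params` is a hypothetical helper, and the 16 kHz, 25 ms and 10 ms values are assumed purely for illustration. It shows the millisecond-to-sample arithmetic and how a keyword-only signature rejects positional calls.

```python
def mel_frame_params(*, sampling_rate: int, window_length_ms: int, window_step_ms: int) -> dict:
    """Hypothetical helper: convert millisecond settings to sample counts.

    The bare `*` makes every parameter keyword-only, mirroring the
    keyword-style call (`y=...`, `sr=...`) used in the hunk above.
    """
    return {
        "n_fft": int(sampling_rate * window_length_ms / 1000),
        "hop_length": int(sampling_rate * window_step_ms / 1000),
    }


# Assumed example values (not taken from the diff): 16 kHz, 25 ms window, 10 ms step.
params = mel_frame_params(sampling_rate=16000, window_length_ms=25, window_step_ms=10)
print(params)  # {'n_fft': 400, 'hop_length': 160}

# A keyword-only signature rejects a positional call outright.
try:
    mel_frame_params(16000, 25, 10)
    rejected = False
except TypeError as exc:
    rejected = True
    print("positional call rejected:", exc)
```

Passing arguments by name keeps the call valid even if a library later reorders its parameters or makes them keyword-only.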

128
gen_voice.py Normal file

@@ -0,0 +1,128 @@
from encoder.params_model import model_embedding_size as speaker_embedding_size
from utils.argutils import print_args
from utils.modelutils import check_model_paths
from synthesizer.inference import Synthesizer
from encoder import inference as encoder
from vocoder.wavernn import inference as rnn_vocoder
from vocoder.hifigan import inference as gan_vocoder
from pathlib import Path
import numpy as np
import soundfile as sf
import librosa
import argparse
import torch
import sys
import os
import re
import cn2an
import glob
from audioread.exceptions import NoBackendError

vocoder = gan_vocoder

def gen_one_wav(synthesizer, in_fpath, embed, texts, file_name, seq):
    embeds = [embed] * len(texts)
    # If you know what the attention layer alignments are, you can retrieve them here by
    # passing return_alignments=True
    specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=-1, min_stop_token=4, steps=400)
    #spec = specs[0]
    breaks = [spec.shape[1] for spec in specs]
    spec = np.concatenate(specs, axis=1)
    # If seed is specified, reset torch seed and reload vocoder
    # Synthesizing the waveform is fairly straightforward. Remember that the longer the
    # spectrogram, the more time-efficient the vocoder.
    generated_wav, output_sample_rate = vocoder.infer_waveform(spec)
    # Add breaks
    b_ends = np.cumsum(np.array(breaks) * synthesizer.hparams.hop_size)
    b_starts = np.concatenate(([0], b_ends[:-1]))
    wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)]
    breaks = [np.zeros(int(0.15 * synthesizer.sample_rate))] * len(breaks)
    generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
    ## Post-generation
    # There's a bug with sounddevice that makes the audio cut one second earlier, so we
    # pad it.
    # Trim excess silences to compensate for gaps in spectrograms (issue #53)
    generated_wav = encoder.preprocess_wav(generated_wav)
    generated_wav = generated_wav / np.abs(generated_wav).max() * 0.97
    # Save it on the disk
    model=os.path.basename(in_fpath)
    filename = "%s_%d_%s.wav" %(file_name, seq, model)
    sf.write(filename, generated_wav, synthesizer.sample_rate)
    print("\nSaved output as %s\n\n" % filename)

def generate_wav(enc_model_fpath, syn_model_fpath, voc_model_fpath, in_fpath, input_txt, file_name):
    if torch.cuda.is_available():
        device_id = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device_id)
        ## Print some environment information (for debugging purposes)
        print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
            "%.1fGb total memory.\n" %
            (torch.cuda.device_count(),
            device_id,
            gpu_properties.name,
            gpu_properties.major,
            gpu_properties.minor,
            gpu_properties.total_memory / 1e9))
    else:
        print("Using CPU for inference.\n")
    print("Preparing the encoder, the synthesizer and the vocoder...")
    encoder.load_model(enc_model_fpath)
    synthesizer = Synthesizer(syn_model_fpath)
    vocoder.load_model(voc_model_fpath)
    encoder_wav = synthesizer.load_preprocess_wav(in_fpath)
    embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
    texts = input_txt.split("\n")
    seq=0
    each_num=1500
    punctuation = '!,。、,' # punctuate and split/clean text
    processed_texts = []
    cur_num = 0
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
                cur_num += len(processed_text.strip())
        if cur_num > each_num:
            seq = seq +1
            gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
            processed_texts = []
            cur_num = 0
    if len(processed_texts)>0:
        seq = seq +1
        gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)

if (len(sys.argv)>=3):
    my_txt = ""
    print("reading from :", sys.argv[1])
    with open(sys.argv[1], "r") as f:
        for line in f.readlines():
            #line = line.strip('\n')
            my_txt += line
    txt_file_name = sys.argv[1]
    wav_file_name = sys.argv[2]
    output = cn2an.transform(my_txt, "an2cn")
    print(output)
    generate_wav(
        Path("encoder/saved_models/pretrained.pt"),
        Path("synthesizer/saved_models/mandarin.pt"),
        Path("vocoder/saved_models/pretrained/g_hifigan.pt"), wav_file_name, output, txt_file_name
    )
else:
    print("please input the file name")
    exit(1)
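
gen_voice.py splits the input text on punctuation and synthesizes it in batches once a character budget (`each_num`) is exceeded. The following is a minimal standalone sketch of that idea, not the script's exact logic: `split_sentences` and `batch_by_length` are hypothetical names, the threshold is checked once per input line, and whitespace-only pieces are dropped.

```python
import re
from typing import Iterator, List

PUNCTUATION = "!,。、,"  # same split characters as in gen_voice.py


def split_sentences(text: str) -> List[str]:
    """Split one line on runs of punctuation and drop empty pieces."""
    pieces = re.sub(r"[{}]+".format(PUNCTUATION), "\n", text).split("\n")
    return [p.strip() for p in pieces if p.strip()]


def batch_by_length(lines: List[str], max_chars: int) -> Iterator[List[str]]:
    """Yield groups of segments; a group closes once its total length exceeds max_chars."""
    batch: List[str] = []
    count = 0
    for line in lines:
        for segment in split_sentences(line):
            batch.append(segment)
            count += len(segment)
        # Checked once per input line, so a batch may overshoot the budget.
        if count > max_chars:
            yield batch
            batch, count = [], 0
    if batch:  # flush whatever is left after the last line
        yield batch


demo = ["你好,世界。今天天气不错", "第二行!结束"]
batches = list(batch_by_length(demo, max_chars=8))
print(batches)
```

Each yielded batch would correspond to one output wav file in a pipeline like the one above.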

0
mkgui/__init__.py Normal file

145
mkgui/app.py Normal file

@@ -0,0 +1,145 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from encoder import inference as encoder
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from synthesizer.inference import Synthesizer
from typing import Any, Tuple
import matplotlib.pyplot as plt

# Constants
AUDIO_SAMPLES_DIR = f"samples{os.sep}"
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
VOC_MODELS_DIRT = f"vocoder{os.sep}saved_models"
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
if not os.path.isdir("wavs"):
    os.makedirs("wavs")
# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoders models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")

class Input(BaseModel):
    message: str = Field(
        ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
    )
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音本地wav",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    encoder: encoders = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    synthesizer: synthesizers = Field(
        ..., alias="合成模型",
        description="选择语音合成模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity]
    def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, result = self.__root__
        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")
        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Result Audio)")
        streamlit_app.pyplot(fig)

def synthesize(input: Input) -> Output:
    """synthesize(合成)"""
    # load models
    encoder.load_model(Path(input.encoder.value))
    current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))
    # load file
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
    source_spec = Synthesizer.make_spectrogram(wav)
    # preprocess
    encoder_wav = encoder.preprocess_wav(wav, sample_rate)
    embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
    # Load input text
    texts = filter(None, input.message.split("\n"))
    punctuation = '!,。、,' # punctuate and split/clean text
    processed_texts = []
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
    texts = processed_texts
    # synthesize and vocode
    embeds = [embed] * len(texts)
    specs = current_synt.synthesize_spectrograms(texts, embeds)
    spec = np.concatenate(specs, axis=1)
    sample_rate = Synthesizer.sample_rate
    wav, sample_rate = gan_vocoder.infer_waveform(spec)
    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()
    return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
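
mkgui/app.py builds its model pickers by turning the files found under a folder into an `Enum` through the functional `Enum(name, pairs)` API. Below is a small self-contained sketch of that pattern using temporary files; `build_model_enum` is a hypothetical helper, and the empty-folder check is an addition that the original code does not make.

```python
import tempfile
from enum import Enum
from pathlib import Path


def build_model_enum(name: str, folder: Path, pattern: str) -> type:
    """Create an Enum whose members are the files under `folder` matching `pattern`."""
    if not folder.is_dir():
        raise FileNotFoundError(f"Model folder {folder} doesn't exist.")
    files = sorted(folder.glob(pattern))
    if not files:
        # Assumption: fail early instead of exposing an empty picker.
        raise FileNotFoundError(f"No files matching {pattern} in {folder}.")
    # Functional Enum API: an iterable of (member_name, value) pairs.
    return Enum(name, [(f.name, f) for f in files])


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    (root / "a.pt").touch()
    (root / "sub" / "b_gan.pt").touch()
    (root / "notes.txt").touch()

    models = build_model_enum("models", root, "**/*.pt")
    names = [m.name for m in models]
    gan_only = build_model_enum("vocoders", root, "**/*gan*.pt")
    gan_names = [m.name for m in gan_only]

print(names)      # ['a.pt', 'b_gan.pt']
print(gan_names)  # ['b_gan.pt']
```

Because the file name is used as the member name, two files with the same name in different subfolders would collide and raise a `TypeError`; keying on the relative path would avoid that.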

166
mkgui/app_vc.py Normal file

@@ -0,0 +1,166 @@
from synthesizer.inference import Synthesizer
from pydantic import BaseModel, Field
from encoder import inference as speacker_encoder
import torch
import os
from pathlib import Path
from enum import Enum
import ppg_extractor as Extractor
import ppg2mel as Convertor
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from typing import Any, Tuple
import matplotlib.pyplot as plt

# Constants
AUDIO_SAMPLES_DIR = f'sample{os.sep}'
EXT_MODELS_DIRT = f'ppg_extractor{os.sep}saved_models'
CONV_MODELS_DIRT = f'ppg2mel{os.sep}saved_models'
VOC_MODELS_DIRT = f'vocoder{os.sep}saved_models'
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'
# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoders models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")

class Input(BaseModel):
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音本地wav",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    local_audio_file_target: audio_input_selection = Field(
        ..., alias="目标语音本地wav",
        description="选择本地语音文件."
    )
    upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    extractor: extractors = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity, AudioEntity]
    def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, target, result = self.__root__
        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")
        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(target.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Target Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Result Audio)")
        streamlit_app.pyplot(fig)

def convert(input: Input) -> Output:
    """convert(转换)"""
    # load models
    extractor = Extractor.load_model(Path(input.extractor.value))
    convertor = Convertor.load_model(Path(input.convertor.value))
    # current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))
    # load file
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        src_wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
    if input.upload_audio_file_target is not None:
        with open(TEMP_TARGET_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file_target.as_bytes())
            f.seek(0)
        ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
    else:
        ref_wav, _ = librosa.load(input.local_audio_file_target.value)
        write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
    ppg = extractor.extract_from_wav(src_wav)
    # Import necessary dependency of Voice Conversion
    from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
    # f-string so that {os.sep} is actually substituted into the path
    speacker_encoder.load_model(Path(f"encoder{os.sep}saved_models{os.sep}pretrained_bak_5805000.pt"))
    embed = speacker_encoder.embed_utterance(ref_wav)
    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
    min_len = min(ppg.shape[1], len(lf0_uv))
    ppg = ppg[:, :min_len]
    lf0_uv = lf0_uv[:min_len]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, mel_pred, att_ws = convertor.inference(
        ppg,
        logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
        spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
    )
    mel_pred= mel_pred.transpose(0, 1)
    breaks = [mel_pred.shape[1]]
    mel_pred= mel_pred.detach().cpu().numpy()
    # synthesize and vocode
    wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_TARGET_AUDIO, "rb") as f:
        target_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()
    return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
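
Before calling the conversion model, `convert` trims the PPG features and the log-F0/voicing track to the same number of frames (`min_len`). The sketch below reproduces that alignment step with plain lists instead of tensors. `align_frames` is a hypothetical name, and the layouts, (batch, frames, dims) for `ppg` and (frames, dims) for `lf0_uv`, are inferred from the indexing in the code above.

```python
from typing import List, Sequence, Tuple

Frame = Sequence[float]


def align_frames(
    ppg: List[List[Frame]], lf0_uv: List[Frame]
) -> Tuple[List[List[Frame]], List[Frame]]:
    """Trim both feature streams to the shorter frame count.

    Mirrors `ppg[:, :min_len]` and `lf0_uv[:min_len]` for nested lists.
    """
    min_len = min(len(ppg[0]), len(lf0_uv))
    trimmed_ppg = [utterance[:min_len] for utterance in ppg]
    return trimmed_ppg, lf0_uv[:min_len]


# One utterance with 5 PPG frames of 3 dims, but only 4 F0 frames of 2 dims.
ppg_demo = [[[0.1 * t, 0.2 * t, 0.3 * t] for t in range(5)]]
lf0_demo = [[5.0 + t, 1.0] for t in range(4)]

ppg_out, lf0_out = align_frames(ppg_demo, lf0_demo)
print(len(ppg_out[0]), len(lf0_out))  # 4 4
```

Frame counts from two independent feature extractors often differ by a frame or two because of padding and rounding, which is why truncating to the shorter stream is a common fix.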

2
mkgui/base/__init__.py Normal file

@@ -0,0 +1,2 @@
from .core import Opyrator

View File

@@ -0,0 +1 @@
from .fastapi_app import create_api

View File

@@ -0,0 +1,102 @@
"""Collection of utilities for FastAPI apps."""
import inspect
from typing import Any, Type
from fastapi import FastAPI, Form
from pydantic import BaseModel
def as_form(cls: Type[BaseModel]) -> Any:
"""Adds an as_form class method to decorated models.
The as_form class method can be used with FastAPI endpoints
"""
new_params = [
inspect.Parameter(
field.alias,
inspect.Parameter.POSITIONAL_ONLY,
default=(Form(field.default) if not field.required else Form(...)),
)
for field in cls.__fields__.values()
]
async def _as_form(**data): # type: ignore
return cls(**data)
sig = inspect.signature(_as_form)
sig = sig.replace(parameters=new_params)
_as_form.__signature__ = sig # type: ignore
setattr(cls, "as_form", _as_form)
return cls
def patch_fastapi(app: FastAPI) -> None:
"""Patch function to allow relative url resolution.
This patch is required to make fastapi fully functional with a relative url path.
This code snippet can be copy-pasted to any Fastapi application.
"""
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from starlette.requests import Request
from starlette.responses import HTMLResponse
async def redoc_ui_html(req: Request) -> HTMLResponse:
assert app.openapi_url is not None
redoc_ui = get_redoc_html(
openapi_url="./" + app.openapi_url.lstrip("/"),
title=app.title + " - Redoc UI",
)
return HTMLResponse(redoc_ui.body.decode("utf-8"))
async def swagger_ui_html(req: Request) -> HTMLResponse:
assert app.openapi_url is not None
swagger_ui = get_swagger_ui_html(
openapi_url="./" + app.openapi_url.lstrip("/"),
title=app.title + " - Swagger UI",
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
)
# insert request interceptor so that all requests run on a relative path
request_interceptor = (
"requestInterceptor: (e) => {"
"\n\t\t\tvar url = window.location.origin + window.location.pathname"
'\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
"\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
"\n\t\t\te.contextUrl = url"
"\n\t\t\te.url = url"
"\n\t\t\treturn e;}"
)
return HTMLResponse(
swagger_ui.body.decode("utf-8").replace(
"dom_id: '#swagger-ui',",
"dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
)
)
# remove old docs route and add our patched route
routes_new = []
for app_route in app.routes:
if app_route.path == "/docs": # type: ignore
continue
if app_route.path == "/redoc": # type: ignore
continue
routes_new.append(app_route)
app.router.routes = routes_new
assert app.docs_url is not None
app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
assert app.redoc_url is not None
app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
# Make GraphQL relative
from starlette import graphql
graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
)
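
The `as_form` decorator above relies on the fact that `inspect.signature` returns a function's `__signature__` attribute when one is set, so a `**data` function can advertise explicit parameters. Below is a minimal standard-library sketch of that mechanism. The helper `with_named_signature` and the function `build` are hypothetical names, and the sketch uses keyword-only parameters with plain defaults rather than the `Form(...)` defaults from the source.

```python
import inspect
from typing import Any, Callable, Dict


def with_named_signature(func: Callable, defaults: Dict[str, Any]) -> Callable:
    """Hypothetical helper: advertise explicit parameters on a **kwargs function."""
    params = [
        inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
        for name, default in defaults.items()
    ]
    # inspect.signature() returns func.__signature__ when it is set.
    func.__signature__ = inspect.signature(func).replace(parameters=params)
    return func


def build(**data: Any) -> Dict[str, Any]:
    # The body still receives everything as keyword arguments.
    return dict(data)


with_named_signature(build, {"label": "cat", "score": 0.5})
print(inspect.signature(build))
```

Only the advertised signature changes; calling `build(label="dog")` still passes through `**data` unchanged.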


@@ -0,0 +1,43 @@
from typing import List
from pydantic import BaseModel
class ScoredLabel(BaseModel):
label: str
score: float
class ClassificationOutput(BaseModel):
__root__: List[ScoredLabel]
def __iter__(self): # type: ignore
return iter(self.__root__)
def __getitem__(self, item): # type: ignore
return self.__root__[item]
def render_output_ui(self, streamlit) -> None: # type: ignore
import plotly.express as px
sorted_predictions = sorted(
[prediction.dict() for prediction in self.__root__],
key=lambda k: k["score"],
)
num_labels = len(sorted_predictions)
if len(sorted_predictions) > 10:
num_labels = streamlit.slider(
"Maximum labels to show: ",
min_value=1,
max_value=len(sorted_predictions),
value=len(sorted_predictions),
)
fig = px.bar(
sorted_predictions[len(sorted_predictions) - num_labels :],
x="score",
y="label",
orientation="h",
)
streamlit.plotly_chart(fig, use_container_width=True)
# fig.show()
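
The selection logic in `render_output_ui` above can be tried without Streamlit or Plotly. This is a small sketch, assuming plain dictionaries in place of `ScoredLabel` instances and a hypothetical helper name `top_predictions`. It sorts by ascending score and keeps the last `num_labels` entries, which are the highest-scoring ones.

```python
from typing import Dict, List, Union

Prediction = Dict[str, Union[str, float]]


def top_predictions(predictions: List[Prediction], num_labels: int) -> List[Prediction]:
    """Hypothetical helper mirroring the sort-then-slice step of render_output_ui."""
    ordered = sorted(predictions, key=lambda k: k["score"])
    # Keep the tail of the ascending list, i.e. the highest scores.
    return ordered[len(ordered) - num_labels:]


sample = [
    {"label": "cat", "score": 0.7},
    {"label": "dog", "score": 0.2},
    {"label": "bird", "score": 0.1},
]
print(top_predictions(sample, 2))
```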


@@ -0,0 +1,46 @@
import base64
from typing import Any, Dict, overload
class FileContent(str):
def as_bytes(self) -> bytes:
return base64.b64decode(self, validate=True)
def as_str(self) -> str:
return self.as_bytes().decode()
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(format="byte")
@classmethod
def __get_validators__(cls) -> Any: # type: ignore
yield cls.validate
@classmethod
def validate(cls, value: Any) -> "FileContent":
if isinstance(value, FileContent):
return value
elif isinstance(value, str):
return FileContent(value)
elif isinstance(value, (bytes, bytearray, memoryview)):
return FileContent(base64.b64encode(value).decode())
else:
raise Exception("Wrong type")
# # Temporarily unusable, because browsers do not support selecting a folder
# class DirectoryContent(FileContent):
# @classmethod
# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
# field_schema.update(format="path")
# @classmethod
# def validate(cls, value: Any) -> "DirectoryContent":
# if isinstance(value, DirectoryContent):
# return value
# elif isinstance(value, str):
# return DirectoryContent(value)
# elif isinstance(value, (bytes, bytearray, memoryview)):
# return DirectoryContent(base64.b64encode(value).decode())
# else:
# raise Exception("Wrong type")
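
`FileContent` above stores file bytes as a base64 string, so it can travel through JSON, and decodes them on demand. The round trip can be sketched with the standard library alone. `Base64Text` and `from_bytes` are hypothetical names that stand in for `FileContent` and its `validate` branch for bytes, and the pydantic hooks are left out.

```python
import base64


class Base64Text(str):
    """Hypothetical stand-in for FileContent without the pydantic hooks."""

    @classmethod
    def from_bytes(cls, raw: bytes) -> "Base64Text":
        # Same encoding step as the bytes branch of FileContent.validate.
        return cls(base64.b64encode(raw).decode())

    def as_bytes(self) -> bytes:
        # validate=True rejects characters outside the base64 alphabet.
        return base64.b64decode(self, validate=True)


payload = b"RIFF\x00\x01 fake wav bytes"
encoded = Base64Text.from_bytes(payload)
print(encoded)
print(encoded.as_bytes() == payload)
```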

mkgui/base/core.py Normal file

@@ -0,0 +1,203 @@
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints
from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as
def name_to_title(name: str) -> str:
"""Converts a camelCase or snake_case name to title case."""
# If camelCase -> convert to snake case
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
# Convert to title case
return name.replace("_", " ").strip().title()
def is_compatible_type(type: Type) -> bool:
"""Returns `True` if the type is opyrator-compatible."""
try:
if issubclass(type, BaseModel):
return True
except Exception:
pass
try:
# valid list type
if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
return True
except Exception:
pass
return False
def get_input_type(func: Callable) -> Type:
"""Returns the input type of a given function (callable).
Args:
func: The function for which to get the input type.
Raises:
ValueError: If the function does not have a valid input type annotation.
"""
type_hints = get_type_hints(func)
if "input" not in type_hints:
raise ValueError(
"The callable MUST have a parameter with the name `input` with typing annotation. "
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
)
input_type = type_hints["input"]
if not is_compatible_type(input_type):
raise ValueError(
"The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
)
# TODO: return warning if more than one input parameters
return input_type
def get_output_type(func: Callable) -> Type:
"""Returns the output type of a given function (callable).
Args:
func: The function for which to get the output type.
Raises:
ValueError: If the function does not have a valid output type annotation.
"""
type_hints = get_type_hints(func)
if "return" not in type_hints:
raise ValueError(
"The return type of the callable MUST be annotated with type hints. "
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
)
output_type = type_hints["return"]
if not is_compatible_type(output_type):
raise ValueError(
"The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
)
return output_type
def get_callable(import_string: str) -> Callable:
"""Import a callable from a string."""
callable_separator = ":"
if callable_separator not in import_string:
# Use dot as separator
callable_separator = "."
if callable_separator not in import_string:
raise ValueError("The callable path MUST specify the function.")
mod_name, callable_name = import_string.rsplit(callable_separator, 1)
mod = importlib.import_module(mod_name)
return getattr(mod, callable_name)
class Opyrator:
def __init__(self, func: Union[Callable, str]) -> None:
if isinstance(func, str):
# Try to load the function from a string notation
self.function = get_callable(func)
else:
self.function = func
self._action = "Execute"
self._input_type = None
self._output_type = None
if not callable(self.function):
raise ValueError("The provided function parameter is not callable.")
if inspect.isclass(self.function):
raise ValueError(
"The provided callable is an uninitialized Class. This is not allowed."
)
if inspect.isfunction(self.function):
# The provided callable is a function
self._input_type = get_input_type(self.function)
self._output_type = get_output_type(self.function)
try:
# Get name
self._name = name_to_title(self.function.__name__)
except Exception:
pass
try:
# Get description from function
doc_string = inspect.getdoc(self.function)
if doc_string:
self._action = doc_string
except Exception:
pass
elif hasattr(self.function, "__call__"):
# The provided callable is an object that implements __call__
self._input_type = get_input_type(self.function.__call__) # type: ignore
self._output_type = get_output_type(self.function.__call__) # type: ignore
try:
# Get name
self._name = name_to_title(type(self.function).__name__)
except Exception:
pass
try:
# Get action from the __call__ docstring
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
if doc_string:
self._action = doc_string
if (
not self._action
or self._action == "Call"
):
# Get docstring from class instead of __call__ function
doc_string = inspect.getdoc(self.function)
if doc_string:
self._action = doc_string
except Exception:
pass
else:
raise ValueError("Unknown callable type.")
@property
def name(self) -> str:
return self._name
@property
def action(self) -> str:
return self._action
@property
def input_type(self) -> Any:
return self._input_type
@property
def output_type(self) -> Any:
return self._output_type
def __call__(self, input: Any, **kwargs: Any) -> Any:
input_obj = input
if isinstance(input, str):
# Allow json input
input_obj = parse_raw_as(self.input_type, input)
if isinstance(input, dict):
# Allow dict input
input_obj = parse_obj_as(self.input_type, input)
return self.function(input_obj, **kwargs)
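
Two helpers from this file, `name_to_title` and `get_callable`, depend only on the standard library and can be tried in isolation. The sketch below copies `name_to_title` as listed and condenses `get_callable` slightly. The example inputs (`convertVoice`, `gen_voice`, `json:dumps`, `os.path.join`) are illustrative assumptions rather than names used by the project.

```python
import importlib
import re
from typing import Callable


def name_to_title(name: str) -> str:
    """Converts a camelCase or snake_case name to title case."""
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
    return name.replace("_", " ").strip().title()


def get_callable(import_string: str) -> Callable:
    """Import a callable from 'module:func' or 'module.func' notation."""
    separator = ":" if ":" in import_string else "."
    if separator not in import_string:
        raise ValueError("The callable path MUST specify the function.")
    mod_name, callable_name = import_string.rsplit(separator, 1)
    return getattr(importlib.import_module(mod_name), callable_name)


print(name_to_title("convertVoice"))    # Convert Voice
print(name_to_title("gen_voice"))       # Gen Voice
print(get_callable("json:dumps")([1]))  # [1]
print(get_callable("os.path.join")("a", "b"))
```

The colon form avoids ambiguity when the module path itself contains dots; the dot form splits on the last dot only.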


@@ -0,0 +1 @@
from .streamlit_ui import render_streamlit_ui


@@ -0,0 +1,129 @@
from typing import Dict
def resolve_reference(reference: str, references: Dict) -> Dict:
return references[reference.split("/")[-1]]
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
# Ref can either be directly in the properties or the first element of allOf
reference = property.get("$ref")
if reference is None:
reference = property["allOf"][0]["$ref"]
return resolve_reference(reference, references)
def is_single_string_property(property: Dict) -> bool:
return property.get("type") == "string"
def is_single_datetime_property(property: Dict) -> bool:
if property.get("type") != "string":
return False
return property.get("format") in ["date-time", "time", "date"]
def is_single_boolean_property(property: Dict) -> bool:
return property.get("type") == "boolean"
def is_single_number_property(property: Dict) -> bool:
return property.get("type") in ["integer", "number"]
def is_single_file_property(property: Dict) -> bool:
if property.get("type") != "string":
return False
# TODO: binary?
return property.get("format") == "byte"
def is_single_directory_property(property: Dict) -> bool:
if property.get("type") != "string":
return False
return property.get("format") == "path"
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
if property.get("type") != "array":
return False
if property.get("uniqueItems") is not True:
# Only relevant if it is a set or another data structure with unique items
return False
try:
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
return True
except Exception:
return False
def is_single_enum_property(property: Dict, references: Dict) -> bool:
try:
_ = get_single_reference_item(property, references)["enum"]
return True
except Exception:
return False
def is_single_dict_property(property: Dict) -> bool:
if property.get("type") != "object":
return False
return "additionalProperties" in property
def is_single_reference(property: Dict) -> bool:
if property.get("type") is not None:
return False
return bool(property.get("$ref"))
def is_multi_file_property(property: Dict) -> bool:
if property.get("type") != "array":
return False
if property.get("items") is None:
return False
try:
# TODO: binary
return property["items"]["format"] == "byte"
except Exception:
return False
def is_single_object(property: Dict, references: Dict) -> bool:
try:
object_reference = get_single_reference_item(property, references)
if object_reference["type"] != "object":
return False
return "properties" in object_reference
except Exception:
return False
def is_property_list(property: Dict) -> bool:
if property.get("type") != "array":
return False
if property.get("items") is None:
return False
try:
return property["items"]["type"] in ["string", "number", "integer"]
except Exception:
return False
def is_object_list_property(property: Dict, references: Dict) -> bool:
if property.get("type") != "array":
return False
try:
object_reference = resolve_reference(property["items"]["$ref"], references)
if object_reference["type"] != "object":
return False
return "properties" in object_reference
except Exception:
return False
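
The predicates above operate on plain JSON-schema dictionaries, so they can be exercised without pydantic. The sketch below copies `resolve_reference` and inlines the lookup from `get_single_reference_item` into `is_single_enum_property`. The schema fragment is hand-written for illustration (the `Vocoder` enum and its values are assumptions) and only imitates the `properties`/`definitions` layout that `InputUI` reads.

```python
from typing import Dict


def resolve_reference(reference: str, references: Dict) -> Dict:
    return references[reference.split("/")[-1]]


def is_single_enum_property(prop: Dict, references: Dict) -> bool:
    try:
        # The ref is either directly on the property or the first element of allOf.
        reference = prop.get("$ref")
        if reference is None:
            reference = prop["allOf"][0]["$ref"]
        _ = resolve_reference(reference, references)["enum"]
        return True
    except Exception:
        return False


# Hand-written schema fragment (an assumption, not generated output).
definitions = {"Vocoder": {"title": "Vocoder", "enum": ["hifigan", "wavernn"]}}
properties = {
    "vocoder": {"allOf": [{"$ref": "#/definitions/Vocoder"}]},
    "text": {"title": "Text", "type": "string"},
}

for key, prop in properties.items():
    print(key, is_single_enum_property(prop, definitions))
```

Because every lookup failure is caught and mapped to `False`, the predicate is safe to call on any property, which is what lets `_render_property` try the checks one after another.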


@@ -0,0 +1,888 @@
import datetime
import inspect
import mimetypes
import sys
from os import getcwd, unlink
from platform import system
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Type
from PIL import Image
import pandas as pd
import streamlit as st
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pydantic import BaseModel, ValidationError, parse_obj_as
from mkgui.base import Opyrator
from mkgui.base.core import name_to_title
from mkgui.base.ui import schema_utils
from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS
STREAMLIT_RUNNER_SNIPPET = """
from mkgui.base.ui import render_streamlit_ui
from mkgui.base import Opyrator
import streamlit as st
# TODO: Make it configurable
# Page config can only be setup once
st.set_page_config(
page_title="MockingBird",
page_icon="🧊",
layout="wide")
render_streamlit_ui()
"""
# with st.spinner("Loading MockingBird GUI. Please wait..."):
# opyrator = Opyrator("{opyrator_path}")
def launch_ui(port: int = 8501) -> None:
with NamedTemporaryFile(
suffix=".py", mode="w", encoding="utf-8", delete=False
) as f:
f.write(STREAMLIT_RUNNER_SNIPPET)
f.seek(0)
import subprocess
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
if system() == "Windows":
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
subprocess.run(
"set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false",
shell=True,
)
subprocess.run(
f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
shell=True,
)
f.close()
unlink(f.name)
def function_has_named_arg(func: Callable, parameter: str) -> bool:
try:
sig = inspect.signature(func)
for param in sig.parameters.values():
if param.name == parameter:
return True
except Exception:
return False
return False
def has_output_ui_renderer(data_item: BaseModel) -> bool:
return hasattr(data_item, "render_output_ui")
def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
return hasattr(input_class, "render_input_ui")
def is_compatible_audio(mime_type: str) -> bool:
return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
def is_compatible_image(mime_type: str) -> bool:
return mime_type in ["image/png", "image/jpeg"]
def is_compatible_video(mime_type: str) -> bool:
return mime_type in ["video/mp4"]
class InputUI:
def __init__(self, session_state, input_class: Type[BaseModel]):
self._session_state = session_state
self._input_class = input_class
self._schema_properties = input_class.schema(by_alias=True).get(
"properties", {}
)
self._schema_references = input_class.schema(by_alias=True).get(
"definitions", {}
)
def render_ui(self, streamlit_app_root) -> None:
if has_input_ui_renderer(self._input_class):
# The input model has a rendering function
# The rendering also returns the current state of input data
self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
st, self._session_state.input_data
)
return
# print(self._schema_properties)
for property_key in self._schema_properties.keys():
property = self._schema_properties[property_key]
if not property.get("title"):
# Set property key as fallback title
property["title"] = name_to_title(property_key)
try:
if "input_data" in self._session_state:
self._store_value(
property_key,
self._render_property(streamlit_app_root, property_key, property),
)
except Exception as e:
print("Exception!", e)
pass
def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
streamlit_kwargs = {
"label": property.get("title"),
"key": key,
}
if property.get("description"):
streamlit_kwargs["help"] = property.get("description")
return streamlit_kwargs
def _store_value(self, key: str, value: Any) -> None:
data_element = self._session_state.input_data
key_elements = key.split(".")
for i, key_element in enumerate(key_elements):
if i == len(key_elements) - 1:
# add value to this element
data_element[key_element] = value
return
if key_element not in data_element:
data_element[key_element] = {}
data_element = data_element[key_element]
def _get_value(self, key: str) -> Any:
data_element = self._session_state.input_data
key_elements = key.split(".")
for i, key_element in enumerate(key_elements):
if i == len(key_elements) - 1:
# return the value stored under this element, if any
if key_element not in data_element:
return None
return data_element[key_element]
if key_element not in data_element:
data_element[key_element] = {}
data_element = data_element[key_element]
return None
def _render_single_datetime_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
if property.get("format") == "time":
if property.get("default"):
try:
streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
property.get("default")
)
except Exception:
pass
return streamlit_app.time_input(**streamlit_kwargs)
elif property.get("format") == "date":
if property.get("default"):
try:
streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
property.get("default")
)
except Exception:
pass
return streamlit_app.date_input(**streamlit_kwargs)
elif property.get("format") == "date-time":
if property.get("default"):
try:
streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
property.get("default")
)
except Exception:
pass
with streamlit_app.container():
streamlit_app.subheader(streamlit_kwargs.get("label"))
if streamlit_kwargs.get("help"):
streamlit_app.text(streamlit_kwargs.get("help"))
selected_date = None
selected_time = None
date_col, time_col = streamlit_app.columns(2)
with date_col:
date_kwargs = {"label": "Date", "key": key + "-date-input"}
if streamlit_kwargs.get("value"):
try:
date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
"value"
).date()
except Exception:
pass
selected_date = streamlit_app.date_input(**date_kwargs)
with time_col:
time_kwargs = {"label": "Time", "key": key + "-time-input"}
if streamlit_kwargs.get("value"):
try:
time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
"value"
).time()
except Exception:
pass
selected_time = streamlit_app.time_input(**time_kwargs)
return datetime.datetime.combine(selected_date, selected_time)
else:
streamlit_app.warning(
"Date format is not supported: " + str(property.get("format"))
)
def _render_single_file_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
file_extension = None
if "mime_type" in property:
file_extension = mimetypes.guess_extension(property["mime_type"])
uploaded_file = streamlit_app.file_uploader(
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
)
if uploaded_file is None:
return None
bytes = uploaded_file.getvalue()
if property.get("mime_type"):
if is_compatible_audio(property["mime_type"]):
# Show audio
streamlit_app.audio(bytes, format=property.get("mime_type"))
if is_compatible_image(property["mime_type"]):
# Show image
streamlit_app.image(bytes)
if is_compatible_video(property["mime_type"]):
# Show video
streamlit_app.video(bytes, format=property.get("mime_type"))
return bytes
def _render_single_string_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
if property.get("default"):
streamlit_kwargs["value"] = property.get("default")
elif property.get("example"):
# TODO: also use example for other property types
# Use example as value if it is provided
streamlit_kwargs["value"] = property.get("example")
if property.get("maxLength") is not None:
streamlit_kwargs["max_chars"] = property.get("maxLength")
if (
property.get("format")
or (
property.get("maxLength") is not None
and int(property.get("maxLength")) < 140 # type: ignore
)
or property.get("writeOnly")
):
# If any format is set, use single text input
# If max chars is set to less than 140, use single text input
# If write only -> password field
if property.get("writeOnly"):
streamlit_kwargs["type"] = "password"
return streamlit_app.text_input(**streamlit_kwargs)
else:
# Otherwise use multiline text area
return streamlit_app.text_area(**streamlit_kwargs)
def _render_multi_enum_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
reference_item = schema_utils.resolve_reference(
property["items"]["$ref"], self._schema_references
)
# TODO: how to select defaults
return streamlit_app.multiselect(
**streamlit_kwargs, options=reference_item["enum"]
)
def _render_single_enum_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
reference_item = schema_utils.get_single_reference_item(
property, self._schema_references
)
if property.get("default") is not None:
try:
streamlit_kwargs["index"] = reference_item["enum"].index(
property.get("default")
)
except Exception:
# Use default selection
pass
return streamlit_app.selectbox(
**streamlit_kwargs, options=reference_item["enum"]
)
def _render_single_dict_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# Add title and subheader
streamlit_app.subheader(property.get("title"))
if property.get("description"):
streamlit_app.markdown(property.get("description"))
streamlit_app.markdown("---")
current_dict = self._get_value(key)
if not current_dict:
current_dict = {}
key_col, value_col = streamlit_app.columns(2)
with key_col:
updated_key = streamlit_app.text_input(
"Key", value="", key=key + "-new-key"
)
with value_col:
# TODO: also add boolean?
value_kwargs = {"label": "Value", "key": key + "-new-value"}
if property["additionalProperties"].get("type") == "integer":
value_kwargs["value"] = 0 # type: ignore
updated_value = streamlit_app.number_input(**value_kwargs)
elif property["additionalProperties"].get("type") == "number":
value_kwargs["value"] = 0.0 # type: ignore
value_kwargs["format"] = "%f"
updated_value = streamlit_app.number_input(**value_kwargs)
else:
value_kwargs["value"] = ""
updated_value = streamlit_app.text_input(**value_kwargs)
streamlit_app.markdown("---")
with streamlit_app.container():
clear_col, add_col = streamlit_app.columns([1, 2])
with clear_col:
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
current_dict = {}
with add_col:
if (
streamlit_app.button("Add Item", key=key + "-add-item")
and updated_key
):
current_dict[updated_key] = updated_value
streamlit_app.write(current_dict)
return current_dict
def _render_single_reference(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
reference_item = schema_utils.get_single_reference_item(
property, self._schema_references
)
return self._render_property(streamlit_app, key, reference_item)
def _render_multi_file_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
file_extension = None
if "mime_type" in property:
file_extension = mimetypes.guess_extension(property["mime_type"])
uploaded_files = streamlit_app.file_uploader(
**streamlit_kwargs, accept_multiple_files=True, type=file_extension
)
uploaded_files_bytes = []
if uploaded_files:
for uploaded_file in uploaded_files:
uploaded_files_bytes.append(uploaded_file.read())
return uploaded_files_bytes
def _render_single_boolean_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
if property.get("default"):
streamlit_kwargs["value"] = property.get("default")
return streamlit_app.checkbox(**streamlit_kwargs)
def _render_single_number_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
number_transform = int
if property.get("type") == "number":
number_transform = float # type: ignore
streamlit_kwargs["format"] = "%f"
if "multipleOf" in property:
# Set step size based on the multipleOf parameter
streamlit_kwargs["step"] = number_transform(property["multipleOf"])
elif number_transform == int:
# Set step size to 1 as default
streamlit_kwargs["step"] = 1
elif number_transform == float:
# Set step size to 0.01 as default
# TODO: adapt to default value
streamlit_kwargs["step"] = 0.01
if "minimum" in property:
streamlit_kwargs["min_value"] = number_transform(property["minimum"])
if "exclusiveMinimum" in property:
streamlit_kwargs["min_value"] = number_transform(
property["exclusiveMinimum"] + streamlit_kwargs["step"]
)
if "maximum" in property:
streamlit_kwargs["max_value"] = number_transform(property["maximum"])
if "exclusiveMaximum" in property:
streamlit_kwargs["max_value"] = number_transform(
property["exclusiveMaximum"] - streamlit_kwargs["step"]
)
if property.get("default") is not None:
streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
else:
if "min_value" in streamlit_kwargs:
streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
elif number_transform == int:
streamlit_kwargs["value"] = 0
else:
# Set default value to step
streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
# TODO: Only if less than X steps
return streamlit_app.slider(**streamlit_kwargs)
else:
return streamlit_app.number_input(**streamlit_kwargs)
def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
properties = property["properties"]
object_inputs = {}
for property_key in properties:
property = properties[property_key]
if not property.get("title"):
# Set property key as fallback title
property["title"] = name_to_title(property_key)
# construct full key based on key parts -> required later to get the value
full_key = key + "." + property_key
object_inputs[property_key] = self._render_property(
streamlit_app, full_key, property
)
return object_inputs
def _render_single_object_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# Add title and subheader
title = property.get("title")
streamlit_app.subheader(title)
if property.get("description"):
streamlit_app.markdown(property.get("description"))
object_reference = schema_utils.get_single_reference_item(
property, self._schema_references
)
return self._render_object_input(streamlit_app, key, object_reference)
def _render_property_list_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# Add title and subheader
streamlit_app.subheader(property.get("title"))
if property.get("description"):
streamlit_app.markdown(property.get("description"))
streamlit_app.markdown("---")
current_list = self._get_value(key)
if not current_list:
current_list = []
value_kwargs = {"label": "Value", "key": key + "-new-value"}
if property["items"]["type"] == "integer":
value_kwargs["value"] = 0 # type: ignore
new_value = streamlit_app.number_input(**value_kwargs)
elif property["items"]["type"] == "number":
value_kwargs["value"] = 0.0 # type: ignore
value_kwargs["format"] = "%f"
new_value = streamlit_app.number_input(**value_kwargs)
else:
value_kwargs["value"] = ""
new_value = streamlit_app.text_input(**value_kwargs)
streamlit_app.markdown("---")
with streamlit_app.container():
clear_col, add_col = streamlit_app.columns([1, 2])
with clear_col:
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
current_list = []
with add_col:
if (
streamlit_app.button("Add Item", key=key + "-add-item")
and new_value is not None
):
current_list.append(new_value)
streamlit_app.write(current_list)
return current_list
def _render_object_list_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# TODO: support max_items, and min_items properties
# Add title and subheader
streamlit_app.subheader(property.get("title"))
if property.get("description"):
streamlit_app.markdown(property.get("description"))
streamlit_app.markdown("---")
current_list = self._get_value(key)
if not current_list:
current_list = []
object_reference = schema_utils.resolve_reference(
property["items"]["$ref"], self._schema_references
)
input_data = self._render_object_input(streamlit_app, key, object_reference)
streamlit_app.markdown("---")
with streamlit_app.container():
clear_col, add_col = streamlit_app.columns([1, 2])
with clear_col:
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
current_list = []
with add_col:
if (
streamlit_app.button("Add Item", key=key + "-add-item")
and input_data
):
current_list.append(input_data)
streamlit_app.write(current_list)
return current_list
def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
if schema_utils.is_single_enum_property(property, self._schema_references):
return self._render_single_enum_input(streamlit_app, key, property)
if schema_utils.is_multi_enum_property(property, self._schema_references):
return self._render_multi_enum_input(streamlit_app, key, property)
if schema_utils.is_single_file_property(property):
return self._render_single_file_input(streamlit_app, key, property)
if schema_utils.is_multi_file_property(property):
return self._render_multi_file_input(streamlit_app, key, property)
if schema_utils.is_single_datetime_property(property):
return self._render_single_datetime_input(streamlit_app, key, property)
if schema_utils.is_single_boolean_property(property):
return self._render_single_boolean_input(streamlit_app, key, property)
if schema_utils.is_single_dict_property(property):
return self._render_single_dict_input(streamlit_app, key, property)
if schema_utils.is_single_number_property(property):
return self._render_single_number_input(streamlit_app, key, property)
if schema_utils.is_single_string_property(property):
return self._render_single_string_input(streamlit_app, key, property)
if schema_utils.is_single_object(property, self._schema_references):
return self._render_single_object_input(streamlit_app, key, property)
if schema_utils.is_object_list_property(property, self._schema_references):
return self._render_object_list_input(streamlit_app, key, property)
if schema_utils.is_property_list(property):
return self._render_property_list_input(streamlit_app, key, property)
if schema_utils.is_single_reference(property):
return self._render_single_reference(streamlit_app, key, property)
streamlit_app.warning(
"The type of the following property is currently not supported: "
+ str(property.get("title"))
)
raise Exception("Unsupported property")
class OutputUI:
def __init__(self, output_data: Any, input_data: Any):
self._output_data = output_data
self._input_data = input_data
def render_ui(self, streamlit_app) -> None:
try:
if isinstance(self._output_data, BaseModel):
self._render_single_output(streamlit_app, self._output_data)
return
if type(self._output_data) == list:
self._render_list_output(streamlit_app, self._output_data)
return
except Exception as ex:
streamlit_app.exception(ex)
# Fall back to rendering the raw output as JSON
streamlit_app.json(jsonable_encoder(self._output_data))
def _render_single_text_property(
self, streamlit: st, property_schema: Dict, value: Any
) -> None:
# Add title and subheader
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
if value is None or value == "":
streamlit.info("No value returned!")
else:
streamlit.code(str(value), language="plain")
def _render_single_file_property(
self, streamlit: st, property_schema: Dict, value: Any
) -> None:
# Add title and subheader
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
if value is None or value == "":
streamlit.info("No value returned!")
else:
# TODO: Detect if it is a FileContent instance
# TODO: detect if it is base64
file_extension = ""
if "mime_type" in property_schema:
mime_type = property_schema["mime_type"]
file_extension = mimetypes.guess_extension(mime_type) or ""
if is_compatible_audio(mime_type):
streamlit.audio(value.as_bytes(), format=mime_type)
return
if is_compatible_image(mime_type):
streamlit.image(value.as_bytes())
return
if is_compatible_video(mime_type):
streamlit.video(value.as_bytes(), format=mime_type)
return
filename = (
(property_schema["title"] + file_extension)
.lower()
.strip()
.replace(" ", "-")
)
streamlit.markdown(
f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
unsafe_allow_html=True,
)
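Review note: the download link interpolates `value` straight into a `data:` URI, which only works if `value` is already base64 text (the TODOs above hint at this). A minimal sketch of building the same link from raw bytes follows; `make_download_link` and its parameters are hypothetical names used for illustration, not part of this commit.

```python
import base64
import html


def make_download_link(data: bytes, title: str, file_extension: str = "") -> str:
    """Build an HTML anchor that downloads `data` through a base64 data URI."""
    # Same slug rule as the renderer: append extension, lowercase, trim, spaces -> dashes.
    filename = (title + file_extension).lower().strip().replace(" ", "-")
    encoded = base64.b64encode(data).decode("ascii")
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{html.escape(filename, quote=True)}">'
        '<input type="button" value="Download File"></a>'
    )


link = make_download_link(b"hello", "My Result", ".txt")
```

Escaping the filename is an extra precaution added in this sketch; the original code does not do it.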
def _render_single_complex_property(
self, streamlit: st, property_schema: Dict, value: Any
) -> None:
# Add title and subheader
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
streamlit.json(jsonable_encoder(value))
def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
try:
if has_output_ui_renderer(output_data):
if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
# render method also requests the input data
output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
else:
output_data.render_output_ui(streamlit) # type: ignore
return
except Exception:
# Use default auto-generation methods if the custom rendering throws an exception
logger.exception(
"Failed to execute custom render_output_ui function. Using auto-generation instead"
)
model_schema = output_data.schema(by_alias=False)
model_properties = model_schema.get("properties")
definitions = model_schema.get("definitions")
if model_properties:
for property_key in output_data.__dict__:
property_schema = model_properties.get(property_key)
if not property_schema.get("title"):
# Set property key as fallback title
property_schema["title"] = property_key
output_property_value = output_data.__dict__[property_key]
if has_output_ui_renderer(output_property_value):
output_property_value.render_output_ui(streamlit) # type: ignore
continue
if isinstance(output_property_value, BaseModel):
# Render output recursively
streamlit.subheader(property_schema.get("title"))
if property_schema.get("description"):
streamlit.markdown(property_schema.get("description"))
self._render_single_output(streamlit, output_property_value)
continue
if property_schema:
if schema_utils.is_single_file_property(property_schema):
self._render_single_file_property(
streamlit, property_schema, output_property_value
)
continue
if (
schema_utils.is_single_string_property(property_schema)
or schema_utils.is_single_number_property(property_schema)
or schema_utils.is_single_datetime_property(property_schema)
or schema_utils.is_single_boolean_property(property_schema)
):
self._render_single_text_property(
streamlit, property_schema, output_property_value
)
continue
if definitions and schema_utils.is_single_enum_property(
property_schema, definitions
):
self._render_single_text_property(
streamlit, property_schema, output_property_value.value
)
continue
# TODO: render dict as table
self._render_single_complex_property(
streamlit, property_schema, output_property_value
)
return
def _render_list_output(self, streamlit: st, output_data: List) -> None:
try:
data_items: List = []
for data_item in output_data:
if has_output_ui_renderer(data_item):
# Render using the render function
data_item.render_output_ui(streamlit) # type: ignore
continue
data_items.append(data_item.dict())
# Try to show as dataframe
streamlit.table(pd.DataFrame(data_items))
except Exception:
# Fallback to raw JSON rendering
streamlit.json(jsonable_encoder(output_data))
def getOpyrator(mode: str) -> Opyrator:
if mode is None or mode.startswith('VC'):
from mkgui.app_vc import convert
return Opyrator(convert)
if mode.startswith('预处理'):
from mkgui.preprocess import preprocess
return Opyrator(preprocess)
# Check the more specific '模型训练(VC)' prefix first; '模型训练' alone would also match it
if mode.startswith('模型训练(VC)'):
from mkgui.train_vc import train_vc
return Opyrator(train_vc)
if mode.startswith('模型训练'):
from mkgui.train import train
return Opyrator(train)
from mkgui.app import synthesize
return Opyrator(synthesize)
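Review note: with prefix-based dispatch, a short prefix such as '模型训练' also matches the longer label '模型训练(VC)', so the order of the checks decides which handler wins. The sketch below shows the idea with a lookup table; `pick_handler` and `MODE_TABLE` are hypothetical names, and handler names are plain strings here instead of `Opyrator` objects.

```python
from typing import List, Tuple


def pick_handler(mode: str, table: List[Tuple[str, str]], default: str) -> str:
    """Return the handler name for the first prefix that matches `mode`."""
    for prefix, handler in table:
        if mode.startswith(prefix):
            return handler
    return default


# Longer, more specific prefixes are listed before the shorter ones they extend.
MODE_TABLE = [
    ("模型训练(VC)", "train_vc"),
    ("模型训练", "train"),
    ("预处理", "preprocess"),
    ("VC", "convert"),
]

print(pick_handler("模型训练(VC)", MODE_TABLE, "synthesize"))
```

Reversing the table makes the VC training label resolve to the plain training handler, which is the failure mode to watch for.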
def render_streamlit_ui() -> None:
# init
session_state = st.session_state
session_state.input_data = {}
# Add custom css settings
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
with st.spinner("Loading MockingBird GUI. Please wait..."):
session_state.mode = st.sidebar.selectbox(
'模式选择',
( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
)
if "mode" in session_state:
mode = session_state.mode
else:
mode = ""
opyrator = getOpyrator(mode)
title = opyrator.name + mode
col1, col2, _ = st.columns(3)
col2.title(title)
col2.markdown("Welcome to MockingBird Web 2")
image = Image.open('mkgui/static/mb.png')
col1.image(image)
st.markdown("---")
left, right = st.columns([0.4, 0.6])
with left:
st.header("Control 控制")
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
execute_selected = st.button(opyrator.action)
if execute_selected:
with st.spinner("Executing operation. Please wait..."):
try:
input_data_obj = parse_obj_as(
opyrator.input_type, session_state.input_data
)
session_state.output_data = opyrator(input=input_data_obj)
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
except ValidationError as ex:
st.error(ex)
else:
# st.success("Operation executed successfully.")
pass
with right:
st.header("Result 结果")
if 'output_data' in session_state:
OutputUI(
session_state.output_data, session_state.latest_operation_input
).render_ui(st)
if st.button("Clear"):
# Clear all state
for key in list(st.session_state.keys()):
del st.session_state[key]
session_state.input_data = {}
st.experimental_rerun()
else:
# placeholder
st.caption("Use the control panel on the left to enter inputs, then run to see the results")
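Review note: the "Clear" handler deletes entries while looping over the key view of the state object. With a plain `dict` that pattern raises `RuntimeError`; whether Streamlit's session state behaves the same way is an assumption not verified here. A sketch of the safer pattern, using an ordinary dictionary as a stand-in (`clear_state` is a hypothetical helper):

```python
def clear_state(state: dict) -> None:
    """Remove every key from `state` without mutating it mid-iteration."""
    # list(...) takes a snapshot of the keys, so deleting entries is safe.
    for key in list(state.keys()):
        del state[key]


state = {"input_data": {"text": "hi"}, "output_data": 1, "mode": "VC"}
clear_state(state)
state["input_data"] = {}  # re-initialise, as the handler does after clearing
```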

View File

@@ -0,0 +1,13 @@
CUSTOM_STREAMLIT_CSS = """
div[data-testid="stBlock"] button {
width: 100% !important;
margin-bottom: 20px !important;
border-color: #bfbfbf !important;
}
section[data-testid="stSidebar"] div {
max-width: 10rem;
}
pre code {
white-space: pre-wrap;
}
"""

96
mkgui/preprocess.py Normal file
View File

@@ -0,0 +1,96 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
# Constants
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
if os.path.isdir(EXT_MODELS_DIRT):
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
print("Loaded extractor models: " + str(len(extractors)))
else:
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
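Review note: the model pickers are built by turning a folder scan into an `Enum` through the functional API. The sketch below isolates that step against a temporary folder; `build_model_enum` is a hypothetical helper, and sorting the paths is an addition made here for stable ordering.

```python
import tempfile
from enum import Enum
from pathlib import Path


def build_model_enum(name: str, models_dir: Path, pattern: str = "**/*.pt") -> type:
    """Create an Enum class whose members map file names to their paths."""
    if not models_dir.is_dir():
        raise FileNotFoundError(f"Model folder {models_dir} doesn't exist.")
    # sorted() keeps the member order stable across platforms.
    members = [(f.name, f) for f in sorted(models_dir.glob(pattern))]
    return Enum(name, members)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "a.pt").touch()
    (root / "sub").mkdir()
    (root / "sub" / "b.pt").touch()
    (root / "notes.txt").touch()  # ignored: does not match the pattern
    demo = build_model_enum("extractors", root)
    member_names = [m.name for m in demo]
```

Because the member name is only the file name, two checkpoints with the same name in different subfolders would collide; keying on the relative path is one way around that.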
class Model(str, Enum):
VC_PPG2MEL = "ppg2mel"
class Dataset(str, Enum):
AIDATATANG_200ZH = "aidatatang_200zh"
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
class Input(BaseModel):
# def render_input_ui(st, input) -> Dict:
# input["selected_dataset"] = st.selectbox(
# 'Select dataset',
# ("aidatatang_200zh", "aidatatang_200zh_s")
# )
# return input
model: Model = Field(
Model.VC_PPG2MEL, title="Target model",
)
dataset: Dataset = Field(
Dataset.AIDATATANG_200ZH, title="Dataset selection",
)
datasets_root: str = Field(
..., alias="Dataset root directory", description="Root directory of the input datasets (relative or absolute)",
format=True,
example="..\\trainning_data\\"
)
output_root: str = Field(
..., alias="Output root directory", description="Root directory for the output (relative or absolute)",
format=True,
example="..\\trainning_data\\"
)
n_processes: int = Field(
2, alias="Number of worker threads", description="Set according to the number of CPU threads",
le=32, ge=1
)
extractor: extractors = Field(
..., alias="Feature extraction model",
description="Select the PPG feature extraction model file."
)
encoder: encoders = Field(
..., alias="Voice encoder model",
description="Select the voice encoder model file."
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: Tuple[str, int]
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
sr, count = self.__root__
streamlit_app.subheader(f"Finished processing dataset {sr}: {count} items in total")
def preprocess(input: Input) -> Output:
"""Preprocess(预处理)"""
finished = 0
if input.model == Model.VC_PPG2MEL:
from ppg2mel.preprocess import preprocess_dataset
finished = preprocess_dataset(
datasets_root=Path(input.datasets_root),
dataset=input.dataset,
out_dir=Path(input.output_root),
n_processes=input.n_processes,
ppg_encoder_model_fpath=Path(input.extractor.value),
speaker_encoder_model=Path(input.encoder.value)
)
# TODO: pass useful return code
return Output(__root__=(input.dataset, finished))

BIN
mkgui/static/mb.png Normal file

Binary file not shown.

Size: 5.6 KiB

106
mkgui/train.py Normal file
View File

@@ -0,0 +1,106 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any
from synthesizer.hparams import hparams
from synthesizer.train import train as synt_train
# Constants
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
class Model(str, Enum):
DEFAULT = "default"
class Input(BaseModel):
model: Model = Field(
Model.DEFAULT, title="Model type",
)
# datasets_root: str = Field(
# ..., alias="Preprocessed data root", description="Input directory (relative or absolute); not applicable to the ppg2mel model",
# format=True,
# example="..\\trainning_data\\"
# )
input_root: str = Field(
..., alias="Input directory", description="Root directory of the preprocessed data",
format=True,
example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
)
run_id: str = Field(
"", alias="New model name / run ID", description="Enter a new ID to train from scratch; otherwise training continues from the model selected below",
)
synthesizer: synthesizers = Field(
..., alias="Existing synthesizer model",
description="Select the speech synthesis model file."
)
gpu: bool = Field(
True, alias="GPU training", description="Select 'Yes' to train on the GPU",
)
verbose: bool = Field(
True, alias="Verbose output", description="Select 'Yes' to print more details",
)
encoder: encoders = Field(
..., alias="Voice encoder model",
description="Select the voice encoder model file."
)
save_every: int = Field(
1000, alias="Update interval", description="Update the model every n steps",
)
backup_every: int = Field(
10000, alias="Save interval", description="Save the model every n steps",
)
log_every: int = Field(
500, alias="Log interval", description="Print training statistics every n steps",
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: int
def render_output_ui(self, streamlit_app) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
streamlit_app.subheader(f"Training started with code: {self.__root__}")
def train(input: Input) -> Output:
"""Train(训练)"""
print(">>> Start training ...")
force_restart = len(input.run_id) > 0
if not force_restart:
input.run_id = Path(input.synthesizer.value).name.split('.')[0]
synt_train(
input.run_id,
input.input_root,
f"synthesizer{os.sep}saved_models",
input.save_every,
input.backup_every,
input.log_every,
force_restart,
hparams
)
return Output(__root__=0)
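Review note: when no run ID is given, `train` derives one from the checkpoint file name by taking everything before the first dot. The sketch below shows what that yields for a few names and how it differs from `Path.stem`; `run_id_from_checkpoint` is a hypothetical helper and the paths are made-up examples.

```python
from pathlib import Path


def run_id_from_checkpoint(checkpoint: str) -> str:
    """Derive a run ID the way `train` does: the text before the first dot."""
    return Path(checkpoint).name.split(".")[0]


first = run_id_from_checkpoint("synthesizer/saved_models/mandarin.pt")
second = run_id_from_checkpoint("synthesizer/saved_models/my.model.v2.pt")
# Path.stem only strips the final suffix, so it keeps the inner dots.
stem = Path("synthesizer/saved_models/my.model.v2.pt").stem
```

For checkpoint names that contain dots, the two approaches give different IDs, so the choice affects which run folder training resumes from.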

155
mkgui/train_vc.py Normal file
View File

@@ -0,0 +1,155 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch
# Constants
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
if os.path.isdir(EXT_MODELS_DIRT):
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
print("Loaded extractor models: " + str(len(extractors)))
else:
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
if os.path.isdir(CONV_MODELS_DIRT):
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
print("Loaded convertor models: " + str(len(convertors)))
else:
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
if os.path.isdir(ENC_MODELS_DIRT):
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
print("Loaded encoders models: " + str(len(encoders)))
else:
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
class Model(str, Enum):
VC_PPG2MEL = "ppg2mel"
class Dataset(str, Enum):
AIDATATANG_200ZH = "aidatatang_200zh"
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
class Input(BaseModel):
# def render_input_ui(st, input) -> Dict:
# input["selected_dataset"] = st.selectbox(
# 'Select dataset',
# ("aidatatang_200zh", "aidatatang_200zh_s")
# )
# return input
model: Model = Field(
Model.VC_PPG2MEL, title="Model type",
)
# datasets_root: str = Field(
# ..., alias="Preprocessed data root", description="Input directory (relative or absolute); not applicable to the ppg2mel model",
# format=True,
# example="..\\trainning_data\\"
# )
output_root: str = Field(
..., alias="Output directory (optional)", description="Recommended: leave blank to keep the default",
format=True,
example=""
)
continue_mode: bool = Field(
True, alias="Continue-training mode", description="Select 'Yes' to continue training from the model selected below",
)
gpu: bool = Field(
True, alias="GPU training", description="Select 'Yes' to train on the GPU",
)
verbose: bool = Field(
True, alias="Verbose output", description="Select 'Yes' to print more details",
)
# TODO: Move to hidden fields by default
convertor: convertors = Field(
..., alias="Conversion model",
description="Select the voice conversion model file."
)
extractor: extractors = Field(
..., alias="Feature extraction model",
description="Select the PPG feature extraction model file."
)
encoder: encoders = Field(
..., alias="Voice encoder model",
description="Select the voice encoder model file."
)
njobs: int = Field(
8, alias="Number of processes", description="Applies to ppg2mel",
)
seed: int = Field(
default=0, alias="Random seed", description="Applies to ppg2mel",
)
model_name: str = Field(
..., alias="New model name", description="Only used when training from scratch; ignored when continue-training mode is selected",
example="test"
)
model_config: str = Field(
..., alias="New model config", description="Only used when training from scratch; ignored when continue-training mode is selected",
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
)
class AudioEntity(BaseModel):
content: bytes
mel: Any
class Output(BaseModel):
__root__: Tuple[str, int]
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
"""Custom output UI.
If this method is implemented, it will be used instead of the default Output UI renderer.
"""
sr, count = self.__root__
streamlit_app.subheader(f"Finished processing dataset {sr}: {count} items in total")
def train_vc(input: Input) -> Output:
"""Train VC(训练 VC)"""
print(">>> OneShot VC training ...")
params = AttrDict()
params.update({
"gpu": input.gpu,
"cpu": not input.gpu,
"njobs": input.njobs,
"seed": input.seed,
"verbose": input.verbose,
"load": input.convertor.value,
"warm_start": False,
})
if input.continue_mode:
# trace old model and config
p = Path(input.convertor.value)
params.name = p.parent.name
# search a config file
model_config_fpaths = list(p.parent.rglob("*.yaml"))
if len(model_config_fpaths) == 0:
raise Exception("No model yaml config found for convertor")
config = HpsYaml(model_config_fpaths[0])
params.ckpdir = p.parent.parent
params.config = model_config_fpaths[0]
params.logdir = os.path.join(p.parent, "log")
else:
# Make the config dict dot visitable
config = HpsYaml(input.model_config)
np.random.seed(input.seed)
torch.manual_seed(input.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(input.seed)
mode = "train"
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
solver = Solver(config, params, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> Oneshot VC train finished!")
# TODO: pass useful return code
return Output(__root__=(input.dataset, 0))
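Review note: in continue mode the training config is discovered by searching for a YAML file under the checkpoint's folder, and the run name comes from that folder's name. The sketch below reproduces that lookup with the standard library only; `find_model_config` is a hypothetical helper, and picking the first match in sorted order is a choice made here for determinism.

```python
import tempfile
from pathlib import Path


def find_model_config(model_file: Path) -> Path:
    """Return the first *.yaml file found under the checkpoint's folder."""
    candidates = sorted(model_file.parent.rglob("*.yaml"))
    if not candidates:
        # Use a real exception type; a plain string cannot be raised in Python 3.
        raise FileNotFoundError(f"No model yaml config found next to {model_file}")
    return candidates[0]


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "my_run"
    run_dir.mkdir()
    ckpt = run_dir / "best.pth"
    ckpt.touch()
    (run_dir / "config.yaml").touch()
    found_name = find_model_config(ckpt).name
    run_name = ckpt.parent.name  # what the code stores as params.name

    empty_dir = Path(tmp) / "empty"
    empty_dir.mkdir()
    try:
        find_model_config(empty_dir / "x.pth")
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
```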

View File

@@ -191,12 +191,15 @@ class MelDecoderMOLv2(AbsMelDecoder):
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
def load_model(train_config, model_file, device=None):
def load_model(model_file, device=None):
# search a config file
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
if len(model_config_fpaths) == 0:
raise Exception("No model yaml config found for convertor")
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = HpsYaml(train_config)
model_config = HpsYaml(model_config_fpaths[0])
ppg2mel_model = MelDecoderMOLv2(
**model_config["model"]
).to(device)

View File

@@ -110,3 +110,4 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
t_fid_file.close()
d_fid_file.close()
e_fid_file.close()
return len(wav_file_list)

View File

@@ -31,15 +31,10 @@ def main():
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true',
help='Disable pin-memory for dataloader')
parser.add_argument('--test', action='store_true', help='Test the model.')
# parser.add_argument('--no-pin', action='store_true',
# help='Disable pin-memory for dataloader')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
parser.add_argument('--finetune', action='store_true', help='Finetune model')
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
###
paras = parser.parse_args()

View File

@@ -93,6 +93,7 @@ class BaseSolver():
def load_ckpt(self):
''' Load ckpt if --load option is specified '''
print(self.paras)
if self.paras.load is not None:
if self.paras.warm_start:
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
@@ -100,7 +101,7 @@ class BaseSolver():
self.paras.load, map_location=self.device if self.mode == 'train'
else 'cpu')
model_dict = ckpt['model']
if len(self.config.model.ignore_layers) > 0:
if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
model_dict = {k:v for k, v in model_dict.items()
if k not in self.config.model.ignore_layers}
dummy_dict = self.model.state_dict()

View File

@@ -21,6 +21,8 @@ flask_cors==3.0.10
gevent==21.8.0
flask_restx
tensorboard
streamlit==1.8.0
PyYAML==5.4.1
torch_complex
espnet
espnet
PyWavelets

142
run.py Normal file
View File

@@ -0,0 +1,142 @@
import time
import os
import argparse
import torch
import numpy as np
import glob
from pathlib import Path
from tqdm import tqdm
from ppg_extractor import load_model
import librosa
import soundfile as sf
from utils.load_yaml import HpsYaml
from encoder.audio import preprocess_wav
from encoder import inference as speacker_encoder
from vocoder.hifigan import inference as vocoder
from ppg2mel import MelDecoderMOLv2
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
def _build_ppg2mel_model(model_config, model_file, device):
ppg2mel_model = MelDecoderMOLv2(
**model_config["model"]
).to(device)
ckpt = torch.load(model_file, map_location=device)
ppg2mel_model.load_state_dict(ckpt["model"])
ppg2mel_model.eval()
return ppg2mel_model
@torch.no_grad()
def convert(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
# Build models
print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
ppg_model = load_model(
Path('./ppg_extractor/saved_models/24epoch.pt'),
device,
)
ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
# vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')
# Data related
ref_wav_path = args.ref_wav_path
ref_wav = preprocess_wav(ref_wav_path)
ref_fid = os.path.basename(ref_wav_path)[:-4]
# TODO: specify encoder
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
print(f"Number of source utterances: {len(source_file_list)}.")
total_rtf = 0.0
cnt = 0
for src_wav_path in tqdm(source_file_list):
# Load the audio to a numpy array:
src_wav, _ = librosa.load(src_wav_path, sr=16000)
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
ppg = ppg_model(src_wav_tensor, src_wav_lengths)
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
start = time.time()
_, mel_pred, att_ws = ppg2mel_model.inference(
ppg,
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
spembs=ref_spk_dvec,
)
src_fid = os.path.basename(src_wav_path)[:-4]
wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
mel_len = mel_pred.shape[0]
rtf = (time.time() - start) / (0.01 * mel_len)
total_rtf += rtf
cnt += 1
# continue
mel_pred= mel_pred.transpose(0, 1)
y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")
print("RTF:")
print(total_rtf / cnt)
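Review note: the loop reports a real-time factor (RTF), computed as inference time divided by `0.01 * mel_len`. Reading `0.01` as a 10 ms frame shift is an assumption; the script does not name the constant. A small sketch of the arithmetic, with hypothetical helper names, including a guard for the empty-input case where the final average would otherwise divide by zero:

```python
def real_time_factor(elapsed_seconds: float, num_frames: int, frame_shift: float = 0.01) -> float:
    """Processing time divided by the duration of the generated audio."""
    audio_seconds = frame_shift * num_frames
    return elapsed_seconds / audio_seconds


def mean_rtf(rtfs):
    """Average RTF; returns None for an empty list instead of dividing by zero."""
    return sum(rtfs) / len(rtfs) if rtfs else None


# 200 frames at 10 ms each is 2 s of audio, generated here in 0.5 s.
rtf = real_time_factor(elapsed_seconds=0.5, num_frames=200)
```

An RTF below 1 means the model generates audio faster than it plays back.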
def get_parser():
parser = argparse.ArgumentParser(description="Conversion from wave input")
parser.add_argument(
"--wav_dir",
type=str,
default=None,
required=True,
help="Source wave directory.",
)
parser.add_argument(
"--ref_wav_path",
type=str,
required=True,
help="Reference wave file path.",
)
parser.add_argument(
"--ppg2mel_model_train_config", "-c",
type=str,
default=None,
required=True,
help="Training config file (yaml file)",
)
parser.add_argument(
"--ppg2mel_model_file", "-m",
type=str,
default=None,
required=True,
help="ppg2mel model checkpoint file path"
)
parser.add_argument(
"--output_dir", "-o",
type=str,
default="vc_gens_vctk_oneshot",
help="Output folder to save the converted wave."
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
convert(args)
if __name__ == "__main__":
main()

BIN
samples/T0055G0013S0005.wav Normal file

Binary file not shown.

View File

@@ -167,7 +167,7 @@ def _mel_to_linear(mel_spectrogram, hparams):
def _build_mel_basis(hparams):
assert hparams.fmax <= hparams.sample_rate // 2
return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
fmin=hparams.fmin, fmax=hparams.fmax)
def _amp_to_db(x, hparams):

View File

@@ -22,7 +22,8 @@ class HParams(object):
def loadJson(self, dict):
print("\Loading the json with %s\n", dict)
for k in dict.keys():
self.__dict__[k] = dict[k]
if k not in ["tts_schedule", "tts_finetune_layers"]:
self.__dict__[k] = dict[k]
return self
def dumpJson(self, fp):

View File

@@ -149,7 +149,7 @@ class Synthesizer:
Loads and preprocesses an audio file under the same conditions the audio files were used to
train the synthesizer.
"""
wav = librosa.load(str(fpath), hparams.sample_rate)[0]
wav = librosa.load(path=str(fpath), sr=hparams.sample_rate)[0]
if hparams.rescale:
wav = wav / np.abs(wav).max() * hparams.rescaling_max
# denoise

View File

@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import numpy as np
class Base(nn.Module):
def __init__(self, stop_threshold):
super().__init__()
self.init_model()
self.num_params()
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
@property
def r(self):
return self.decoder.r.item()
@r.setter
def r(self, value):
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
def init_model(self):
for p in self.parameters():
if p.dim() > 1: nn.init.xavier_uniform_(p)
def finetune_partial(self, whitelist_layers):
self.zero_grad()
for name, child in self.named_children():
if name in whitelist_layers:
print("Trainable Layer: %s" % name)
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
for param in child.parameters():
param.requires_grad = False
def get_step(self):
return self.step.data.item()
def reset_step(self):
# assignment to parameters or buffers is overloaded, updates internal dict entry
self.step = self.step.data.new_tensor(1)
def log(self, path, msg):
with open(path, "a") as f:
print(msg, file=f)
def load(self, path, device, optimizer=None):
# Use device of model params as location for loaded state
checkpoint = torch.load(str(path), map_location=device)
self.load_state_dict(checkpoint["model_state"], strict=False)
if "optimizer_state" in checkpoint and optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer_state"])
def save(self, path, optimizer=None):
if optimizer is not None:
torch.save({
"model_state": self.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, str(path))
else:
torch.save({
"model_state": self.state_dict(),
}, str(path))
def num_params(self, print_out=True):
parameters = filter(lambda p: p.requires_grad, self.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
if print_out:
print("Trainable Parameters: %.3fM" % parameters)
return parameters

View File

@@ -0,0 +1 @@
#

View File

@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
from .common.batch_norm_conv import BatchNormConv
from .common.highway_network import HighwayNetwork
class CBHG(nn.Module):
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
super().__init__()
# List of all rnns to call `flatten_parameters()` on
self._to_flatten = []
self.bank_kernels = [i for i in range(1, K + 1)]
self.conv1d_bank = nn.ModuleList()
for k in self.bank_kernels:
conv = BatchNormConv(in_channels, channels, k)
self.conv1d_bank.append(conv)
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
# Fix the highway input if necessary
if proj_channels[-1] != channels:
self.highway_mismatch = True
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
else:
self.highway_mismatch = False
self.highways = nn.ModuleList()
for i in range(num_highways):
hn = HighwayNetwork(channels)
self.highways.append(hn)
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
self._to_flatten.append(self.rnn)
# Avoid fragmentation of RNN parameters and associated warning
self._flatten_parameters()
def forward(self, x):
# Although we `_flatten_parameters()` on init, when using DataParallel
# the model gets replicated, making it no longer guaranteed that the
# weights are contiguous in GPU memory. Hence, we must call it again
self.rnn.flatten_parameters()
# Save these for later
residual = x
seq_len = x.size(-1)
conv_bank = []
# Convolution Bank
for conv in self.conv1d_bank:
c = conv(x) # Convolution
conv_bank.append(c[:, :, :seq_len])
# Stack along the channel axis
conv_bank = torch.cat(conv_bank, dim=1)
# dump the last padding to fit residual
x = self.maxpool(conv_bank)[:, :, :seq_len]
# Conv1d projections
x = self.conv_project1(x)
x = self.conv_project2(x)
# Residual Connect
x = x + residual
# Through the highways
x = x.transpose(1, 2)
if self.highway_mismatch is True:
x = self.pre_highway(x)
for h in self.highways: x = h(x)
# And then the RNN
x, _ = self.rnn(x)
return x
def _flatten_parameters(self):
"""Calls `flatten_parameters` on all the RNNs registered in `self._to_flatten`. Used
to improve efficiency and avoid the PyTorch warning about fragmented RNN parameters."""
[m.flatten_parameters() for m in self._to_flatten]

View File

@@ -0,0 +1,14 @@
import torch.nn as nn
import torch.nn.functional as F
class BatchNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel, relu=True):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
self.bnorm = nn.BatchNorm1d(out_channels)
self.relu = relu
def forward(self, x):
x = self.conv(x)
x = F.relu(x) if self.relu is True else x
return self.bnorm(x)
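Review note: `BatchNormConv` pads with `kernel // 2`, so even kernel sizes produce one extra time step, which is consistent with `CBHG.forward` trimming each bank output to `seq_len`. The sketch below applies the standard output-length formula for a stride-1, dilation-1 1-D convolution; `conv1d_out_len` is a hypothetical helper.

```python
def conv1d_out_len(seq_len: int, kernel: int) -> int:
    """Output length of a stride-1, dilation-1 Conv1d with padding = kernel // 2."""
    padding = kernel // 2
    return seq_len + 2 * padding - (kernel - 1)


# Kernel sizes 1..5 correspond to a conv bank built with K = 5.
lengths = {k: conv1d_out_len(100, k) for k in range(1, 6)}
```

Odd kernels preserve the length, even kernels add one frame.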

View File

@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class HighwayNetwork(nn.Module):
def __init__(self, size):
super().__init__()
self.W1 = nn.Linear(size, size)
self.W2 = nn.Linear(size, size)
self.W1.bias.data.fill_(0.)
def forward(self, x):
x1 = self.W1(x)
x2 = self.W2(x)
g = torch.sigmoid(x2)
y = g * F.relu(x1) + (1. - g) * x
return y
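Review note: the highway layer mixes a transformed signal with the untouched input through a learned gate, `y = g * relu(x1) + (1 - g) * x`. The scalar sketch below shows how the gate moves the output between the two; in the real layer `x1` and `x2` come from `W1(x)` and `W2(x)`, while here they are passed in directly as a simplification (`highway_unit` is a hypothetical name).

```python
import math


def highway_unit(x: float, x1: float, x2: float) -> float:
    """Scalar form of y = g * relu(x1) + (1 - g) * x with g = sigmoid(x2)."""
    g = 1.0 / (1.0 + math.exp(-x2))
    return g * max(x1, 0.0) + (1.0 - g) * x


mostly_carry = highway_unit(x=2.0, x1=5.0, x2=-20.0)     # gate near 0: output stays near x
mostly_transform = highway_unit(x=2.0, x1=5.0, x2=20.0)  # gate near 1: output near relu(x1)
half = highway_unit(x=2.0, x1=5.0, x2=0.0)               # gate is 0.5: even mix
```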

View File

@@ -97,7 +97,7 @@ class STL(nn.Module):
def forward(self, inputs):
N = inputs.size(0)
query = inputs.unsqueeze(1) # [N, 1, E//2]
keys = tFunctional.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
style_embed = self.attention(query, keys)
return style_embed

View File

@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSA(nn.Module):
def __init__(self, attn_dim, kernel_size=31, filters=32):
super().__init__()
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
self.L = nn.Linear(filters, attn_dim, bias=False)
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
self.v = nn.Linear(attn_dim, 1, bias=False)
self.cumulative = None
self.attention = None
def init_attention(self, encoder_seq_proj):
device = encoder_seq_proj.device # use same device as parameters
b, t, c = encoder_seq_proj.size()
self.cumulative = torch.zeros(b, t, device=device)
self.attention = torch.zeros(b, t, device=device)
def forward(self, encoder_seq_proj, query, times, chars):
if times == 0: self.init_attention(encoder_seq_proj)
processed_query = self.W(query).unsqueeze(1)
location = self.cumulative.unsqueeze(1)
processed_loc = self.L(self.conv(location).transpose(1, 2))
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
u = u.squeeze(-1)
# Mask zero padding chars
u = u * (chars != 0).float()
# Smooth Attention
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
scores = F.softmax(u, dim=1)
self.attention = scores
self.cumulative = self.cumulative + self.attention
return scores.unsqueeze(-1).transpose(1, 2)
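Review note: padding positions are masked by multiplying the energies by zero before the softmax. Since `exp(0)` is 1, those positions still receive some attention weight. The pure-Python sketch below compares that with filling padded energies with negative infinity; the second variant is an alternative shown for contrast, not what this commit does, and whether the difference matters in practice is not established by this diff.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


energies = [2.0, 1.0, -3.0, -3.0]
mask = [1.0, 1.0, 0.0, 0.0]  # the last two positions are padding

# Multiplicative masking, as in `u * (chars != 0).float()`: padded energies become 0.
multiplied = softmax([e * m for e, m in zip(energies, mask)])

# Alternative (an assumption, not the source's method): fill padding with -inf.
filled = softmax([e if m else float("-inf") for e, m in zip(energies, mask)])
```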

View File

@@ -0,0 +1,27 @@
import torch.nn as nn
import torch.nn.functional as F
class PreNet(nn.Module):
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
super().__init__()
self.fc1 = nn.Linear(in_dims, fc1_dims)
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
self.p = dropout
def forward(self, x):
"""forward
Args:
x (3D tensor with size `[batch_size, num_chars, tts_embed_dims]`): input texts list
Returns:
3D tensor with size `[batch_size, num_chars, encoder_dims]`
"""
x = self.fc1(x)
x = F.relu(x)
x = F.dropout(x, self.p, training=True)
x = self.fc2(x)
x = F.relu(x)
x = F.dropout(x, self.p, training=True)
return x
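
Both dropout calls in `PreNet.forward` pass `training=True`, so they remain active regardless of whether the module is in train or eval mode, and the pre-net output is therefore stochastic at inference time as well. The following is a minimal plain-Python sketch of inverted dropout to show what that implies; the `dropout` function and its `rng` argument are illustrative stand-ins, not part of the repository or of PyTorch.

```python
import random

def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p,
    and scale the survivors by 1 / (1 - p)."""
    keep = 1.0 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]

rng = random.Random(0)            # fixed seed so the demo is repeatable
activations = [1.0] * 1000        # hypothetical pre-net activations
dropped = dropout(activations, p=0.5, rng=rng)

zeroed = sum(1 for v in dropped if v == 0.0)
print(zeroed, "of", len(dropped), "values were zeroed")
```

With `p=0.5`, roughly half the values are zeroed and the rest are doubled, so two calls with the same input generally produce different outputs.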


@@ -1,277 +1,88 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from synthesizer.models.global_style_token import GlobalStyleToken
from .sublayer.global_style_token import GlobalStyleToken
from .sublayer.pre_net import PreNet
from .sublayer.cbhg import CBHG
from .sublayer.lsa import LSA
from .base import Base
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
from synthesizer.hparams import hparams
class HighwayNetwork(nn.Module):
def __init__(self, size):
super().__init__()
self.W1 = nn.Linear(size, size)
self.W2 = nn.Linear(size, size)
self.W1.bias.data.fill_(0.)
def forward(self, x):
x1 = self.W1(x)
x2 = self.W2(x)
g = torch.sigmoid(x2)
y = g * F.relu(x1) + (1. - g) * x
return y
class Encoder(nn.Module):
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
def __init__(self, num_chars, embed_dims=512, encoder_dims=256, K=5, num_highways=4, dropout=0.5):
""" Encoder for SV2TTS
Args:
num_chars (int): length of symbols
embed_dims (int, optional): embedding dim for input texts. Defaults to 512.
encoder_dims (int, optional): output dim for encoder. Defaults to 256.
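K (int, optional): number of kernel sizes in the CBHG convolution bank (kernels 1..K). Defaults to 5.
num_highways (int, optional): number of highway layers in the CBHG. Defaults to 4.
dropout (float, optional): dropout probability used by the pre-net. Defaults to 0.5.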
"""
super().__init__()
prenet_dims = (encoder_dims, encoder_dims)
cbhg_channels = encoder_dims
self.embedding = nn.Embedding(num_chars, embed_dims)
self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
self.pre_net = PreNet(embed_dims, fc1_dims=encoder_dims, fc2_dims=encoder_dims,
dropout=dropout)
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
proj_channels=[cbhg_channels, cbhg_channels],
self.cbhg = CBHG(K=K, in_channels=encoder_dims, channels=encoder_dims,
proj_channels=[encoder_dims, encoder_dims],
num_highways=num_highways)
def forward(self, x, speaker_embedding=None):
x = self.embedding(x)
x = self.pre_net(x)
x.transpose_(1, 2)
x = self.cbhg(x)
if speaker_embedding is not None:
x = self.add_speaker_embedding(x, speaker_embedding)
return x
def add_speaker_embedding(self, x, speaker_embedding):
# SV2TTS
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
# This concats the speaker embedding for each char in the encoder output
# Save the dimensions as human-readable names
batch_size = x.size()[0]
num_chars = x.size()[1]
if speaker_embedding.dim() == 1:
idx = 0
else:
idx = 1
# Start by making a copy of each speaker embedding to match the input text length
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
speaker_embedding_size = speaker_embedding.size()[idx]
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
# Reshape it and transpose
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
e = e.transpose(1, 2)
# Concatenate the tiled speaker embedding with the encoder output
x = torch.cat((x, e), 2)
return x
class BatchNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel, relu=True):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
self.bnorm = nn.BatchNorm1d(out_channels)
self.relu = relu
def forward(self, x):
x = self.conv(x)
x = F.relu(x) if self.relu is True else x
return self.bnorm(x)
"""forward pass for encoder
Args:
x (2D tensor with size `[batch_size, text_num_chars]`): input texts list
class CBHG(nn.Module):
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
super().__init__()
# List of all rnns to call `flatten_parameters()` on
self._to_flatten = []
self.bank_kernels = [i for i in range(1, K + 1)]
self.conv1d_bank = nn.ModuleList()
for k in self.bank_kernels:
conv = BatchNormConv(in_channels, channels, k)
self.conv1d_bank.append(conv)
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
# Fix the highway input if necessary
if proj_channels[-1] != channels:
self.highway_mismatch = True
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
else:
self.highway_mismatch = False
self.highways = nn.ModuleList()
for i in range(num_highways):
hn = HighwayNetwork(channels)
self.highways.append(hn)
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
self._to_flatten.append(self.rnn)
# Avoid fragmentation of RNN parameters and associated warning
self._flatten_parameters()
def forward(self, x):
# Although we `_flatten_parameters()` on init, when using DataParallel
# the model gets replicated, making it no longer guaranteed that the
# weights are contiguous in GPU memory. Hence, we must call it again
self.rnn.flatten_parameters()
# Save these for later
residual = x
seq_len = x.size(-1)
conv_bank = []
# Convolution Bank
for conv in self.conv1d_bank:
c = conv(x) # Convolution
conv_bank.append(c[:, :, :seq_len])
# Stack along the channel axis
conv_bank = torch.cat(conv_bank, dim=1)
# dump the last padding to fit residual
x = self.maxpool(conv_bank)[:, :, :seq_len]
# Conv1d projections
x = self.conv_project1(x)
x = self.conv_project2(x)
# Residual Connect
x = x + residual
# Through the highways
x = x.transpose(1, 2)
if self.highway_mismatch is True:
x = self.pre_highway(x)
for h in self.highways: x = h(x)
# And then the RNN
x, _ = self.rnn(x)
return x
def _flatten_parameters(self):
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
to improve efficiency and avoid PyTorch yelling at us."""
[m.flatten_parameters() for m in self._to_flatten]
class PreNet(nn.Module):
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
super().__init__()
self.fc1 = nn.Linear(in_dims, fc1_dims)
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
self.p = dropout
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = F.dropout(x, self.p, training=True)
x = self.fc2(x)
x = F.relu(x)
x = F.dropout(x, self.p, training=True)
return x
class Attention(nn.Module):
def __init__(self, attn_dims):
super().__init__()
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
self.v = nn.Linear(attn_dims, 1, bias=False)
def forward(self, encoder_seq_proj, query, t):
# print(encoder_seq_proj.shape)
# Transform the query vector
query_proj = self.W(query).unsqueeze(1)
# Compute the scores
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
scores = F.softmax(u, dim=1)
return scores.transpose(1, 2)
class LSA(nn.Module):
def __init__(self, attn_dim, kernel_size=31, filters=32):
super().__init__()
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
self.L = nn.Linear(filters, attn_dim, bias=False)
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
self.v = nn.Linear(attn_dim, 1, bias=False)
self.cumulative = None
self.attention = None
def init_attention(self, encoder_seq_proj):
device = encoder_seq_proj.device # use same device as parameters
b, t, c = encoder_seq_proj.size()
self.cumulative = torch.zeros(b, t, device=device)
self.attention = torch.zeros(b, t, device=device)
def forward(self, encoder_seq_proj, query, t, chars):
if t == 0: self.init_attention(encoder_seq_proj)
processed_query = self.W(query).unsqueeze(1)
location = self.cumulative.unsqueeze(1)
processed_loc = self.L(self.conv(location).transpose(1, 2))
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
u = u.squeeze(-1)
# Mask zero padding chars
u = u * (chars != 0).float()
# Smooth Attention
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
scores = F.softmax(u, dim=1)
self.attention = scores
self.cumulative = self.cumulative + self.attention
return scores.unsqueeze(-1).transpose(1, 2)
Returns:
3D tensor with size `[batch_size, text_num_chars, encoder_dims]`
"""
x = self.embedding(x) # return: [batch_size, text_num_chars, tts_embed_dims]
x = self.pre_net(x) # return: [batch_size, text_num_chars, encoder_dims]
x.transpose_(1, 2) # return: [batch_size, encoder_dims, text_num_chars]
return self.cbhg(x) # return: [batch_size, text_num_chars, encoder_dims]
class Decoder(nn.Module):
# Class variable because its value doesn't change between classes
# yet ought to be scoped by class because it's a property of a Decoder
max_r = 20
def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
def __init__(self, n_mels, input_dims, decoder_dims, lstm_dims,
dropout, speaker_embedding_size):
super().__init__()
self.register_buffer("r", torch.tensor(1, dtype=torch.int))
self.n_mels = n_mels
prenet_dims = (decoder_dims * 2, decoder_dims * 2)
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
self.prenet = PreNet(n_mels, fc1_dims=decoder_dims * 2, fc2_dims=decoder_dims * 2,
dropout=dropout)
self.attn_net = LSA(decoder_dims)
if hparams.use_gst:
speaker_embedding_size += gst_hp.E
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
self.attn_rnn = nn.GRUCell(input_dims + decoder_dims * 2, decoder_dims)
self.rnn_input = nn.Linear(input_dims + decoder_dims, lstm_dims)
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
self.stop_proj = nn.Linear(input_dims + lstm_dims, 1)
def zoneout(self, prev, current, device, p=0.1):
mask = torch.zeros(prev.size(),device=device).bernoulli_(p)
return prev * mask + current * (1 - mask)
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
hidden_states, cell_states, context_vec, t, chars):
hidden_states, cell_states, context_vec, times, chars):
"""_summary_
Args:
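encoder_seq (3D tensor `[batch_size, text_num_chars, project_dim(default to 512)]`): encoder output with the speaker (and optional style) embedding concatenated
encoder_seq_proj (3D tensor `[batch_size, text_num_chars, decoder_dims(default to 128)]`): `encoder_seq` after the `encoder_proj` linear layer
prenet_in (2D tensor `[batch_size, n_mels]`): previous mel frame, or the all-zero <GO> frame on the first step
hidden_states (tuple of 3 tensors): `(attn_hidden, rnn1_hidden, rnn2_hidden)`
cell_states (tuple of 2 tensors): `(rnn1_cell, rnn2_cell)`
context_vec (2D tensor `[batch_size, project_dim(default to 512)]`): attention context vector from the previous step (zeros on the first step)
times (int): index of the current decoder step; 0 on the first call, which resets the attention state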
chars (2D tensor with size `[batch_size, text_num_chars]`): original texts list input
"""
# Need this for reshaping mels
batch_size = encoder_seq.size(0)
device = encoder_seq.device
@@ -280,25 +91,25 @@ class Decoder(nn.Module):
rnn1_cell, rnn2_cell = cell_states
# PreNet for the Attention RNN
prenet_out = self.prenet(prenet_in)
prenet_out = self.prenet(prenet_in) # return: `[batch_size, decoder_dims * 2(256)]`
# Compute the Attention RNN hidden state
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) # `[batch_size, project_dim + decoder_dims * 2 (768)]`
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) # `[batch_size, decoder_dims (128)]`
# Compute the attention scores
scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
scores = self.attn_net(encoder_seq_proj, attn_hidden, times, chars)
# Dot product to create the context vector
context_vec = scores @ encoder_seq
context_vec = context_vec.squeeze(1)
# Concat Attention RNN output w. Context Vector & project
x = torch.cat([context_vec, attn_hidden], dim=1)
x = self.rnn_input(x)
x = torch.cat([context_vec, attn_hidden], dim=1) # `[batch_size, project_dim + decoder_dims (640)]`
x = self.rnn_input(x) # `[batch_size, lstm_dims(1024)]`
# Compute first Residual RNN
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
# Compute first Residual RNN, training with fixed zoneout rate 0.1
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) # `[batch_size, lstm_dims(1024)]`
if self.training:
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device)
else:
@@ -306,7 +117,7 @@ class Decoder(nn.Module):
x = x + rnn1_hidden
# Compute second Residual RNN
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) # `[batch_size, lstm_dims(1024)]`
if self.training:
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
else:
@@ -314,8 +125,8 @@ class Decoder(nn.Module):
x = x + rnn2_hidden
# Project Mels
mels = self.mel_proj(x)
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
mels = self.mel_proj(x) # `[batch_size, 1600]`
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] # `[batch_size, n_mels, r]`
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
cell_states = (rnn1_cell, rnn2_cell)
@@ -326,45 +137,30 @@ class Decoder(nn.Module):
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
class Tacotron(nn.Module):
class Tacotron(Base):
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
dropout, stop_threshold, speaker_embedding_size):
super().__init__()
super().__init__(stop_threshold)
self.n_mels = n_mels
self.lstm_dims = lstm_dims
self.encoder_dims = encoder_dims
self.decoder_dims = decoder_dims
self.speaker_embedding_size = speaker_embedding_size
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
self.encoder = Encoder(num_chars, embed_dims, encoder_dims,
encoder_K, num_highways, dropout)
project_dims = encoder_dims + speaker_embedding_size
self.project_dims = encoder_dims + speaker_embedding_size
if hparams.use_gst:
project_dims += gst_hp.E
self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
self.project_dims += gst_hp.E
self.encoder_proj = nn.Linear(self.project_dims, decoder_dims, bias=False)
if hparams.use_gst:
self.gst = GlobalStyleToken(speaker_embedding_size)
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
self.decoder = Decoder(n_mels, self.project_dims, decoder_dims, lstm_dims,
dropout, speaker_embedding_size)
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
[postnet_dims, fft_bins], num_highways)
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
self.init_model()
self.num_params()
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
@property
def r(self):
return self.decoder.r.item()
@r.setter
def r(self, value):
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
@staticmethod
def _concat_speaker_embedding(outputs, speaker_embeddings):
speaker_embeddings_ = speaker_embeddings.expand(
@@ -372,11 +168,52 @@ class Tacotron(nn.Module):
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
return outputs
def forward(self, texts, mels, speaker_embedding):
@staticmethod
def _add_speaker_embedding(x, speaker_embedding):
"""Add speaker embedding
This concats the speaker embedding for each char in the encoder output
Args:
x (3D tensor with size `[batch_size, text_num_chars, encoder_dims]`): the encoder output
speaker_embedding (2D tensor `[batch_size, speaker_embedding_size]`): the speaker embedding
Returns:
3D tensor with size `[batch_size, text_num_chars, encoder_dims+speaker_embedding_size]`
"""
# Save the dimensions as human-readable names
batch_size = x.size()[0]
text_num_chars = x.size()[1]
# Start by making a copy of each speaker embedding to match the input text length
# The output of this has size (batch_size, text_num_chars * speaker_embedding_size)
speaker_embedding_size = speaker_embedding.size()[1]
e = speaker_embedding.repeat_interleave(text_num_chars, dim=1)
# Reshape it and transpose
e = e.reshape(batch_size, speaker_embedding_size, text_num_chars)
e = e.transpose(1, 2)
# Concatenate the tiled speaker embedding with the encoder output
x = torch.cat((x, e), 2)
return x
def forward(self, texts, mels, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
"""Forward pass for Tacotron
Args:
texts (`[batch_size, text_num_chars]`): input texts list
mels (`[batch_size, varied_mel_lengths, steps]`): mels for comparison (training only)
speaker_embedding (`[batch_size, speaker_embedding_size(default to 256)]`): referring embedding.
steps (int, optional): maximum number of decoder steps at inference (overridden by the mel length during training). Defaults to 2000.
style_idx (int, optional): GST style selected. Defaults to 0.
min_stop_token (int, optional): decoder min_stop_token. Defaults to 5.
"""
device = texts.device # use same device as parameters
self.step += 1
batch_size, _, steps = mels.size()
if self.training:
self.step += 1
batch_size, _, steps = mels.size()
else:
batch_size, _ = texts.size()
# Initialise all hidden states and pack into tuple
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
@@ -392,35 +229,50 @@ class Tacotron(nn.Module):
# <GO> Frame for start of decoder loop
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
# Need an initial context vector
size = self.encoder_dims + self.speaker_embedding_size
if hparams.use_gst:
size += gst_hp.E
context_vec = torch.zeros(batch_size, size, device=device)
# SV2TTS: Run the encoder with the speaker embedding
# The projection avoids unnecessary matmuls in the decoder loop
encoder_seq = self.encoder(texts, speaker_embedding)
# put after encoder
encoder_seq = self.encoder(texts)
encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)
if hparams.use_gst and self.gst is not None:
style_embed = self.gst(speaker_embedding, speaker_embedding) # during training, the speaker embedding serves as both the style input and the reference
# style_embed = style_embed.expand_as(encoder_seq)
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
encoder_seq_proj = self.encoder_proj(encoder_seq)
if self.training:
style_embed = self.gst(speaker_embedding, speaker_embedding) # during training, the speaker embedding serves as both the style input and the reference
# style_embed = style_embed.expand_as(encoder_seq)
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
elif style_idx >= 0 and style_idx < 10:
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
if device.type == 'cuda':
query = query.cuda()
gst_embed = torch.tanh(self.gst.stl.embed)
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
style_embed = self.gst.stl.attention(query, key)
else:
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed) # return: [batch_size, text_num_chars, project_dims]
encoder_seq_proj = self.encoder_proj(encoder_seq) # return: [batch_size, text_num_chars, decoder_dims]
# Need a couple of lists for outputs
mel_outputs, attn_scores, stop_outputs = [], [], []
# Need an initial context vector
context_vec = torch.zeros(batch_size, self.project_dims, device=device)
# Run the decoder loop
for t in range(0, steps, self.r):
prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
if self.training:
prenet_in = mels[:, :, t -1] if t > 0 else go_frame
else:
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
hidden_states, cell_states, context_vec, t, texts)
mel_outputs.append(mel_frames)
attn_scores.append(scores)
stop_outputs.extend([stop_tokens] * self.r)
if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break
# Concat the mel outputs into sequence
mel_outputs = torch.cat(mel_outputs, dim=2)
@@ -435,135 +287,12 @@ class Tacotron(nn.Module):
# attn_scores = attn_scores.cpu().data.numpy()
stop_outputs = torch.cat(stop_outputs, 1)
if self.training:
self.train()
return mel_outputs, linear, attn_scores, stop_outputs
def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
def generate(self, x, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
self.eval()
device = x.device # use same device as parameters
batch_size, _ = x.size()
# Need to initialise all hidden states and pack into tuple for tidyness
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
# Need to initialise all lstm cell states and pack into tuple for tidyness
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
cell_states = (rnn1_cell, rnn2_cell)
# Need a <GO> Frame for start of decoder loop
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
# Need an initial context vector
size = self.encoder_dims + self.speaker_embedding_size
if hparams.use_gst:
size += gst_hp.E
context_vec = torch.zeros(batch_size, size, device=device)
# SV2TTS: Run the encoder with the speaker embedding
# The projection avoids unnecessary matmuls in the decoder loop
encoder_seq = self.encoder(x, speaker_embedding)
# put after encoder
if hparams.use_gst and self.gst is not None:
if style_idx >= 0 and style_idx < 10:
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
if device.type == 'cuda':
query = query.cuda()
gst_embed = torch.tanh(self.gst.stl.embed)
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
style_embed = self.gst.stl.attention(query, key)
else:
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
# style_embed = style_embed.expand_as(encoder_seq)
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
encoder_seq_proj = self.encoder_proj(encoder_seq)
# Need a couple of lists for outputs
mel_outputs, attn_scores, stop_outputs = [], [], []
# Run the decoder loop
for t in range(0, steps, self.r):
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
hidden_states, cell_states, context_vec, t, x)
mel_outputs.append(mel_frames)
attn_scores.append(scores)
stop_outputs.extend([stop_tokens] * self.r)
# Stop the loop when all stop tokens in batch exceed threshold
if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
# Concat the mel outputs into sequence
mel_outputs = torch.cat(mel_outputs, dim=2)
# Post-Process for Linear Spectrograms
postnet_out = self.postnet(mel_outputs)
linear = self.post_proj(postnet_out)
linear = linear.transpose(1, 2)
# For easy visualisation
attn_scores = torch.cat(attn_scores, 1)
stop_outputs = torch.cat(stop_outputs, 1)
self.train()
mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
return mel_outputs, linear, attn_scores
def init_model(self):
for p in self.parameters():
if p.dim() > 1: nn.init.xavier_uniform_(p)
def finetune_partial(self, whitelist_layers):
self.zero_grad()
for name, child in self.named_children():
if name in whitelist_layers:
print("Trainable Layer: %s" % name)
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
for param in child.parameters():
param.requires_grad = False
def get_step(self):
return self.step.data.item()
def reset_step(self):
# assignment to parameters or buffers is overloaded, updates internal dict entry
self.step = self.step.data.new_tensor(1)
def log(self, path, msg):
with open(path, "a") as f:
print(msg, file=f)
def load(self, path, device, optimizer=None):
# Use device of model params as location for loaded state
checkpoint = torch.load(str(path), map_location=device)
self.load_state_dict(checkpoint["model_state"], strict=False)
if "optimizer_state" in checkpoint and optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer_state"])
def save(self, path, optimizer=None):
if optimizer is not None:
torch.save({
"model_state": self.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, str(path))
else:
torch.save({
"model_state": self.state_dict(),
}, str(path))
def num_params(self, print_out=True):
parameters = filter(lambda p: p.requires_grad, self.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
if print_out:
print("Trainable Parameters: %.3fM" % parameters)
return parameters
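
The refactored `Decoder` receives a single `input_dims` (the model's `project_dims`) instead of separate encoder and speaker sizes. As a rough aid for checking the shape comments in the diff, the sketch below recomputes the layer input widths from that wiring. The function name and the default values (256 / 256 / 128 / 1024, taken from the "default to" notes in the docstrings) are assumptions for illustration, and `gst_dims` stands in for `gst_hp.E`, whose value is not shown in this diff.

```python
def decoder_input_sizes(encoder_dims=256, speaker_embedding_size=256,
                        decoder_dims=128, lstm_dims=1024, gst_dims=0):
    """Return the input widths of the decoder layers as wired in the new code.

    gst_dims is 0 when GST is disabled; otherwise it would be gst_hp.E.
    """
    project_dims = encoder_dims + speaker_embedding_size + gst_dims
    return {
        "project_dims": project_dims,                    # width of encoder_seq and context_vec
        "attn_rnn_in": project_dims + decoder_dims * 2,  # context_vec + pre-net output
        "rnn_input_in": project_dims + decoder_dims,     # context_vec + attn_hidden
        "stop_proj_in": project_dims + lstm_dims,        # context_vec + LSTM output
    }

sizes = decoder_input_sizes()
for name, width in sizes.items():
    print(f"{name}: {width}")
```

With these assumed defaults and GST disabled, the widths come out as 512, 768, 640, and 1536 respectively.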


@@ -63,7 +63,7 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
def _split_on_silences(wav_fpath, words, hparams):
# Load the audio waveform
wav, _ = librosa.load(wav_fpath, hparams.sample_rate)
wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate)
wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0]
if hparams.rescale:
wav = wav / np.abs(wav).max() * hparams.rescaling_max


@@ -15,9 +15,8 @@ from datetime import datetime
import json
import numpy as np
from pathlib import Path
import sys
import time
import os
def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
@@ -265,7 +264,19 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
loss=loss,
hparams=hparams,
sw=sw)
MAX_SAVED_COUNT = 20
if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
# clean up and save last MAX_SAVED_COUNT;
plots = next(os.walk(plot_dir), (None, None, []))[2]
for plot in plots[-MAX_SAVED_COUNT:]:
os.remove(plot_dir.joinpath(plot))
mel_files = next(os.walk(mel_output_dir), (None, None, []))[2]
for mel_file in mel_files[-MAX_SAVED_COUNT:]:
os.remove(mel_output_dir.joinpath(mel_file))
wavs = next(os.walk(wav_dir), (None, None, []))[2]
for w in wavs[-MAX_SAVED_COUNT:]:
os.remove(wav_dir.joinpath(w))
# Break out of loop to update training schedule
if step >= max_step:
break
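
The cleanup added in this hunk slices the file list returned by `os.walk`, whose ordering is not guaranteed, so which files get removed depends on the platform. One possible alternative, sketched here under the assumption that "keep the most recent N outputs" is the intent, sorts by modification time before pruning. `prune_oldest` and the file names in the demo are hypothetical and do not exist in the repository.

```python
import os
import tempfile
from pathlib import Path

def prune_oldest(directory, keep):
    """Delete all but the `keep` most recently modified files in `directory`.

    Returns the names of the files that were removed, oldest first.
    """
    files = sorted((p for p in Path(directory).iterdir() if p.is_file()),
                   key=lambda p: p.stat().st_mtime)
    stale = files[:-keep] if keep > 0 else files
    for path in stale:
        path.unlink()
    return [p.name for p in stale]

# Demo in a throwaway directory with explicit, increasing mtimes.
with tempfile.TemporaryDirectory() as tmp:
    for i in range(5):
        f = Path(tmp, f"step_{i}.png")
        f.touch()
        os.utime(f, (1000 + i, 1000 + i))
    removed = prune_oldest(tmp, keep=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(removed, remaining)
```

Sorting by modification time makes the result independent of directory listing order, at the cost of one `stat` call per file.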


@@ -3,12 +3,10 @@ from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder.wavernn import inference as rnn_vocoder
from vocoder.hifigan import inference as gan_vocoder
import ppg_extractor as extractor
import ppg2mel as convertor
from vocoder.fregan import inference as fgan_vocoder
from pathlib import Path
from time import perf_counter as timer
from toolbox.utterance import Utterance
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
import numpy as np
import traceback
import sys
@@ -373,6 +371,8 @@ class Toolbox:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ref_wav = self.ui.selected_utterance.wav
# Import necessary dependency of Voice Conversion
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))
@@ -397,6 +397,7 @@ class Toolbox:
self.ui.log("Loading the extractor %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
import ppg_extractor as extractor
self.extractor = extractor.load_model(model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
@@ -405,15 +406,11 @@ class Toolbox:
if self.ui.current_convertor_fpath is None:
return
model_fpath = self.ui.current_convertor_fpath
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.yaml"))
if self.ui.current_convertor_fpath is None:
return
model_config_fpath = model_config_fpaths[0]
self.ui.log("Loading the convertor %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
self.convertor = convertor.load_model(model_config_fpath, model_fpath)
import ppg2mel as convertor
self.convertor = convertor.load_model( model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
@@ -446,14 +443,24 @@ class Toolbox:
return
# Select vocoder based on model name
model_config_fpath = None
if model_fpath.name[0] == "g":
if model_fpath.name is not None and model_fpath.name.find("hifigan") > -1:
vocoder = gan_vocoder
self.ui.log("set hifigan as vocoder")
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
if self.ui.current_extractor_fpath is None:
if self.vc_mode and self.ui.current_extractor_fpath is None:
return
model_config_fpath = model_config_fpaths[0]
if len(model_config_fpaths) > 0:
model_config_fpath = model_config_fpaths[0]
elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1:
vocoder = fgan_vocoder
self.ui.log("set fregan as vocoder")
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
if self.vc_mode and self.ui.current_extractor_fpath is None:
return
if len(model_config_fpaths) > 0:
model_config_fpath = model_config_fpaths[0]
else:
vocoder = rnn_vocoder
self.ui.log("set wavernn as vocoder")
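
The new selection logic keys off a substring of the checkpoint file name instead of its first character. A minimal sketch of that rule, isolated from the toolbox UI, could look like the following; `pick_vocoder` and the example paths are illustrative only, and the function returns a label where the real code assigns an inference module.

```python
from pathlib import Path

def pick_vocoder(model_fpath):
    """Return the vocoder family implied by a checkpoint file name."""
    name = Path(model_fpath).name
    if "hifigan" in name:
        return "hifigan"
    if "fregan" in name:
        return "fregan"
    return "wavernn"   # fallback, as in the else branch of the diff

# Hypothetical checkpoint paths.
for path in ("saved_models/g_hifigan.pt",
             "saved_models/fregan_150k.pt",
             "saved_models/pretrained.pt"):
    print(path, "->", pick_vocoder(path))
```

Only the file name is inspected, so a directory called `hifigan/` does not influence the choice.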

67
train.py Normal file

@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
# Arguments
parser = argparse.ArgumentParser(description=
'Training PPG2Mel VC model.')
parser.add_argument('--config', type=str,
help='Path to experiment config, e.g., config/vc.yaml')
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
parser.add_argument('--logdir', default='log/', type=str,
help='Logging path.', required=False)
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
help='Checkpoint path.', required=False)
parser.add_argument('--outdir', default='result/', type=str,
help='Decode output path.', required=False)
parser.add_argument('--load', default=None, type=str,
help='Load pre-trained model (for training only)', required=False)
parser.add_argument('--warm_start', action='store_true',
help='Load model weights only, ignore specified layers.')
parser.add_argument('--seed', default=0, type=int,
help='Random seed for reproducible results.', required=False)
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true',
help='Disable pin-memory for dataloader')
parser.add_argument('--test', action='store_true', help='Test the model.')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
parser.add_argument('--finetune', action='store_true', help='Finetune model')
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
###
paras = parser.parse_args()
setattr(paras, 'gpu', not paras.cpu)
setattr(paras, 'pin_memory', not paras.no_pin)
setattr(paras, 'verbose', not paras.no_msg)
# Make the config dict dot visitable
config = HpsYaml(paras.config)
np.random.seed(paras.seed)
torch.manual_seed(paras.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(paras.seed)
print(">>> OneShot VC training ...")
mode = "train"
solver = Solver(config, paras, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> OneShot VC training finished!")
sys.exit(0)
if __name__ == "__main__":
main()
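
The training script above parses negated flags (`--cpu`, `--no-pin`, `--no-msg`) and then derives the positive attributes `gpu`, `pin_memory` and `verbose` from them. A minimal sketch of that pattern, using only `argparse`; the argument list passed to `parse_args` is a made-up example, not a command line taken from the repository:

```python
import argparse

parser = argparse.ArgumentParser(description='Flag inversion sketch.')
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true', help='Disable pin-memory for dataloader')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')

# Hypothetical command line: train on CPU, keep pinned memory and messages.
paras = parser.parse_args(['--cpu'])

# argparse stores '--no-pin' under the attribute name 'no_pin'.
paras.gpu = not paras.cpu
paras.pin_memory = not paras.no_pin
paras.verbose = not paras.no_msg

print(paras.gpu, paras.pin_memory, paras.verbose)
```

Plain attribute assignment behaves the same as the `setattr` calls in the script; `setattr` is only needed when the attribute name is computed at run time.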


@@ -42,3 +42,9 @@ def human_format(num):
# add more suffixes if you need them
return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
# Provide attribute-style access to dict entries, e.g. abc.key
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
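
A short usage sketch of the `AttrDict` helper above may help. The keys are borrowed from the Fre-GAN `config.json` for illustration only; the class body is repeated so the snippet runs on its own.

```python
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The instance dictionary *is* the dict itself, so keys and attributes share storage.
        self.__dict__ = self


h = AttrDict({"sampling_rate": 16000, "dist_config": {"world_size": 1}})

# Keys can be read as attributes ...
print(h.sampling_rate)

# ... and attributes written later show up as keys.
h.hop_size = 200
print(h["hop_size"])

# Nested dicts are not converted; they stay plain dicts.
print(type(h.dist_config).__name__)
```

Because nested values remain ordinary dicts, `h.dist_config.world_size` would raise `AttributeError`; `h.dist_config["world_size"]` is the form that works with this helper.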

129
vocoder/fregan/.gitignore vendored Normal file

@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

21
vocoder/fregan/LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Rishikesh (ऋषिकेश)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,42 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"disc_start_step":0,
"upsample_rates": [5,5,2,2,2],
"upsample_kernel_sizes": [10,10,4,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1, 3, 5, 7], [1,3,5,7], [1,3,5,7]],
"segment_size": 6400,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 200,
"win_size": 800,
"sampling_rate": 16000,
"fmin": 0,
"fmax": 7600,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}
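
The values in this config have to agree with each other: the product of `upsample_rates` should equal `hop_size`, so that each mel frame expands to exactly one hop of audio. The sketch below checks this in plain Python. It assumes the padding rule used in `generator.py` (`padding = u // 2 + u % 2`, `output_padding = u % 2`) and the standard `ConvTranspose1d` output-length formula with dilation 1; the helper name `conv_transpose1d_len` is made up for this example.

```python
import math

config = {
    "upsample_rates": [5, 5, 2, 2, 2],
    "upsample_kernel_sizes": [10, 10, 4, 4, 4],
    "segment_size": 6400,
    "hop_size": 200,
}


def conv_transpose1d_len(l_in, kernel, stride, padding, output_padding):
    """Output length of a 1-D transposed convolution with dilation 1."""
    return (l_in - 1) * stride - 2 * padding + kernel + output_padding


total_factor = math.prod(config["upsample_rates"])
frames = config["segment_size"] // config["hop_size"]

# Walk one training segment (in mel frames) through all upsampling layers.
length = frames
for u, k in zip(config["upsample_rates"], config["upsample_kernel_sizes"]):
    length = conv_transpose1d_len(length, k, u, padding=u // 2 + u % 2, output_padding=u % 2)

print(total_factor, frames, length)
```

If `hop_size` or the rates are changed independently, the final length no longer matches `segment_size`, and the generated audio cannot be compared sample by sample with the ground-truth segment.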


@@ -0,0 +1,303 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, spectral_norm
from vocoder.fregan.utils import get_padding
from vocoder.fregan.stft_loss import stft
from vocoder.fregan.dwt import DWT_1D
LRELU_SLOPE = 0.1
class SpecDiscriminator(nn.Module):
"""Discriminator that operates on a spectrogram of the input waveform at a single STFT resolution."""
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
super(SpecDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.discriminators = nn.ModuleList([
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
])
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
with torch.no_grad():
y = y.squeeze(1)
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))  # y.device also works on CPU, where get_device() returns -1
y = y.unsqueeze(1)
for i, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__(self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann_window"):
super(MultiResSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.dwt1d = DWT_1D()
self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
self.dwt_proj1 = norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
self.dwt_proj2 = norm_f(Conv2d(1, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
self.dwt_conv3 = norm_f(Conv1d(8, 1, 1))
self.dwt_proj3 = norm_f(Conv2d(1, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# DWT 1
x_d1_high1, x_d1_low1 = self.dwt1d(x)
x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
# 1d to 2d
b, c, t = x_d1.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x_d1 = F.pad(x_d1, (0, n_pad), "reflect")
t = t + n_pad
x_d1 = x_d1.view(b, c, t // self.period, self.period)
x_d1 = self.dwt_proj1(x_d1)
# DWT 2
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
# 1d to 2d
b, c, t = x_d2.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x_d2 = F.pad(x_d2, (0, n_pad), "reflect")
t = t + n_pad
x_d2 = x_d2.view(b, c, t // self.period, self.period)
x_d2 = self.dwt_proj2(x_d2)
# DWT 3
x_d3_high1, x_d3_low1 = self.dwt1d(x_d2_high1)
x_d3_high2, x_d3_low2 = self.dwt1d(x_d2_low1)
x_d3_high3, x_d3_low3 = self.dwt1d(x_d2_high2)
x_d3_high4, x_d3_low4 = self.dwt1d(x_d2_low2)
x_d3 = self.dwt_conv3(
torch.cat([x_d3_high1, x_d3_low1, x_d3_high2, x_d3_low2, x_d3_high3, x_d3_low3, x_d3_high4, x_d3_low4],
dim=1))
# 1d to 2d
b, c, t = x_d3.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x_d3 = F.pad(x_d3, (0, n_pad), "reflect")
t = t + n_pad
x_d3 = x_d3.view(b, c, t // self.period, self.period)
x_d3 = self.dwt_proj3(x_d3)
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
i = 0
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
if i == 0:
x = torch.cat([x, x_d1], dim=2)
elif i == 1:
x = torch.cat([x, x_d2], dim=2)
elif i == 2:
x = torch.cat([x, x_d3], dim=2)
else:
x = x
i = i + 1
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class ResWiseMultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super(ResWiseMultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.dwt1d = DWT_1D()
self.dwt_conv1 = norm_f(Conv1d(2, 128, 15, 1, padding=7))
self.dwt_conv2 = norm_f(Conv1d(4, 128, 41, 2, padding=20))
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
# DWT 1
x_d1_high1, x_d1_low1 = self.dwt1d(x)
x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
# DWT 2
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
i = 0
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
if i == 0:
x = torch.cat([x, x_d1], dim=2)
if i == 1:
x = torch.cat([x, x_d2], dim=2)
i = i + 1
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class ResWiseMultiScaleDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(ResWiseMultiScaleDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.dwt1d = DWT_1D()
self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
self.discriminators = nn.ModuleList([
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
# DWT 1
y_hi, y_lo = self.dwt1d(y)
y_1 = self.dwt_conv1(torch.cat([y_hi, y_lo], dim=1))
x_d1_high1, x_d1_low1 = self.dwt1d(y_hat)
y_hat_1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
# DWT 2
x_d2_high1, x_d2_low1 = self.dwt1d(y_hi)
x_d2_high2, x_d2_low2 = self.dwt1d(y_lo)
y_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
y_hat_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
for i, d in enumerate(self.discriminators):
if i == 1:
y = y_1
y_hat = y_hat_1
if i == 2:
y = y_2
y_hat = y_hat_2
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
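
`DiscriminatorP` repeats one step four times: pad the signal on the right until its length is a multiple of `period`, then view it as a 2-D array with `period` columns. The following is a dependency-free sketch of that step on a plain Python list. The function name `fold_by_period` is invented here, and the reflect rule is written out by hand to mirror what `F.pad(..., "reflect")` does on the right edge.

```python
def fold_by_period(samples, period):
    """Right-pad by reflection to a multiple of `period`, then split into rows."""
    t = len(samples)
    if t % period != 0:
        n_pad = period - (t % period)
        # Reflection mirrors the signal around its last sample without repeating it.
        # This needs n_pad <= len(samples) - 1.
        samples = samples + samples[-2:-2 - n_pad:-1]
    return [samples[i:i + period] for i in range(0, len(samples), period)]


rows = fold_by_period([0, 1, 2, 3, 4, 5, 6], period=3)
print(rows)  # [[0, 1, 2], [3, 4, 5], [6, 5, 4]]
```

Each column of the result holds samples that are `period` steps apart, which is what lets the `(kernel_size, 1)` 2-D convolutions look at periodic structure in the waveform.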

76
vocoder/fregan/dwt.py Normal file

@@ -0,0 +1,76 @@
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
# DWT code borrowed from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['DWT_1D']
Pad_Mode = ['constant', 'reflect', 'replicate', 'circular']
class DWT_1D(nn.Module):
def __init__(self, pad_type='reflect', wavename='haar',
stride=2, in_channels=1, out_channels=None, groups=None,
kernel_size=None, trainable=False):
super(DWT_1D, self).__init__()
self.trainable = trainable
self.kernel_size = kernel_size
if not self.trainable:
assert self.kernel_size == None
self.in_channels = in_channels
self.out_channels = self.in_channels if out_channels == None else out_channels
self.groups = self.in_channels if groups == None else groups
assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
self.stride = stride
assert self.stride == 2
self.wavename = wavename
self.pad_type = pad_type
assert self.pad_type in Pad_Mode
self.get_filters()
self.initialization()
def get_filters(self):
wavelet = pywt.Wavelet(self.wavename)
band_low = torch.tensor(wavelet.rec_lo)
band_high = torch.tensor(wavelet.rec_hi)
length_band = band_low.size()[0]
self.kernel_size = length_band if self.kernel_size == None else self.kernel_size
assert self.kernel_size >= length_band
a = (self.kernel_size - length_band) // 2
b = - (self.kernel_size - length_band - a)
b = None if b == 0 else b
self.filt_low = torch.zeros(self.kernel_size)
self.filt_high = torch.zeros(self.kernel_size)
self.filt_low[a:b] = band_low
self.filt_high[a:b] = band_high
def initialization(self):
self.filter_low = self.filt_low[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
self.filter_high = self.filt_high[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
if torch.cuda.is_available():
self.filter_low = self.filter_low.cuda()
self.filter_high = self.filter_high.cuda()
if self.trainable:
self.filter_low = nn.Parameter(self.filter_low)
self.filter_high = nn.Parameter(self.filter_high)
if self.kernel_size % 2 == 0:
self.pad_sizes = [self.kernel_size // 2 - 1, self.kernel_size // 2 - 1]
else:
self.pad_sizes = [self.kernel_size // 2, self.kernel_size // 2]
def forward(self, input):
assert isinstance(input, torch.Tensor)
assert len(input.size()) == 3
assert input.size()[1] == self.in_channels
input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
return F.conv1d(input, self.filter_low.to(input.device), stride=self.stride, groups=self.groups), \
F.conv1d(input, self.filter_high.to(input.device), stride=self.stride, groups=self.groups)
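
With the default `wavename='haar'` and `stride=2`, the two strided convolutions in `forward` reduce to pairwise sums and differences. The sketch below reproduces this in plain Python for an even-length signal. It assumes the Haar filters are `[1/√2, 1/√2]` (low-pass) and `[1/√2, -1/√2]` (high-pass), which is how I understand PyWavelets to report `rec_lo` and `rec_hi`; the sign convention of the high band is therefore an assumption.

```python
import math


def haar_dwt_1d(x):
    """One level of a Haar DWT: returns (low band, high band), each half as long as x."""
    if len(x) % 2 != 0:
        raise ValueError("this sketch expects an even number of samples")
    s = 1.0 / math.sqrt(2.0)
    low = [s * (x[i] + x[i + 1]) for i in range(0, len(x), 2)]
    high = [s * (x[i] - x[i + 1]) for i in range(0, len(x), 2)]
    return low, high


signal = [1.0, 3.0, 5.0, 7.0]
low, high = haar_dwt_1d(signal)
print(low, high)

# The transform is orthonormal, so the total energy is preserved.
energy_in = sum(v * v for v in signal)
energy_out = sum(v * v for v in low) + sum(v * v for v in high)
print(energy_in, energy_out)
```

Halving the length at each level is why the discriminators can feed the level-1 and level-2 bands into progressively deeper layers instead of average-pooled copies of the waveform.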

210
vocoder/fregan/generator.py Normal file

@@ -0,0 +1,210 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from vocoder.fregan.utils import init_weights, get_padding
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5, 7)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[3],
padding=get_padding(kernel_size, dilation[3])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class FreGAN(torch.nn.Module):
def __init__(self, h, top_k=4):
super(FreGAN, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.upsample_rates = h.upsample_rates
self.up_kernels = h.upsample_kernel_sizes
self.cond_level = self.num_upsamples - top_k
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
self.cond_up = nn.ModuleList()
self.res_output = nn.ModuleList()
upsample_ = 1
kr = 80
for i, (u, k) in enumerate(zip(self.upsample_rates, self.up_kernels)):
# self.ups.append(weight_norm(
# ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
# k, u, padding=(k - u) // 2)))
self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
h.upsample_initial_channel//(2**(i+1)),
k, u, padding=(u//2 + u%2), output_padding=u%2)))
if i > (self.num_upsamples - top_k):
self.res_output.append(
nn.Sequential(
nn.Upsample(scale_factor=u, mode='nearest'),
weight_norm(nn.Conv1d(h.upsample_initial_channel // (2 ** i),
h.upsample_initial_channel // (2 ** (i + 1)), 1))
)
)
if i >= (self.num_upsamples - top_k):
self.cond_up.append(
weight_norm(
ConvTranspose1d(kr, h.upsample_initial_channel // (2 ** i),
self.up_kernels[i - 1], self.upsample_rates[i - 1],
padding=(self.upsample_rates[i-1]//2+self.upsample_rates[i-1]%2), output_padding=self.upsample_rates[i-1]%2))
)
kr = h.upsample_initial_channel // (2 ** i)
upsample_ *= u
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.cond_up.apply(init_weights)
self.res_output.apply(init_weights)
def forward(self, x):
mel = x
x = self.conv_pre(x)
output = None
for i in range(self.num_upsamples):
if i >= self.cond_level:
mel = self.cond_up[i - self.cond_level](mel)
x += mel
if i > self.cond_level:
if output is None:
output = self.res_output[i - self.cond_level - 1](x)
else:
output = self.res_output[i - self.cond_level - 1](output)
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
if output is not None:
output = output + x
x = F.leaky_relu(output)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
for l in self.cond_up:
remove_weight_norm(l)
for l in self.res_output:
remove_weight_norm(l[1])
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
'''
to run this, fix
from . import ResStack
into
from res_stack import ResStack
'''
if __name__ == '__main__':
    '''
    torch.Size([3, 80, 10])
    torch.Size([3, 1, 2000])
    4527362
    '''
    import json
    from utils import AttrDict

    with open('config.json') as f:
        data = f.read()
    json_config = json.loads(data)
    h = AttrDict(json_config)
    model = FreGAN(h)
    c = torch.randn(3, 80, 10)  # (B, channels, T)
    print(c.shape)
    y = model(c)  # (B, 1, T * prod(upsample_rates))
    print(y.shape)
    # With upsample_rates [5, 5, 2, 2, 2] the total factor is 200, so T=10 yields 2000 samples
    # (a standard MelGAN with a 256x factor would yield 2560).
    assert y.shape == torch.Size([3, 1, 2000])
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)


@@ -0,0 +1,74 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import json
import torch
from utils.util import AttrDict
from vocoder.fregan.generator import FreGAN
generator = None # type: FreGAN
output_sample_rate = None
_device = None
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print("Loading '{}'".format(filepath))
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def load_model(weights_fpath, config_fpath=None, verbose=True):
global generator, _device, output_sample_rate
if verbose:
print("Building fregan")
if config_fpath == None:
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths) > 0:
config_fpath = model_config_fpaths[0]
else:
config_fpath = "./vocoder/fregan/config.json"
with open(config_fpath) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
output_sample_rate = h.sampling_rate
torch.manual_seed(h.seed)
if torch.cuda.is_available():
# _model = _model.cuda()
_device = torch.device('cuda')
else:
_device = torch.device('cpu')
generator = FreGAN(h).to(_device)
state_dict_g = load_checkpoint(
weights_fpath, _device
)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()
def is_loaded():
return generator is not None
def infer_waveform(mel, progress_callback=None):
if generator is None:
raise Exception("Please load fre-gan in memory before using it")
mel = torch.FloatTensor(mel).to(_device)
mel = mel.unsqueeze(0)
with torch.no_grad():
y_g_hat = generator(mel)
audio = y_g_hat.squeeze()
audio = audio.cpu().numpy()
return audio, output_sample_rate

35
vocoder/fregan/loss.py Normal file

@@ -0,0 +1,35 @@
import torch


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))
    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l
    return loss, gen_losses
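
The discriminator and generator losses above are least-squares GAN objectives: the discriminator is pushed towards 1 on real audio and 0 on generated audio, and the generator is pushed to make the discriminator output 1. A small numeric sketch for a single discriminator, with plain floats standing in for tensors; the function names mirror `loss.py`, but the scores are invented for illustration:

```python
def mean(values):
    return sum(values) / len(values)


def discriminator_loss(real_scores, fake_scores):
    r_loss = mean([(1 - r) ** 2 for r in real_scores])  # real outputs should be 1
    g_loss = mean([f ** 2 for f in fake_scores])        # generated outputs should be 0
    return r_loss + g_loss


def generator_loss(fake_scores):
    return mean([(1 - f) ** 2 for f in fake_scores])    # generator wants outputs of 1


d_loss = discriminator_loss([1.0, 0.5], [0.0, 0.5])
g_loss = generator_loss([0.0, 0.5])
print(d_loss, g_loss)
```

In `loss.py` the same terms are summed over every sub-discriminator, which is why those functions take lists of output tensors.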


@@ -0,0 +1,176 @@
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))
    global mel_basis, hann_window
    # Cache the mel filterbank and window per (fmax, device), looking up the same key that is stored.
    mel_key = str(fmax)+'_'+str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)
    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec
def get_dataset_filelist(a):
    #with open(a.input_training_file, 'r', encoding='utf-8') as fi:
    #    training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
    #                      for x in fi.read().split('\n') if len(x) > 0]
    #with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
    #    validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
    #                        for x in fi.read().split('\n') if len(x) > 0]
    files = os.listdir(a.input_wavs_dir)
    random.shuffle(files)
    files = [os.path.join(a.input_wavs_dir, f) for f in files]
    # Hold out about 5% for validation, but at least one file: with fewer than 20 files
    # int(len(files) * 0.05) is 0, and files[:-0] would leave the training list empty.
    n_val = max(1, int(len(files) * 0.05))
    training_files = files[:-n_val]
    validation_files = files[-n_val:]
    return training_files, validation_files
class MelDataset(torch.utils.data.Dataset):
def __init__(self, training_files, segment_size, n_fft, num_mels,
hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.cached_wav = None
self.n_cache_reuse = n_cache_reuse
self._cache_ref_count = 0
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
def __getitem__(self, index):
filename = self.audio_files[index]
if self._cache_ref_count == 0:
#audio, sampling_rate = load_wav(filename)
#audio = audio / MAX_WAV_VALUE
audio = np.load(filename)
if not self.fine_tuning:
audio = normalize(audio) * 0.95
self.cached_wav = audio
#if sampling_rate != self.sampling_rate:
# raise ValueError("{} SR doesn't match target {} SR".format(
# sampling_rate, self.sampling_rate))
self._cache_ref_count = self.n_cache_reuse
else:
audio = self.cached_wav
self._cache_ref_count -= 1
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
if not self.fine_tuning:
if self.split:
if audio.size(1) >= self.segment_size:
max_audio_start = audio.size(1) - self.segment_size
audio_start = random.randint(0, max_audio_start)
audio = audio[:, audio_start:audio_start+self.segment_size]
else:
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
center=False)
else:
mel_path = os.path.join(self.base_mels_path, "mel" + "-" + filename.split("/")[-1].split("-")[-1])
mel = np.load(mel_path).T
#mel = np.load(
# os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
mel = torch.from_numpy(mel)
if len(mel.shape) < 3:
mel = mel.unsqueeze(0)
if self.split:
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
if audio.size(1) >= self.segment_size:
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
mel = mel[:, :, mel_start:mel_start + frames_per_seg]
audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
else:
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
center=False)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
def __len__(self):
return len(self.audio_files)
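
The `dynamic_range_compression` and `dynamic_range_decompression` helpers in this file form a log/exp pair with a floor at `clip_val`. A scalar sketch with `math` in place of NumPy or torch shows the round trip and what the floor does to silence; the function names `compress` and `decompress` are shortened for this example:

```python
import math


def compress(x, C=1, clip_val=1e-5):
    """Log-compress a non-negative magnitude, clipping tiny values to avoid log(0)."""
    return math.log(max(x, clip_val) * C)


def decompress(y, C=1):
    return math.exp(y) / C


magnitude = 0.5
restored = decompress(compress(magnitude))
print(restored)

# Exact silence is not recoverable: it comes back as the clip value.
silence_restored = decompress(compress(0.0))
print(silence_restored)
```

The floor means that every mel bin has a minimum value of `log(1e-5)`, roughly -11.5, which is the value that silent regions take in the training targets.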

201
vocoder/fregan/modules.py Normal file

@@ -0,0 +1,201 @@
import torch
import torch.nn.functional as F
class KernelPredictor(torch.nn.Module):
''' Kernel predictor for the location-variable convolutions
'''
def __init__(self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1}
):
'''
Args:
cond_channels (int): number of channels of the conditioning sequence,
conv_in_channels (int): number of channels of the input sequence,
conv_out_channels (int): number of channels of the output sequence,
conv_layers (int): number of convolution layers whose kernels and biases are predicted,
kpnet_
'''
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
l_b = conv_out_channels * conv_layers
padding = (kpnet_conv_size - 1) // 2
self.input_conv = torch.nn.Sequential(
torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_conv = torch.nn.Sequential(
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size,
padding=padding, bias=True)
self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding,
bias=True)
def forward(self, c):
'''
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
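Returns:
kernels (Tensor): predicted kernels (batch, conv_layers, conv_in_channels, conv_out_channels, conv_kernel_size, cond_length)
bias (Tensor): predicted biases (batch, conv_layers, conv_out_channels, cond_length)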
'''
batch, cond_channels, cond_length = c.shape
c = self.input_conv(c)
c = c + self.residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length)
bias = b.contiguous().view(batch,
self.conv_layers,
self.conv_out_channels,
cond_length)
return kernels, bias
class LVCBlock(torch.nn.Module):
''' the location-variable convolutions
'''
def __init__(self,
in_channels,
cond_channels,
upsample_ratio,
conv_layers=4,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = conv_layers
self.conv_kernel_size = conv_kernel_size
self.convs = torch.nn.ModuleList()
self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels,
kernel_size=upsample_ratio*2, stride=upsample_ratio,
padding=upsample_ratio // 2 + upsample_ratio % 2,
output_padding=upsample_ratio % 2)
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=conv_layers,
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout
)
for i in range(conv_layers):
padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i)
self.convs.append(conv)
def forward(self, x, c):
''' forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
'''
batch, in_channels, in_length = x.shape
kernels, bias = self.kernel_predictor(c)
x = F.leaky_relu(x, 0.2)
x = self.upsample(x)
for i in range(self.conv_layers):
y = F.leaky_relu(x, 0.2)
y = self.convs[i](y)
y = F.leaky_relu(y, 0.2)
k = kernels[:, i, :, :, :, :]
b = bias[:, i, :, :]
y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
return x
def location_variable_convolution(self, x, kernel, bias, dilation, hop_size):
''' Perform the location-variable convolution operation on the input sequence (x) using the local convolution kernel.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
'''
batch, in_channels, in_length = x.shape
batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), 'constant', 0)
x = x.unfold(3, dilation, dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
o = o + bias.unsqueeze(-1).unsqueeze(-1)
o = o.contiguous().view(batch, out_channels, -1)
return o
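As a rough check on the sizes that `KernelPredictor` works with, the sketch below recomputes `l_w` and `l_b` and the shapes that `forward` views the two convolution outputs into. The helper name `lvc_kernel_shapes` and the example values (32 input channels, 10 conditioning frames) are illustrative assumptions, not part of the commit; `out_channels = 2 * in_channels` follows the `LVCBlock` constructor, which needs twice the channels for the sigmoid/tanh gate.

```python
import math


def lvc_kernel_shapes(batch, in_channels, out_channels, kernel_size, layers, cond_length):
    """Hypothetical helper: sizes used by KernelPredictor, without any torch dependency."""
    # Output channels of kernel_conv and bias_conv, as computed in __init__.
    l_w = in_channels * out_channels * kernel_size * layers
    l_b = out_channels * layers
    # Shapes that forward() reshapes the (batch, l_w, cond_length) and
    # (batch, l_b, cond_length) tensors into.
    kernels_shape = (batch, layers, in_channels, out_channels, kernel_size, cond_length)
    bias_shape = (batch, layers, out_channels, cond_length)
    return l_w, l_b, kernels_shape, bias_shape


in_channels = 32  # assumed example value
l_w, l_b, k_shape, b_shape = lvc_kernel_shapes(
    batch=2, in_channels=in_channels, out_channels=2 * in_channels,
    kernel_size=3, layers=4, cond_length=10)

# The view() calls only succeed if the element counts agree.
assert math.prod(k_shape) == 2 * l_w * 10
assert math.prod(b_shape) == 2 * l_b * 10
print(l_w, l_b, k_shape, b_shape)
```

One kernel is predicted per conditioning frame, so `location_variable_convolution` additionally requires the upsampled input length to equal `kernel_length * hop_size`, which is what its `assert` enforces.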

136
vocoder/fregan/stft_loss.py Normal file

@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""STFT-based Loss modules."""
import torch
import torch.nn.functional as F
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (Tensor): Window tensor of length win_length (STFTLoss passes the tensor built from the window function name).
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
real = x_stft[..., 0]
imag = x_stft[..., 1]
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
class SpectralConvergengeLoss(torch.nn.Module):
"""Spectral convergence loss module."""
def __init__(self):
"""Initialize spectral convergence loss module."""
super(SpectralConvergengeLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Spectral convergence loss value.
"""
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
def __init__(self):
"""Initialize log STFT magnitude loss module."""
super(LogSTFTMagnitudeLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Log STFT magnitude loss value.
"""
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
class STFTLoss(torch.nn.Module):
"""STFT loss module."""
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
"""Initialize STFT loss module."""
super(STFTLoss, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.spectral_convergenge_loss = SpectralConvergengeLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Spectral convergence loss value.
Tensor: Log STFT magnitude loss value.
"""
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device()))
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device()))
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
"""Multi resolution STFT loss module."""
def __init__(self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann_window"):
"""Initialize Multi resolution STFT loss module.
Args:
fft_sizes (list): List of FFT sizes.
hop_sizes (list): List of hop sizes.
win_lengths (list): List of window lengths.
window (str): Window function type.
"""
super(MultiResolutionSTFTLoss, self).__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
self.stft_losses = torch.nn.ModuleList()
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Multi resolution spectral convergence loss value.
Tensor: Multi resolution log STFT magnitude loss value.
"""
sc_loss = 0.0
mag_loss = 0.0
for f in self.stft_losses:
sc_l, mag_l = f(x, y)
sc_loss += sc_l
mag_loss += mag_l
sc_loss /= len(self.stft_losses)
mag_loss /= len(self.stft_losses)
return sc_loss, mag_loss
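For reference, the two losses in this file can be written out as follows. This is a reading of the code rather than a quotation from a paper; the symbols are assumptions introduced here: $X$ and $Y$ are the predicted and ground-truth magnitude spectrograms, $N$ is the number of elements in a spectrogram (the default mean reduction of `F.l1_loss`), and $M$ is the number of STFT resolutions.

```latex
\mathcal{L}_{\mathrm{sc}}(X, Y) = \frac{\lVert Y - X \rVert_F}{\lVert Y \rVert_F},
\qquad
\mathcal{L}_{\mathrm{mag}}(X, Y) = \frac{1}{N} \bigl\lVert \log Y - \log X \bigr\rVert_1

\mathcal{L}_{\mathrm{sc}}^{\mathrm{multi}} = \frac{1}{M} \sum_{m=1}^{M} \mathcal{L}_{\mathrm{sc}}\bigl(X^{(m)}, Y^{(m)}\bigr),
\qquad
\mathcal{L}_{\mathrm{mag}}^{\mathrm{multi}} = \frac{1}{M} \sum_{m=1}^{M} \mathcal{L}_{\mathrm{mag}}\bigl(X^{(m)}, Y^{(m)}\bigr)
```

With the default arguments of `MultiResolutionSTFTLoss`, $M = 3$ and the three resolutions use FFT sizes 1024, 2048 and 512.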

246
vocoder/fregan/train.py Normal file

@@ -0,0 +1,246 @@
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from vocoder.fregan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from vocoder.fregan.generator import FreGAN
from vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator, ResWiseMultiScaleDiscriminator
from vocoder.fregan.loss import feature_loss, generator_loss, discriminator_loss
from vocoder.fregan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
def train(rank, a, h):
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_fregan')
a.checkpoint_path.mkdir(exist_ok=True)
a.training_epochs = 3100
a.stdout_interval = 5
a.checkpoint_interval = a.backup_every
a.summary_interval = 5000
a.validation_interval = 1000
a.fine_tuning = True
a.input_wavs_dir = a.syn_dir.joinpath("audio")
a.input_mels_dir = a.syn_dir.joinpath("mels")
if h.num_gpus > 1:
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
torch.cuda.manual_seed(h.seed)
device = torch.device('cuda:{:d}'.format(rank))
generator = FreGAN(h).to(device)
mpd = ResWiseMultiPeriodDiscriminator().to(device)
msd = ResWiseMultiScaleDiscriminator().to(device)
if rank == 0:
print(generator)
os.makedirs(a.checkpoint_path, exist_ok=True)
print("checkpoints directory : ", a.checkpoint_path)
if os.path.isdir(a.checkpoint_path):
cp_g = scan_checkpoint(a.checkpoint_path, 'g_fregan_')
cp_do = scan_checkpoint(a.checkpoint_path, 'do_fregan_')
steps = 0
if cp_g is None or cp_do is None:
state_dict_do = None
last_epoch = -1
else:
state_dict_g = load_checkpoint(cp_g, device)
state_dict_do = load_checkpoint(cp_do, device)
generator.load_state_dict(state_dict_g['generator'])
mpd.load_state_dict(state_dict_do['mpd'])
msd.load_state_dict(state_dict_do['msd'])
steps = state_dict_do['steps'] + 1
last_epoch = state_dict_do['epoch']
if h.num_gpus > 1:
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
if state_dict_do is not None:
optim_g.load_state_dict(state_dict_do['optim_g'])
optim_d.load_state_dict(state_dict_do['optim_d'])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
training_filelist, validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
sampler=train_sampler,
batch_size=h.batch_size,
pin_memory=True,
drop_last=True)
if rank == 0:
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir)
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True)
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
generator.train()
mpd.train()
msd.train()
for epoch in range(max(0, last_epoch), a.training_epochs):
if rank == 0:
start = time.time()
print("Epoch: {}".format(epoch + 1))
if h.num_gpus > 1:
train_sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
if rank == 0:
start_b = time.time()
x, y, _, y_mel = batch
x = torch.autograd.Variable(x.to(device, non_blocking=True))
y = torch.autograd.Variable(y.to(device, non_blocking=True))
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
y = y.unsqueeze(1)
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
h.win_size,
h.fmin, h.fmax_for_loss)
if steps > h.disc_start_step:
optim_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MSD
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
loss_disc_all.backward()
optim_d.step()
# Generator
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
if steps > h.disc_start_step:
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
else:
loss_gen_all = loss_mel
loss_gen_all.backward()
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
with torch.no_grad():
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
format(steps, loss_gen_all, mel_error, time.time() - start_b))
# checkpointing
if steps % a.checkpoint_interval == 0 and steps != 0:
checkpoint_path = "{}/g_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1
else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1
else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
# Tensorboard summary logging
if steps % a.summary_interval == 0:
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
# Validation
if steps % a.validation_interval == 0: # and steps != 0:
generator.eval()
torch.cuda.empty_cache()
val_err_tot = 0
with torch.no_grad():
for j, batch in enumerate(validation_loader):
x, y, _, y_mel = batch
y_g_hat = generator(x.to(device))
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
#val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
if j <= 4:
if steps == 0:
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
h.sampling_rate, h.hop_size, h.win_size,
h.fmin, h.fmax)
sw.add_figure('generated/y_hat_spec_{}'.format(j),
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
val_err = val_err_tot / (j + 1)
sw.add_scalar("validation/mel_spec_error", val_err, steps)
generator.train()
steps += 1
scheduler_g.step()
scheduler_d.step()
if rank == 0:
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
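The two `ExponentialLR` schedulers multiply the learning rate by `h.lr_decay` each time they are stepped. The sketch below evaluates the closed form of that decay. It assumes the schedulers are stepped once per epoch (the flattened diff does not preserve the indentation that would confirm this), uses a hypothetical base learning rate of `2e-4`, and takes `0.999` as the decay factor because that is the `lr_decay` value visible in the config hunk of this diff; the epoch count 3100 is `a.training_epochs` from `train`.

```python
import math


def lr_at_epoch(base_lr, gamma, epoch):
    """Closed form of exponential decay: base_lr * gamma ** epoch."""
    return base_lr * gamma ** epoch


base_lr = 2e-4  # hypothetical value, not taken from the diff
gamma = 0.999   # lr_decay as shown in the config hunk

for epoch in (0, 100, 1000, 3100):
    print("epoch {:4d}: lr = {:.3e}".format(epoch, lr_at_epoch(base_lr, gamma, epoch)))

# Fraction of the base learning rate that remains after the last epoch.
remaining = lr_at_epoch(1.0, gamma, 3100)
assert math.isclose(remaining, math.exp(3100 * math.log(gamma)))
```

When training resumes from a checkpoint, `last_epoch` is passed to the schedulers so that the decay continues from the saved epoch instead of restarting at the base rate.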

65
vocoder/fregan/utils.py Normal file

@@ -0,0 +1,65 @@
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
import shutil
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
weight_norm(m)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print("Loading '{}'".format(filepath))
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def save_checkpoint(filepath, obj):
print("Saving checkpoint to {}".format(filepath))
torch.save(obj, filepath)
print("Complete.")
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + '????????.pt')
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None
return sorted(cp_list)[-1]
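`scan_checkpoint` relies on the step number being zero-padded to eight digits, so that a plain lexicographic `sorted` also orders the files by step. A minimal sketch of that behaviour, using a temporary directory and made-up step numbers (both are assumptions for illustration; the function body is copied from the diff):

```python
import glob
import os
import tempfile


def scan_checkpoint(cp_dir, prefix):
    # Same logic as in the diff: eight wildcard characters followed by ".pt".
    pattern = os.path.join(cp_dir, prefix + '????????.pt')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]


with tempfile.TemporaryDirectory() as cp_dir:
    # Create empty files named the way train() names generator checkpoints.
    for step in (500, 10000, 2500):
        open(os.path.join(cp_dir, "g_fregan_{:08d}.pt".format(step)), "w").close()
    # A file without the eight-character suffix does not match the pattern.
    open(os.path.join(cp_dir, "g_fregan.pt"), "w").close()

    latest = os.path.basename(scan_checkpoint(cp_dir, "g_fregan_"))
    missing = scan_checkpoint(cp_dir, "do_fregan_")

print(latest)   # g_fregan_00010000.pt
print(missing)  # None
```

Because `train` only resumes when both the `g_fregan_` and `do_fregan_` checkpoints are found, a directory like the one above would start training from scratch.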


@@ -7,6 +7,7 @@
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"disc_start_step":0,
"upsample_rates": [5,5,4,2],
"upsample_kernel_sizes": [10,10,8,4],
@@ -27,5 +28,11 @@
"fmax": 7600,
"fmax_for_loss": null,
"num_workers": 4
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}


@@ -1,13 +1,6 @@
import os
import shutil
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:


@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import json
import torch
from vocoder.hifigan.env import AttrDict
from utils.util import AttrDict
from vocoder.hifigan.models import Generator
generator = None # type: Generator
@@ -19,12 +19,18 @@ def load_checkpoint(filepath, device):
return checkpoint_dict
def load_model(weights_fpath, config_fpath="./vocoder/saved_models/24k/config.json", verbose=True):
def load_model(weights_fpath, config_fpath=None, verbose=True):
global generator, _device, output_sample_rate
if verbose:
print("Building hifigan")
if config_fpath == None:
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths) > 0:
config_fpath = model_config_fpaths[0]
else:
config_fpath = "./vocoder/hifigan/config_16k_.json"
with open(config_fpath) as f:
data = f.read()
json_config = json.loads(data)
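The new `load_model` default looks for a JSON config next to the weights file and only falls back to the bundled 16k config when none is found. The sketch below isolates that lookup. The function name `resolve_config`, the directory layout and the `sorted` call are assumptions made for the example: the code in the diff takes the first `rglob` hit as returned, whose order is not guaranteed.

```python
import json
import tempfile
from pathlib import Path

FALLBACK_CONFIG = "./vocoder/hifigan/config_16k_.json"  # fallback path from the diff


def resolve_config(weights_fpath, config_fpath=None):
    """Hypothetical helper mirroring the config lookup added to load_model."""
    if config_fpath is None:
        # Search the weights directory (recursively) for any JSON file.
        candidates = sorted(Path(weights_fpath).parent.rglob("*.json"))
        config_fpath = candidates[0] if candidates else FALLBACK_CONFIG
    return config_fpath


with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "my_run"  # made-up model directory
    model_dir.mkdir()
    weights = model_dir / "g_hifigan.pt"
    weights.touch()

    # No JSON next to the weights yet: the fallback path is returned.
    no_config = resolve_config(weights)

    # Once a config sits beside the weights, it is picked up instead.
    (model_dir / "config.json").write_text(json.dumps({"sampling_rate": 16000}))
    found_name = resolve_config(weights).name

    # An explicitly passed config always wins.
    explicit = resolve_config(weights, "custom.json")

print(no_config, found_name, explicit)
```

Note that the lookup calls `.parent` on `weights_fpath`, so callers of the real `load_model` need to pass a `pathlib.Path` rather than a plain string.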


@@ -12,7 +12,6 @@ from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from vocoder.hifigan.env import AttrDict, build_env
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
discriminator_loss
@@ -52,8 +51,8 @@ def train(rank, a, h):
print("checkpoints directory : ", a.checkpoint_path)
if os.path.isdir(a.checkpoint_path):
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
cp_g = scan_checkpoint(a.checkpoint_path, 'g_hifigan_')
cp_do = scan_checkpoint(a.checkpoint_path, 'do_hifigan_')
steps = 0
if cp_g is None or cp_do is None:
@@ -138,21 +137,21 @@ def train(rank, a, h):
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
if steps > h.disc_start_step:
optim_d.zero_grad()
optim_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MSD
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
# MSD
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
loss_disc_all = loss_disc_s + loss_disc_f
loss_disc_all.backward()
optim_d.step()
loss_disc_all.backward()
optim_d.step()
# Generator
optim_g.zero_grad()
@@ -160,13 +159,16 @@ def train(rank, a, h):
# L1 Mel-Spectrogram Loss
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
if steps > h.disc_start_step:
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
loss_gen_all = loss_mel
loss_gen_all.backward()
optim_g.step()
@@ -182,10 +184,10 @@ def train(rank, a, h):
# checkpointing
if steps % a.checkpoint_interval == 0 and steps != 0:
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
checkpoint_path = "{}/g_hifigan_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
checkpoint_path = "{}/do_hifigan_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
@@ -203,7 +205,7 @@ def train(rank, a, h):
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
checkpoint_path = "{}/do_hifigan.pt".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),


@@ -50,7 +50,7 @@ def save_checkpoint(filepath, obj):
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + '????????')
pattern = os.path.join(cp_dir, prefix + '????????.pt')
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None


@@ -59,7 +59,7 @@ class WaveRNN(nn.Module) :
# Compute all gates for coarse and fine
u = F.sigmoid(R_u + I_u + self.bias_u)
r = F.sigmoid(R_r + I_r + self.bias_r)
e = F.tanh(r * R_e + I_e + self.bias_e)
e = torch.tanh(r * R_e + I_e + self.bias_e)
hidden = u * prev_hidden + (1. - u) * e
# Split the hidden state
@@ -118,7 +118,7 @@ class WaveRNN(nn.Module) :
# Compute the coarse gates
u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
hidden_coarse = u * hidden_coarse + (1. - u) * e
# Compute the coarse output
@@ -138,7 +138,7 @@ class WaveRNN(nn.Module) :
# Compute the fine gates
u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
hidden_fine = u * hidden_fine + (1. - u) * e
# Compute the fine output


@@ -1,11 +1,13 @@
from utils.argutils import print_args
from vocoder.wavernn.train import train
from vocoder.hifigan.train import train as train_hifigan
from vocoder.hifigan.env import AttrDict
from vocoder.fregan.train import train as train_fregan
from utils.util import AttrDict
from pathlib import Path
import argparse
import json
import torch
import torch.multiprocessing as mp
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -61,11 +63,30 @@ if __name__ == "__main__":
# Process the arguments
if args.vocoder_type == "wavernn":
# Run the training wavernn
delattr(args, 'vocoder_type')
delattr(args, 'config')
train(**vars(args))
elif args.vocoder_type == "hifigan":
with open(args.config) as f:
json_config = json.load(f)
h = AttrDict(json_config)
train_hifigan(0, args, h)
if h.num_gpus > 1:
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
else:
train_hifigan(0, args, h)
elif args.vocoder_type == "fregan":
with open('vocoder/fregan/config.json') as f:
json_config = json.load(f)
h = AttrDict(json_config)
if h.num_gpus > 1:
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
else:
train_fregan(0, args, h)
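Both GAN branches treat the configured `batch_size` as a global batch and divide it across the GPUs before spawning one process per device. The sketch below reproduces that arithmetic without torch. `AttrDict` is the small class shown elsewhere in this diff; `split_batch` and its `visible_gpus` parameter are hypothetical stand-ins for the inline code and for `torch.cuda.device_count()`.

```python
class AttrDict(dict):
    """Dictionary whose keys are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def split_batch(h, visible_gpus):
    """Hypothetical helper mirroring the multi-GPU branch of the training script."""
    if h.num_gpus > 1:
        # The config value only switches DDP on; the actual count comes from the machine.
        h.num_gpus = visible_gpus
        h.batch_size = int(h.batch_size / h.num_gpus)
    return h


multi = split_batch(AttrDict({"num_gpus": 2, "batch_size": 16}), visible_gpus=4)
single = split_batch(AttrDict({"num_gpus": 1, "batch_size": 16}), visible_gpus=4)

print(multi.num_gpus, multi.batch_size)    # 4 4
print(single.num_gpus, single.batch_size)  # 1 16
```

A consequence of this design is that `num_gpus` in the config acts as an on/off switch: any value above 1 uses every visible GPU, and the per-GPU batch is truncated when the global batch is not divisible by the GPU count.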

26
web.py

@@ -1,11 +1,21 @@
from web import webApp
from gevent import pywsgi as wsgi
import os
import sys
import typer
cli = typer.Typer()
@cli.command()
def launch(port: int = typer.Option(8080, "--port", "-p")) -> None:
"""Start a graphical UI server for the opyrator.
The UI is auto-generated from the input- and output-schema of the given function.
"""
# Add the current working directory to the sys path
# This is required to resolve the opyrator path
sys.path.append(os.getcwd())
from mkgui.base.ui.streamlit_ui import launch_ui
launch_ui(port)
if __name__ == "__main__":
app = webApp()
host = app.config.get("HOST")
port = app.config.get("PORT")
print(f"Web server: http://{host}:{port}")
server = wsgi.WSGIServer((host, port), app)
server.serve_forever()
cli()


@@ -94,7 +94,7 @@ def webApp():
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
# Load input text
texts = request.form["text"].split("\n")
texts = filter(None, request.form["text"].split("\n"))
punctuation = '!,。、,' # punctuate and split/clean text
processed_texts = []
for text in texts:
@@ -109,7 +109,7 @@ def webApp():
spec = np.concatenate(specs, axis=1)
sample_rate = Synthesizer.sample_rate
if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN":
wav = rnn_vocoder.infer_waveform(spec)
wav, sample_rate = rnn_vocoder.infer_waveform(spec)
else:
wav, sample_rate = gan_vocoder.infer_waveform(spec)
@@ -132,4 +132,4 @@ def webApp():
return app
if __name__ == "__main__":
webApp()
webApp()
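The `texts = filter(None, ...)` change in the web app drops empty lines before they reach the synthesizer. A minimal sketch of what that call does, with a made-up input string standing in for `request.form["text"]`:

```python
raw_text = "hello\n\nworld\n"  # made-up form input with a blank line and a trailing newline

# Before the change: empty strings are kept.
before = raw_text.split("\n")

# After the change: filter(None, ...) keeps only truthy (non-empty) strings.
after = list(filter(None, raw_text.split("\n")))

print(before)  # ['hello', '', 'world', '']
print(after)   # ['hello', 'world']
```

`filter` returns a one-shot iterator, which is sufficient here because the handler walks `texts` once in a single `for` loop; the `list(...)` call above is only there so the result can be printed.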