mirror of
https://github.com/babysor/Realtime-Voice-Clone-Chinese.git
synced 2026-02-04 02:54:07 +08:00
Compare commits
103 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a191587417 | ||
|
|
d3ba597be9 | ||
|
|
6134c94b4d | ||
|
|
c04a1097bf | ||
|
|
9b4f8cc6c9 | ||
|
|
96993a5c61 | ||
|
|
70cc3988d3 | ||
|
|
c5998bfe71 | ||
|
|
c997dbdf66 | ||
|
|
47cc597ad0 | ||
|
|
8c895ed2c6 | ||
|
|
2e57bf3f11 | ||
|
|
11a5e2a141 | ||
|
|
7f0d983da7 | ||
|
|
0353bfc6e6 | ||
|
|
9ec114a7c1 | ||
|
|
ddf612e87c | ||
|
|
374cc89cfa | ||
|
|
6009da7072 | ||
|
|
1c61a601d1 | ||
|
|
02ee514aa3 | ||
|
|
6532d65153 | ||
|
|
3fe0690cc6 | ||
|
|
79f424d614 | ||
|
|
3c97d22938 | ||
|
|
fc26c38152 | ||
|
|
6c01b92703 | ||
|
|
c36f02634a | ||
|
|
b05e7441ff | ||
|
|
693de98f4d | ||
|
|
252a5e11b3 | ||
|
|
b617a87ee4 | ||
|
|
ad22997614 | ||
|
|
9e072c2619 | ||
|
|
b79e9d68e4 | ||
|
|
0536874dec | ||
|
|
4529479091 | ||
|
|
8ad9ba2b60 | ||
|
|
b56ec5ee1b | ||
|
|
0bc34a5bc9 | ||
|
|
875fe15069 | ||
|
|
4728863f9d | ||
|
|
a4daf42868 | ||
|
|
b50c7984ab | ||
|
|
26fe4a047d | ||
|
|
aff1b5313b | ||
|
|
7dca74e032 | ||
|
|
a37b26a89c | ||
|
|
902e1eb537 | ||
|
|
5c0e53a29a | ||
|
|
4edebdfeba | ||
|
|
6c8f3f4515 | ||
|
|
2bd323b7df | ||
|
|
3674d8b5c6 | ||
|
|
80aaf32164 | ||
|
|
c396792b22 | ||
|
|
7c58fe01d1 | ||
|
|
724194a4de | ||
|
|
31bc6656c3 | ||
|
|
aa35fb3139 | ||
|
|
727eafc51b | ||
|
|
d328ecba81 | ||
|
|
fad574118c | ||
|
|
b0c156a537 | ||
|
|
724809abf4 | ||
|
|
05cd1a54ea | ||
|
|
245099c740 | ||
|
|
6dd2af49fe | ||
|
|
8b43ec9a64 | ||
|
|
2a99f0ff05 | ||
|
|
a824b54122 | ||
|
|
81befb91b0 | ||
|
|
e2017d0314 | ||
|
|
547ac816df | ||
|
|
6b4ab39601 | ||
|
|
b46e7a7866 | ||
|
|
8a384a1191 | ||
|
|
11154783d8 | ||
|
|
d52db0444e | ||
|
|
790d11a58b | ||
|
|
cb82fcfe58 | ||
|
|
26ecb7546d | ||
|
|
f64914fca8 | ||
|
|
512da52775 | ||
|
|
9c219f05c2 | ||
|
|
4d9e460063 | ||
|
|
0d0b55d3e9 | ||
|
|
4acfee2a64 | ||
|
|
99269b2046 | ||
|
|
28e6bce570 | ||
|
|
5238c43799 | ||
|
|
2dd76e1b8d | ||
|
|
ddd478c0ad | ||
|
|
4178416385 | ||
|
|
3fbe03f2ff | ||
|
|
222e302274 | ||
|
|
32b9755cbe | ||
|
|
78fcfc4651 | ||
|
|
45bc43bf3c | ||
|
|
dacedfa9cc | ||
|
|
b60b75ea89 | ||
|
|
c4a8c72b83 | ||
|
|
8195a55a25 |
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
github: babysor
|
||||
17
.github/ISSUE_TEMPLATE/issue.md
vendored
Normal file
17
.github/ISSUE_TEMPLATE/issue.md
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
---
|
||||
name: Issue
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Summary[问题简述(一句话)]**
|
||||
A clear and concise description of what the issue is.
|
||||
|
||||
**Env & To Reproduce[复现与环境]**
|
||||
描述你用的环境、代码版本、模型
|
||||
|
||||
**Screenshots[截图(如有)]**
|
||||
If applicable, add screenshots to help
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -13,8 +13,9 @@
|
||||
*.bbl
|
||||
*.bcf
|
||||
*.toc
|
||||
*.wav
|
||||
*.sh
|
||||
synthesizer/saved_models/*
|
||||
vocoder/saved_models/*
|
||||
!vocoder/saved_models/pretrained/*
|
||||
*/saved_models
|
||||
!vocoder/saved_models/pretrained/**
|
||||
!encoder/saved_models/pretrained.pt
|
||||
wavs
|
||||
log
|
||||
119
.vscode/launch.json
vendored
119
.vscode/launch.json
vendored
@@ -1,48 +1,73 @@
|
||||
{
|
||||
// 使用 IntelliSense 了解相关属性。
|
||||
// 悬停以查看现有属性的描述。
|
||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Syn Preprocess",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "pre.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"D:\\ttsdata\\BZNSYP", "-d", "BZNSYP"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Python: Vocoder Preprocess",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "vocoder_preprocess.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"..\\..\\chs1"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Python: Vocoder Train",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "vocoder_train.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"dev", "..\\..\\chs1"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Python: demo box",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "demo_toolbox.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"-d", "..\\..\\chs"
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
// 使用 IntelliSense 了解相关属性。
|
||||
// 悬停以查看现有属性的描述。
|
||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Web",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "web.py",
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Python: Vocoder Preprocess",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "vocoder_preprocess.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Vocoder Train",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "vocoder_train.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["dev", "..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Demo Box",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "demo_toolbox.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["-d","..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Demo Box VC",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "demo_toolbox.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["-d","..\\audiodata","-vc"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Synth Train",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "synthesizer_train.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["my_run", "..\\"]
|
||||
},
|
||||
{
|
||||
"name": "Python: PPG Convert",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
|
||||
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GUI",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "mkgui\\base\\_cli.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": []
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"python.formatting.provider": "black"
|
||||
}
|
||||
219
README-CN.md
219
README-CN.md
@@ -5,83 +5,222 @@
|
||||
|
||||
### [English](README.md) | 中文
|
||||
|
||||
### [DEMO VIDEO](https://www.bilibili.com/video/BV1sA411P7wM/)
|
||||
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki教程](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [训练教程](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)
|
||||
|
||||
## 特性
|
||||
🌍 **中文** 支持普通话并使用多种中文数据集进行测试:adatatang_200zh, magicdata, aishell3
|
||||
🌍 **中文** 支持普通话并使用多种中文数据集进行测试:aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell 等
|
||||
|
||||
🤩 **PyTorch** 适用于 pytorch,已在 1.9.0 版本(最新于 2021 年 8 月)中测试,GPU Tesla T4 和 GTX 2060
|
||||
|
||||
🌍 **Windows + Linux** 在修复 nits 后在 Windows 操作系统和 linux 操作系统中进行测试
|
||||
🌍 **Windows + Linux** 可在 Windows 操作系统和 linux 操作系统中运行(苹果系统M1版也有社区成功运行案例)
|
||||
|
||||
🤩 **Easy & Awesome** 仅使用新训练的合成器(synthesizer)就有良好效果,复用预训练的编码器/声码器
|
||||
🤩 **Easy & Awesome** 仅需下载或新训练合成器(synthesizer)就有良好效果,复用预训练的编码器/声码器,或实时的HiFi-GAN作为vocoder
|
||||
|
||||
## 快速开始
|
||||
> 0训练新手友好版可以参考 [Quick Start (Newbie)](https://github.com/babysor/Realtime-Voice-Clone-Chinese/wiki/Quick-Start-(Newbie))
|
||||
🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用
|
||||
|
||||
### 进行中的工作
|
||||
* GUI/客户端大升级与合并
|
||||
[X] 初始化框架 `./mkgui` (基于streamlit + fastapi)和 [技术设计](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
|
||||
[X] 增加 Voice Cloning and Conversion的演示页面
|
||||
[X] 增加Voice Conversion的预处理preprocessing 和训练 training 页面
|
||||
[ ] 增加其他的的预处理preprocessing 和训练 training 页面
|
||||
* 模型后端基于ESPnet2升级
|
||||
|
||||
|
||||
## 开始
|
||||
### 1. 安装要求
|
||||
> 按照原始存储库测试您是否已准备好所有环境。
|
||||
**Python 3.7 或更高版本** 需要运行工具箱。
|
||||
运行工具箱(demo_toolbox.py)需要 **Python 3.7 或更高版本** 。
|
||||
|
||||
* 安装 [PyTorch](https://pytorch.org/get-started/locally/)。
|
||||
> 如果在用 pip 方式安装的时候出现 `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)` 这个错误可能是 python 版本过低,3.9 可以安装成功
|
||||
* 安装 [ffmpeg](https://ffmpeg.org/download.html#get-packages)。
|
||||
* 运行`pip install -r requirements.txt` 来安装剩余的必要包。
|
||||
* 安装 webrtcvad 用 `pip install webrtcvad-wheels`。
|
||||
* 安装 webrtcvad `pip install webrtcvad-wheels`。
|
||||
|
||||
### 2. 使用数据集训练合成器
|
||||
### 2. 准备预训练模型
|
||||
考虑训练您自己专属的模型或者下载社区他人训练好的模型:
|
||||
> 近期创建了[知乎专题](https://www.zhihu.com/column/c_1425605280340504576) 将不定期更新炼丹小技巧or心得,也欢迎提问
|
||||
#### 2.1 使用数据集自己训练encoder模型 (可选)
|
||||
|
||||
* 进行音频和梅尔频谱图预处理:
|
||||
`python encoder_preprocess.py <datasets_root>`
|
||||
使用`-d {dataset}` 指定数据集,支持 librispeech_other,voxceleb1,aidatatang_200zh,使用逗号分割处理多数据集。
|
||||
* 训练encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
|
||||
> 训练encoder使用了visdom。你可以加上`-no_visdom`禁用visdom,但是有可视化会更好。在单独的命令行/进程中运行"visdom"来启动visdom服务器。
|
||||
|
||||
#### 2.2 使用数据集自己训练合成器模型(与2.3二选一)
|
||||
* 下载 数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
|
||||
* 进行音频和梅尔频谱图预处理:
|
||||
`python pre.py <datasets_root>`
|
||||
|
||||
可以传入参数 --dataset `{dataset}` 支持 adatatang_200zh, magicdata, aishell3, BZNSYP
|
||||
`python pre.py <datasets_root> -d {dataset} -n {number}`
|
||||
可传入参数:
|
||||
* `-d {dataset}` 指定数据集,支持 aidatatang_200zh, magicdata, aishell3, data_aishell, 不传默认为aidatatang_200zh
|
||||
* `-n {number}` 指定并行数,CPU 11770k + 32GB实测10没有问题
|
||||
> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`
|
||||
|
||||
>假如發生 `頁面文件太小,無法完成操作`,請參考這篇[文章](https://blog.csdn.net/qq_17755303/article/details/112564030),將虛擬內存更改為100G(102400),例如:档案放置D槽就更改D槽的虚拟内存
|
||||
|
||||
* 训练合成器:
|
||||
`python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
|
||||
|
||||
* 当您在训练文件夹 *synthesizer/saved_models/* 中看到注意线显示和损失满足您的需要时,请转到下一步。
|
||||
> 仅供参考,我的注意力是在 18k 步之后出现的,并且在 50k 步之后损失变得低于 0.4
|
||||

|
||||

|
||||
* 当您在训练文件夹 *synthesizer/saved_models/* 中看到注意线显示和损失满足您的需要时,请转到`启动程序`一步。
|
||||
|
||||
### 2.2 使用预先训练好的合成器
|
||||
> 实在没有设备或者不想慢慢调试,可以使用网友贡献的模型(欢迎持续分享):
|
||||
#### 2.3使用社区预先训练好的合成器(与2.2二选一)
|
||||
> 当实在没有设备或者不想慢慢调试,可以使用社区贡献的模型(欢迎持续分享):
|
||||
|
||||
| 作者 | 下载链接 | 效果预览 |
|
||||
| --- | ----------- | ----- |
|
||||
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ 提取码:2021 | https://www.bilibili.com/video/BV1uh411B7AD/)
|
||||
| 作者 | 下载链接 | 效果预览 | 信息 |
|
||||
| --- | ----------- | ----- | ----- |
|
||||
| 作者 | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [百度盘链接](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps 用3个开源数据集混合训练
|
||||
| 作者 | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [百度盘链接](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) 提取码:om7f | | 25k steps 用3个开源数据集混合训练, 切换到tag v0.0.1使用
|
||||
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [百度盘链接](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) 提取码:1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps 台湾口音需切换到tag v0.0.1使用
|
||||
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ 提取码:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps 注意:根据[issue](https://github.com/babysor/MockingBird/issues/37)修复 并切换到tag v0.0.1使用
|
||||
|
||||
### 2.3 训练声码器 (Optional)
|
||||
#### 2.4训练声码器 (可选)
|
||||
对效果影响不大,已经预置3款,如果希望自己训练可以参考以下命令。
|
||||
* 预处理数据:
|
||||
`python vocoder_preprocess.py <datasets_root>`
|
||||
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
|
||||
> `<datasets_root>`替换为你的数据集目录,`<synthesizer_model_path>`替换为一个你最好的synthesizer模型目录,例如 *sythensizer\saved_mode\xxx*
|
||||
|
||||
* 训练声码器:
|
||||
`python vocoder_train.py mandarin <datasets_root>`
|
||||
|
||||
### 3. 启动工具箱
|
||||
然后您可以尝试使用工具箱:
|
||||
* 训练wavernn声码器:
|
||||
`python vocoder_train.py <trainid> <datasets_root>`
|
||||
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
||||
|
||||
* 训练hifigan声码器:
|
||||
`python vocoder_train.py <trainid> <datasets_root> hifigan`
|
||||
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
||||
|
||||
### 3. 启动程序或工具箱
|
||||
您可以尝试使用以下命令:
|
||||
|
||||
### 3.1 启动Web程序(v2):
|
||||
`python web.py`
|
||||
运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`
|
||||
> * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
|
||||
|
||||
### 3.2 启动工具箱:
|
||||
`python demo_toolbox.py -d <datasets_root>`
|
||||
> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
|
||||
|
||||
> Good news🤩: 可直接使用中文
|
||||
<img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">
|
||||
|
||||
## TODO
|
||||
- [X] 允许直接使用中文
|
||||
- [X] 添加演示视频
|
||||
- [X] 添加对更多数据集的支持
|
||||
- [X] 上传预训练模型
|
||||
- [ ] 支持parallel tacotron
|
||||
- [ ] 服务化与容器化
|
||||
- [ ] 🙏 欢迎补充
|
||||
### 4. 番外:语音转换Voice Conversion(PPG based)
|
||||
想像柯南拿着变声器然后发出毛利小五郎的声音吗?本项目现基于PPG-VC,引入额外两个模块(PPG extractor + PPG2Mel), 可以实现变声功能。(文档不全,尤其是训练部分,正在努力补充中)
|
||||
#### 4.0 准备环境
|
||||
* 确保项目以上环境已经安装ok,运行`pip install -r requirements_vc.txt` 来安装剩余的必要包。
|
||||
* 下载以下模型 链接:https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
|
||||
提取码:gh41
|
||||
* 24K采样率专用的vocoder(hifigan)到 *vocoder\saved_mode\xxx*
|
||||
* 预训练的ppg特征encoder(ppg_extractor)到 *ppg_extractor\saved_mode\xxx*
|
||||
* 预训练的PPG2Mel到 *ppg2mel\saved_mode\xxx*
|
||||
|
||||
#### 4.1 使用数据集自己训练PPG2Mel模型 (可选)
|
||||
|
||||
* 下载aidatatang_200zh数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
|
||||
* 进行音频和梅尔频谱图预处理:
|
||||
`python pre4ppg.py <datasets_root> -d {dataset} -n {number}`
|
||||
可传入参数:
|
||||
* `-d {dataset}` 指定数据集,支持 aidatatang_200zh, 不传默认为aidatatang_200zh
|
||||
* `-n {number}` 指定并行数,CPU 11770k在8的情况下,需要运行12到18小时!待优化
|
||||
> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`
|
||||
|
||||
* 训练合成器, 注意在上一步先下载好`ppg2mel.yaml`, 修改里面的地址指向预训练好的文件夹:
|
||||
`python ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
|
||||
* 如果想要继续上一次的训练,可以通过`--load .\ppg2mel\saved_models\<old_pt_file>` 参数指定一个预训练模型文件。
|
||||
|
||||
#### 4.2 启动工具箱VC模式
|
||||
您可以尝试使用以下命令:
|
||||
`python demo_toolbox.py vc -d <datasets_root>`
|
||||
> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
|
||||
<img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">
|
||||
|
||||
## 引用及论文
|
||||
> 该库一开始从仅支持英语的[Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) 分叉出来的,鸣谢作者。
|
||||
|
||||
| URL | Designation | 标题 | 实现源码 |
|
||||
| --- | ----------- | ----- | --------------------- |
|
||||
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|
||||
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | 本代码库 |
|
||||
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | 本代码库 |
|
||||
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | 本代码库 |
|
||||
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|
||||
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|
||||
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | 本代码库 |
|
||||
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | 本代码库 |
|
||||
|
||||
## 常見問題(FQ&A)
|
||||
#### 1.數據集哪裡下載?
|
||||
| 数据集 | OpenSLR地址 | 其他源 (Google Drive, Baidu网盘等) |
|
||||
| --- | ----------- | ---------------|
|
||||
| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
|
||||
| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
|
||||
| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
|
||||
| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
|
||||
> 解壓 aidatatang_200zh 後,還需將 `aidatatang_200zh\corpus\train`下的檔案全選解壓縮
|
||||
|
||||
#### 2.`<datasets_root>`是什麼意思?
|
||||
假如數據集路徑為 `D:\data\aidatatang_200zh`,那麼 `<datasets_root>`就是 `D:\data`
|
||||
|
||||
#### 3.訓練模型顯存不足
|
||||
訓練合成器時:將 `synthesizer/hparams.py`中的batch_size參數調小
|
||||
```
|
||||
//調整前
|
||||
tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
|
||||
(2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
|
||||
(2, 2e-4, 80_000, 12), #
|
||||
(2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
|
||||
(2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
|
||||
(2, 1e-5, 640_000, 12)], # lr = learning rate
|
||||
//調整後
|
||||
tts_schedule = [(2, 1e-3, 20_000, 8), # Progressive training schedule
|
||||
(2, 5e-4, 40_000, 8), # (r, lr, step, batch_size)
|
||||
(2, 2e-4, 80_000, 8), #
|
||||
(2, 1e-4, 160_000, 8), # r = reduction factor (# of mel frames
|
||||
(2, 3e-5, 320_000, 8), # synthesized for each decoder iteration)
|
||||
(2, 1e-5, 640_000, 8)], # lr = learning rate
|
||||
```
|
||||
|
||||
聲碼器-預處理數據集時:將 `synthesizer/hparams.py`中的batch_size參數調小
|
||||
```
|
||||
//調整前
|
||||
### Data Preprocessing
|
||||
max_mel_frames = 900,
|
||||
rescale = True,
|
||||
rescaling_max = 0.9,
|
||||
synthesis_batch_size = 16, # For vocoder preprocessing and inference.
|
||||
//調整後
|
||||
### Data Preprocessing
|
||||
max_mel_frames = 900,
|
||||
rescale = True,
|
||||
rescaling_max = 0.9,
|
||||
synthesis_batch_size = 8, # For vocoder preprocessing and inference.
|
||||
```
|
||||
|
||||
聲碼器-訓練聲碼器時:將 `vocoder/wavernn/hparams.py`中的batch_size參數調小
|
||||
```
|
||||
//調整前
|
||||
# Training
|
||||
voc_batch_size = 100
|
||||
voc_lr = 1e-4
|
||||
voc_gen_at_checkpoint = 5
|
||||
voc_pad = 2
|
||||
|
||||
//調整後
|
||||
# Training
|
||||
voc_batch_size = 6
|
||||
voc_lr = 1e-4
|
||||
voc_gen_at_checkpoint = 5
|
||||
voc_pad =2
|
||||
```
|
||||
|
||||
#### 4.碰到`RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).`
|
||||
請參照 issue [#37](https://github.com/babysor/MockingBird/issues/37)
|
||||
|
||||
#### 5.如何改善CPU、GPU佔用率?
|
||||
適情況調整batch_size參數來改善
|
||||
|
||||
#### 6.發生 `頁面文件太小,無法完成操作`
|
||||
請參考這篇[文章](https://blog.csdn.net/qq_17755303/article/details/112564030),將虛擬內存更改為100G(102400),例如:档案放置D槽就更改D槽的虚拟内存
|
||||
|
||||
#### 7.什么时候算训练完成?
|
||||
首先一定要出现注意力模型,其次是loss足够低,取决于硬件设备和数据集。拿本人的供参考,我的注意力是在 18k 步之后出现的,并且在 50k 步之后损失变得低于 0.4
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
167
README.md
167
README.md
@@ -6,16 +6,25 @@
|
||||
> English | [中文](README-CN.md)
|
||||
|
||||
## Features
|
||||
🌍 **Chinese** supported mandarin and tested with multiple datasets: aidatatang_200zh, magicdata, aishell3
|
||||
🌍 **Chinese** supported mandarin and tested with multiple datasets: aidatatang_200zh, magicdata, aishell3, data_aishell, and etc.
|
||||
|
||||
🤩 **PyTorch** worked for pytorch, tested in version of 1.9.0(latest in August 2021), with GPU Tesla T4 and GTX 2060
|
||||
|
||||
🌍 **Windows + Linux** tested in both Windows OS and linux OS after fixing nits
|
||||
🌍 **Windows + Linux** run in both Windows OS and linux OS (even in M1 MACOS)
|
||||
|
||||
🤩 **Easy & Awesome** effect with only newly-trained synthesizer, by reusing the pretrained encoder/vocoder
|
||||
|
||||
🌍 **Webserver Ready** to serve your result with remote calling
|
||||
|
||||
### [DEMO VIDEO](https://www.bilibili.com/video/BV1sA411P7wM/)
|
||||
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
|
||||
|
||||
### Ongoing Works(Helps Needed)
|
||||
* Major upgrade on GUI/Client and unifying web and toolbox
|
||||
[X] Init framework `./mkgui` and [tech design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
|
||||
[X] Add demo part of Voice Cloning and Conversion
|
||||
[X] Add preprocessing and training for Voice Conversion
|
||||
[ ] Add preprocessing and training for Encoder/Synthesizer/Vocoder
|
||||
* Major upgrade on model backend based on ESPnet2(not yet started)
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -29,64 +38,146 @@
|
||||
* Run `pip install -r requirements.txt` to install the remaining necessary packages.
|
||||
* Install webrtcvad `pip install webrtcvad-wheels`(If you need)
|
||||
> Note that we are using the pretrained encoder/vocoder but synthesizer, since the original model is incompatible with the Chinese sympols. It means the demo_cli is not working at this moment.
|
||||
### 2. Train synthesizer with your dataset
|
||||
* Download aidatatang_200zh or other dataset and unzip: make sure you can access all .wav in *train* folder
|
||||
### 2. Prepare your models
|
||||
You can either train your models or use existing ones:
|
||||
|
||||
#### 2.1 Train encoder with your dataset (Optional)
|
||||
|
||||
* Preprocess with the audios and the mel spectrograms:
|
||||
`python encoder_preprocess.py <datasets_root>` Allowing parameter `--dataset {dataset}` to support the datasets you want to preprocess. Only the train set of these datasets will be used. Possible names: librispeech_other, voxceleb1, voxceleb2. Use comma to sperate multiple datasets.
|
||||
|
||||
* Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
|
||||
> For training, the encoder uses visdom. You can disable it with `--no_visdom`, but it's nice to have. Run "visdom" in a separate CLI/process to start your visdom server.
|
||||
|
||||
#### 2.2 Train synthesizer with your dataset
|
||||
* Download dataset and unzip: make sure you can access all .wav in folder
|
||||
* Preprocess with the audios and the mel spectrograms:
|
||||
`python pre.py <datasets_root>`
|
||||
|
||||
Allowing parameter `--dataset {dataset}` to support adatatang_200zh, magicdata, aishell3, BZNSYP
|
||||
|
||||
>If it happens `the page file is too small to complete the operation`, please refer to this [video](https://www.youtube.com/watch?v=Oh6dga-Oy10&ab_channel=CodeProf) and change the virtual memory to 100G (102400), for example : When the file is placed in the D disk, the virtual memory of the D disk is changed.
|
||||
|
||||
Allowing parameter `--dataset {dataset}` to support aidatatang_200zh, magicdata, aishell3, data_aishell, etc.If this parameter is not passed, the default dataset will be aidatatang_200zh.
|
||||
|
||||
* Train the synthesizer:
|
||||
`python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
|
||||
|
||||
* Go to next step when you see attention line show and loss meet your need in training folder *synthesizer/saved_models/*.
|
||||
> FYI, my attention came after 18k steps and loss became lower than 0.4 after 50k steps.
|
||||

|
||||

|
||||
|
||||
### 2.2 Use pretrained model of synthesizer
|
||||
#### 2.3 Use pretrained model of synthesizer
|
||||
> Thanks to the community, some models will be shared:
|
||||
|
||||
| author | Download link | Previow Video |
|
||||
| --- | ----------- | ----- |
|
||||
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code:2021 | https://www.bilibili.com/video/BV1uh411B7AD/
|
||||
| author | Download link | Preview Video | Info |
|
||||
| --- | ----------- | ----- |----- |
|
||||
| @author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps trained by multiple datasets
|
||||
| @author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code:om7f | | 25k steps trained by multiple datasets, only works under version 0.0.1
|
||||
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan, only works under version 0.0.1
|
||||
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
|
||||
|
||||
> A link to my early trained model: [Baidu Yun](https://pan.baidu.com/s/10t3XycWiNIg5dN5E_bMORQ)
|
||||
Code:aid4
|
||||
|
||||
### 2.3 Train vocoder (Optional)
|
||||
#### 2.4 Train vocoder (Optional)
|
||||
> note: vocoder has little difference in effect, so you may not need to train a new one.
|
||||
* Preprocess the data:
|
||||
`python vocoder_preprocess.py <datasets_root>`
|
||||
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
|
||||
> `<datasets_root>` replace with your dataset root,`<synthesizer_model_path>`replace with directory of your best trained models of sythensizer, e.g. *sythensizer\saved_mode\xxx*
|
||||
|
||||
* Train the vocoder:
|
||||
* Train the wavernn vocoder:
|
||||
`python vocoder_train.py mandarin <datasets_root>`
|
||||
|
||||
### 3. Launch the Toolbox
|
||||
* Train the hifigan vocoder
|
||||
`python vocoder_train.py mandarin <datasets_root> hifigan`
|
||||
|
||||
### 3. Launch
|
||||
#### 3.1 Using the web server
|
||||
You can then try to run:`python web.py` and open it in browser, default as `http://localhost:8080`
|
||||
|
||||
#### 3.2 Using the Toolbox
|
||||
You can then try the toolbox:
|
||||
|
||||
`python demo_toolbox.py -d <datasets_root>`
|
||||
or
|
||||
`python demo_toolbox.py`
|
||||
|
||||
> Good news🤩: Chinese Characters are supported
|
||||
|
||||
## TODO
|
||||
- [x] Add demo video
|
||||
- [X] Add support for more dataset
|
||||
- [X] Upload pretrained model
|
||||
- [ ] Support parallel tacotron
|
||||
- [ ] Service orianted and docterize
|
||||
- 🙏 Welcome to add more
|
||||
|
||||
## Reference
|
||||
> This repository is forked from [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) which only support English.
|
||||
|
||||
| URL | Designation | Title | Implementation source |
|
||||
| --- | ----------- | ----- | --------------------- |
|
||||
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
|
||||
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
|
||||
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|
||||
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|
||||
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|
||||
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |
|
||||
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |
|
||||
|
||||
## F Q&A
|
||||
#### 1.Where can I download the dataset?
|
||||
| Dataset | Original Source | Alternative Sources |
|
||||
| --- | ----------- | ---------------|
|
||||
| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
|
||||
| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
|
||||
| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
|
||||
| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
|
||||
> After unzip aidatatang_200zh, you need to unzip all the files under `aidatatang_200zh\corpus\train`
|
||||
|
||||
#### 2.What is`<datasets_root>`?
|
||||
If the dataset path is `D:\data\aidatatang_200zh`,then `<datasets_root>` is`D:\data`
|
||||
|
||||
#### 3.Not enough VRAM
|
||||
Train the synthesizer:adjust the batch_size in `synthesizer/hparams.py`
|
||||
```
|
||||
//Before
|
||||
tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
|
||||
(2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
|
||||
(2, 2e-4, 80_000, 12), #
|
||||
(2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
|
||||
(2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
|
||||
(2, 1e-5, 640_000, 12)], # lr = learning rate
|
||||
//After
|
||||
tts_schedule = [(2, 1e-3, 20_000, 8), # Progressive training schedule
|
||||
(2, 5e-4, 40_000, 8), # (r, lr, step, batch_size)
|
||||
(2, 2e-4, 80_000, 8), #
|
||||
(2, 1e-4, 160_000, 8), # r = reduction factor (# of mel frames
|
||||
(2, 3e-5, 320_000, 8), # synthesized for each decoder iteration)
|
||||
(2, 1e-5, 640_000, 8)], # lr = learning rate
|
||||
```
|
||||
|
||||
Train Vocoder-Preprocess the data:adjust the batch_size in `synthesizer/hparams.py`
|
||||
```
|
||||
//Before
|
||||
### Data Preprocessing
|
||||
max_mel_frames = 900,
|
||||
rescale = True,
|
||||
rescaling_max = 0.9,
|
||||
synthesis_batch_size = 16, # For vocoder preprocessing and inference.
|
||||
//After
|
||||
### Data Preprocessing
|
||||
max_mel_frames = 900,
|
||||
rescale = True,
|
||||
rescaling_max = 0.9,
|
||||
synthesis_batch_size = 8, # For vocoder preprocessing and inference.
|
||||
```
|
||||
|
||||
Train Vocoder-Train the vocoder:adjust the batch_size in `vocoder/wavernn/hparams.py`
|
||||
```
|
||||
//Before
|
||||
# Training
|
||||
voc_batch_size = 100
|
||||
voc_lr = 1e-4
|
||||
voc_gen_at_checkpoint = 5
|
||||
voc_pad = 2
|
||||
|
||||
//After
|
||||
# Training
|
||||
voc_batch_size = 6
|
||||
voc_lr = 1e-4
|
||||
voc_gen_at_checkpoint = 5
|
||||
voc_pad =2
|
||||
```
|
||||
|
||||
#### 4.If it happens `RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).`
|
||||
Please refer to issue [#37](https://github.com/babysor/MockingBird/issues/37)
|
||||
|
||||
#### 5. How to improve CPU and GPU occupancy rate?
|
||||
Adjust the batch_size as appropriate to improve
|
||||
|
||||
|
||||
#### 6. What if it happens `the page file is too small to complete the operation`
|
||||
Please refer to this [video](https://www.youtube.com/watch?v=Oh6dga-Oy10&ab_channel=CodeProf) and change the virtual memory to 100G (102400), for example : When the file is placed in the D disk, the virtual memory of the D disk is changed.
|
||||
|
||||
#### 7. When should I stop during training?
|
||||
FYI, my attention came after 18k steps and loss became lower than 0.4 after 50k steps.
|
||||

|
||||

|
||||
|
||||
@@ -15,12 +15,18 @@ if __name__ == '__main__':
|
||||
parser.add_argument("-d", "--datasets_root", type=Path, help= \
|
||||
"Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
|
||||
"supported datasets.", default=None)
|
||||
parser.add_argument("-vc", "--vc_mode", action="store_true",
|
||||
help="Voice Conversion Mode(PPG based)")
|
||||
parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
|
||||
help="Directory containing saved encoder models")
|
||||
parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
|
||||
help="Directory containing saved synthesizer models")
|
||||
parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
|
||||
help="Directory containing saved vocoder models")
|
||||
parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
|
||||
help="Directory containing saved extrator models")
|
||||
parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
|
||||
help="Directory containing saved convert models")
|
||||
parser.add_argument("--cpu", action="store_true", help=\
|
||||
"If True, processing is done on CPU, even when a GPU is available.")
|
||||
parser.add_argument("--seed", type=int, default=None, help=\
|
||||
|
||||
@@ -34,8 +34,16 @@ def load_model(weights_fpath: Path, device=None):
|
||||
_model.load_state_dict(checkpoint["model_state"])
|
||||
_model.eval()
|
||||
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
||||
return _model
|
||||
|
||||
|
||||
def set_model(model, device=None):
|
||||
global _model, _device
|
||||
_model = model
|
||||
if device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_device = device
|
||||
_model.to(device)
|
||||
|
||||
def is_loaded():
|
||||
return _model is not None
|
||||
|
||||
@@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
|
||||
|
||||
|
||||
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
||||
min_pad_coverage=0.75, overlap=0.5):
|
||||
min_pad_coverage=0.75, overlap=0.5, rate=None):
|
||||
"""
|
||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
||||
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
||||
@@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
|
||||
assert 0 <= overlap < 1
|
||||
assert 0 < min_pad_coverage <= 1
|
||||
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
if rate != None:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
||||
else:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
|
||||
assert 0 < frame_step, "The rate is too high"
|
||||
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
|
||||
(sampling_rate / (samples_per_frame * partials_n_frames))
|
||||
|
||||
# Compute the slices
|
||||
wav_slices, mel_slices = [], []
|
||||
|
||||
@@ -117,6 +117,15 @@ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir,
|
||||
logger.finalize()
|
||||
print("Done preprocessing %s.\n" % dataset_name)
|
||||
|
||||
def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False):
|
||||
dataset_name = "aidatatang_200zh"
|
||||
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
||||
if not dataset_root:
|
||||
return
|
||||
# Preprocess all speakers
|
||||
speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*"))
|
||||
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
|
||||
skip_existing, logger)
|
||||
|
||||
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
|
||||
for dataset_name in librispeech_datasets["train"]["other"]:
|
||||
|
||||
Binary file not shown.
@@ -1,4 +1,4 @@
|
||||
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2
|
||||
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh
|
||||
from utils.argutils import print_args
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
@@ -10,17 +10,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
|
||||
"writes them to the disk. This will allow you to train the encoder. The "
|
||||
"datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
|
||||
"Ideally, you should have all three. You should extract them as they are "
|
||||
"after having downloaded them and put them in a same directory, e.g.:\n"
|
||||
"-[datasets_root]\n"
|
||||
" -LibriSpeech\n"
|
||||
" -train-other-500\n"
|
||||
" -VoxCeleb1\n"
|
||||
" -wav\n"
|
||||
" -vox1_meta.csv\n"
|
||||
" -VoxCeleb2\n"
|
||||
" -dev",
|
||||
"datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
|
||||
formatter_class=MyFormatter
|
||||
)
|
||||
parser.add_argument("datasets_root", type=Path, help=\
|
||||
@@ -29,7 +19,7 @@ if __name__ == "__main__":
|
||||
"Path to the output directory that will contain the mel spectrograms. If left out, "
|
||||
"defaults to <datasets_root>/SV2TTS/encoder/")
|
||||
parser.add_argument("-d", "--datasets", type=str,
|
||||
default="librispeech_other,voxceleb1,voxceleb2", help=\
|
||||
default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
|
||||
"Comma-separated list of the name of the datasets you want to preprocess. Only the train "
|
||||
"set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
|
||||
"voxceleb2.")
|
||||
@@ -63,6 +53,7 @@ if __name__ == "__main__":
|
||||
"librispeech_other": preprocess_librispeech,
|
||||
"voxceleb1": preprocess_voxceleb1,
|
||||
"voxceleb2": preprocess_voxceleb2,
|
||||
"aidatatang_200zh": preprocess_aidatatang_200zh,
|
||||
}
|
||||
args = vars(args)
|
||||
for dataset in args.pop("datasets"):
|
||||
0
mkgui/__init__.py
Normal file
0
mkgui/__init__.py
Normal file
143
mkgui/app.py
Normal file
143
mkgui/app.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from asyncio.windows_events import NULL
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from encoder import inference as encoder
|
||||
import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from mkgui.base.components.types import FileContent
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from synthesizer.inference import Synthesizer
|
||||
from typing import Any
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = 'samples\\'
|
||||
SYN_MODELS_DIRT = "synthesizer\\saved_models"
|
||||
ENC_MODELS_DIRT = "encoder\\saved_models"
|
||||
VOC_MODELS_DIRT = "vocoder\\saved_models"
|
||||
TEMP_SOURCE_AUDIO = "wavs/temp_source.wav"
|
||||
TEMP_RESULT_AUDIO = "wavs/temp_result.wav"
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(SYN_MODELS_DIRT):
|
||||
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
|
||||
)
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
encoder: encoders = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
synthesizer: synthesizers = Field(
|
||||
..., alias="合成模型",
|
||||
description="选择语音合成模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音解码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
|
||||
def synthesize(input: Input) -> Output:
|
||||
"""synthesize(合成)"""
|
||||
# load models
|
||||
encoder.load_model(Path(input.encoder.value))
|
||||
current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
|
||||
source_spec = Synthesizer.make_spectrogram(wav)
|
||||
|
||||
# preprocess
|
||||
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
|
||||
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
||||
# Load input text
|
||||
texts = filter(None, input.message.split("\n"))
|
||||
punctuation = '!,。、,' # punctuate and split/clean text
|
||||
processed_texts = []
|
||||
for text in texts:
|
||||
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
||||
if processed_text:
|
||||
processed_texts.append(processed_text.strip())
|
||||
texts = processed_texts
|
||||
|
||||
# synthesize and vocode
|
||||
embeds = [embed] * len(texts)
|
||||
specs = current_synt.synthesize_spectrograms(texts, embeds)
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
sample_rate = Synthesizer.sample_rate
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(spec)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
|
||||
167
mkgui/app_vc.py
Normal file
167
mkgui/app_vc.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from asyncio.windows_events import NULL
|
||||
from synthesizer.inference import Synthesizer
|
||||
from pydantic import BaseModel, Field
|
||||
from encoder import inference as speacker_encoder
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import ppg_extractor as Extractor
|
||||
import ppg2mel as Convertor
|
||||
import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from mkgui.base.components.types import FileContent
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from typing import Any
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = 'samples\\'
|
||||
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
|
||||
CONV_MODELS_DIRT = "ppg2mel\\saved_models"
|
||||
VOC_MODELS_DIRT = "vocoder\\saved_models"
|
||||
TEMP_SOURCE_AUDIO = "wavs/temp_source.wav"
|
||||
TEMP_TARGET_AUDIO = "wavs/temp_target.wav"
|
||||
TEMP_RESULT_AUDIO = "wavs/temp_result.wav"
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(vocoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Input(BaseModel):
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
local_audio_file_target: audio_input_selection = Field(
|
||||
..., alias="目标语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
extractor: extractors = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[AudioEntity, AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, target, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(target.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Target Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
def convert(input: Input) -> Output:
|
||||
"""convert(转换)"""
|
||||
# load models
|
||||
extractor = Extractor.load_model(Path(input.extractor.value))
|
||||
convertor = Convertor.load_model(Path(input.convertor.value))
|
||||
# current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
src_wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
|
||||
|
||||
if input.upload_audio_file_target != None:
|
||||
with open(TEMP_TARGET_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file_target.as_bytes())
|
||||
f.seek(0)
|
||||
ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
|
||||
else:
|
||||
ref_wav, _ = librosa.load(input.local_audio_file_target.value)
|
||||
write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
|
||||
|
||||
ppg = extractor.extract_from_wav(src_wav)
|
||||
# Import necessary dependency of Voice Conversion
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
|
||||
embed = speacker_encoder.embed_utterance(ref_wav)
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_, mel_pred, att_ws = convertor.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
|
||||
)
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
breaks = [mel_pred.shape[1]]
|
||||
mel_pred= mel_pred.detach().cpu().numpy()
|
||||
|
||||
# synthesize and vocode
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_TARGET_AUDIO, "rb") as f:
|
||||
target_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
|
||||
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
|
||||
2
mkgui/base/__init__.py
Normal file
2
mkgui/base/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
|
||||
from .core import Opyrator
|
||||
1
mkgui/base/api/__init__.py
Normal file
1
mkgui/base/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .fastapi_app import create_api
|
||||
102
mkgui/base/api/fastapi_utils.py
Normal file
102
mkgui/base/api/fastapi_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Collection of utilities for FastAPI apps."""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Type
|
||||
|
||||
from fastapi import FastAPI, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def as_form(cls: Type[BaseModel]) -> Any:
|
||||
"""Adds an as_form class method to decorated models.
|
||||
|
||||
The as_form class method can be used with FastAPI endpoints
|
||||
"""
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
field.alias,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
default=(Form(field.default) if not field.required else Form(...)),
|
||||
)
|
||||
for field in cls.__fields__.values()
|
||||
]
|
||||
|
||||
async def _as_form(**data): # type: ignore
|
||||
return cls(**data)
|
||||
|
||||
sig = inspect.signature(_as_form)
|
||||
sig = sig.replace(parameters=new_params)
|
||||
_as_form.__signature__ = sig # type: ignore
|
||||
setattr(cls, "as_form", _as_form)
|
||||
return cls
|
||||
|
||||
|
||||
def patch_fastapi(app: FastAPI) -> None:
|
||||
"""Patch function to allow relative url resolution.
|
||||
|
||||
This patch is required to make fastapi fully functional with a relative url path.
|
||||
This code snippet can be copy-pasted to any Fastapi application.
|
||||
"""
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
async def redoc_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
redoc_ui = get_redoc_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Redoc UI",
|
||||
)
|
||||
|
||||
return HTMLResponse(redoc_ui.body.decode("utf-8"))
|
||||
|
||||
async def swagger_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
swagger_ui = get_swagger_ui_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||
)
|
||||
|
||||
# insert request interceptor to have all request run on relativ path
|
||||
request_interceptor = (
|
||||
"requestInterceptor: (e) => {"
|
||||
"\n\t\t\tvar url = window.location.origin + window.location.pathname"
|
||||
'\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
|
||||
"\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
|
||||
"\n\t\t\te.contextUrl = url"
|
||||
"\n\t\t\te.url = url"
|
||||
"\n\t\t\treturn e;}"
|
||||
)
|
||||
|
||||
return HTMLResponse(
|
||||
swagger_ui.body.decode("utf-8").replace(
|
||||
"dom_id: '#swagger-ui',",
|
||||
"dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
|
||||
)
|
||||
)
|
||||
|
||||
# remove old docs route and add our patched route
|
||||
routes_new = []
|
||||
for app_route in app.routes:
|
||||
if app_route.path == "/docs": # type: ignore
|
||||
continue
|
||||
|
||||
if app_route.path == "/redoc": # type: ignore
|
||||
continue
|
||||
|
||||
routes_new.append(app_route)
|
||||
|
||||
app.router.routes = routes_new
|
||||
|
||||
assert app.docs_url is not None
|
||||
app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
|
||||
assert app.redoc_url is not None
|
||||
app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
|
||||
|
||||
# Make graphql realtive
|
||||
from starlette import graphql
|
||||
|
||||
graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
|
||||
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
|
||||
)
|
||||
0
mkgui/base/components/__init__.py
Normal file
0
mkgui/base/components/__init__.py
Normal file
43
mkgui/base/components/outputs.py
Normal file
43
mkgui/base/components/outputs.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScoredLabel(BaseModel):
|
||||
label: str
|
||||
score: float
|
||||
|
||||
|
||||
class ClassificationOutput(BaseModel):
|
||||
__root__: List[ScoredLabel]
|
||||
|
||||
def __iter__(self): # type: ignore
|
||||
return iter(self.__root__)
|
||||
|
||||
def __getitem__(self, item): # type: ignore
|
||||
return self.__root__[item]
|
||||
|
||||
def render_output_ui(self, streamlit) -> None: # type: ignore
|
||||
import plotly.express as px
|
||||
|
||||
sorted_predictions = sorted(
|
||||
[prediction.dict() for prediction in self.__root__],
|
||||
key=lambda k: k["score"],
|
||||
)
|
||||
|
||||
num_labels = len(sorted_predictions)
|
||||
if len(sorted_predictions) > 10:
|
||||
num_labels = streamlit.slider(
|
||||
"Maximum labels to show: ",
|
||||
min_value=1,
|
||||
max_value=len(sorted_predictions),
|
||||
value=len(sorted_predictions),
|
||||
)
|
||||
fig = px.bar(
|
||||
sorted_predictions[len(sorted_predictions) - num_labels :],
|
||||
x="score",
|
||||
y="label",
|
||||
orientation="h",
|
||||
)
|
||||
streamlit.plotly_chart(fig, use_container_width=True)
|
||||
# fig.show()
|
||||
46
mkgui/base/components/types.py
Normal file
46
mkgui/base/components/types.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import base64
|
||||
from typing import Any, Dict, overload
|
||||
|
||||
|
||||
class FileContent(str):
|
||||
def as_bytes(self) -> bytes:
|
||||
return base64.b64decode(self, validate=True)
|
||||
|
||||
def as_str(self) -> str:
|
||||
return self.as_bytes().decode()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(format="byte")
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Any: # type: ignore
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> "FileContent":
|
||||
if isinstance(value, FileContent):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
return FileContent(value)
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return FileContent(base64.b64encode(value).decode())
|
||||
else:
|
||||
raise Exception("Wrong type")
|
||||
|
||||
# # 暂时无法使用,因为浏览器中没有考虑选择文件夹
|
||||
# class DirectoryContent(FileContent):
|
||||
# @classmethod
|
||||
# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
# field_schema.update(format="path")
|
||||
|
||||
# @classmethod
|
||||
# def validate(cls, value: Any) -> "DirectoryContent":
|
||||
# if isinstance(value, DirectoryContent):
|
||||
# return value
|
||||
# elif isinstance(value, str):
|
||||
# return DirectoryContent(value)
|
||||
# elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# return DirectoryContent(base64.b64encode(value).decode())
|
||||
# else:
|
||||
# raise Exception("Wrong type")
|
||||
203
mkgui/base/core.py
Normal file
203
mkgui/base/core.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Type, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from pydantic.tools import parse_obj_as
|
||||
|
||||
|
||||
def name_to_title(name: str) -> str:
|
||||
"""Converts a camelCase or snake_case name to title case."""
|
||||
# If camelCase -> convert to snake case
|
||||
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
||||
name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
||||
# Convert to title case
|
||||
return name.replace("_", " ").strip().title()
|
||||
|
||||
|
||||
def is_compatible_type(type: Type) -> bool:
|
||||
"""Returns `True` if the type is opyrator-compatible."""
|
||||
try:
|
||||
if issubclass(type, BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# valid list type
|
||||
if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_input_type(func: Callable) -> Type:
|
||||
"""Returns the input type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the input type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid input type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
if "input" not in type_hints:
|
||||
raise ValueError(
|
||||
"The callable MUST have a parameter with the name `input` with typing annotation. "
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
input_type = type_hints["input"]
|
||||
|
||||
if not is_compatible_type(input_type):
|
||||
raise ValueError(
|
||||
"The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
# TODO: return warning if more than one input parameters
|
||||
|
||||
return input_type
|
||||
|
||||
|
||||
def get_output_type(func: Callable) -> Type:
|
||||
"""Returns the output type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the output type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid output type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
if "return" not in type_hints:
|
||||
raise ValueError(
|
||||
"The return type of the callable MUST be annotated with type hints."
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
output_type = type_hints["return"]
|
||||
|
||||
if not is_compatible_type(output_type):
|
||||
raise ValueError(
|
||||
"The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
return output_type
|
||||
|
||||
|
||||
def get_callable(import_string: str) -> Callable:
|
||||
"""Import a callable from an string."""
|
||||
callable_seperator = ":"
|
||||
if callable_seperator not in import_string:
|
||||
# Use dot as seperator
|
||||
callable_seperator = "."
|
||||
|
||||
if callable_seperator not in import_string:
|
||||
raise ValueError("The callable path MUST specify the function. ")
|
||||
|
||||
mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
|
||||
mod = importlib.import_module(mod_name)
|
||||
return getattr(mod, callable_name)
|
||||
|
||||
|
||||
class Opyrator:
|
||||
def __init__(self, func: Union[Callable, str]) -> None:
|
||||
if isinstance(func, str):
|
||||
# Try to load the function from a string notion
|
||||
self.function = get_callable(func)
|
||||
else:
|
||||
self.function = func
|
||||
|
||||
self._action = "Execute"
|
||||
self._input_type = None
|
||||
self._output_type = None
|
||||
|
||||
if not callable(self.function):
|
||||
raise ValueError("The provided function parameters is not a callable.")
|
||||
|
||||
if inspect.isclass(self.function):
|
||||
raise ValueError(
|
||||
"The provided callable is an uninitialized Class. This is not allowed."
|
||||
)
|
||||
|
||||
if inspect.isfunction(self.function):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function)
|
||||
self._output_type = get_output_type(self.function)
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(self.function.__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get description from function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(self.function, "__call__"):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function.__call__) # type: ignore
|
||||
self._output_type = get_output_type(self.function.__call__) # type: ignore
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(type(self.function).__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get action from
|
||||
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
|
||||
if (
|
||||
not self._action
|
||||
or self._action == "Call"
|
||||
):
|
||||
# Get docstring from class instead of __call__ function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown callable type.")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def action(self) -> str:
|
||||
return self._action
|
||||
|
||||
@property
|
||||
def input_type(self) -> Any:
|
||||
return self._input_type
|
||||
|
||||
@property
|
||||
def output_type(self) -> Any:
|
||||
return self._output_type
|
||||
|
||||
def __call__(self, input: Any, **kwargs: Any) -> Any:
|
||||
|
||||
input_obj = input
|
||||
|
||||
if isinstance(input, str):
|
||||
# Allow json input
|
||||
input_obj = parse_raw_as(self.input_type, input)
|
||||
|
||||
if isinstance(input, dict):
|
||||
# Allow dict input
|
||||
input_obj = parse_obj_as(self.input_type, input)
|
||||
|
||||
return self.function(input_obj, **kwargs)
|
||||
1
mkgui/base/ui/__init__.py
Normal file
1
mkgui/base/ui/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .streamlit_ui import render_streamlit_ui
|
||||
129
mkgui/base/ui/schema_utils.py
Normal file
129
mkgui/base/ui/schema_utils.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def resolve_reference(reference: str, references: Dict) -> Dict:
|
||||
return references[reference.split("/")[-1]]
|
||||
|
||||
|
||||
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
|
||||
# Ref can either be directly in the properties or the first element of allOf
|
||||
reference = property.get("$ref")
|
||||
if reference is None:
|
||||
reference = property["allOf"][0]["$ref"]
|
||||
return resolve_reference(reference, references)
|
||||
|
||||
|
||||
def is_single_string_property(property: Dict) -> bool:
|
||||
return property.get("type") == "string"
|
||||
|
||||
|
||||
def is_single_datetime_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") in ["date-time", "time", "date"]
|
||||
|
||||
|
||||
def is_single_boolean_property(property: Dict) -> bool:
|
||||
return property.get("type") == "boolean"
|
||||
|
||||
|
||||
def is_single_number_property(property: Dict) -> bool:
|
||||
return property.get("type") in ["integer", "number"]
|
||||
|
||||
|
||||
def is_single_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
# TODO: binary?
|
||||
return property.get("format") == "byte"
|
||||
|
||||
|
||||
def is_single_directory_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") == "path"
|
||||
|
||||
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("uniqueItems") is not True:
|
||||
# Only relevant if it is a set or other datastructures with unique items
|
||||
return False
|
||||
|
||||
try:
|
||||
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_enum_property(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
_ = get_single_reference_item(property, references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_dict_property(property: Dict) -> bool:
|
||||
if property.get("type") != "object":
|
||||
return False
|
||||
return "additionalProperties" in property
|
||||
|
||||
|
||||
def is_single_reference(property: Dict) -> bool:
|
||||
if property.get("type") is not None:
|
||||
return False
|
||||
|
||||
return bool(property.get("$ref"))
|
||||
|
||||
|
||||
def is_multi_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO: binary
|
||||
return property["items"]["format"] == "byte"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_object(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
object_reference = get_single_reference_item(property, references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_property_list(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
return property["items"]["type"] in ["string", "number", "integer"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_object_list_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
try:
|
||||
object_reference = resolve_reference(property["items"]["$ref"], references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
885
mkgui/base/ui/streamlit_ui.py
Normal file
885
mkgui/base/ui/streamlit_ui.py
Normal file
@@ -0,0 +1,885 @@
|
||||
import datetime
|
||||
import inspect
|
||||
import mimetypes
|
||||
import sys
|
||||
from os import getcwd, unlink
|
||||
from platform import system
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
from PIL import Image
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError, parse_obj_as
|
||||
|
||||
from mkgui.base import Opyrator
|
||||
from mkgui.base.core import name_to_title
|
||||
from mkgui.base.ui import schema_utils
|
||||
from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS
|
||||
|
||||
STREAMLIT_RUNNER_SNIPPET = """
|
||||
from mkgui.base.ui import render_streamlit_ui
|
||||
from mkgui.base import Opyrator
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# TODO: Make it configurable
|
||||
# Page config can only be setup once
|
||||
st.set_page_config(
|
||||
page_title="MockingBird",
|
||||
page_icon="🧊",
|
||||
layout="wide")
|
||||
|
||||
render_streamlit_ui()
|
||||
"""
|
||||
|
||||
# with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
# opyrator = Opyrator("{opyrator_path}")
|
||||
|
||||
|
||||
def launch_ui(port: int = 8501) -> None:
|
||||
with NamedTemporaryFile(
|
||||
suffix=".py", mode="w", encoding="utf-8", delete=False
|
||||
) as f:
|
||||
f.write(STREAMLIT_RUNNER_SNIPPET)
|
||||
f.seek(0)
|
||||
|
||||
import subprocess
|
||||
|
||||
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
||||
if system() == "Windows":
|
||||
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
|
||||
subprocess.run(
|
||||
f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
subprocess.run(
|
||||
f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
f.close()
|
||||
unlink(f.name)
|
||||
|
||||
|
||||
def function_has_named_arg(func: Callable, parameter: str) -> bool:
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
for param in sig.parameters.values():
|
||||
if param.name == "input":
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def has_output_ui_renderer(data_item: BaseModel) -> bool:
|
||||
return hasattr(data_item, "render_output_ui")
|
||||
|
||||
|
||||
def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
|
||||
return hasattr(input_class, "render_input_ui")
|
||||
|
||||
|
||||
def is_compatible_audio(mime_type: str) -> bool:
|
||||
return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
|
||||
|
||||
|
||||
def is_compatible_image(mime_type: str) -> bool:
|
||||
return mime_type in ["image/png", "image/jpeg"]
|
||||
|
||||
|
||||
def is_compatible_video(mime_type: str) -> bool:
|
||||
return mime_type in ["video/mp4"]
|
||||
|
||||
|
||||
class InputUI:
|
||||
def __init__(self, session_state, input_class: Type[BaseModel]):
|
||||
self._session_state = session_state
|
||||
self._input_class = input_class
|
||||
|
||||
self._schema_properties = input_class.schema(by_alias=True).get(
|
||||
"properties", {}
|
||||
)
|
||||
self._schema_references = input_class.schema(by_alias=True).get(
|
||||
"definitions", {}
|
||||
)
|
||||
|
||||
def render_ui(self, streamlit_app_root) -> None:
|
||||
if has_input_ui_renderer(self._input_class):
|
||||
# The input model has a rendering function
|
||||
# The rendering also returns the current state of input data
|
||||
self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
|
||||
st, self._session_state.input_data
|
||||
)
|
||||
return
|
||||
|
||||
# print(self._schema_properties)
|
||||
for property_key in self._schema_properties.keys():
|
||||
property = self._schema_properties[property_key]
|
||||
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
|
||||
try:
|
||||
if "input_data" in self._session_state:
|
||||
self._store_value(
|
||||
property_key,
|
||||
self._render_property(streamlit_app_root, property_key, property),
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception!", e)
|
||||
pass
|
||||
|
||||
def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
|
||||
streamlit_kwargs = {
|
||||
"label": property.get("title"),
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if property.get("description"):
|
||||
streamlit_kwargs["help"] = property.get("description")
|
||||
return streamlit_kwargs
|
||||
|
||||
def _store_value(self, key: str, value: Any) -> None:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
data_element[key_element] = value
|
||||
return
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
|
||||
def _get_value(self, key: str) -> Any:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
if key_element not in data_element:
|
||||
return None
|
||||
return data_element[key_element]
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
return None
|
||||
|
||||
def _render_single_datetime_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("format") == "time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.time_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.date_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date-time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
with streamlit_app.container():
|
||||
streamlit_app.subheader(streamlit_kwargs.get("label"))
|
||||
if streamlit_kwargs.get("description"):
|
||||
streamlit_app.text(streamlit_kwargs.get("description"))
|
||||
selected_date = None
|
||||
selected_time = None
|
||||
date_col, time_col = streamlit_app.columns(2)
|
||||
with date_col:
|
||||
date_kwargs = {"label": "Date", "key": key + "-date-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).date()
|
||||
except Exception:
|
||||
pass
|
||||
selected_date = streamlit_app.date_input(**date_kwargs)
|
||||
|
||||
with time_col:
|
||||
time_kwargs = {"label": "Time", "key": key + "-time-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).time()
|
||||
except Exception:
|
||||
pass
|
||||
selected_time = streamlit_app.time_input(**time_kwargs)
|
||||
return datetime.datetime.combine(selected_date, selected_time)
|
||||
else:
|
||||
streamlit_app.warning(
|
||||
"Date format is not supported: " + str(property.get("format"))
|
||||
)
|
||||
|
||||
def _render_single_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_file = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
|
||||
)
|
||||
if uploaded_file is None:
|
||||
return None
|
||||
|
||||
bytes = uploaded_file.getvalue()
|
||||
if property.get("mime_type"):
|
||||
if is_compatible_audio(property["mime_type"]):
|
||||
# Show audio
|
||||
streamlit_app.audio(bytes, format=property.get("mime_type"))
|
||||
if is_compatible_image(property["mime_type"]):
|
||||
# Show image
|
||||
streamlit_app.image(bytes)
|
||||
if is_compatible_video(property["mime_type"]):
|
||||
# Show video
|
||||
streamlit_app.video(bytes, format=property.get("mime_type"))
|
||||
return bytes
|
||||
|
||||
def _render_single_string_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
elif property.get("example"):
|
||||
# TODO: also use example for other property types
|
||||
# Use example as value if it is provided
|
||||
streamlit_kwargs["value"] = property.get("example")
|
||||
|
||||
if property.get("maxLength") is not None:
|
||||
streamlit_kwargs["max_chars"] = property.get("maxLength")
|
||||
|
||||
if (
|
||||
property.get("format")
|
||||
or (
|
||||
property.get("maxLength") is not None
|
||||
and int(property.get("maxLength")) < 140 # type: ignore
|
||||
)
|
||||
or property.get("writeOnly")
|
||||
):
|
||||
# If any format is set, use single text input
|
||||
# If max chars is set to less than 140, use single text input
|
||||
# If write only -> password field
|
||||
if property.get("writeOnly"):
|
||||
streamlit_kwargs["type"] = "password"
|
||||
return streamlit_app.text_input(**streamlit_kwargs)
|
||||
else:
|
||||
# Otherwise use multiline text area
|
||||
return streamlit_app.text_area(**streamlit_kwargs)
|
||||
|
||||
def _render_multi_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
# TODO: how to select defaults
|
||||
return streamlit_app.multiselect(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
try:
|
||||
streamlit_kwargs["index"] = reference_item["enum"].index(
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
# Use default selection
|
||||
pass
|
||||
|
||||
return streamlit_app.selectbox(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_dict_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_dict = self._get_value(key)
|
||||
if not current_dict:
|
||||
current_dict = {}
|
||||
|
||||
key_col, value_col = streamlit_app.columns(2)
|
||||
|
||||
with key_col:
|
||||
updated_key = streamlit_app.text_input(
|
||||
"Key", value="", key=key + "-new-key"
|
||||
)
|
||||
|
||||
with value_col:
|
||||
# TODO: also add boolean?
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["additionalProperties"].get("type") == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["additionalProperties"].get("type") == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
updated_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_dict = {}
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and updated_key
|
||||
):
|
||||
current_dict[updated_key] = updated_value
|
||||
|
||||
streamlit_app.write(current_dict)
|
||||
|
||||
return current_dict
|
||||
|
||||
def _render_single_reference(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_property(streamlit_app, key, reference_item)
|
||||
|
||||
def _render_multi_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_files = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=True, type=file_extension
|
||||
)
|
||||
uploaded_files_bytes = []
|
||||
if uploaded_files:
|
||||
for uploaded_file in uploaded_files:
|
||||
uploaded_files_bytes.append(uploaded_file.read())
|
||||
return uploaded_files_bytes
|
||||
|
||||
def _render_single_boolean_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
return streamlit_app.checkbox(**streamlit_kwargs)
|
||||
|
||||
def _render_single_number_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
number_transform = int
|
||||
if property.get("type") == "number":
|
||||
number_transform = float # type: ignore
|
||||
streamlit_kwargs["format"] = "%f"
|
||||
|
||||
if "multipleOf" in property:
|
||||
# Set stepcount based on multiple of parameter
|
||||
streamlit_kwargs["step"] = number_transform(property["multipleOf"])
|
||||
elif number_transform == int:
|
||||
# Set step size to 1 as default
|
||||
streamlit_kwargs["step"] = 1
|
||||
elif number_transform == float:
|
||||
# Set step size to 0.01 as default
|
||||
# TODO: adapt to default value
|
||||
streamlit_kwargs["step"] = 0.01
|
||||
|
||||
if "minimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(property["minimum"])
|
||||
if "exclusiveMinimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(
|
||||
property["exclusiveMinimum"] + streamlit_kwargs["step"]
|
||||
)
|
||||
if "maximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(property["maximum"])
|
||||
|
||||
if "exclusiveMaximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(
|
||||
property["exclusiveMaximum"] - streamlit_kwargs["step"]
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
|
||||
else:
|
||||
if "min_value" in streamlit_kwargs:
|
||||
streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
|
||||
elif number_transform == int:
|
||||
streamlit_kwargs["value"] = 0
|
||||
else:
|
||||
# Set default value to step
|
||||
streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
|
||||
|
||||
if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
|
||||
# TODO: Only if less than X steps
|
||||
return streamlit_app.slider(**streamlit_kwargs)
|
||||
else:
|
||||
return streamlit_app.number_input(**streamlit_kwargs)
|
||||
|
||||
def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
properties = property["properties"]
|
||||
object_inputs = {}
|
||||
for property_key in properties:
|
||||
property = properties[property_key]
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
# construct full key based on key parts -> required later to get the value
|
||||
full_key = key + "." + property_key
|
||||
object_inputs[property_key] = self._render_property(
|
||||
streamlit_app, full_key, property
|
||||
)
|
||||
return object_inputs
|
||||
|
||||
def _render_single_object_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
# Add title and subheader
|
||||
title = property.get("title")
|
||||
streamlit_app.subheader(title)
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
object_reference = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
def _render_property_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["items"]["type"] == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["items"]["type"] == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
new_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and new_value is not None
|
||||
):
|
||||
current_list.append(new_value)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
|
||||
return current_list
|
||||
|
||||
def _render_object_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# TODO: support max_items, and min_items properties
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
object_reference = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
input_data = self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and input_data
|
||||
):
|
||||
current_list.append(input_data)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
return current_list
|
||||
|
||||
def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
if schema_utils.is_single_enum_property(property, self._schema_references):
|
||||
return self._render_single_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_enum_property(property, self._schema_references):
|
||||
return self._render_multi_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_file_property(property):
|
||||
return self._render_single_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_file_property(property):
|
||||
return self._render_multi_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_datetime_property(property):
|
||||
return self._render_single_datetime_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_boolean_property(property):
|
||||
return self._render_single_boolean_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_dict_property(property):
|
||||
return self._render_single_dict_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_number_property(property):
|
||||
return self._render_single_number_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_string_property(property):
|
||||
return self._render_single_string_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_object(property, self._schema_references):
|
||||
return self._render_single_object_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_object_list_property(property, self._schema_references):
|
||||
return self._render_object_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_property_list(property):
|
||||
return self._render_property_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_reference(property):
|
||||
return self._render_single_reference(streamlit_app, key, property)
|
||||
|
||||
streamlit_app.warning(
|
||||
"The type of the following property is currently not supported: "
|
||||
+ str(property.get("title"))
|
||||
)
|
||||
raise Exception("Unsupported property")
|
||||
|
||||
|
||||
class OutputUI:
|
||||
def __init__(self, output_data: Any, input_data: Any):
|
||||
self._output_data = output_data
|
||||
self._input_data = input_data
|
||||
|
||||
def render_ui(self, streamlit_app) -> None:
|
||||
try:
|
||||
if isinstance(self._output_data, BaseModel):
|
||||
self._render_single_output(streamlit_app, self._output_data)
|
||||
return
|
||||
if type(self._output_data) == list:
|
||||
self._render_list_output(streamlit_app, self._output_data)
|
||||
return
|
||||
except Exception as ex:
|
||||
streamlit_app.exception(ex)
|
||||
# Fallback to
|
||||
streamlit_app.json(jsonable_encoder(self._output_data))
|
||||
|
||||
def _render_single_text_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
streamlit.code(str(value), language="plain")
|
||||
|
||||
def _render_single_file_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
# TODO: Detect if it is a FileContent instance
|
||||
# TODO: detect if it is base64
|
||||
file_extension = ""
|
||||
if "mime_type" in property_schema:
|
||||
mime_type = property_schema["mime_type"]
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ""
|
||||
|
||||
if is_compatible_audio(mime_type):
|
||||
streamlit.audio(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
if is_compatible_image(mime_type):
|
||||
streamlit.image(value.as_bytes())
|
||||
return
|
||||
|
||||
if is_compatible_video(mime_type):
|
||||
streamlit.video(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
filename = (
|
||||
(property_schema["title"] + file_extension)
|
||||
.lower()
|
||||
.strip()
|
||||
.replace(" ", "-")
|
||||
)
|
||||
streamlit.markdown(
|
||||
f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
def _render_single_complex_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
|
||||
streamlit.json(jsonable_encoder(value))
|
||||
|
||||
def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
|
||||
try:
|
||||
if has_output_ui_renderer(output_data):
|
||||
if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
|
||||
# render method also requests the input data
|
||||
output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
|
||||
else:
|
||||
output_data.render_output_ui(streamlit) # type: ignore
|
||||
return
|
||||
except Exception:
|
||||
# Use default auto-generation methods if the custom rendering throws an exception
|
||||
logger.exception(
|
||||
"Failed to execute custom render_output_ui function. Using auto-generation instead"
|
||||
)
|
||||
|
||||
model_schema = output_data.schema(by_alias=False)
|
||||
model_properties = model_schema.get("properties")
|
||||
definitions = model_schema.get("definitions")
|
||||
|
||||
if model_properties:
|
||||
for property_key in output_data.__dict__:
|
||||
property_schema = model_properties.get(property_key)
|
||||
if not property_schema.get("title"):
|
||||
# Set property key as fallback title
|
||||
property_schema["title"] = property_key
|
||||
|
||||
output_property_value = output_data.__dict__[property_key]
|
||||
|
||||
if has_output_ui_renderer(output_property_value):
|
||||
output_property_value.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
|
||||
if isinstance(output_property_value, BaseModel):
|
||||
# Render output recursivly
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
self._render_single_output(streamlit, output_property_value)
|
||||
continue
|
||||
|
||||
if property_schema:
|
||||
if schema_utils.is_single_file_property(property_schema):
|
||||
self._render_single_file_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
schema_utils.is_single_string_property(property_schema)
|
||||
or schema_utils.is_single_number_property(property_schema)
|
||||
or schema_utils.is_single_datetime_property(property_schema)
|
||||
or schema_utils.is_single_boolean_property(property_schema)
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
if definitions and schema_utils.is_single_enum_property(
|
||||
property_schema, definitions
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value.value
|
||||
)
|
||||
continue
|
||||
|
||||
# TODO: render dict as table
|
||||
|
||||
self._render_single_complex_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
return
|
||||
|
||||
def _render_list_output(self, streamlit: st, output_data: List) -> None:
|
||||
try:
|
||||
data_items: List = []
|
||||
for data_item in output_data:
|
||||
if has_output_ui_renderer(data_item):
|
||||
# Render using the render function
|
||||
data_item.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
data_items.append(data_item.dict())
|
||||
# Try to show as dataframe
|
||||
streamlit.table(pd.DataFrame(data_items))
|
||||
except Exception:
|
||||
# Fallback to
|
||||
streamlit.json(jsonable_encoder(output_data))
|
||||
|
||||
|
||||
def getOpyrator(mode: str) -> Opyrator:
|
||||
if mode == None or mode.startswith('VC'):
|
||||
from mkgui.app_vc import convert
|
||||
return Opyrator(convert)
|
||||
if mode == None or mode.startswith('预处理'):
|
||||
from mkgui.preprocess import preprocess
|
||||
return Opyrator(preprocess)
|
||||
if mode == None or mode.startswith('模型训练'):
|
||||
from mkgui.train import train
|
||||
return Opyrator(train)
|
||||
from mkgui.app import synthesize
|
||||
return Opyrator(synthesize)
|
||||
|
||||
|
||||
def render_streamlit_ui() -> None:
|
||||
# init
|
||||
session_state = st.session_state
|
||||
session_state.input_data = {}
|
||||
# Add custom css settings
|
||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||
|
||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
session_state.mode = st.sidebar.selectbox(
|
||||
'模式选择',
|
||||
( "AI拟音", "VC拟音", "预处理", "模型训练")
|
||||
)
|
||||
if "mode" in session_state:
|
||||
mode = session_state.mode
|
||||
else:
|
||||
mode = ""
|
||||
opyrator = getOpyrator(mode)
|
||||
title = opyrator.name + mode
|
||||
|
||||
col1, col2, _ = st.columns(3)
|
||||
col2.title(title)
|
||||
col2.markdown("欢迎使用MockingBird Web 2")
|
||||
|
||||
image = Image.open('.\\mkgui\\static\\mb.png')
|
||||
col1.image(image)
|
||||
|
||||
st.markdown("---")
|
||||
left, right = st.columns([0.4, 0.6])
|
||||
|
||||
with left:
|
||||
st.header("Control 控制")
|
||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
||||
execute_selected = st.button(opyrator.action)
|
||||
if execute_selected:
|
||||
with st.spinner("Executing operation. Please wait..."):
|
||||
try:
|
||||
input_data_obj = parse_obj_as(
|
||||
opyrator.input_type, session_state.input_data
|
||||
)
|
||||
session_state.output_data = opyrator(input=input_data_obj)
|
||||
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
|
||||
except ValidationError as ex:
|
||||
st.error(ex)
|
||||
else:
|
||||
# st.success("Operation executed successfully.")
|
||||
pass
|
||||
|
||||
with right:
|
||||
st.header("Result 结果")
|
||||
if 'output_data' in session_state:
|
||||
OutputUI(
|
||||
session_state.output_data, session_state.latest_operation_input
|
||||
).render_ui(st)
|
||||
if st.button("Clear"):
|
||||
# Clear all state
|
||||
for key in st.session_state.keys():
|
||||
del st.session_state[key]
|
||||
session_state.input_data = {}
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
# placeholder
|
||||
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
||||
|
||||
|
||||
13
mkgui/base/ui/streamlit_utils.py
Normal file
13
mkgui/base/ui/streamlit_utils.py
Normal file
@@ -0,0 +1,13 @@
|
||||
CUSTOM_STREAMLIT_CSS = """
|
||||
div[data-testid="stBlock"] button {
|
||||
width: 100% !important;
|
||||
margin-bottom: 20px !important;
|
||||
border-color: #bfbfbf !important;
|
||||
}
|
||||
section[data-testid="stSidebar"] div {
|
||||
max-width: 10rem;
|
||||
}
|
||||
pre code {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
"""
|
||||
96
mkgui/preprocess.py
Normal file
96
mkgui/preprocess.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
|
||||
ENC_MODELS_DIRT = "encoder\\saved_models"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="目标模型",
|
||||
)
|
||||
dataset: Dataset = Field(
|
||||
Dataset.AIDATATANG_200ZH, title="数据集选择",
|
||||
)
|
||||
datasets_root: str = Field(
|
||||
..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
output_root: str = Field(
|
||||
..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
n_processes: int = Field(
|
||||
2, alias="处理线程数", description="根据CPU线程数来设置",
|
||||
le=32, ge=1
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def preprocess(input: Input) -> Output:
|
||||
"""Preprocess(预处理)"""
|
||||
finished = 0
|
||||
if input.model == Model.VC_PPG2MEL:
|
||||
from ppg2mel.preprocess import preprocess_dataset
|
||||
finished = preprocess_dataset(
|
||||
datasets_root=Path(input.datasets_root),
|
||||
dataset=input.dataset,
|
||||
out_dir=Path(input.output_root),
|
||||
n_processes=input.n_processes,
|
||||
ppg_encoder_model_fpath=Path(input.extractor.value),
|
||||
speaker_encoder_model=Path(input.encoder.value)
|
||||
)
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, finished))
|
||||
BIN
mkgui/static/mb.png
Normal file
BIN
mkgui/static/mb.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.6 KiB |
156
mkgui/train.py
Normal file
156
mkgui/train.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from utils.util import AttrDict
|
||||
import torch
|
||||
|
||||
# TODO: seperator for *unix systems
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
|
||||
CONV_MODELS_DIRT = "ppg2mel\\saved_models"
|
||||
ENC_MODELS_DIRT = "encoder\\saved_models"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="模型类型",
|
||||
)
|
||||
# datasets_root: str = Field(
|
||||
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
|
||||
# format=True,
|
||||
# example="..\\trainning_data\\"
|
||||
# )
|
||||
output_root: str = Field(
|
||||
..., alias="输出目录(可选)", description="建议不填,保持默认",
|
||||
format=True,
|
||||
example=""
|
||||
)
|
||||
continue_mode: bool = Field(
|
||||
True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
|
||||
)
|
||||
gpu: bool = Field(
|
||||
True, alias="GPU训练", description="选择“是”,则使用GPU训练",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
True, alias="打印详情", description="选择“是”,输出更多详情",
|
||||
)
|
||||
# TODO: Move to hiden fields by default
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
njobs: int = Field(
|
||||
8, alias="进程数", description="适用于ppg2mel",
|
||||
)
|
||||
seed: int = Field(
|
||||
default=0, alias="初始随机数", description="适用于ppg2mel",
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example="test"
|
||||
)
|
||||
model_config: str = Field(
|
||||
..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def train(input: Input) -> Output:
|
||||
"""Train(训练)"""
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
params = AttrDict()
|
||||
params.update({
|
||||
"gpu": input.gpu,
|
||||
"cpu": not input.gpu,
|
||||
"njobs": input.njobs,
|
||||
"seed": input.seed,
|
||||
"verbose": input.verbose,
|
||||
"load": input.convertor.value,
|
||||
"warm_start": False,
|
||||
})
|
||||
if input.continue_mode:
|
||||
# trace old model and config
|
||||
p = Path(input.convertor.value)
|
||||
params.name = p.parent.name
|
||||
# search a config file
|
||||
model_config_fpaths = list(p.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
config = HpsYaml(model_config_fpaths[0])
|
||||
params.ckpdir = p.parent.parent
|
||||
params.config = model_config_fpaths[0]
|
||||
params.logdir = os.path.join(p.parent, "log")
|
||||
else:
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(input.config)
|
||||
np.random.seed(input.seed)
|
||||
torch.manual_seed(input.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(input.seed)
|
||||
mode = "train"
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
solver = Solver(config, params, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, 0))
|
||||
209
ppg2mel/__init__.py
Normal file
209
ppg2mel/__init__.py
Normal file
@@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2020 Songxiang Liu
|
||||
# Apache 2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .utils.abs_model import AbsMelDecoder
|
||||
from .rnn_decoder_mol import Decoder
|
||||
from .utils.cnn_postnet import Postnet
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
class MelDecoderMOLv2(AbsMelDecoder):
|
||||
"""Use an encoder to preprocess ppg."""
|
||||
def __init__(
|
||||
self,
|
||||
num_speakers: int,
|
||||
spk_embed_dim: int,
|
||||
bottle_neck_feature_dim: int,
|
||||
encoder_dim: int = 256,
|
||||
encoder_downsample_rates: List = [2, 2],
|
||||
attention_rnn_dim: int = 512,
|
||||
decoder_rnn_dim: int = 512,
|
||||
num_decoder_rnn_layer: int = 1,
|
||||
concat_context_to_last: bool = True,
|
||||
prenet_dims: List = [256, 128],
|
||||
num_mixtures: int = 5,
|
||||
frames_per_step: int = 2,
|
||||
mask_padding: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mask_padding = mask_padding
|
||||
self.bottle_neck_feature_dim = bottle_neck_feature_dim
|
||||
self.num_mels = 80
|
||||
self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1]
|
||||
self.frames_per_step = frames_per_step
|
||||
self.use_spk_dvec = True
|
||||
|
||||
input_dim = bottle_neck_feature_dim
|
||||
|
||||
# Downsampling convolution
|
||||
self.bnf_prenet = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[0],
|
||||
stride=encoder_downsample_rates[0],
|
||||
padding=encoder_downsample_rates[0]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[1],
|
||||
stride=encoder_downsample_rates[1],
|
||||
padding=encoder_downsample_rates[1]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
)
|
||||
decoder_enc_dim = encoder_dim
|
||||
self.pitch_convs = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[0],
|
||||
stride=encoder_downsample_rates[0],
|
||||
padding=encoder_downsample_rates[0]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[1],
|
||||
stride=encoder_downsample_rates[1],
|
||||
padding=encoder_downsample_rates[1]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
)
|
||||
|
||||
self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)
|
||||
|
||||
# Decoder
|
||||
self.decoder = Decoder(
|
||||
enc_dim=decoder_enc_dim,
|
||||
num_mels=self.num_mels,
|
||||
frames_per_step=frames_per_step,
|
||||
attention_rnn_dim=attention_rnn_dim,
|
||||
decoder_rnn_dim=decoder_rnn_dim,
|
||||
num_decoder_rnn_layer=num_decoder_rnn_layer,
|
||||
prenet_dims=prenet_dims,
|
||||
num_mixtures=num_mixtures,
|
||||
use_stop_tokens=True,
|
||||
concat_context_to_last=concat_context_to_last,
|
||||
encoder_down_factor=self.encoder_down_factor,
|
||||
)
|
||||
|
||||
# Mel-Spec Postnet: some residual CNN layers
|
||||
self.postnet = Postnet()
|
||||
|
||||
def parse_output(self, outputs, output_lengths=None):
|
||||
if self.mask_padding and output_lengths is not None:
|
||||
mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
|
||||
mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
|
||||
outputs[0].data.masked_fill_(mask, 0.0)
|
||||
outputs[1].data.masked_fill_(mask, 0.0)
|
||||
return outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
feature_lengths: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
output_att_ws: bool = False,
|
||||
):
|
||||
decoder_inputs = self.bnf_prenet(
|
||||
bottle_neck_features.transpose(1, 2)
|
||||
).transpose(1, 2)
|
||||
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs + logf0_uv
|
||||
|
||||
assert spembs is not None
|
||||
spk_embeds = F.normalize(
|
||||
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
|
||||
decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
|
||||
decoder_inputs = self.reduce_proj(decoder_inputs)
|
||||
|
||||
# (B, num_mels, T_dec)
|
||||
T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
|
||||
mel_outputs, predicted_stop, alignments = self.decoder(
|
||||
decoder_inputs, speech, T_dec)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
if output_att_ws:
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
|
||||
else:
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)
|
||||
|
||||
# return mel_outputs, mel_outputs_postnet
|
||||
|
||||
def inference(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
):
|
||||
decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
|
||||
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs + logf0_uv
|
||||
|
||||
assert spembs is not None
|
||||
spk_embeds = F.normalize(
|
||||
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
|
||||
bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
|
||||
bottle_neck_features = self.reduce_proj(bottle_neck_features)
|
||||
|
||||
## Decoder
|
||||
if bottle_neck_features.size(0) > 1:
|
||||
mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
# outputs = mel_outputs_postnet[0]
|
||||
|
||||
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
|
||||
|
||||
def load_model(model_file, device=None):
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model_config = HpsYaml(model_config_fpaths[0])
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
ckpt = torch.load(model_file, map_location=device)
|
||||
ppg2mel_model.load_state_dict(ckpt["model"])
|
||||
ppg2mel_model.eval()
|
||||
return ppg2mel_model
|
||||
113
ppg2mel/preprocess.py
Normal file
113
ppg2mel/preprocess.py
Normal file
@@ -0,0 +1,113 @@
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import soundfile
|
||||
import resampy
|
||||
|
||||
from ppg_extractor import load_model
|
||||
import encoder.inference as Encoder
|
||||
from encoder.audio import preprocess_wav
|
||||
from encoder import audio
|
||||
from utils.f0_utils import compute_f0
|
||||
|
||||
from torch.multiprocessing import Pool, cpu_count
|
||||
from functools import partial
|
||||
|
||||
SAMPLE_RATE=16000
|
||||
|
||||
def _compute_bnf(
|
||||
wav: any,
|
||||
output_fpath: str,
|
||||
device: torch.device,
|
||||
ppg_model_local: any,
|
||||
):
|
||||
"""
|
||||
Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
|
||||
"""
|
||||
ppg_model_local.to(device)
|
||||
wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
|
||||
wav_length = torch.LongTensor([wav.shape[0]]).to(device)
|
||||
with torch.no_grad():
|
||||
bnf = ppg_model_local(wav_tensor, wav_length)
|
||||
bnf_npy = bnf.squeeze(0).cpu().numpy()
|
||||
np.save(output_fpath, bnf_npy, allow_pickle=False)
|
||||
return bnf_npy, len(bnf_npy)
|
||||
|
||||
def _compute_f0_from_wav(wav, output_fpath):
|
||||
"""Compute merged f0 values."""
|
||||
f0 = compute_f0(wav, SAMPLE_RATE)
|
||||
np.save(output_fpath, f0, allow_pickle=False)
|
||||
return f0, len(f0)
|
||||
|
||||
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
|
||||
Encoder.set_model(encoder_model_local)
|
||||
# Compute where to split the utterance into partials and pad if necessary
|
||||
wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
|
||||
max_wave_length = wave_slices[-1].stop
|
||||
if max_wave_length >= len(wav):
|
||||
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
||||
|
||||
# Split the utterance into partials
|
||||
frames = audio.wav_to_mel_spectrogram(wav)
|
||||
frames_batch = np.array([frames[s] for s in mel_slices])
|
||||
partial_embeds = Encoder.embed_frames_batch(frames_batch)
|
||||
|
||||
# Compute the utterance embedding from the partial embeddings
|
||||
raw_embed = np.mean(partial_embeds, axis=0)
|
||||
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
||||
|
||||
np.save(output_fpath, embed, allow_pickle=False)
|
||||
return embed, len(embed)
|
||||
|
||||
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
|
||||
# wav = preprocess_wav(wav_path)
|
||||
# try:
|
||||
wav, sr = soundfile.read(wav_path)
|
||||
if len(wav) < sr:
|
||||
return None, sr, len(wav)
|
||||
if sr != SAMPLE_RATE:
|
||||
wav = resampy.resample(wav, sr, SAMPLE_RATE)
|
||||
sr = SAMPLE_RATE
|
||||
utt_id = os.path.basename(wav_path).rstrip(".wav")
|
||||
|
||||
_, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
|
||||
_, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
|
||||
_, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
|
||||
|
||||
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
|
||||
# Glob wav files
|
||||
wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
|
||||
print(f"Globbed {len(wav_file_list)} wav files.")
|
||||
|
||||
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
|
||||
ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
|
||||
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
|
||||
if n_processes is None:
|
||||
n_processes = cpu_count()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
|
||||
job = Pool(n_processes).imap(func, wav_file_list)
|
||||
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
|
||||
|
||||
# finish processing and mark
|
||||
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
|
||||
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
|
||||
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
|
||||
for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
|
||||
id = os.path.basename(file).split(".f0.npy")[0]
|
||||
if id.endswith("01"):
|
||||
d_fid_file.write(id + "\n")
|
||||
elif id.endswith("09"):
|
||||
e_fid_file.write(id + "\n")
|
||||
else:
|
||||
t_fid_file.write(id + "\n")
|
||||
t_fid_file.close()
|
||||
d_fid_file.close()
|
||||
e_fid_file.close()
|
||||
return len(wav_file_list)
|
||||
374
ppg2mel/rnn_decoder_mol.py
Normal file
374
ppg2mel/rnn_decoder_mol.py
Normal file
@@ -0,0 +1,374 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .utils.mol_attention import MOLAttention
|
||||
from .utils.basic_layers import Linear
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
|
||||
class DecoderPrenet(nn.Module):
|
||||
def __init__(self, in_dim, sizes):
|
||||
super().__init__()
|
||||
in_sizes = [in_dim] + sizes[:-1]
|
||||
self.layers = nn.ModuleList(
|
||||
[Linear(in_size, out_size, bias=False)
|
||||
for (in_size, out_size) in zip(in_sizes, sizes)])
|
||||
|
||||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Mixture of Logistic (MoL) attention-based RNN Decoder."""
|
||||
def __init__(
|
||||
self,
|
||||
enc_dim,
|
||||
num_mels,
|
||||
frames_per_step,
|
||||
attention_rnn_dim,
|
||||
decoder_rnn_dim,
|
||||
prenet_dims,
|
||||
num_mixtures,
|
||||
encoder_down_factor=1,
|
||||
num_decoder_rnn_layer=1,
|
||||
use_stop_tokens=False,
|
||||
concat_context_to_last=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.enc_dim = enc_dim
|
||||
self.encoder_down_factor = encoder_down_factor
|
||||
self.num_mels = num_mels
|
||||
self.frames_per_step = frames_per_step
|
||||
self.attention_rnn_dim = attention_rnn_dim
|
||||
self.decoder_rnn_dim = decoder_rnn_dim
|
||||
self.prenet_dims = prenet_dims
|
||||
self.use_stop_tokens = use_stop_tokens
|
||||
self.num_decoder_rnn_layer = num_decoder_rnn_layer
|
||||
self.concat_context_to_last = concat_context_to_last
|
||||
|
||||
# Mel prenet
|
||||
self.prenet = DecoderPrenet(num_mels, prenet_dims)
|
||||
self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)
|
||||
|
||||
# Attention RNN
|
||||
self.attention_rnn = nn.LSTMCell(
|
||||
prenet_dims[-1] + enc_dim,
|
||||
attention_rnn_dim
|
||||
)
|
||||
|
||||
# Attention
|
||||
self.attention_layer = MOLAttention(
|
||||
attention_rnn_dim,
|
||||
r=frames_per_step/encoder_down_factor,
|
||||
M=num_mixtures,
|
||||
)
|
||||
|
||||
# Decoder RNN
|
||||
self.decoder_rnn_layers = nn.ModuleList()
|
||||
for i in range(num_decoder_rnn_layer):
|
||||
if i == 0:
|
||||
self.decoder_rnn_layers.append(
|
||||
nn.LSTMCell(
|
||||
enc_dim + attention_rnn_dim,
|
||||
decoder_rnn_dim))
|
||||
else:
|
||||
self.decoder_rnn_layers.append(
|
||||
nn.LSTMCell(
|
||||
decoder_rnn_dim,
|
||||
decoder_rnn_dim))
|
||||
# self.decoder_rnn = nn.LSTMCell(
|
||||
# 2 * enc_dim + attention_rnn_dim,
|
||||
# decoder_rnn_dim
|
||||
# )
|
||||
if concat_context_to_last:
|
||||
self.linear_projection = Linear(
|
||||
enc_dim + decoder_rnn_dim,
|
||||
num_mels * frames_per_step
|
||||
)
|
||||
else:
|
||||
self.linear_projection = Linear(
|
||||
decoder_rnn_dim,
|
||||
num_mels * frames_per_step
|
||||
)
|
||||
|
||||
|
||||
# Stop-token layer
|
||||
if self.use_stop_tokens:
|
||||
if concat_context_to_last:
|
||||
self.stop_layer = Linear(
|
||||
enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
|
||||
)
|
||||
else:
|
||||
self.stop_layer = Linear(
|
||||
decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
|
||||
)
|
||||
|
||||
|
||||
def get_go_frame(self, memory):
|
||||
B = memory.size(0)
|
||||
go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
|
||||
device=memory.device)
|
||||
return go_frame
|
||||
|
||||
def initialize_decoder_states(self, memory, mask):
|
||||
device = next(self.parameters()).device
|
||||
B = memory.size(0)
|
||||
|
||||
# attention rnn states
|
||||
self.attention_hidden = torch.zeros(
|
||||
(B, self.attention_rnn_dim), device=device)
|
||||
self.attention_cell = torch.zeros(
|
||||
(B, self.attention_rnn_dim), device=device)
|
||||
|
||||
# decoder rnn states
|
||||
self.decoder_hiddens = []
|
||||
self.decoder_cells = []
|
||||
for i in range(self.num_decoder_rnn_layer):
|
||||
self.decoder_hiddens.append(
|
||||
torch.zeros((B, self.decoder_rnn_dim),
|
||||
device=device)
|
||||
)
|
||||
self.decoder_cells.append(
|
||||
torch.zeros((B, self.decoder_rnn_dim),
|
||||
device=device)
|
||||
)
|
||||
# self.decoder_hidden = torch.zeros(
|
||||
# (B, self.decoder_rnn_dim), device=device)
|
||||
# self.decoder_cell = torch.zeros(
|
||||
# (B, self.decoder_rnn_dim), device=device)
|
||||
|
||||
self.attention_context = torch.zeros(
|
||||
(B, self.enc_dim), device=device)
|
||||
|
||||
self.memory = memory
|
||||
# self.processed_memory = self.attention_layer.memory_layer(memory)
|
||||
self.mask = mask
|
||||
|
||||
def parse_decoder_inputs(self, decoder_inputs):
|
||||
"""Prepare decoder inputs, i.e. gt mel
|
||||
Args:
|
||||
decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training.
|
||||
"""
|
||||
decoder_inputs = decoder_inputs.reshape(
|
||||
decoder_inputs.size(0),
|
||||
int(decoder_inputs.size(1)/self.frames_per_step), -1)
|
||||
# (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
|
||||
decoder_inputs = decoder_inputs.transpose(0, 1)
|
||||
# (T_out//r, B, num_mels)
|
||||
decoder_inputs = decoder_inputs[:,:,-self.num_mels:]
|
||||
return decoder_inputs
|
||||
|
||||
def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
|
||||
""" Prepares decoder outputs for output
|
||||
Args:
|
||||
mel_outputs:
|
||||
alignments:
|
||||
"""
|
||||
# (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
|
||||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
# (T_out//r, B) -> (B, T_out//r)
|
||||
if stop_outputs is not None:
|
||||
if alignments.size(0) == 1:
|
||||
stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
|
||||
else:
|
||||
stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
|
||||
stop_outputs = stop_outputs.contiguous()
|
||||
# (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
|
||||
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
|
||||
# decouple frames per step
|
||||
# (B, T_out, num_mels)
|
||||
mel_outputs = mel_outputs.view(
|
||||
mel_outputs.size(0), -1, self.num_mels)
|
||||
return mel_outputs, alignments, stop_outputs
|
||||
|
||||
def attend(self, decoder_input):
|
||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_context, attention_weights = self.attention_layer(
|
||||
self.attention_hidden, self.memory, None, self.mask)
|
||||
|
||||
decoder_rnn_input = torch.cat(
|
||||
(self.attention_hidden, self.attention_context), -1)
|
||||
|
||||
return decoder_rnn_input, self.attention_context, attention_weights
|
||||
|
||||
def decode(self, decoder_input):
|
||||
for i in range(self.num_decoder_rnn_layer):
|
||||
if i == 0:
|
||||
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
|
||||
decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
|
||||
else:
|
||||
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
|
||||
self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
|
||||
return self.decoder_hiddens[-1]
|
||||
|
||||
def forward(self, memory, mel_inputs, memory_lengths):
|
||||
""" Decoder forward pass for training
|
||||
Args:
|
||||
memory: (B, T_enc, enc_dim) Encoder outputs
|
||||
decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
|
||||
memory_lengths: (B, ) Encoder output lengths for attention masking.
|
||||
Returns:
|
||||
mel_outputs: (B, T, num_mels) mel outputs from the decoder
|
||||
alignments: (B, T//r, T_enc) attention weights.
|
||||
"""
|
||||
# [1, B, num_mels]
|
||||
go_frame = self.get_go_frame(memory).unsqueeze(0)
|
||||
# [T//r, B, num_mels]
|
||||
mel_inputs = self.parse_decoder_inputs(mel_inputs)
|
||||
# [T//r + 1, B, num_mels]
|
||||
mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
|
||||
# [T//r + 1, B, prenet_dim]
|
||||
decoder_inputs = self.prenet(mel_inputs)
|
||||
# decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)
|
||||
|
||||
self.initialize_decoder_states(
|
||||
memory, mask=~get_mask_from_lengths(memory_lengths),
|
||||
)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
# self.attention_layer_pitch.init_states(memory_pitch)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
if self.use_stop_tokens:
|
||||
stop_outputs = []
|
||||
else:
|
||||
stop_outputs = None
|
||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||
# decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]
|
||||
|
||||
decoder_rnn_input, context, attention_weights = self.attend(decoder_input)
|
||||
|
||||
decoder_rnn_output = self.decode(decoder_rnn_input)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
if self.use_stop_tokens:
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
stop_outputs += [stop_output.squeeze()]
|
||||
mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze
|
||||
alignments += [attention_weights]
|
||||
# alignments_pitch += [attention_weights_pitch]
|
||||
|
||||
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, stop_outputs)
|
||||
if stop_outputs is None:
|
||||
return mel_outputs, alignments
|
||||
else:
|
||||
return mel_outputs, stop_outputs, alignments
|
||||
|
||||
def inference(self, memory, stop_threshold=0.5):
|
||||
""" Decoder inference
|
||||
Args:
|
||||
memory: (1, T_enc, D_enc) Encoder outputs
|
||||
Returns:
|
||||
mel_outputs: mel outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
# [1, num_mels]
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
# NOTE(sx): heuristic
|
||||
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
|
||||
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
||||
decoder_input_final, context, alignment = self.attend(decoder_input)
|
||||
|
||||
#mel_output, stop_output, alignment = self.decode(decoder_input)
|
||||
decoder_rnn_output = self.decode(decoder_input_final)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
alignments += [alignment]
|
||||
|
||||
if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
|
||||
break
|
||||
if len(mel_outputs) >= max_decoder_step:
|
||||
# print("Warning! Decoding steps reaches max decoder steps.")
|
||||
break
|
||||
|
||||
decoder_input = mel_output[:,-self.num_mels:]
|
||||
|
||||
|
||||
mel_outputs, alignments, _ = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, None)
|
||||
|
||||
return mel_outputs, alignments
|
||||
|
||||
def inference_batched(self, memory, stop_threshold=0.5):
|
||||
""" Decoder inference
|
||||
Args:
|
||||
memory: (B, T_enc, D_enc) Encoder outputs
|
||||
Returns:
|
||||
mel_outputs: mel outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
# [1, num_mels]
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_outputs = []
|
||||
# NOTE(sx): heuristic
|
||||
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
|
||||
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
||||
decoder_input_final, context, alignment = self.attend(decoder_input)
|
||||
|
||||
#mel_output, stop_output, alignment = self.decode(decoder_input)
|
||||
decoder_rnn_output = self.decode(decoder_input_final)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
# (B, 1)
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
stop_outputs += [stop_output.squeeze()]
|
||||
# stop_outputs.append(stop_output)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
alignments += [alignment]
|
||||
# print(stop_output.shape)
|
||||
if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
|
||||
and len(mel_outputs) >= min_decoder_step:
|
||||
break
|
||||
if len(mel_outputs) >= max_decoder_step:
|
||||
# print("Warning! Decoding steps reaches max decoder steps.")
|
||||
break
|
||||
|
||||
decoder_input = mel_output[:,-self.num_mels:]
|
||||
|
||||
|
||||
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, stop_outputs)
|
||||
mel_outputs_stacked = []
|
||||
for mel, stop_logit in zip(mel_outputs, stop_outputs):
|
||||
idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
|
||||
mel_outputs_stacked.append(mel[:idx,:])
|
||||
mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
|
||||
return mel_outputs, alignments
|
||||
62
ppg2mel/train.py
Normal file
62
ppg2mel/train.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ckpt/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducable results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
# parser.add_argument('--no-pin', action='store_true',
|
||||
# help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not paras.no_pin)
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
ppg2mel/train/__init__.py
Normal file
1
ppg2mel/train/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
#
|
||||
50
ppg2mel/train/loss.py
Normal file
50
ppg2mel/train/loss.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Dict
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class MaskedMSELoss(nn.Module):
|
||||
def __init__(self, frames_per_step):
|
||||
super().__init__()
|
||||
self.frames_per_step = frames_per_step
|
||||
self.mel_loss_criterion = nn.MSELoss(reduction='none')
|
||||
# self.loss = nn.MSELoss()
|
||||
self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
|
||||
|
||||
def get_mask(self, lengths, max_len=None):
|
||||
# lengths: [B,]
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths)
|
||||
batch_size = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
|
||||
seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
|
||||
return (seq_range_expand < seq_length_expand).float()
|
||||
|
||||
def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
|
||||
stop_target, stop_pred):
|
||||
## process stop_target
|
||||
B = stop_target.size(0)
|
||||
stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
|
||||
stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
|
||||
stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))
|
||||
|
||||
mel_trg.requires_grad = False
|
||||
# (B, T, 1)
|
||||
mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
|
||||
# (B, T, D)
|
||||
mel_mask = mel_mask.expand_as(mel_trg)
|
||||
mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
|
||||
mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()
|
||||
|
||||
mel_loss = mel_loss_pre + mel_loss_post
|
||||
|
||||
# stop token loss
|
||||
stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()
|
||||
|
||||
return mel_loss, stop_loss
|
||||
45
ppg2mel/train/optim.py
Normal file
45
ppg2mel/train/optim.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Optimizer():
|
||||
def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
|
||||
**kwargs):
|
||||
|
||||
# Setup torch optimizer
|
||||
self.opt_type = optimizer
|
||||
self.init_lr = lr
|
||||
self.sch_type = lr_scheduler
|
||||
opt = getattr(torch.optim, optimizer)
|
||||
if lr_scheduler == 'warmup':
|
||||
warmup_step = 4000.0
|
||||
init_lr = lr
|
||||
self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
|
||||
np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
|
||||
self.opt = opt(parameters, lr=1.0)
|
||||
else:
|
||||
self.lr_scheduler = None
|
||||
self.opt = opt(parameters, lr=lr, eps=eps) # ToDo: 1e-8 better?
|
||||
|
||||
def get_opt_state_dict(self):
|
||||
return self.opt.state_dict()
|
||||
|
||||
def load_opt_state_dict(self, state_dict):
|
||||
self.opt.load_state_dict(state_dict)
|
||||
|
||||
def pre_step(self, step):
|
||||
if self.lr_scheduler is not None:
|
||||
cur_lr = self.lr_scheduler(step)
|
||||
for param_group in self.opt.param_groups:
|
||||
param_group['lr'] = cur_lr
|
||||
else:
|
||||
cur_lr = self.init_lr
|
||||
self.opt.zero_grad()
|
||||
return cur_lr
|
||||
|
||||
def step(self):
|
||||
self.opt.step()
|
||||
|
||||
def create_msg(self):
|
||||
return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
|
||||
.format(self.opt_type, self.init_lr, self.sch_type)]
|
||||
10
ppg2mel/train/option.py
Normal file
10
ppg2mel/train/option.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Default parameters which will be imported by solver
|
||||
default_hparas = {
|
||||
'GRAD_CLIP': 5.0, # Grad. clip threshold
|
||||
'PROGRESS_STEP': 100, # Std. output refresh freq.
|
||||
# Decode steps for objective validation (step = ratio*input_txt_len)
|
||||
'DEV_STEP_RATIO': 1.2,
|
||||
# Number of examples (alignment/text) to show in tensorboard
|
||||
'DEV_N_EXAMPLE': 4,
|
||||
'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs)
|
||||
}
|
||||
217
ppg2mel/train/solver.py
Normal file
217
ppg2mel/train/solver.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import os
|
||||
import sys
|
||||
import abc
|
||||
import math
|
||||
import yaml
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .option import default_hparas
|
||||
from utils.util import human_format, Timer
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
|
||||
class BaseSolver():
|
||||
'''
|
||||
Prototype Solver for all kinds of tasks
|
||||
Arguments
|
||||
config - yaml-styled config
|
||||
paras - argparse outcome
|
||||
mode - "train"/"test"
|
||||
'''
|
||||
|
||||
def __init__(self, config, paras, mode="train"):
|
||||
# General Settings
|
||||
self.config = config # load from yaml file
|
||||
self.paras = paras # command line args
|
||||
self.mode = mode # 'train' or 'test'
|
||||
for k, v in default_hparas.items():
|
||||
setattr(self, k, v)
|
||||
self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
|
||||
else torch.device('cpu')
|
||||
|
||||
# Name experiment
|
||||
self.exp_name = paras.name
|
||||
if self.exp_name is None:
|
||||
if 'exp_name' in self.config:
|
||||
self.exp_name = self.config.exp_name
|
||||
else:
|
||||
# By default, exp is named after config file
|
||||
self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
|
||||
if mode == 'train':
|
||||
self.exp_name += '_seed{}'.format(paras.seed)
|
||||
|
||||
|
||||
if mode == 'train':
|
||||
# Filepath setup
|
||||
os.makedirs(paras.ckpdir, exist_ok=True)
|
||||
self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
|
||||
os.makedirs(self.ckpdir, exist_ok=True)
|
||||
|
||||
# Logger settings
|
||||
self.logdir = os.path.join(paras.logdir, self.exp_name)
|
||||
self.log = SummaryWriter(
|
||||
self.logdir, flush_secs=self.TB_FLUSH_FREQ)
|
||||
self.timer = Timer()
|
||||
|
||||
# Hyper-parameters
|
||||
self.step = 0
|
||||
self.valid_step = config.hparas.valid_step
|
||||
self.max_step = config.hparas.max_step
|
||||
|
||||
self.verbose('Exp. name : {}'.format(self.exp_name))
|
||||
self.verbose('Loading data... large corpus may took a while.')
|
||||
|
||||
# elif mode == 'test':
|
||||
# # Output path
|
||||
# os.makedirs(paras.outdir, exist_ok=True)
|
||||
# self.ckpdir = os.path.join(paras.outdir, self.exp_name)
|
||||
|
||||
# Load training config to get acoustic feat and build model
|
||||
# self.src_config = HpsYaml(config.src.config)
|
||||
# self.paras.load = config.src.ckpt
|
||||
|
||||
# self.verbose('Evaluating result of tr. config @ {}'.format(
|
||||
# config.src.config))
|
||||
|
||||
def backward(self, loss):
|
||||
'''
|
||||
Standard backward step with self.timer and debugger
|
||||
Arguments
|
||||
loss - the loss to perform loss.backward()
|
||||
'''
|
||||
self.timer.set()
|
||||
loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.GRAD_CLIP)
|
||||
if math.isnan(grad_norm):
|
||||
self.verbose('Error : grad norm is NaN @ step '+str(self.step))
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.timer.cnt('bw')
|
||||
return grad_norm
|
||||
|
||||
def load_ckpt(self):
|
||||
''' Load ckpt if --load option is specified '''
|
||||
print(self.paras)
|
||||
if self.paras.load is not None:
|
||||
if self.paras.warm_start:
|
||||
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
|
||||
ckpt = torch.load(
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
model_dict = ckpt['model']
|
||||
if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
|
||||
model_dict = {k:v for k, v in model_dict.items()
|
||||
if k not in self.config.model.ignore_layers}
|
||||
dummy_dict = self.model.state_dict()
|
||||
dummy_dict.update(model_dict)
|
||||
model_dict = dummy_dict
|
||||
self.model.load_state_dict(model_dict)
|
||||
else:
|
||||
# Load weights
|
||||
ckpt = torch.load(
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
self.model.load_state_dict(ckpt['model'])
|
||||
|
||||
# Load task-dependent items
|
||||
if self.mode == 'train':
|
||||
self.step = ckpt['global_step']
|
||||
self.optimizer.load_opt_state_dict(ckpt['optimizer'])
|
||||
self.verbose('Load ckpt from {}, restarting at step {}'.format(
|
||||
self.paras.load, self.step))
|
||||
else:
|
||||
for k, v in ckpt.items():
|
||||
if type(v) is float:
|
||||
metric, score = k, v
|
||||
self.model.eval()
|
||||
self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
|
||||
self.paras.load, metric, score))
|
||||
|
||||
def verbose(self, msg):
|
||||
''' Verbose function for print information to stdout'''
|
||||
if self.paras.verbose:
|
||||
if type(msg) == list:
|
||||
for m in msg:
|
||||
print('[INFO]', m.ljust(100))
|
||||
else:
|
||||
print('[INFO]', msg.ljust(100))
|
||||
|
||||
def progress(self, msg):
|
||||
''' Verbose function for updating progress on stdout (do not include newline) '''
|
||||
if self.paras.verbose:
|
||||
sys.stdout.write("\033[K") # Clear line
|
||||
print('[{}] {}'.format(human_format(self.step), msg), end='\r')
|
||||
|
||||
def write_log(self, log_name, log_dict):
|
||||
'''
|
||||
Write log to TensorBoard
|
||||
log_name - <str> Name of tensorboard variable
|
||||
log_value - <dict>/<array> Value of variable (e.g. dict of losses), passed if value = None
|
||||
'''
|
||||
if type(log_dict) is dict:
|
||||
log_dict = {key: val for key, val in log_dict.items() if (
|
||||
val is not None and not math.isnan(val))}
|
||||
if log_dict is None:
|
||||
pass
|
||||
elif len(log_dict) > 0:
|
||||
if 'align' in log_name or 'spec' in log_name:
|
||||
img, form = log_dict
|
||||
self.log.add_image(
|
||||
log_name, img, global_step=self.step, dataformats=form)
|
||||
elif 'text' in log_name or 'hyp' in log_name:
|
||||
self.log.add_text(log_name, log_dict, self.step)
|
||||
else:
|
||||
self.log.add_scalars(log_name, log_dict, self.step)
|
||||
|
||||
def save_checkpoint(self, f_name, metric, score, show_msg=True):
|
||||
''''
|
||||
Ckpt saver
|
||||
f_name - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
|
||||
score - <float> The value of metric used to evaluate model
|
||||
'''
|
||||
ckpt_path = os.path.join(self.ckpdir, f_name)
|
||||
full_dict = {
|
||||
"model": self.model.state_dict(),
|
||||
"optimizer": self.optimizer.get_opt_state_dict(),
|
||||
"global_step": self.step,
|
||||
metric: score
|
||||
}
|
||||
|
||||
torch.save(full_dict, ckpt_path)
|
||||
if show_msg:
|
||||
self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
|
||||
format(human_format(self.step), metric, score, ckpt_path))
|
||||
|
||||
|
||||
# ----------------------------------- Abtract Methods ------------------------------------------ #
|
||||
@abc.abstractmethod
|
||||
def load_data(self):
|
||||
'''
|
||||
Called by main to load all data
|
||||
After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
|
||||
No return value
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_model(self):
|
||||
'''
|
||||
Called by main to set models
|
||||
After this call, model related attributes should be setup (e.g. self.l2_loss)
|
||||
The followings MUST be setup
|
||||
- self.model (torch.nn.Module)
|
||||
- self.optimizer (src.Optimizer),
|
||||
init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
|
||||
Loading pre-trained model should also be performed here
|
||||
No return value
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exec(self):
|
||||
'''
|
||||
Called by main to execute training/inference
|
||||
'''
|
||||
raise NotImplementedError
|
||||
288
ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py
Normal file
288
ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import os, sys
|
||||
# sys.path.append('/home/shaunxliu/projects/nnsp')
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
from .solver import BaseSolver
|
||||
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
|
||||
# from src.rnn_ppg2mel import BiRnnPpg2MelModel
|
||||
# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
|
||||
from .loss import MaskedMSELoss
|
||||
from .optim import Optimizer
|
||||
from utils.util import human_format
|
||||
from ppg2mel import MelDecoderMOLv2
|
||||
|
||||
|
||||
class Solver(BaseSolver):
|
||||
"""Customized Solver."""
|
||||
def __init__(self, config, paras, mode):
|
||||
super().__init__(config, paras, mode)
|
||||
self.num_att_plots = 5
|
||||
self.att_ws_dir = f"{self.logdir}/att_ws"
|
||||
os.makedirs(self.att_ws_dir, exist_ok=True)
|
||||
self.best_loss = np.inf
|
||||
|
||||
def fetch_data(self, data):
|
||||
"""Move data to device"""
|
||||
data = [i.to(self.device) for i in data]
|
||||
return data
|
||||
|
||||
def load_data(self):
|
||||
""" Load data for training/validation/plotting."""
|
||||
train_dataset = OneshotVcDataset(
|
||||
meta_file=self.config.data.train_fid_list,
|
||||
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
|
||||
libri_ppg_dir=self.config.data.libri_ppg_dir,
|
||||
vctk_f0_dir=self.config.data.vctk_f0_dir,
|
||||
libri_f0_dir=self.config.data.libri_f0_dir,
|
||||
vctk_wav_dir=self.config.data.vctk_wav_dir,
|
||||
libri_wav_dir=self.config.data.libri_wav_dir,
|
||||
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
|
||||
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
|
||||
ppg_file_ext=self.config.data.ppg_file_ext,
|
||||
min_max_norm_mel=self.config.data.min_max_norm_mel,
|
||||
mel_min=self.config.data.mel_min,
|
||||
mel_max=self.config.data.mel_max,
|
||||
)
|
||||
dev_dataset = OneshotVcDataset(
|
||||
meta_file=self.config.data.dev_fid_list,
|
||||
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
|
||||
libri_ppg_dir=self.config.data.libri_ppg_dir,
|
||||
vctk_f0_dir=self.config.data.vctk_f0_dir,
|
||||
libri_f0_dir=self.config.data.libri_f0_dir,
|
||||
vctk_wav_dir=self.config.data.vctk_wav_dir,
|
||||
libri_wav_dir=self.config.data.libri_wav_dir,
|
||||
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
|
||||
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
|
||||
ppg_file_ext=self.config.data.ppg_file_ext,
|
||||
min_max_norm_mel=self.config.data.min_max_norm_mel,
|
||||
mel_min=self.config.data.mel_min,
|
||||
mel_max=self.config.data.mel_max,
|
||||
)
|
||||
self.train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=True,
|
||||
batch_size=self.config.hparas.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=True,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True),
|
||||
)
|
||||
self.dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=False,
|
||||
batch_size=self.config.hparas.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True),
|
||||
)
|
||||
self.plot_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True,
|
||||
give_uttids=True),
|
||||
)
|
||||
msg = "Have prepared training set and dev set."
|
||||
self.verbose(msg)
|
||||
|
||||
def load_pretrained_params(self):
|
||||
print("Load pretrained model from: ", self.config.data.pretrain_model_file)
|
||||
ignore_layer_prefixes = ["speaker_embedding_table"]
|
||||
pretrain_model_file = self.config.data.pretrain_model_file
|
||||
pretrain_ckpt = torch.load(
|
||||
pretrain_model_file, map_location=self.device
|
||||
)["model"]
|
||||
model_dict = self.model.state_dict()
|
||||
print(self.model)
|
||||
|
||||
# 1. filter out unnecessrary keys
|
||||
for prefix in ignore_layer_prefixes:
|
||||
pretrain_ckpt = {k : v
|
||||
for k, v in pretrain_ckpt.items() if not k.startswith(prefix)
|
||||
}
|
||||
# 2. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrain_ckpt)
|
||||
|
||||
# 3. load the new state dict
|
||||
self.model.load_state_dict(model_dict)
|
||||
|
||||
def set_model(self):
|
||||
"""Setup model and optimizer"""
|
||||
# Model
|
||||
print("[INFO] Model name: ", self.config["model_name"])
|
||||
self.model = MelDecoderMOLv2(
|
||||
**self.config["model"]
|
||||
).to(self.device)
|
||||
# self.load_pretrained_params()
|
||||
|
||||
# model_params = [{'params': self.model.spk_embedding.weight}]
|
||||
model_params = [{'params': self.model.parameters()}]
|
||||
|
||||
# Loss criterion
|
||||
self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)
|
||||
|
||||
# Optimizer
|
||||
self.optimizer = Optimizer(model_params, **self.config["hparas"])
|
||||
self.verbose(self.optimizer.create_msg())
|
||||
|
||||
# Automatically load pre-trained model if self.paras.load is given
|
||||
self.load_ckpt()
|
||||
|
||||
def exec(self):
|
||||
self.verbose("Total training steps {}.".format(
|
||||
human_format(self.max_step)))
|
||||
|
||||
mel_loss = None
|
||||
n_epochs = 0
|
||||
# Set as current time
|
||||
self.timer.set()
|
||||
|
||||
while self.step < self.max_step:
|
||||
for data in self.train_dataloader:
|
||||
# Pre-step: updata lr_rate and do zero_grad
|
||||
lr_rate = self.optimizer.pre_step(self.step)
|
||||
total_loss = 0
|
||||
# data to device
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
|
||||
self.timer.cnt("rd")
|
||||
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids
|
||||
)
|
||||
mel_loss, stop_loss = self.loss_criterion(
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mels,
|
||||
out_lengths,
|
||||
stop_tokens,
|
||||
predicted_stop
|
||||
)
|
||||
loss = mel_loss + stop_loss
|
||||
|
||||
self.timer.cnt("fw")
|
||||
|
||||
# Back-prop
|
||||
grad_norm = self.backward(loss)
|
||||
self.step += 1
|
||||
|
||||
# Logger
|
||||
if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
|
||||
self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
|
||||
.format(loss.cpu().item(), mel_loss.cpu().item(),
|
||||
stop_loss.cpu().item(), grad_norm, self.timer.show()))
|
||||
self.write_log('loss', {'tr/loss': loss,
|
||||
'tr/mel-loss': mel_loss,
|
||||
'tr/stop-loss': stop_loss})
|
||||
|
||||
# Validation
|
||||
if (self.step == 1) or (self.step % self.valid_step == 0):
|
||||
self.validate()
|
||||
|
||||
# End of step
|
||||
# https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
|
||||
torch.cuda.empty_cache()
|
||||
self.timer.set()
|
||||
if self.step > self.max_step:
|
||||
break
|
||||
n_epochs += 1
|
||||
self.log.close()
|
||||
|
||||
def validate(self):
|
||||
self.model.eval()
|
||||
dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0
|
||||
|
||||
for i, data in enumerate(self.dev_dataloader):
|
||||
self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
|
||||
# Fetch data
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
|
||||
with torch.no_grad():
|
||||
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids
|
||||
)
|
||||
mel_loss, stop_loss = self.loss_criterion(
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mels,
|
||||
out_lengths,
|
||||
stop_tokens,
|
||||
predicted_stop
|
||||
)
|
||||
loss = mel_loss + stop_loss
|
||||
|
||||
dev_loss += loss.cpu().item()
|
||||
dev_mel_loss += mel_loss.cpu().item()
|
||||
dev_stop_loss += stop_loss.cpu().item()
|
||||
|
||||
dev_loss = dev_loss / (i + 1)
|
||||
dev_mel_loss = dev_mel_loss / (i + 1)
|
||||
dev_stop_loss = dev_stop_loss / (i + 1)
|
||||
self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
|
||||
if dev_loss < self.best_loss:
|
||||
self.best_loss = dev_loss
|
||||
self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
|
||||
self.write_log('loss', {'dv/loss': dev_loss,
|
||||
'dv/mel-loss': dev_mel_loss,
|
||||
'dv/stop-loss': dev_stop_loss})
|
||||
|
||||
# plot attention
|
||||
for i, data in enumerate(self.plot_dataloader):
|
||||
if i == self.num_att_plots:
|
||||
break
|
||||
# Fetch data
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
|
||||
fid = data[-1][0]
|
||||
with torch.no_grad():
|
||||
_, _, _, att_ws = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids,
|
||||
output_att_ws=True
|
||||
)
|
||||
att_ws = att_ws.squeeze(0).cpu().numpy()
|
||||
att_ws = att_ws[None]
|
||||
w, h = plt.figaspect(1.0 / len(att_ws))
|
||||
fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
|
||||
axes = fig.subplots(1, len(att_ws))
|
||||
if len(att_ws) == 1:
|
||||
axes = [axes]
|
||||
|
||||
for ax, aw in zip(axes, att_ws):
|
||||
ax.imshow(aw.astype(np.float32), aspect="auto")
|
||||
ax.set_title(f"{fid}")
|
||||
ax.set_xlabel("Input")
|
||||
ax.set_ylabel("Output")
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
|
||||
fig.savefig(fig_name)
|
||||
|
||||
# Resume training
|
||||
self.model.train()
|
||||
|
||||
23
ppg2mel/utils/abs_model.py
Normal file
23
ppg2mel/utils/abs_model.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
class AbsMelDecoder(torch.nn.Module, ABC):
|
||||
"""The abstract PPG-based voice conversion class
|
||||
This "model" is one of mediator objects for "Task" class.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
feature_lengths: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
styleembs: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
79
ppg2mel/utils/basic_layers.py
Normal file
79
ppg2mel/utils/basic_layers.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
"""
|
||||
Tiles x on dimension dim count times.
|
||||
"""
|
||||
perm = list(range(len(x.size())))
|
||||
if dim != 0:
|
||||
perm[0], perm[dim] = perm[dim], perm[0]
|
||||
x = x.permute(perm).contiguous()
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = x.view(batch, -1) \
|
||||
.transpose(0, 1) \
|
||||
.repeat(count, 1) \
|
||||
.transpose(0, 1) \
|
||||
.contiguous() \
|
||||
.view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||
super(Linear, self).__init__()
|
||||
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear_layer.weight,
|
||||
gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_layer(x)
|
||||
|
||||
class Conv1d(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
||||
padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
|
||||
super(Conv1d, self).__init__()
|
||||
if padding is None:
|
||||
assert(kernel_size % 2 == 1)
|
||||
padding = int(dilation * (kernel_size - 1)/2)
|
||||
|
||||
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation,
|
||||
bias=bias)
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
|
||||
|
||||
def forward(self, x):
|
||||
# x: BxDxT
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
"""
|
||||
Tiles x on dimension dim count times.
|
||||
"""
|
||||
perm = list(range(len(x.size())))
|
||||
if dim != 0:
|
||||
perm[0], perm[dim] = perm[dim], perm[0]
|
||||
x = x.permute(perm).contiguous()
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = x.view(batch, -1) \
|
||||
.transpose(0, 1) \
|
||||
.repeat(count, 1) \
|
||||
.transpose(0, 1) \
|
||||
.contiguous() \
|
||||
.view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
||||
52
ppg2mel/utils/cnn_postnet.py
Normal file
52
ppg2mel/utils/cnn_postnet.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .basic_layers import Linear, Conv1d
|
||||
|
||||
|
||||
class Postnet(nn.Module):
|
||||
"""Postnet
|
||||
- Five 1-d convolution with 512 channels and kernel size 5
|
||||
"""
|
||||
def __init__(self, num_mels=80,
|
||||
num_layers=5,
|
||||
hidden_dim=512,
|
||||
kernel_size=5):
|
||||
super(Postnet, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
Conv1d(
|
||||
num_mels, hidden_dim,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hidden_dim)))
|
||||
|
||||
for i in range(1, num_layers - 1):
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
Conv1d(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hidden_dim)))
|
||||
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
Conv1d(
|
||||
hidden_dim, num_mels,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='linear'),
|
||||
nn.BatchNorm1d(num_mels)))
|
||||
|
||||
def forward(self, x):
|
||||
# x: (B, num_mels, T_dec)
|
||||
for i in range(len(self.convolutions) - 1):
|
||||
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
|
||||
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
|
||||
return x
|
||||
123
ppg2mel/utils/mol_attention.py
Normal file
123
ppg2mel/utils/mol_attention.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MOLAttention(nn.Module):
|
||||
""" Discretized Mixture of Logistic (MOL) attention.
|
||||
C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
|
||||
GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
r=1,
|
||||
M=5,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
query_dim: attention_rnn_dim.
|
||||
M: number of mixtures.
|
||||
"""
|
||||
super().__init__()
|
||||
if r < 1:
|
||||
self.r = float(r)
|
||||
else:
|
||||
self.r = int(r)
|
||||
self.M = M
|
||||
self.score_mask_value = 0.0 # -float("inf")
|
||||
self.eps = 1e-5
|
||||
# Position arrary for encoder time steps
|
||||
self.J = None
|
||||
# Query layer: [w, sigma,]
|
||||
self.query_layer = torch.nn.Sequential(
|
||||
nn.Linear(query_dim, 256, bias=True),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 3*M, bias=True)
|
||||
)
|
||||
self.mu_prev = None
|
||||
self.initialize_bias()
|
||||
|
||||
def initialize_bias(self):
|
||||
"""Initialize sigma and Delta."""
|
||||
# sigma
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
|
||||
# Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
|
||||
# softplus(-0.432) = 0.5003
|
||||
if self.r == 2:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
|
||||
elif self.r == 4:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
|
||||
elif self.r == 1:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
|
||||
else:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)
|
||||
|
||||
|
||||
def init_states(self, memory):
|
||||
"""Initialize mu_prev and J.
|
||||
This function should be called by the decoder before decoding one batch.
|
||||
Args:
|
||||
memory: (B, T, D_enc) encoder output.
|
||||
"""
|
||||
B, T_enc, _ = memory.size()
|
||||
device = memory.device
|
||||
self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage
|
||||
# self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float)
|
||||
self.mu_prev = torch.zeros(B, self.M).to(device)
|
||||
|
||||
def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
|
||||
"""
|
||||
att_rnn_h: attetion rnn hidden state.
|
||||
memory: encoder outputs (B, T_enc, D).
|
||||
mask: binary mask for padded data (B, T_enc).
|
||||
"""
|
||||
# [B, 3M]
|
||||
mixture_params = self.query_layer(att_rnn_h)
|
||||
|
||||
# [B, M]
|
||||
w_hat = mixture_params[:, :self.M]
|
||||
sigma_hat = mixture_params[:, self.M:2*self.M]
|
||||
Delta_hat = mixture_params[:, 2*self.M:3*self.M]
|
||||
|
||||
# print("w_hat: ", w_hat)
|
||||
# print("sigma_hat: ", sigma_hat)
|
||||
# print("Delta_hat: ", Delta_hat)
|
||||
|
||||
# Dropout to de-correlate attention heads
|
||||
w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed?
|
||||
|
||||
# Mixture parameters
|
||||
w = torch.softmax(w_hat, dim=-1) + self.eps
|
||||
sigma = F.softplus(sigma_hat) + self.eps
|
||||
Delta = F.softplus(Delta_hat)
|
||||
mu_cur = self.mu_prev + Delta
|
||||
# print("w:", w)
|
||||
j = self.J[:memory.size(1) + 1]
|
||||
|
||||
# Attention weights
|
||||
# CDF of logistic distribution
|
||||
phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
|
||||
(mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))
|
||||
# print("phi_t:", phi_t)
|
||||
|
||||
# Discretize attention weights
|
||||
# (B, T_enc + 1)
|
||||
alpha_t = torch.sum(phi_t, dim=1)
|
||||
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
|
||||
alpha_t[alpha_t == 0] = self.eps
|
||||
# print("alpha_t: ", alpha_t.size())
|
||||
# Apply masking
|
||||
if mask is not None:
|
||||
alpha_t.data.masked_fill_(mask, self.score_mask_value)
|
||||
|
||||
context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
|
||||
if memory_pitch is not None:
|
||||
context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)
|
||||
|
||||
self.mu_prev = mu_cur
|
||||
|
||||
if memory_pitch is not None:
|
||||
return context, context_pitch, alpha_t
|
||||
return context, alpha_t
|
||||
|
||||
451
ppg2mel/utils/nets_utils.py
Normal file
451
ppg2mel/utils/nets_utils.py
Normal file
@@ -0,0 +1,451 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Network related utility tools."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(m, x):
|
||||
"""Send tensor into the device of the module.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): Torch module.
|
||||
x (Tensor): Torch tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Torch tensor located in the same place as torch module.
|
||||
|
||||
"""
|
||||
assert isinstance(m, torch.nn.Module)
|
||||
device = next(m.parameters()).device
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def pad_list(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
max_len = max(x.size(0) for x in xs)
|
||||
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
pad[i, :xs[i].size(0)] = xs[i]
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]],
|
||||
[[0, 0, 0, 1],
|
||||
[0, 0, 0, 1]],
|
||||
[[0, 0, 1, 1],
|
||||
[0, 0, 1, 1]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_pad_mask(lengths, xs, 1)
|
||||
tensor([[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
>>> make_pad_mask(lengths, xs, 2)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
if length_dim == 0:
|
||||
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
|
||||
|
||||
if not isinstance(lengths, list):
|
||||
lengths = lengths.tolist()
|
||||
bs = int(len(lengths))
|
||||
if xs is None:
|
||||
maxlen = int(max(lengths))
|
||||
else:
|
||||
maxlen = xs.size(length_dim)
|
||||
|
||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
if xs is not None:
|
||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
||||
|
||||
if length_dim < 0:
|
||||
length_dim = xs.dim() + length_dim
|
||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
||||
ind = tuple(slice(None) if i in (0, length_dim) else None
|
||||
for i in range(xs.dim()))
|
||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
||||
return mask
|
||||
|
||||
|
||||
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of non-padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
ByteTensor: mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[1, 1, 1, 1 ,1],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]],
|
||||
[[1, 1, 1, 0],
|
||||
[1, 1, 1, 0]],
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_non_pad_mask(lengths, xs, 1)
|
||||
tensor([[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
>>> make_non_pad_mask(lengths, xs, 2)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
return ~make_pad_mask(lengths, xs, length_dim)
|
||||
|
||||
|
||||
def mask_by_length(xs, lengths, fill=0):
|
||||
"""Mask tensor according to length.
|
||||
|
||||
Args:
|
||||
xs (Tensor): Batch of input tensor (B, `*`).
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
fill (int or float): Value to fill masked part.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of masked input tensor (B, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = torch.arange(5).repeat(3, 1) + 1
|
||||
>>> x
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5]])
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> mask_by_length(x, lengths)
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 0, 0],
|
||||
[1, 2, 0, 0, 0]])
|
||||
|
||||
"""
|
||||
assert xs.size(0) == len(lengths)
|
||||
ret = xs.data.new(*xs.size()).fill_(fill)
|
||||
for i, l in enumerate(lengths):
|
||||
ret[i, :l] = xs[i, :l]
|
||||
return ret
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(
|
||||
pad_targets.size(0),
|
||||
pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return float(numerator) / float(denominator)
|
||||
|
||||
|
||||
def to_torch_tensor(x):
|
||||
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
|
||||
|
||||
Args:
|
||||
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
|
||||
|
||||
Returns:
|
||||
Tensor or ComplexTensor: Type converted inputs.
|
||||
|
||||
Examples:
|
||||
>>> xs = np.ones(3, dtype=np.float32)
|
||||
>>> xs = to_torch_tensor(xs)
|
||||
tensor([1., 1., 1.])
|
||||
>>> xs = torch.ones(3, 4, 5)
|
||||
>>> assert to_torch_tensor(xs) is xs
|
||||
>>> xs = {'real': xs, 'imag': xs}
|
||||
>>> to_torch_tensor(xs)
|
||||
ComplexTensor(
|
||||
Real:
|
||||
tensor([1., 1., 1.])
|
||||
Imag;
|
||||
tensor([1., 1., 1.])
|
||||
)
|
||||
|
||||
"""
|
||||
# If numpy, change to torch tensor
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype.kind == 'c':
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
return ComplexTensor(x)
|
||||
else:
|
||||
return torch.from_numpy(x)
|
||||
|
||||
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
|
||||
elif isinstance(x, dict):
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
if 'real' not in x or 'imag' not in x:
|
||||
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
|
||||
# Relative importing because of using python3 syntax
|
||||
return ComplexTensor(x['real'], x['imag'])
|
||||
|
||||
# If torch.Tensor, as it is
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x
|
||||
|
||||
else:
|
||||
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
|
||||
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
|
||||
"but got {}".format(type(x)))
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except Exception:
|
||||
# If PY2
|
||||
raise ValueError(error)
|
||||
else:
|
||||
# If PY3
|
||||
if isinstance(x, ComplexTensor):
|
||||
return x
|
||||
else:
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_subsample(train_args, mode, arch):
|
||||
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
|
||||
|
||||
Args:
|
||||
train_args: argument Namespace containing options.
|
||||
mode: one of ('asr', 'mt', 'st')
|
||||
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
|
||||
|
||||
Returns:
|
||||
np.ndarray / List[np.ndarray]: subsampling factors.
|
||||
"""
|
||||
if arch == 'transformer':
|
||||
return np.array([1])
|
||||
|
||||
elif mode == 'mt' and arch == 'rnn':
|
||||
# +1 means input (+1) and layers outputs (train_args.elayer)
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
logging.warning('Subsampling is not performed for machine translation.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
|
||||
(mode == 'mt' and arch == 'rnn') or \
|
||||
(mode == 'st' and arch == 'rnn'):
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mix':
|
||||
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mulenc':
|
||||
subsample_list = []
|
||||
for idx in range(train_args.num_encs):
|
||||
subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
|
||||
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
|
||||
ss = train_args.subsample[idx].split("_")
|
||||
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Encoder %d: Subsampling is not performed for vgg*. '
|
||||
'It is performed in max pooling layers at CNN.', idx + 1)
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
subsample_list.append(subsample)
|
||||
return subsample_list
|
||||
|
||||
else:
|
||||
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
|
||||
|
||||
|
||||
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
|
||||
"""Replace keys of old prefix with new prefix in state dict."""
|
||||
# need this list not to break the dict iterator
|
||||
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
|
||||
if len(old_keys) > 0:
|
||||
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
|
||||
for k in old_keys:
|
||||
v = state_dict.pop(k)
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
state_dict[new_k] = v
|
||||
22
ppg2mel/utils/vc_utils.py
Normal file
22
ppg2mel/utils/vc_utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
|
||||
def gcd(a, b):
|
||||
"""Greatest common divisor."""
|
||||
a, b = (a, b) if a >=b else (b, a)
|
||||
if a%b == 0:
|
||||
return b
|
||||
else :
|
||||
return gcd(b, a%b)
|
||||
|
||||
def lcm(a, b):
|
||||
"""Least common multiple"""
|
||||
return a * b // gcd(a, b)
|
||||
|
||||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
|
||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||
return mask
|
||||
|
||||
67
ppg2mel_train.py
Normal file
67
ppg2mel_train.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducable results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not paras.no_pin)
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
102
ppg_extractor/__init__.py
Normal file
102
ppg_extractor/__init__.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import argparse
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
from .frontend import DefaultFrontend
|
||||
from .utterance_mvn import UtteranceMVN
|
||||
from .encoder.conformer_encoder import ConformerEncoder
|
||||
|
||||
_model = None # type: PPGModel
|
||||
_device = None
|
||||
|
||||
class PPGModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
frontend,
|
||||
normalizer,
|
||||
encoder,
|
||||
):
|
||||
super().__init__()
|
||||
self.frontend = frontend
|
||||
self.normalize = normalizer
|
||||
self.encoder = encoder
|
||||
|
||||
def forward(self, speech, speech_lengths):
|
||||
"""
|
||||
|
||||
Args:
|
||||
speech (tensor): (B, L)
|
||||
speech_lengths (tensor): (B, )
|
||||
|
||||
Returns:
|
||||
bottle_neck_feats (tensor): (B, L//hop_size, 144)
|
||||
|
||||
"""
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
return encoder_out
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
):
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
|
||||
def extract_from_wav(self, src_wav):
|
||||
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
|
||||
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
|
||||
return self(src_wav_tensor, src_wav_lengths)
|
||||
|
||||
|
||||
def build_model(args):
|
||||
normalizer = UtteranceMVN(**args.normalize_conf)
|
||||
frontend = DefaultFrontend(**args.frontend_conf)
|
||||
encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
|
||||
model = PPGModel(frontend, normalizer, encoder)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(model_file, device=None):
|
||||
global _model, _device
|
||||
|
||||
if device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
_device = device
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
|
||||
config_file = model_config_fpaths[0]
|
||||
with config_file.open("r", encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
|
||||
args = argparse.Namespace(**args)
|
||||
|
||||
model = build_model(args)
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
ckpt_state_dict = torch.load(model_file, map_location=_device)
|
||||
ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k}
|
||||
|
||||
model_state_dict.update(ckpt_state_dict)
|
||||
model.load_state_dict(model_state_dict)
|
||||
|
||||
_model = model.eval().to(_device)
|
||||
return _model
|
||||
|
||||
|
||||
398
ppg_extractor/e2e_asr_common.py
Normal file
398
ppg_extractor/e2e_asr_common.py
Normal file
@@ -0,0 +1,398 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Common functions for ASR."""
|
||||
|
||||
import argparse
|
||||
import editdistance
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import six
|
||||
import sys
|
||||
|
||||
from itertools import groupby
|
||||
|
||||
|
||||
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
|
||||
"""End detection.
|
||||
|
||||
desribed in Eq. (50) of S. Watanabe et al
|
||||
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
|
||||
|
||||
:param ended_hyps:
|
||||
:param i:
|
||||
:param M:
|
||||
:param D_end:
|
||||
:return:
|
||||
"""
|
||||
if len(ended_hyps) == 0:
|
||||
return False
|
||||
count = 0
|
||||
best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
|
||||
for m in six.moves.range(M):
|
||||
# get ended_hyps with their length is i - m
|
||||
hyp_length = i - m
|
||||
hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
|
||||
if len(hyps_same_length) > 0:
|
||||
best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
|
||||
if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
|
||||
count += 1
|
||||
|
||||
if count == M:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# TODO(takaaki-hori): add different smoothing methods
|
||||
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
|
||||
"""Obtain label distribution for loss smoothing.
|
||||
|
||||
:param odim:
|
||||
:param lsm_type:
|
||||
:param blank:
|
||||
:param transcript:
|
||||
:return:
|
||||
"""
|
||||
if transcript is not None:
|
||||
with open(transcript, 'rb') as f:
|
||||
trans_json = json.load(f)['utts']
|
||||
|
||||
if lsm_type == 'unigram':
|
||||
assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
|
||||
labelcount = np.zeros(odim)
|
||||
for k, v in trans_json.items():
|
||||
ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
|
||||
# to avoid an error when there is no text in an uttrance
|
||||
if len(ids) > 0:
|
||||
labelcount[ids] += 1
|
||||
labelcount[odim - 1] = len(transcript) # count <eos>
|
||||
labelcount[labelcount == 0] = 1 # flooring
|
||||
labelcount[blank] = 0 # remove counts for blank
|
||||
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
|
||||
else:
|
||||
logging.error(
|
||||
"Error: unexpected label smoothing type: %s" % lsm_type)
|
||||
sys.exit()
|
||||
|
||||
return labeldist
|
||||
|
||||
|
||||
def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
|
||||
"""Return the output size of the VGG frontend.
|
||||
|
||||
:param in_channel: input channel size
|
||||
:param out_channel: output channel size
|
||||
:return: output size
|
||||
:rtype int
|
||||
"""
|
||||
idim = idim / in_channel
|
||||
if downsample:
|
||||
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
|
||||
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
|
||||
return int(idim) * out_channel # numer of channels
|
||||
|
||||
|
||||
class ErrorCalculator(object):
|
||||
"""Calculate CER and WER for E2E_ASR and CTC models during training.
|
||||
|
||||
:param y_hats: numpy array with predicted text
|
||||
:param y_pads: numpy array with true (target) text
|
||||
:param char_list:
|
||||
:param sym_space:
|
||||
:param sym_blank:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
|
||||
trans_type="char"):
|
||||
"""Construct an ErrorCalculator object."""
|
||||
super(ErrorCalculator, self).__init__()
|
||||
|
||||
self.report_cer = report_cer
|
||||
self.report_wer = report_wer
|
||||
self.trans_type = trans_type
|
||||
self.char_list = char_list
|
||||
self.space = sym_space
|
||||
self.blank = sym_blank
|
||||
self.idx_blank = self.char_list.index(self.blank)
|
||||
if self.space in self.char_list:
|
||||
self.idx_space = self.char_list.index(self.space)
|
||||
else:
|
||||
self.idx_space = None
|
||||
|
||||
def __call__(self, ys_hat, ys_pad, is_ctc=False):
|
||||
"""Calculate sentence-level WER/CER score.
|
||||
|
||||
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
||||
:param bool is_ctc: calculate CER score for CTC
|
||||
:return: sentence-level WER score
|
||||
:rtype float
|
||||
:return: sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
cer, wer = None, None
|
||||
if is_ctc:
|
||||
return self.calculate_cer_ctc(ys_hat, ys_pad)
|
||||
elif not self.report_cer and not self.report_wer:
|
||||
return cer, wer
|
||||
|
||||
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(seqs_hat, seqs_true)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(seqs_hat, seqs_true)
|
||||
return cer, wer
|
||||
|
||||
def calculate_cer_ctc(self, ys_hat, ys_pad):
|
||||
"""Calculate sentence-level CER score for CTC.
|
||||
|
||||
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
||||
:return: average sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
cers, char_ref_lens = [], []
|
||||
for i, y in enumerate(ys_hat):
|
||||
y_hat = [x[0] for x in groupby(y)]
|
||||
y_true = ys_pad[i]
|
||||
seq_hat, seq_true = [], []
|
||||
for idx in y_hat:
|
||||
idx = int(idx)
|
||||
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
||||
seq_hat.append(self.char_list[int(idx)])
|
||||
|
||||
for idx in y_true:
|
||||
idx = int(idx)
|
||||
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
||||
seq_true.append(self.char_list[int(idx)])
|
||||
if self.trans_type == "char":
|
||||
hyp_chars = "".join(seq_hat)
|
||||
ref_chars = "".join(seq_true)
|
||||
else:
|
||||
hyp_chars = " ".join(seq_hat)
|
||||
ref_chars = " ".join(seq_true)
|
||||
|
||||
if len(ref_chars) > 0:
|
||||
cers.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
|
||||
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
|
||||
return cer_ctc
|
||||
|
||||
def convert_to_char(self, ys_hat, ys_pad):
|
||||
"""Convert index to character.
|
||||
|
||||
:param torch.Tensor seqs_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor seqs_true: reference (batch, seqlen)
|
||||
:return: token list of prediction
|
||||
:rtype list
|
||||
:return: token list of reference
|
||||
:rtype list
|
||||
"""
|
||||
seqs_hat, seqs_true = [], []
|
||||
for i, y_hat in enumerate(ys_hat):
|
||||
y_true = ys_pad[i]
|
||||
eos_true = np.where(y_true == -1)[0]
|
||||
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
||||
# To avoid wrong higher WER than the one obtained from the decoding
|
||||
# eos from y_true is used to mark the eos in y_hat
|
||||
# because of that y_hats has not padded outs with -1.
|
||||
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
|
||||
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
||||
# seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = seq_hat_text.replace(self.blank, '')
|
||||
# seq_true_text = "".join(seq_true).replace(self.space, ' ')
|
||||
seq_true_text = " ".join(seq_true).replace(self.space, ' ')
|
||||
seqs_hat.append(seq_hat_text)
|
||||
seqs_true.append(seq_true_text)
|
||||
return seqs_hat, seqs_true
|
||||
|
||||
def calculate_cer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level CER score.
|
||||
|
||||
:param list seqs_hat: prediction
|
||||
:param list seqs_true: reference
|
||||
:return: average sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
char_eds, char_ref_lens = [], []
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_chars = seq_hat_text.replace(' ', '')
|
||||
ref_chars = seq_true_text.replace(' ', '')
|
||||
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
return float(sum(char_eds)) / sum(char_ref_lens)
|
||||
|
||||
def calculate_wer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level WER score.
|
||||
|
||||
:param list seqs_hat: prediction
|
||||
:param list seqs_true: reference
|
||||
:return: average sentence-level WER score
|
||||
:rtype float
|
||||
"""
|
||||
word_eds, word_ref_lens = [], []
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_words = seq_hat_text.split()
|
||||
ref_words = seq_true_text.split()
|
||||
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
||||
word_ref_lens.append(len(ref_words))
|
||||
return float(sum(word_eds)) / sum(word_ref_lens)
|
||||
|
||||
|
||||
class ErrorCalculatorTrans(object):
|
||||
"""Calculate CER and WER for transducer models.
|
||||
|
||||
Args:
|
||||
decoder (nn.Module): decoder module
|
||||
args (Namespace): argument Namespace containing options
|
||||
report_cer (boolean): compute CER option
|
||||
report_wer (boolean): compute WER option
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, decoder, args, report_cer=False, report_wer=False):
|
||||
"""Construct an ErrorCalculator object for transducer model."""
|
||||
super(ErrorCalculatorTrans, self).__init__()
|
||||
|
||||
self.dec = decoder
|
||||
|
||||
recog_args = {'beam_size': args.beam_size,
|
||||
'nbest': args.nbest,
|
||||
'space': args.sym_space,
|
||||
'score_norm_transducer': args.score_norm_transducer}
|
||||
|
||||
self.recog_args = argparse.Namespace(**recog_args)
|
||||
|
||||
self.char_list = args.char_list
|
||||
self.space = args.sym_space
|
||||
self.blank = args.sym_blank
|
||||
|
||||
self.report_cer = args.report_cer
|
||||
self.report_wer = args.report_wer
|
||||
|
||||
def __call__(self, hs_pad, ys_pad):
|
||||
"""Calculate sentence-level WER/CER score for transducer models.
|
||||
|
||||
Args:
|
||||
hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D)
|
||||
ys_pad (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): sentence-level CER score
|
||||
(float): sentence-level WER score
|
||||
|
||||
"""
|
||||
cer, wer = None, None
|
||||
|
||||
if not self.report_cer and not self.report_wer:
|
||||
return cer, wer
|
||||
|
||||
batchsize = int(hs_pad.size(0))
|
||||
batch_nbest = []
|
||||
|
||||
for b in six.moves.range(batchsize):
|
||||
if self.recog_args.beam_size == 1:
|
||||
nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
|
||||
else:
|
||||
nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
|
||||
batch_nbest.append(nbest_hyps)
|
||||
|
||||
ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]
|
||||
|
||||
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())
|
||||
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(seqs_hat, seqs_true)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(seqs_hat, seqs_true)
|
||||
|
||||
return cer, wer
|
||||
|
||||
def convert_to_char(self, ys_hat, ys_pad):
|
||||
"""Convert index to character.
|
||||
|
||||
Args:
|
||||
ys_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
ys_pad (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(list): token list of prediction
|
||||
(list): token list of reference
|
||||
|
||||
"""
|
||||
seqs_hat, seqs_true = [], []
|
||||
|
||||
for i, y_hat in enumerate(ys_hat):
|
||||
y_true = ys_pad[i]
|
||||
|
||||
eos_true = np.where(y_true == -1)[0]
|
||||
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
||||
|
||||
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
|
||||
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
||||
|
||||
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = seq_hat_text.replace(self.blank, '')
|
||||
seq_true_text = "".join(seq_true).replace(self.space, ' ')
|
||||
|
||||
seqs_hat.append(seq_hat_text)
|
||||
seqs_true.append(seq_true_text)
|
||||
|
||||
return seqs_hat, seqs_true
|
||||
|
||||
def calculate_cer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level CER score for transducer model.
|
||||
|
||||
Args:
|
||||
seqs_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
seqs_true (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): average sentence-level CER score
|
||||
|
||||
"""
|
||||
char_eds, char_ref_lens = [], []
|
||||
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_chars = seq_hat_text.replace(' ', '')
|
||||
ref_chars = seq_true_text.replace(' ', '')
|
||||
|
||||
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
|
||||
return float(sum(char_eds)) / sum(char_ref_lens)
|
||||
|
||||
def calculate_wer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level WER score for transducer model.
|
||||
|
||||
Args:
|
||||
seqs_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
seqs_true (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): average sentence-level WER score
|
||||
|
||||
"""
|
||||
word_eds, word_ref_lens = [], []
|
||||
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_words = seq_hat_text.split()
|
||||
ref_words = seq_true_text.split()
|
||||
|
||||
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
||||
word_ref_lens.append(len(ref_words))
|
||||
|
||||
return float(sum(word_eds)) / sum(word_ref_lens)
|
||||
0
ppg_extractor/encoder/__init__.py
Normal file
0
ppg_extractor/encoder/__init__.py
Normal file
183
ppg_extractor/encoder/attention.py
Normal file
183
ppg_extractor/encoder/attention.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Multi-Head Attention layer definition."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
:param int n_head: the number of head s
|
||||
:param int n_feat: the number of features
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
self.linear_k = nn.Linear(n_feat, n_feat)
|
||||
self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward_qkv(self, query, key, value):
|
||||
"""Transform query, key and value.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:return torch.Tensor transformed query, key and value
|
||||
|
||||
"""
|
||||
n_batch = query.size(0)
|
||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward_attention(self, value, scores, mask):
|
||||
"""Compute attention context vector.
|
||||
|
||||
:param torch.Tensor value: (batch, head, time2, size)
|
||||
:param torch.Tensor scores: (batch, head, time1, time2)
|
||||
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
|
||||
:return torch.Tensor transformed `value` (batch, time1, d_model)
|
||||
weighted by the attention score (batch, time1, time2)
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
min_value = float(
|
||||
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
|
||||
)
|
||||
scores = scores.masked_fill(mask, min_value)
|
||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||
mask, 0.0
|
||||
) # (batch, head, time1, time2)
|
||||
else:
|
||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(self.attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (
|
||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
"""Compute 'Scaled Dot Product Attention'.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
|
||||
:param torch.nn.Dropout dropout:
|
||||
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||
return self.forward_attention(v, scores, mask)
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
:param int n_head: the number of head s
|
||||
:param int n_feat: the number of features
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate)
|
||||
# linear transformation for positional ecoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x, zero_triu=False):
|
||||
"""Compute relative positinal encoding.
|
||||
|
||||
:param torch.Tensor x: (batch, time, size)
|
||||
:param bool zero_triu: return the lower triangular part of the matrix
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)
|
||||
|
||||
if zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)))
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:param torch.Tensor pos_emb: (batch, time1, size)
|
||||
:param torch.Tensor mask: (batch, time1, time2)
|
||||
:param torch.nn.Dropout dropout:
|
||||
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, time2)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
return self.forward_attention(v, scores, mask)
|
||||
262
ppg_extractor/encoder/conformer_encoder.py
Normal file
262
ppg_extractor/encoder/conformer_encoder.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder definition."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from typing import Callable
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from .convolution import ConvolutionModule
|
||||
from .encoder_layer import EncoderLayer
|
||||
from ..nets_utils import get_activation, make_pad_mask
|
||||
from .vgg import VGG2L
|
||||
from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
|
||||
from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding
|
||||
from .layer_norm import LayerNorm
|
||||
from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d
|
||||
from .positionwise_feed_forward import PositionwiseFeedForward
|
||||
from .repeat import repeat
|
||||
from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling
|
||||
|
||||
|
||||
class ConformerEncoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int attention_dim: dimention of attention
|
||||
:param int attention_heads: the number of heads of multi head attention
|
||||
:param int linear_units: the number of units of position-wise feed forward
|
||||
:param int num_blocks: the number of decoder blocks
|
||||
:param float dropout_rate: dropout rate
|
||||
:param float attention_dropout_rate: dropout rate in attention
|
||||
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
||||
:param str or torch.nn.Module input_layer: input layer type
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
:param str positionwise_layer_type: linear of conv1d
|
||||
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
|
||||
:param str encoder_attn_layer_type: encoder attention layer type
|
||||
:param str activation_type: encoder activation function type
|
||||
:param bool macaron_style: whether to use macaron style for positionwise layer
|
||||
:param bool use_cnn_module: whether to use convolution module
|
||||
:param int cnn_module_kernel: kernerl size of convolution module
|
||||
:param int padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
macaron_style=False,
|
||||
pos_enc_layer_type="abs_pos",
|
||||
selfattention_layer_type="selfattn",
|
||||
activation_type="swish",
|
||||
use_cnn_module=False,
|
||||
cnn_module_kernel=31,
|
||||
padding_idx=-1,
|
||||
no_subsample=False,
|
||||
subsample_by_2=False,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super().__init__()
|
||||
|
||||
self._output_size = attention_dim
|
||||
idim = input_size
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
logging.info("Encoder input layer type: conv2d")
|
||||
if no_subsample:
|
||||
self.embed = Conv2dNoSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
else:
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
subsample_by_2, # NOTE(Sx): added by songxiang
|
||||
)
|
||||
elif input_layer == "vgg2l":
|
||||
self.embed = VGG2L(idim, attention_dim)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input lengths (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
Position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)):
|
||||
# print(xs_pad.shape)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
# print(xs_pad[0].size())
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
if isinstance(xs_pad, tuple):
|
||||
xs_pad = xs_pad[0]
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
return xs_pad, olens, None
|
||||
|
||||
# def forward(self, xs, masks):
|
||||
# """Encode input sequence.
|
||||
|
||||
# :param torch.Tensor xs: input tensor
|
||||
# :param torch.Tensor masks: input mask
|
||||
# :return: position embedded tensor and mask
|
||||
# :rtype Tuple[torch.Tensor, torch.Tensor]:
|
||||
# """
|
||||
# if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
|
||||
# xs, masks = self.embed(xs, masks)
|
||||
# else:
|
||||
# xs = self.embed(xs)
|
||||
|
||||
# xs, masks = self.encoders(xs, masks)
|
||||
# if isinstance(xs, tuple):
|
||||
# xs = xs[0]
|
||||
|
||||
# if self.normalize_before:
|
||||
# xs = self.after_norm(xs)
|
||||
# return xs, masks
|
||||
74
ppg_extractor/encoder/convolution.py
Normal file
74
ppg_extractor/encoder/convolution.py
Normal file
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""ConvolutionModule definition."""
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
|
||||
:param int channels: channels of cnn
|
||||
:param int kernel_size: kernerl size of cnn
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute convolution module.
|
||||
|
||||
:param torch.Tensor x: (batch, time, size)
|
||||
:return torch.Tensor: convoluted `value` (batch, time, d_model)
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
return x.transpose(1, 2)
|
||||
166
ppg_extractor/encoder/embedding.py
Normal file
166
ppg_extractor/encoder/embedding.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Positonal Encoding Module."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _pre_hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""Perform pre-hook in load_state_dict for backward compatibility.
|
||||
|
||||
Note:
|
||||
We saved self.pe until v.0.5.2 but we have omitted it later.
|
||||
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
||||
|
||||
"""
|
||||
k = prefix + "pe"
|
||||
if k in state_dict:
|
||||
state_dict.pop(k)
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
:param reverse: whether to reverse the input position
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.reverse = reverse
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
self._register_load_state_dict_pre_hook(_pre_hook)
|
||||
|
||||
def extend_pe(self, x):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model)
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale + self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class ScaledPositionalEncoding(PositionalEncoding):
|
||||
"""Scaled positional encoding module.
|
||||
|
||||
See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Initialize class.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Reset parameters."""
|
||||
self.alpha.data = torch.tensor(1.0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x + self.alpha * self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class RelPositionalEncoding(PositionalEncoding):
|
||||
"""Relitive positional encoding module.
|
||||
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Initialize class.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: x. Its shape is (batch, time, ...)
|
||||
torch.Tensor: pos_emb. Its shape is (1, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.pe[:, : x.size(1)]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
217
ppg_extractor/encoder/encoder.py
Normal file
217
ppg_extractor/encoder/encoder.py
Normal file
@@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder definition."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
|
||||
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
|
||||
from espnet.nets.pytorch_backend.nets_utils import get_activation
|
||||
from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
|
||||
from espnet.nets.pytorch_backend.transformer.attention import (
|
||||
MultiHeadedAttention, # noqa: H301
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import (
|
||||
PositionalEncoding, # noqa: H301
|
||||
ScaledPositionalEncoding, # noqa: H301
|
||||
RelPositionalEncoding, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
|
||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
||||
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int attention_dim: dimention of attention
|
||||
:param int attention_heads: the number of heads of multi head attention
|
||||
:param int linear_units: the number of units of position-wise feed forward
|
||||
:param int num_blocks: the number of decoder blocks
|
||||
:param float dropout_rate: dropout rate
|
||||
:param float attention_dropout_rate: dropout rate in attention
|
||||
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
||||
:param str or torch.nn.Module input_layer: input layer type
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
:param str positionwise_layer_type: linear of conv1d
|
||||
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
|
||||
:param str encoder_attn_layer_type: encoder attention layer type
|
||||
:param str activation_type: encoder activation function type
|
||||
:param bool macaron_style: whether to use macaron style for positionwise layer
|
||||
:param bool use_cnn_module: whether to use convolution module
|
||||
:param int cnn_module_kernel: kernerl size of convolution module
|
||||
:param int padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idim,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
macaron_style=False,
|
||||
pos_enc_layer_type="abs_pos",
|
||||
selfattention_layer_type="selfattn",
|
||||
activation_type="swish",
|
||||
use_cnn_module=False,
|
||||
cnn_module_kernel=31,
|
||||
padding_idx=-1,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "vgg2l":
|
||||
self.embed = VGG2L(idim, attention_dim)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def forward(self, xs, masks):
|
||||
"""Encode input sequence.
|
||||
|
||||
:param torch.Tensor xs: input tensor
|
||||
:param torch.Tensor masks: input mask
|
||||
:return: position embedded tensor and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
|
||||
xs, masks = self.embed(xs, masks)
|
||||
else:
|
||||
xs = self.embed(xs)
|
||||
|
||||
xs, masks = self.encoders(xs, masks)
|
||||
if isinstance(xs, tuple):
|
||||
xs = xs[0]
|
||||
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
return xs, masks
|
||||
152
ppg_extractor/encoder/encoder_layer.py
Normal file
152
ppg_extractor/encoder/encoder_layer.py
Normal file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder self-attention layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .layer_norm import LayerNorm
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
:param int size: input dim
|
||||
:param espnet.nets.pytorch_backend.transformer.attention.
|
||||
MultiHeadedAttention self_attn: self attention module
|
||||
RelPositionMultiHeadedAttention self_attn: self attention module
|
||||
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
|
||||
PositionwiseFeedForward feed_forward:
|
||||
feed forward module
|
||||
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward
|
||||
for macaron style
|
||||
PositionwiseFeedForward feed_forward:
|
||||
feed forward module
|
||||
:param espnet.nets.pytorch_backend.conformer.convolution.
|
||||
ConvolutionModule feed_foreard:
|
||||
feed forward module
|
||||
:param float dropout_rate: dropout rate
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
feed_forward_macaron,
|
||||
conv_module,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.feed_forward_macaron = feed_forward_macaron
|
||||
self.conv_module = conv_module
|
||||
self.norm_ff = LayerNorm(size) # for the FNN module
|
||||
self.norm_mha = LayerNorm(size) # for the MHA module
|
||||
if feed_forward_macaron is not None:
|
||||
self.norm_ff_macaron = LayerNorm(size)
|
||||
self.ff_scale = 0.5
|
||||
else:
|
||||
self.ff_scale = 1.0
|
||||
if self.conv_module is not None:
|
||||
self.norm_conv = LayerNorm(size) # for the CNN module
|
||||
self.norm_final = LayerNorm(size) # for the final output of the block
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, x_input, mask, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
:param torch.Tensor x_input: encoded source features, w/o pos_emb
|
||||
tuple((batch, max_time_in, size), (1, max_time_in, size))
|
||||
or (batch, max_time_in, size)
|
||||
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
||||
:param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
|
||||
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
if isinstance(x_input, tuple):
|
||||
x, pos_emb = x_input[0], x_input[1]
|
||||
else:
|
||||
x, pos_emb = x_input, None
|
||||
|
||||
# whether to use macaron style
|
||||
if self.feed_forward_macaron is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if pos_emb is not None:
|
||||
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
||||
else:
|
||||
x_att = self.self_attn(x_q, x, x, mask)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, x_att), dim=-1)
|
||||
x = residual + self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(x_att)
|
||||
if not self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
# convolution module
|
||||
if self.conv_module is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
x = residual + self.dropout(self.conv_module(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
|
||||
# feed forward module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
|
||||
if self.conv_module is not None:
|
||||
x = self.norm_final(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
if pos_emb is not None:
|
||||
return (x, pos_emb), mask
|
||||
|
||||
return x, mask
|
||||
33
ppg_extractor/encoder/layer_norm.py
Normal file
33
ppg_extractor/encoder/layer_norm.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Layer normalization module."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
"""Layer normalization module.
|
||||
|
||||
:param int nout: output dim size
|
||||
:param int dim: dimension to be normalized
|
||||
"""
|
||||
|
||||
def __init__(self, nout, dim=-1):
|
||||
"""Construct an LayerNorm object."""
|
||||
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply layer normalization.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:return: layer normalized tensor
|
||||
:rtype torch.Tensor
|
||||
"""
|
||||
if self.dim == -1:
|
||||
return super(LayerNorm, self).forward(x)
|
||||
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
||||
105
ppg_extractor/encoder/multi_layer_conv.py
Normal file
105
ppg_extractor/encoder/multi_layer_conv.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiLayeredConv1d(torch.nn.Module):
|
||||
"""Multi-layered conv1d for Transformer block.
|
||||
|
||||
This is a module of multi-leyered conv1d designed
|
||||
to replace positionwise feed-forward network
|
||||
in Transforner block, which is introduced in
|
||||
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
||||
|
||||
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
||||
https://arxiv.org/pdf/1905.09263.pdf
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
||||
"""Initialize MultiLayeredConv1d module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(MultiLayeredConv1d, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Conv1d(
|
||||
hidden_chans,
|
||||
in_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
||||
|
||||
|
||||
class Conv1dLinear(torch.nn.Module):
|
||||
"""Conv1D + Linear for Transformer block.
|
||||
|
||||
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
||||
"""Initialize Conv1dLinear module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(Conv1dLinear, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.dropout(x))
|
||||
31
ppg_extractor/encoder/positionwise_feed_forward.py
Normal file
31
ppg_extractor/encoder/positionwise_feed_forward.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Positionwise feed forward layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
"""Positionwise feed forward layer.
|
||||
|
||||
:param int idim: input dimenstion
|
||||
:param int hidden_units: number of hidden units
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
|
||||
"""Construct an PositionwiseFeedForward object."""
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
||||
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward funciton."""
|
||||
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
||||
30
ppg_extractor/encoder/repeat.py
Normal file
30
ppg_extractor/encoder/repeat.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Repeat the same layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiSequential(torch.nn.Sequential):
|
||||
"""Multi-input multi-output torch.nn.Sequential."""
|
||||
|
||||
def forward(self, *args):
|
||||
"""Repeat."""
|
||||
for m in self:
|
||||
args = m(*args)
|
||||
return args
|
||||
|
||||
|
||||
def repeat(N, fn):
|
||||
"""Repeat module N times.
|
||||
|
||||
:param int N: repeat time
|
||||
:param function fn: function to generate module
|
||||
:return: repeated modules
|
||||
:rtype: MultiSequential
|
||||
"""
|
||||
return MultiSequential(*[fn(n) for n in range(N)])
|
||||
218
ppg_extractor/encoder/subsampling.py
Normal file
218
ppg_extractor/encoder/subsampling.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Subsampling layer definition."""
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
||||
|
||||
|
||||
class Conv2dSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length or 1/2 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate, pos_enc=None,
|
||||
subsample_by_2=False,
|
||||
):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling, self).__init__()
|
||||
self.subsample_by_2 = subsample_by_2
|
||||
if subsample_by_2:
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (idim // 2), odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
else:
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (idim // 4), odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
if self.subsample_by_2:
|
||||
return x, x_mask[:, :, ::2]
|
||||
else:
|
||||
return x, x_mask[:, :, ::2][:, :, ::2]
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Subsample x.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positioning encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class Conv2dNoSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D without subsampling.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
logging.info("Encoder does not do down-sample on mel-spectrogram.")
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * idim, odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Subsample x.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positioning encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class Conv2dSubsampling6(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/6 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling6, self).__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 5, 3),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
|
||||
PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :, :-2:2][:, :, :-4:3]
|
||||
|
||||
|
||||
class Conv2dSubsampling8(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/8 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling8, self).__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
|
||||
PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
|
||||
18
ppg_extractor/encoder/swish.py
Normal file
18
ppg_extractor/encoder/swish.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Swish() activation function for Conformer."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x):
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
77
ppg_extractor/encoder/vgg.py
Normal file
77
ppg_extractor/encoder/vgg.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""VGG2L definition for transformer-transducer."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class VGG2L(torch.nn.Module):
|
||||
"""VGG2L module for transformer-transducer encoder."""
|
||||
|
||||
def __init__(self, idim, odim):
|
||||
"""Construct a VGG2L object.
|
||||
|
||||
Args:
|
||||
idim (int): dimension of inputs
|
||||
odim (int): dimension of outputs
|
||||
|
||||
"""
|
||||
super(VGG2L, self).__init__()
|
||||
|
||||
self.vgg2l = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.MaxPool2d((3, 2)),
|
||||
torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.MaxPool2d((2, 2)),
|
||||
)
|
||||
|
||||
self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""VGG2L forward for x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input torch (B, T, idim)
|
||||
x_mask (torch.Tensor): (B, 1, T)
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): input torch (B, sub(T), attention_dim)
|
||||
x_mask (torch.Tensor): (B, 1, sub(T))
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
x = self.vgg2l(x)
|
||||
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
else:
|
||||
x_mask = self.create_new_mask(x_mask, x)
|
||||
|
||||
return x, x_mask
|
||||
|
||||
def create_new_mask(self, x_mask, x):
|
||||
"""Create a subsampled version of x_mask.
|
||||
|
||||
Args:
|
||||
x_mask (torch.Tensor): (B, 1, T)
|
||||
x (torch.Tensor): (B, sub(T), attention_dim)
|
||||
|
||||
Returns:
|
||||
x_mask (torch.Tensor): (B, 1, sub(T))
|
||||
|
||||
"""
|
||||
x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
|
||||
x_mask = x_mask[:, :, :x_t1][:, :, ::3]
|
||||
|
||||
x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
|
||||
x_mask = x_mask[:, :, :x_t2][:, :, ::2]
|
||||
|
||||
return x_mask
|
||||
298
ppg_extractor/encoders.py
Normal file
298
ppg_extractor/encoders.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import logging
|
||||
import six
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pack_padded_sequence
|
||||
from torch.nn.utils.rnn import pad_packed_sequence
|
||||
|
||||
from .e2e_asr_common import get_vgg2l_odim
|
||||
from .nets_utils import make_pad_mask, to_device
|
||||
|
||||
|
||||
class RNNP(torch.nn.Module):
|
||||
"""RNN with projection layer module
|
||||
|
||||
:param int idim: dimension of inputs
|
||||
:param int elayers: number of encoder layers
|
||||
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
|
||||
:param int hdim: number of projection units
|
||||
:param np.ndarray subsample: list of subsampling numbers
|
||||
:param float dropout: dropout rate
|
||||
:param str typ: The RNN type
|
||||
"""
|
||||
|
||||
def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
|
||||
super(RNNP, self).__init__()
|
||||
bidir = typ[0] == "b"
|
||||
for i in six.moves.range(elayers):
|
||||
if i == 0:
|
||||
inputdim = idim
|
||||
else:
|
||||
inputdim = hdim
|
||||
rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
|
||||
batch_first=True) if "lstm" in typ \
|
||||
else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
|
||||
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
|
||||
# bottleneck layer to merge
|
||||
if bidir:
|
||||
setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
|
||||
else:
|
||||
setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))
|
||||
|
||||
self.elayers = elayers
|
||||
self.cdim = cdim
|
||||
self.subsample = subsample
|
||||
self.typ = typ
|
||||
self.bidir = bidir
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_state=None):
|
||||
"""RNNP forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous RNN states
|
||||
:return: batch of hidden state sequences (B, Tmax, hdim)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
|
||||
elayer_states = []
|
||||
for layer in six.moves.range(self.elayers):
|
||||
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
|
||||
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
|
||||
rnn.flatten_parameters()
|
||||
if prev_state is not None and rnn.bidirectional:
|
||||
prev_state = reset_backward_rnn_state(prev_state)
|
||||
ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
|
||||
elayer_states.append(states)
|
||||
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
|
||||
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
|
||||
sub = self.subsample[layer + 1]
|
||||
if sub > 1:
|
||||
ys_pad = ys_pad[:, ::sub]
|
||||
ilens = [int(i + 1) // sub for i in ilens]
|
||||
# (sum _utt frame_utt) x dim
|
||||
projected = getattr(self, 'bt' + str(layer)
|
||||
)(ys_pad.contiguous().view(-1, ys_pad.size(2)))
|
||||
if layer == self.elayers - 1:
|
||||
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
|
||||
else:
|
||||
xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
|
||||
|
||||
return xs_pad, ilens, elayer_states # x: utt list of frame x dim
|
||||
|
||||
|
||||
class RNN(torch.nn.Module):
|
||||
"""RNN module
|
||||
|
||||
:param int idim: dimension of inputs
|
||||
:param int elayers: number of encoder layers
|
||||
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
|
||||
:param int hdim: number of final projection units
|
||||
:param float dropout: dropout rate
|
||||
:param str typ: The RNN type
|
||||
"""
|
||||
|
||||
def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
|
||||
super(RNN, self).__init__()
|
||||
bidir = typ[0] == "b"
|
||||
self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
|
||||
dropout=dropout, bidirectional=bidir) if "lstm" in typ \
|
||||
else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
|
||||
bidirectional=bidir)
|
||||
if bidir:
|
||||
self.l_last = torch.nn.Linear(cdim * 2, hdim)
|
||||
else:
|
||||
self.l_last = torch.nn.Linear(cdim, hdim)
|
||||
self.typ = typ
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_state=None):
|
||||
"""RNN forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous RNN states
|
||||
:return: batch of hidden state sequences (B, Tmax, eprojs)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
|
||||
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
|
||||
self.nbrnn.flatten_parameters()
|
||||
if prev_state is not None and self.nbrnn.bidirectional:
|
||||
# We assume that when previous state is passed, it means that we're streaming the input
|
||||
# and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
|
||||
prev_state = reset_backward_rnn_state(prev_state)
|
||||
ys, states = self.nbrnn(xs_pack, hx=prev_state)
|
||||
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
|
||||
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
|
||||
# (sum _utt frame_utt) x dim
|
||||
projected = torch.tanh(self.l_last(
|
||||
ys_pad.contiguous().view(-1, ys_pad.size(2))))
|
||||
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
|
||||
return xs_pad, ilens, states # x: utt list of frame x dim
|
||||
|
||||
|
||||
def reset_backward_rnn_state(states):
|
||||
"""Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs"""
|
||||
if isinstance(states, (list, tuple)):
|
||||
for state in states:
|
||||
state[1::2] = 0.
|
||||
else:
|
||||
states[1::2] = 0.
|
||||
return states
|
||||
|
||||
|
||||
class VGG2L(torch.nn.Module):
|
||||
"""VGG-like module
|
||||
|
||||
:param int in_channel: number of input channels
|
||||
"""
|
||||
|
||||
def __init__(self, in_channel=1, downsample=True):
|
||||
super(VGG2L, self).__init__()
|
||||
# CNN layer (VGG motivated)
|
||||
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
|
||||
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
|
||||
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
|
||||
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
|
||||
|
||||
self.in_channel = in_channel
|
||||
self.downsample = downsample
|
||||
if downsample:
|
||||
self.stride = 2
|
||||
else:
|
||||
self.stride = 1
|
||||
|
||||
def forward(self, xs_pad, ilens, **kwargs):
|
||||
"""VGG2L forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
|
||||
|
||||
# x: utt x frame x dim
|
||||
# xs_pad = F.pad_sequence(xs_pad)
|
||||
|
||||
# x: utt x 1 (input channel num) x frame x dim
|
||||
xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
|
||||
xs_pad.size(2) // self.in_channel).transpose(1, 2)
|
||||
|
||||
# NOTE: max_pool1d ?
|
||||
xs_pad = F.relu(self.conv1_1(xs_pad))
|
||||
xs_pad = F.relu(self.conv1_2(xs_pad))
|
||||
if self.downsample:
|
||||
xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
|
||||
|
||||
xs_pad = F.relu(self.conv2_1(xs_pad))
|
||||
xs_pad = F.relu(self.conv2_2(xs_pad))
|
||||
if self.downsample:
|
||||
xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
|
||||
if torch.is_tensor(ilens):
|
||||
ilens = ilens.cpu().numpy()
|
||||
else:
|
||||
ilens = np.array(ilens, dtype=np.float32)
|
||||
if self.downsample:
|
||||
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
|
||||
ilens = np.array(
|
||||
np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
|
||||
|
||||
# x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
|
||||
xs_pad = xs_pad.transpose(1, 2)
|
||||
xs_pad = xs_pad.contiguous().view(
|
||||
xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
|
||||
return xs_pad, ilens, None # no state in this layer
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Encoder module
|
||||
|
||||
:param str etype: type of encoder network
|
||||
:param int idim: number of dimensions of encoder network
|
||||
:param int elayers: number of layers of encoder network
|
||||
:param int eunits: number of lstm units of encoder network
|
||||
:param int eprojs: number of projection units of encoder network
|
||||
:param np.ndarray subsample: list of subsampling numbers
|
||||
:param float dropout: dropout rate
|
||||
:param int in_channel: number of input channels
|
||||
"""
|
||||
|
||||
def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
|
||||
super(Encoder, self).__init__()
|
||||
typ = etype.lstrip("vgg").rstrip("p")
|
||||
if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
|
||||
logging.error("Error: need to specify an appropriate encoder architecture")
|
||||
|
||||
if etype.startswith("vgg"):
|
||||
if etype[-1] == "p":
|
||||
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
|
||||
RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
|
||||
eprojs,
|
||||
subsample, dropout, typ=typ)])
|
||||
logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
|
||||
else:
|
||||
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
|
||||
RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
|
||||
eprojs,
|
||||
dropout, typ=typ)])
|
||||
logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
|
||||
else:
|
||||
if etype[-1] == "p":
|
||||
self.enc = torch.nn.ModuleList(
|
||||
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
|
||||
logging.info(typ.upper() + ' with every-layer projection for encoder')
|
||||
else:
|
||||
self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
|
||||
logging.info(typ.upper() + ' without projection for encoder')
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_states=None):
|
||||
"""Encoder forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
|
||||
:return: batch of hidden state sequences (B, Tmax, eprojs)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if prev_states is None:
|
||||
prev_states = [None] * len(self.enc)
|
||||
assert len(prev_states) == len(self.enc)
|
||||
|
||||
current_states = []
|
||||
for module, prev_state in zip(self.enc, prev_states):
|
||||
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
|
||||
current_states.append(states)
|
||||
|
||||
# make mask to remove bias value in padded part
|
||||
mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))
|
||||
|
||||
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
|
||||
|
||||
|
||||
def encoder_for(args, idim, subsample):
|
||||
"""Instantiates an encoder module given the program arguments
|
||||
|
||||
:param Namespace args: The arguments
|
||||
:param int or List of integer idim: dimension of input, e.g. 83, or
|
||||
List of dimensions of inputs, e.g. [83,83]
|
||||
:param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
|
||||
List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]]
|
||||
:rtype torch.nn.Module
|
||||
:return: The encoder module
|
||||
"""
|
||||
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
|
||||
if num_encs == 1:
|
||||
# compatible with single encoder asr mode
|
||||
return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
|
||||
elif num_encs >= 1:
|
||||
enc_list = torch.nn.ModuleList()
|
||||
for idx in range(num_encs):
|
||||
enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx],
|
||||
args.dropout_rate[idx])
|
||||
enc_list.append(enc)
|
||||
return enc_list
|
||||
else:
|
||||
raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
|
||||
115
ppg_extractor/frontend.py
Normal file
115
ppg_extractor/frontend.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import copy
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from .log_mel import LogMel
|
||||
from .stft import Stft
|
||||
|
||||
|
||||
class DefaultFrontend(torch.nn.Module):
|
||||
"""Conventional frontend structure for ASR
|
||||
|
||||
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: 16000,
|
||||
n_fft: int = 1024,
|
||||
win_length: int = 800,
|
||||
hop_length: int = 160,
|
||||
center: bool = True,
|
||||
pad_mode: str = "reflect",
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
n_mels: int = 80,
|
||||
fmin: int = None,
|
||||
fmax: int = None,
|
||||
htk: bool = False,
|
||||
norm=1,
|
||||
frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend),
|
||||
kaldi_padding_mode=False,
|
||||
downsample_rate: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.downsample_rate = downsample_rate
|
||||
|
||||
# Deepcopy (In general, dict shouldn't be used as default arg)
|
||||
frontend_conf = copy.deepcopy(frontend_conf)
|
||||
|
||||
self.stft = Stft(
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
center=center,
|
||||
pad_mode=pad_mode,
|
||||
normalized=normalized,
|
||||
onesided=onesided,
|
||||
kaldi_padding_mode=kaldi_padding_mode
|
||||
)
|
||||
if frontend_conf is not None:
|
||||
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
|
||||
else:
|
||||
self.frontend = None
|
||||
|
||||
self.logmel = LogMel(
|
||||
fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
|
||||
)
|
||||
self.n_mels = n_mels
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self.n_mels
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, input_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Domain-conversion: e.g. Stft: time -> time-freq
|
||||
input_stft, feats_lens = self.stft(input, input_lengths)
|
||||
|
||||
assert input_stft.dim() >= 4, input_stft.shape
|
||||
# "2" refers to the real/imag parts of Complex
|
||||
assert input_stft.shape[-1] == 2, input_stft.shape
|
||||
|
||||
# Change torch.Tensor to ComplexTensor
|
||||
# input_stft: (..., F, 2) -> (..., F)
|
||||
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
|
||||
|
||||
# 2. [Option] Speech enhancement
|
||||
if self.frontend is not None:
|
||||
assert isinstance(input_stft, ComplexTensor), type(input_stft)
|
||||
# input_stft: (Batch, Length, [Channel], Freq)
|
||||
input_stft, _, mask = self.frontend(input_stft, feats_lens)
|
||||
|
||||
# 3. [Multi channel case]: Select a channel
|
||||
if input_stft.dim() == 4:
|
||||
# h: (B, T, C, F) -> h: (B, T, F)
|
||||
if self.training:
|
||||
# Select 1ch randomly
|
||||
ch = np.random.randint(input_stft.size(2))
|
||||
input_stft = input_stft[:, :, ch, :]
|
||||
else:
|
||||
# Use the first channel
|
||||
input_stft = input_stft[:, :, 0, :]
|
||||
|
||||
# 4. STFT -> Power spectrum
|
||||
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
|
||||
input_power = input_stft.real ** 2 + input_stft.imag ** 2
|
||||
|
||||
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
|
||||
# input_power: (Batch, [Channel,] Length, Freq)
|
||||
# -> input_feats: (Batch, Length, Dim)
|
||||
input_feats, _ = self.logmel(input_power, feats_lens)
|
||||
|
||||
# NOTE(sx): pad
|
||||
max_len = input_feats.size(1)
|
||||
if self.downsample_rate > 1 and max_len % self.downsample_rate != 0:
|
||||
padding = self.downsample_rate - max_len % self.downsample_rate
|
||||
# print("Logmel: ", input_feats.size())
|
||||
input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding),
|
||||
"constant", 0)
|
||||
# print("Logmel(after padding): ",input_feats.size())
|
||||
feats_lens[torch.argmax(feats_lens)] = max_len + padding
|
||||
|
||||
return input_feats, feats_lens
|
||||
74
ppg_extractor/log_mel.py
Normal file
74
ppg_extractor/log_mel.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class LogMel(torch.nn.Module):
|
||||
"""Convert STFT to fbank feats
|
||||
|
||||
The arguments is same as librosa.filters.mel
|
||||
|
||||
Args:
|
||||
fs: number > 0 [scalar] sampling rate of the incoming signal
|
||||
n_fft: int > 0 [scalar] number of FFT components
|
||||
n_mels: int > 0 [scalar] number of Mel bands to generate
|
||||
fmin: float >= 0 [scalar] lowest frequency (in Hz)
|
||||
fmax: float >= 0 [scalar] highest frequency (in Hz).
|
||||
If `None`, use `fmax = fs / 2.0`
|
||||
htk: use HTK formula instead of Slaney
|
||||
norm: {None, 1, np.inf} [scalar]
|
||||
if 1, divide the triangular mel weights by the width of the mel band
|
||||
(area normalization). Otherwise, leave all the triangles aiming for
|
||||
a peak value of 1.0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 80,
|
||||
fmin: float = None,
|
||||
fmax: float = None,
|
||||
htk: bool = False,
|
||||
norm=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
_mel_options = dict(
|
||||
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
|
||||
)
|
||||
self.mel_options = _mel_options
|
||||
|
||||
# Note(kamo): The mel matrix of librosa is different from kaldi.
|
||||
melmat = librosa.filters.mel(**_mel_options)
|
||||
# melmat: (D2, D1) -> (D1, D2)
|
||||
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
|
||||
inv_mel = np.linalg.pinv(melmat)
|
||||
self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float())
|
||||
|
||||
def extra_repr(self):
|
||||
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
|
||||
|
||||
def forward(
|
||||
self, feat: torch.Tensor, ilens: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
|
||||
mel_feat = torch.matmul(feat, self.melmat)
|
||||
|
||||
logmel_feat = (mel_feat + 1e-20).log()
|
||||
# Zero padding
|
||||
if ilens is not None:
|
||||
logmel_feat = logmel_feat.masked_fill(
|
||||
make_pad_mask(ilens, logmel_feat, 1), 0.0
|
||||
)
|
||||
else:
|
||||
ilens = feat.new_full(
|
||||
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
|
||||
)
|
||||
return logmel_feat, ilens
|
||||
465
ppg_extractor/nets_utils.py
Normal file
465
ppg_extractor/nets_utils.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Network related utility tools."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(m, x):
|
||||
"""Send tensor into the device of the module.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): Torch module.
|
||||
x (Tensor): Torch tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Torch tensor located in the same place as torch module.
|
||||
|
||||
"""
|
||||
assert isinstance(m, torch.nn.Module)
|
||||
device = next(m.parameters()).device
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def pad_list(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
max_len = max(x.size(0) for x in xs)
|
||||
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
pad[i, :xs[i].size(0)] = xs[i]
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]],
|
||||
[[0, 0, 0, 1],
|
||||
[0, 0, 0, 1]],
|
||||
[[0, 0, 1, 1],
|
||||
[0, 0, 1, 1]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_pad_mask(lengths, xs, 1)
|
||||
tensor([[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
>>> make_pad_mask(lengths, xs, 2)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
if length_dim == 0:
|
||||
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
|
||||
|
||||
if not isinstance(lengths, list):
|
||||
lengths = lengths.tolist()
|
||||
bs = int(len(lengths))
|
||||
if xs is None:
|
||||
maxlen = int(max(lengths))
|
||||
else:
|
||||
maxlen = xs.size(length_dim)
|
||||
|
||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
if xs is not None:
|
||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
||||
|
||||
if length_dim < 0:
|
||||
length_dim = xs.dim() + length_dim
|
||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
||||
ind = tuple(slice(None) if i in (0, length_dim) else None
|
||||
for i in range(xs.dim()))
|
||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
||||
return mask
|
||||
|
||||
|
||||
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of non-padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
ByteTensor: mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[1, 1, 1, 1 ,1],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]],
|
||||
[[1, 1, 1, 0],
|
||||
[1, 1, 1, 0]],
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_non_pad_mask(lengths, xs, 1)
|
||||
tensor([[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
>>> make_non_pad_mask(lengths, xs, 2)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
return ~make_pad_mask(lengths, xs, length_dim)
|
||||
|
||||
|
||||
def mask_by_length(xs, lengths, fill=0):
|
||||
"""Mask tensor according to length.
|
||||
|
||||
Args:
|
||||
xs (Tensor): Batch of input tensor (B, `*`).
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
fill (int or float): Value to fill masked part.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of masked input tensor (B, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = torch.arange(5).repeat(3, 1) + 1
|
||||
>>> x
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5]])
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> mask_by_length(x, lengths)
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 0, 0],
|
||||
[1, 2, 0, 0, 0]])
|
||||
|
||||
"""
|
||||
assert xs.size(0) == len(lengths)
|
||||
ret = xs.data.new(*xs.size()).fill_(fill)
|
||||
for i, l in enumerate(lengths):
|
||||
ret[i, :l] = xs[i, :l]
|
||||
return ret
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(
|
||||
pad_targets.size(0),
|
||||
pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return float(numerator) / float(denominator)
|
||||
|
||||
|
||||
def to_torch_tensor(x):
|
||||
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
|
||||
|
||||
Args:
|
||||
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
|
||||
|
||||
Returns:
|
||||
Tensor or ComplexTensor: Type converted inputs.
|
||||
|
||||
Examples:
|
||||
>>> xs = np.ones(3, dtype=np.float32)
|
||||
>>> xs = to_torch_tensor(xs)
|
||||
tensor([1., 1., 1.])
|
||||
>>> xs = torch.ones(3, 4, 5)
|
||||
>>> assert to_torch_tensor(xs) is xs
|
||||
>>> xs = {'real': xs, 'imag': xs}
|
||||
>>> to_torch_tensor(xs)
|
||||
ComplexTensor(
|
||||
Real:
|
||||
tensor([1., 1., 1.])
|
||||
Imag;
|
||||
tensor([1., 1., 1.])
|
||||
)
|
||||
|
||||
"""
|
||||
# If numpy, change to torch tensor
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype.kind == 'c':
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
return ComplexTensor(x)
|
||||
else:
|
||||
return torch.from_numpy(x)
|
||||
|
||||
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
|
||||
elif isinstance(x, dict):
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
if 'real' not in x or 'imag' not in x:
|
||||
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
|
||||
# Relative importing because of using python3 syntax
|
||||
return ComplexTensor(x['real'], x['imag'])
|
||||
|
||||
# If torch.Tensor, as it is
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x
|
||||
|
||||
else:
|
||||
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
|
||||
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
|
||||
"but got {}".format(type(x)))
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except Exception:
|
||||
# If PY2
|
||||
raise ValueError(error)
|
||||
else:
|
||||
# If PY3
|
||||
if isinstance(x, ComplexTensor):
|
||||
return x
|
||||
else:
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_subsample(train_args, mode, arch):
|
||||
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
|
||||
|
||||
Args:
|
||||
train_args: argument Namespace containing options.
|
||||
mode: one of ('asr', 'mt', 'st')
|
||||
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
|
||||
|
||||
Returns:
|
||||
np.ndarray / List[np.ndarray]: subsampling factors.
|
||||
"""
|
||||
if arch == 'transformer':
|
||||
return np.array([1])
|
||||
|
||||
elif mode == 'mt' and arch == 'rnn':
|
||||
# +1 means input (+1) and layers outputs (train_args.elayer)
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
logging.warning('Subsampling is not performed for machine translation.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
|
||||
(mode == 'mt' and arch == 'rnn') or \
|
||||
(mode == 'st' and arch == 'rnn'):
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mix':
|
||||
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mulenc':
|
||||
subsample_list = []
|
||||
for idx in range(train_args.num_encs):
|
||||
subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
|
||||
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
|
||||
ss = train_args.subsample[idx].split("_")
|
||||
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Encoder %d: Subsampling is not performed for vgg*. '
|
||||
'It is performed in max pooling layers at CNN.', idx + 1)
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
subsample_list.append(subsample)
|
||||
return subsample_list
|
||||
|
||||
else:
|
||||
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
|
||||
|
||||
|
||||
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
|
||||
"""Replace keys of old prefix with new prefix in state dict."""
|
||||
# need this list not to break the dict iterator
|
||||
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
|
||||
if len(old_keys) > 0:
|
||||
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
|
||||
for k in old_keys:
|
||||
v = state_dict.pop(k)
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
state_dict[new_k] = v
|
||||
|
||||
def get_activation(act):
|
||||
"""Return activation function."""
|
||||
# Lazy load to avoid unused import
|
||||
from .encoder.swish import Swish
|
||||
|
||||
activation_funcs = {
|
||||
"hardtanh": torch.nn.Hardtanh,
|
||||
"relu": torch.nn.ReLU,
|
||||
"selu": torch.nn.SELU,
|
||||
"swish": Swish,
|
||||
}
|
||||
|
||||
return activation_funcs[act]()
|
||||
118
ppg_extractor/stft.py
Normal file
118
ppg_extractor/stft.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class Stft(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft: int = 512,
|
||||
win_length: Union[int, None] = 512,
|
||||
hop_length: int = 128,
|
||||
center: bool = True,
|
||||
pad_mode: str = "reflect",
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
kaldi_padding_mode=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
if win_length is None:
|
||||
self.win_length = n_fft
|
||||
else:
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
self.normalized = normalized
|
||||
self.onesided = onesided
|
||||
self.kaldi_padding_mode = kaldi_padding_mode
|
||||
if self.kaldi_padding_mode:
|
||||
self.win_length = 400
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"n_fft={self.n_fft}, "
|
||||
f"win_length={self.win_length}, "
|
||||
f"hop_length={self.hop_length}, "
|
||||
f"center={self.center}, "
|
||||
f"pad_mode={self.pad_mode}, "
|
||||
f"normalized={self.normalized}, "
|
||||
f"onesided={self.onesided}"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""STFT forward function.
|
||||
|
||||
Args:
|
||||
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
|
||||
ilens: (Batch)
|
||||
Returns:
|
||||
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
|
||||
|
||||
"""
|
||||
bs = input.size(0)
|
||||
if input.dim() == 3:
|
||||
multi_channel = True
|
||||
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
|
||||
input = input.transpose(1, 2).reshape(-1, input.size(1))
|
||||
else:
|
||||
multi_channel = False
|
||||
|
||||
# output: (Batch, Freq, Frames, 2=real_imag)
|
||||
# or (Batch, Channel, Freq, Frames, 2=real_imag)
|
||||
if not self.kaldi_padding_mode:
|
||||
output = torch.stft(
|
||||
input,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode,
|
||||
normalized=self.normalized,
|
||||
onesided=self.onesided,
|
||||
return_complex=False
|
||||
)
|
||||
else:
|
||||
# NOTE(sx): Use Kaldi-fasion padding, maybe wrong
|
||||
num_pads = self.n_fft - self.win_length
|
||||
input = torch.nn.functional.pad(input, (num_pads, 0))
|
||||
output = torch.stft(
|
||||
input,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=False,
|
||||
pad_mode=self.pad_mode,
|
||||
normalized=self.normalized,
|
||||
onesided=self.onesided,
|
||||
return_complex=False
|
||||
)
|
||||
|
||||
# output: (Batch, Freq, Frames, 2=real_imag)
|
||||
# -> (Batch, Frames, Freq, 2=real_imag)
|
||||
output = output.transpose(1, 2)
|
||||
if multi_channel:
|
||||
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
|
||||
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
|
||||
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
|
||||
1, 2
|
||||
)
|
||||
|
||||
if ilens is not None:
|
||||
if self.center:
|
||||
pad = self.win_length // 2
|
||||
ilens = ilens + 2 * pad
|
||||
olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1
|
||||
# olens = ilens - self.win_length // self.hop_length + 1
|
||||
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
|
||||
else:
|
||||
olens = None
|
||||
|
||||
return output, olens
|
||||
82
ppg_extractor/utterance_mvn.py
Normal file
82
ppg_extractor/utterance_mvn.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class UtteranceMVN(torch.nn.Module):
|
||||
def __init__(
|
||||
self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.eps = eps
|
||||
|
||||
def extra_repr(self):
|
||||
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward function
|
||||
|
||||
Args:
|
||||
x: (B, L, ...)
|
||||
ilens: (B,)
|
||||
|
||||
"""
|
||||
return utterance_mvn(
|
||||
x,
|
||||
ilens,
|
||||
norm_means=self.norm_means,
|
||||
norm_vars=self.norm_vars,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
|
||||
def utterance_mvn(
|
||||
x: torch.Tensor,
|
||||
ilens: torch.Tensor = None,
|
||||
norm_means: bool = True,
|
||||
norm_vars: bool = False,
|
||||
eps: float = 1.0e-20,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply utterance mean and variance normalization
|
||||
|
||||
Args:
|
||||
x: (B, T, D), assumed zero padded
|
||||
ilens: (B,)
|
||||
norm_means:
|
||||
norm_vars:
|
||||
eps:
|
||||
|
||||
"""
|
||||
if ilens is None:
|
||||
ilens = x.new_full([x.size(0)], x.size(1))
|
||||
ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
|
||||
# Zero padding
|
||||
if x.requires_grad:
|
||||
x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
|
||||
else:
|
||||
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
|
||||
# mean: (B, 1, D)
|
||||
mean = x.sum(dim=1, keepdim=True) / ilens_
|
||||
|
||||
if norm_means:
|
||||
x -= mean
|
||||
|
||||
if norm_vars:
|
||||
var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
|
||||
std = torch.clamp(var.sqrt(), min=eps)
|
||||
x = x / std.sqrt()
|
||||
return x, ilens
|
||||
else:
|
||||
if norm_vars:
|
||||
y = x - mean
|
||||
y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
|
||||
var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
|
||||
std = torch.clamp(var.sqrt(), min=eps)
|
||||
x /= std
|
||||
return x, ilens
|
||||
17
pre.py
17
pre.py
@@ -13,7 +13,7 @@ recognized_datasets = [
|
||||
"aidatatang_200zh",
|
||||
"magicdata",
|
||||
"aishell3",
|
||||
"BZNSYP"
|
||||
"data_aishell"
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -29,8 +29,7 @@ if __name__ == "__main__":
|
||||
"Path to the output directory that will contain the mel spectrograms, the audios and the "
|
||||
"embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/")
|
||||
parser.add_argument("-n", "--n_processes", type=int, default=1, help=\
|
||||
"Number of processes in parallel.An encoder is created for each, so you may need to lower "
|
||||
"this value on GPUs with low memory. Set it to 1 if CUDA is unhappy")
|
||||
"Number of processes in parallel.")
|
||||
parser.add_argument("-s", "--skip_existing", action="store_true", help=\
|
||||
"Whether to overwrite existing files with the same name. Useful if the preprocessing was "
|
||||
"interrupted. ")
|
||||
@@ -41,10 +40,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--no_alignments", action="store_true", help=\
|
||||
"Use this option when dataset does not include alignments\
|
||||
(these are used to split long audio files into sub-utterances.)")
|
||||
parser.add_argument("-d","--dataset", type=str, default="aidatatang_200zh", help=\
|
||||
"Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, BZNSYP.")
|
||||
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
|
||||
"Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.")
|
||||
parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\
|
||||
"Path your trained encoder model.")
|
||||
parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\
|
||||
"Number of processes in parallel.An encoder is created for each, so you may need to lower "
|
||||
"this value on GPUs with low memory. Set it to 1 if CUDA is unhappy")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Process the arguments
|
||||
@@ -67,7 +69,8 @@ if __name__ == "__main__":
|
||||
del args.no_trim, args.encoder_model_fpath
|
||||
|
||||
args.hparams = hparams.parse(args.hparams)
|
||||
|
||||
n_processes_embed = args.n_processes_embed
|
||||
del args.n_processes_embed
|
||||
preprocess_dataset(**vars(args))
|
||||
|
||||
create_embeddings(synthesizer_root=args.out_dir, n_processes=args.n_processes, encoder_model_fpath=encoder_model_fpath)
|
||||
create_embeddings(synthesizer_root=args.out_dir, n_processes=n_processes_embed, encoder_model_fpath=encoder_model_fpath)
|
||||
|
||||
49
pre4ppg.py
Normal file
49
pre4ppg.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
from ppg2mel.preprocess import preprocess_dataset
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
recognized_datasets = [
|
||||
"aidatatang_200zh",
|
||||
"aidatatang_200zh_s", # sample
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocesses audio files from datasets, to be used by the "
|
||||
"ppg2mel model for training.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument("datasets_root", type=Path, help=\
|
||||
"Path to the directory containing your datasets.")
|
||||
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
|
||||
"Name of the dataset to process, allowing values: aidatatang_200zh.")
|
||||
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
|
||||
"Path to the output directory that will contain the mel spectrograms, the audios and the "
|
||||
"embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
|
||||
parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
|
||||
"Number of processes in parallel.")
|
||||
# parser.add_argument("-s", "--skip_existing", action="store_true", help=\
|
||||
# "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
|
||||
# "interrupted. ")
|
||||
# parser.add_argument("--hparams", type=str, default="", help=\
|
||||
# "Hyperparameter overrides as a comma-separated list of name-value pairs")
|
||||
# parser.add_argument("--no_trim", action="store_true", help=\
|
||||
# "Preprocess audio without trimming silences (not recommended).")
|
||||
parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
|
||||
"Path your trained ppg encoder model.")
|
||||
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
|
||||
"Path your trained speaker encoder model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.dataset in recognized_datasets, 'is not supported, file a issue to propose a new one'
|
||||
|
||||
# Create directories
|
||||
assert args.datasets_root.exists()
|
||||
if not hasattr(args, "out_dir"):
|
||||
args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
|
||||
args.out_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
preprocess_dataset(**vars(args))
|
||||
@@ -1,6 +1,6 @@
|
||||
umap-learn
|
||||
visdom
|
||||
librosa>=0.8.0
|
||||
librosa==0.8.1
|
||||
matplotlib>=3.3.0
|
||||
numpy==1.19.3; platform_system == "Windows"
|
||||
numpy==1.19.4; platform_system != "Windows"
|
||||
@@ -14,4 +14,14 @@ PyQt5
|
||||
multiprocess
|
||||
numba
|
||||
webrtcvad; platform_system != "Windows"
|
||||
pypinyin
|
||||
pypinyin
|
||||
flask
|
||||
flask_wtf
|
||||
flask_cors==3.0.10
|
||||
gevent==21.8.0
|
||||
flask_restx
|
||||
tensorboard
|
||||
streamlit==1.8.0
|
||||
PyYAML==5.4.1
|
||||
torch_complex
|
||||
espnet
|
||||
142
run.py
Normal file
142
run.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import time
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from ppg_extractor import load_model
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
from encoder.audio import preprocess_wav
|
||||
from encoder import inference as speacker_encoder
|
||||
from vocoder.hifigan import inference as vocoder
|
||||
from ppg2mel import MelDecoderMOLv2
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
|
||||
|
||||
def _build_ppg2mel_model(model_config, model_file, device):
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
ckpt = torch.load(model_file, map_location=device)
|
||||
ppg2mel_model.load_state_dict(ckpt["model"])
|
||||
ppg2mel_model.eval()
|
||||
return ppg2mel_model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert(args):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
|
||||
|
||||
# Build models
|
||||
print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
|
||||
ppg_model = load_model(
|
||||
Path('./ppg_extractor/saved_models/24epoch.pt'),
|
||||
device,
|
||||
)
|
||||
ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
|
||||
# vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
|
||||
vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')
|
||||
# Data related
|
||||
ref_wav_path = args.ref_wav_path
|
||||
ref_wav = preprocess_wav(ref_wav_path)
|
||||
ref_fid = os.path.basename(ref_wav_path)[:-4]
|
||||
|
||||
# TODO: specify encoder
|
||||
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
|
||||
ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
|
||||
ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
|
||||
source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
|
||||
print(f"Number of source utterances: {len(source_file_list)}.")
|
||||
|
||||
total_rtf = 0.0
|
||||
cnt = 0
|
||||
for src_wav_path in tqdm(source_file_list):
|
||||
# Load the audio to a numpy array:
|
||||
src_wav, _ = librosa.load(src_wav_path, sr=16000)
|
||||
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
|
||||
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
|
||||
ppg = ppg_model(src_wav_tensor, src_wav_lengths)
|
||||
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
|
||||
start = time.time()
|
||||
_, mel_pred, att_ws = ppg2mel_model.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=ref_spk_dvec,
|
||||
)
|
||||
src_fid = os.path.basename(src_wav_path)[:-4]
|
||||
wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
|
||||
mel_len = mel_pred.shape[0]
|
||||
rtf = (time.time() - start) / (0.01 * mel_len)
|
||||
total_rtf += rtf
|
||||
cnt += 1
|
||||
# continue
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
|
||||
sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")
|
||||
|
||||
print("RTF:")
|
||||
print(total_rtf / cnt)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(description="Conversion from wave input")
|
||||
parser.add_argument(
|
||||
"--wav_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Source wave directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_wav_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Reference wave file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppg2mel_model_train_config", "-c",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training config file (yaml file)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppg2mel_model_file", "-m",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="ppg2mel model checkpoint file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", "-o",
|
||||
type=str,
|
||||
default="vc_gens_vctk_oneshot",
|
||||
help="Output folder to save the converted wave."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,22 +0,0 @@
|
||||
The audio files in this folder are provided for toolbox testing and
|
||||
benchmarking purposes. These are the same reference utterances
|
||||
used by the SV2TTS authors to generate the audio samples located at:
|
||||
https://google.github.io/tacotron/publications/speaker_adaptation/index.html
|
||||
|
||||
The `p240_00000.mp3` and `p260_00000.mp3` files are compressed
|
||||
versions of audios from the VCTK corpus available at:
|
||||
https://datashare.is.ed.ac.uk/handle/10283/3443
|
||||
VCTK.txt contains the copyright notices and licensing information.
|
||||
|
||||
The `1320_00000.mp3`, `3575_00000.mp3`, `6829_00000.mp3`
|
||||
and `8230_00000.mp3` files are compressed versions of audios
|
||||
from the LibriSpeech dataset available at: https://openslr.org/12
|
||||
For these files, the following notice applies:
|
||||
```
|
||||
LibriSpeech (c) 2014 by Vassil Panayotov
|
||||
|
||||
LibriSpeech ASR corpus is licensed under a
|
||||
Creative Commons Attribution 4.0 International License.
|
||||
|
||||
See <http://creativecommons.org/licenses/by/4.0/>.
|
||||
```
|
||||
BIN
samples/T0055G0013S0005.wav
Normal file
BIN
samples/T0055G0013S0005.wav
Normal file
Binary file not shown.
@@ -1,94 +0,0 @@
|
||||
---------------------------------------------------------------------
|
||||
CSTR VCTK Corpus
|
||||
English Multi-speaker Corpus for CSTR Voice Cloning Toolkit
|
||||
|
||||
(Version 0.92)
|
||||
RELEASE September 2019
|
||||
The Centre for Speech Technology Research
|
||||
University of Edinburgh
|
||||
Copyright (c) 2019
|
||||
|
||||
Junichi Yamagishi
|
||||
jyamagis@inf.ed.ac.uk
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
|
||||
This CSTR VCTK Corpus includes speech data uttered by 110 English
|
||||
speakers with various accents. Each speaker reads out about 400
|
||||
sentences, which were selected from a newspaper, the rainbow passage
|
||||
and an elicitation paragraph used for the speech accent archive.
|
||||
|
||||
The newspaper texts were taken from Herald Glasgow, with permission
|
||||
from Herald & Times Group. Each speaker has a different set of the
|
||||
newspaper texts selected based a greedy algorithm that increases the
|
||||
contextual and phonetic coverage. The details of the text selection
|
||||
algorithms are described in the following paper:
|
||||
|
||||
C. Veaux, J. Yamagishi and S. King,
|
||||
"The voice bank corpus: Design, collection and data analysis of
|
||||
a large regional accent speech database,"
|
||||
https://doi.org/10.1109/ICSDA.2013.6709856
|
||||
|
||||
The rainbow passage and elicitation paragraph are the same for all
|
||||
speakers. The rainbow passage can be found at International Dialects
|
||||
of English Archive:
|
||||
(http://web.ku.edu/~idea/readings/rainbow.htm). The elicitation
|
||||
paragraph is identical to the one used for the speech accent archive
|
||||
(http://accent.gmu.edu). The details of the the speech accent archive
|
||||
can be found at
|
||||
http://www.ualberta.ca/~aacl2009/PDFs/WeinbergerKunath2009AACL.pdf
|
||||
|
||||
All speech data was recorded using an identical recording setup: an
|
||||
omni-directional microphone (DPA 4035) and a small diaphragm condenser
|
||||
microphone with very wide bandwidth (Sennheiser MKH 800), 96kHz
|
||||
sampling frequency at 24 bits and in a hemi-anechoic chamber of
|
||||
the University of Edinburgh. (However, two speakers, p280 and p315
|
||||
had technical issues of the audio recordings using MKH 800).
|
||||
All recordings were converted into 16 bits, were downsampled to
|
||||
48 kHz, and were manually end-pointed.
|
||||
|
||||
This corpus was originally aimed for HMM-based text-to-speech synthesis
|
||||
systems, especially for speaker-adaptive HMM-based speech synthesis
|
||||
that uses average voice models trained on multiple speakers and speaker
|
||||
adaptation technologies. This corpus is also suitable for DNN-based
|
||||
multi-speaker text-to-speech synthesis systems and waveform modeling.
|
||||
|
||||
COPYING
|
||||
|
||||
This corpus is licensed under the Creative Commons License: Attribution 4.0 International
|
||||
http://creativecommons.org/licenses/by/4.0/legalcode
|
||||
|
||||
VCTK VARIANTS
|
||||
There are several variants of the VCTK corpus:
|
||||
Speech enhancement
|
||||
- Noisy speech database for training speech enhancement algorithms and TTS models where we added various types of noises to VCTK artificially: http://dx.doi.org/10.7488/ds/2117
|
||||
- Reverberant speech database for training speech dereverberation algorithms and TTS models where we added various types of reverberantion to VCTK artificially http://dx.doi.org/10.7488/ds/1425
|
||||
- Noisy reverberant speech database for training speech enhancement algorithms and TTS models http://dx.doi.org/10.7488/ds/2139
|
||||
- Device Recorded VCTK where speech signals of the VCTK corpus were played back and re-recorded in office environments using relatively inexpensive consumer devices http://dx.doi.org/10.7488/ds/2316
|
||||
- The Microsoft Scalable Noisy Speech Dataset (MS-SNSD) https://github.com/microsoft/MS-SNSD
|
||||
|
||||
ASV and anti-spoofing
|
||||
- Spoofing and Anti-Spoofing (SAS) corpus, which is a collection of synthetic speech signals produced by nine techniques, two of which are speech synthesis, and seven are voice conversion. All of them were built using the VCTK corpus. http://dx.doi.org/10.7488/ds/252
|
||||
- Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) Database. This database consists of synthetic speech signals produced by ten techniques and this has been used in the first Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) http://dx.doi.org/10.7488/ds/298
|
||||
- ASVspoof 2019: The 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge database. This database has been used in the 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2019) https://doi.org/10.7488/ds/2555
|
||||
|
||||
|
||||
ACKNOWLEDGEMENTS
|
||||
|
||||
The CSTR VCTK Corpus was constructed by:
|
||||
|
||||
Christophe Veaux (University of Edinburgh)
|
||||
Junichi Yamagishi (University of Edinburgh)
|
||||
Kirsten MacDonald
|
||||
|
||||
The research leading to these results was partly funded from EPSRC
|
||||
grants EP/I031022/1 (NST) and EP/J002526/1 (CAF), from the RSE-NSFC
|
||||
grant (61111130120), and from the JST CREST (uDialogue).
|
||||
|
||||
Please cite this corpus as follows:
|
||||
Christophe Veaux, Junichi Yamagishi, Kirsten MacDonald,
|
||||
"CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit",
|
||||
The Centre for Speech Technology Research (CSTR),
|
||||
University of Edinburgh
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -167,7 +167,7 @@ def _mel_to_linear(mel_spectrogram, hparams):
|
||||
|
||||
def _build_mel_basis(hparams):
|
||||
assert hparams.fmax <= hparams.sample_rate // 2
|
||||
return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
|
||||
return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
|
||||
fmin=hparams.fmin, fmax=hparams.fmax)
|
||||
|
||||
def _amp_to_db(x, hparams):
|
||||
|
||||
13
synthesizer/gst_hyperparameters.py
Normal file
13
synthesizer/gst_hyperparameters.py
Normal file
@@ -0,0 +1,13 @@
|
||||
class GSTHyperparameters():
|
||||
E = 512
|
||||
|
||||
# reference encoder
|
||||
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
||||
|
||||
# style token layer
|
||||
token_num = 10
|
||||
# token_emb_size = 256
|
||||
num_heads = 8
|
||||
|
||||
n_mels = 256 # Number of Mel banks to generate
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import ast
|
||||
import pprint
|
||||
import json
|
||||
|
||||
class HParams(object):
|
||||
def __init__(self, **kwargs): self.__dict__.update(kwargs)
|
||||
@@ -18,6 +19,19 @@ class HParams(object):
|
||||
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
|
||||
return self
|
||||
|
||||
def loadJson(self, dict):
|
||||
print("\Loading the json with %s\n", dict)
|
||||
for k in dict.keys():
|
||||
if k not in ["tts_schedule", "tts_finetune_layers"]:
|
||||
self.__dict__[k] = dict[k]
|
||||
return self
|
||||
|
||||
def dumpJson(self, fp):
|
||||
print("\Saving the json with %s\n", fp)
|
||||
with fp.open("w", encoding="utf-8") as f:
|
||||
json.dump(self.__dict__, f)
|
||||
return self
|
||||
|
||||
hparams = HParams(
|
||||
### Signal Processing (used in both synthesizer and vocoder)
|
||||
sample_rate = 16000,
|
||||
@@ -49,19 +63,24 @@ hparams = HParams(
|
||||
# frame that has all values < -3.4
|
||||
|
||||
### Tacotron Training
|
||||
tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
|
||||
(2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
|
||||
(2, 2e-4, 80_000, 12), #
|
||||
(2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
|
||||
(2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
|
||||
(2, 1e-5, 640_000, 12)], # lr = learning rate
|
||||
tts_schedule = [(2, 1e-3, 10_000, 12), # Progressive training schedule
|
||||
(2, 5e-4, 15_000, 12), # (r, lr, step, batch_size)
|
||||
(2, 2e-4, 20_000, 12), # (r, lr, step, batch_size)
|
||||
(2, 1e-4, 30_000, 12), #
|
||||
(2, 5e-5, 40_000, 12), #
|
||||
(2, 1e-5, 60_000, 12), #
|
||||
(2, 5e-6, 160_000, 12), # r = reduction factor (# of mel frames
|
||||
(2, 3e-6, 320_000, 12), # synthesized for each decoder iteration)
|
||||
(2, 1e-6, 640_000, 12)], # lr = learning rate
|
||||
|
||||
tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
|
||||
tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
|
||||
# Set to -1 to generate after completing epoch, or 0 to disable
|
||||
|
||||
tts_eval_num_samples = 1, # Makes this number of samples
|
||||
|
||||
## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj
|
||||
tts_finetune_layers = [],
|
||||
|
||||
### Data Preprocessing
|
||||
max_mel_frames = 900,
|
||||
rescale = True,
|
||||
@@ -86,4 +105,6 @@ hparams = HParams(
|
||||
speaker_embedding_size = 256, # Dimension for the speaker embedding
|
||||
silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
|
||||
utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
|
||||
use_gst = True, # Whether to use global style token
|
||||
use_ser_for_gst = True, # Whether to use speaker embedding referenced for global style token
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Union, List
|
||||
import numpy as np
|
||||
import librosa
|
||||
from utils import logmmse
|
||||
import json
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
|
||||
class Synthesizer:
|
||||
@@ -44,6 +45,11 @@ class Synthesizer:
|
||||
return self._model is not None
|
||||
|
||||
def load(self):
|
||||
# Try to scan config file
|
||||
model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
|
||||
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
|
||||
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
|
||||
hparams.loadJson(json.load(f))
|
||||
"""
|
||||
Instantiates and loads the model given the weights file that was passed in the constructor.
|
||||
"""
|
||||
@@ -62,7 +68,7 @@ class Synthesizer:
|
||||
stop_threshold=hparams.tts_stop_threshold,
|
||||
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
|
||||
|
||||
self._model.load(self.model_fpath)
|
||||
self._model.load(self.model_fpath, self.device)
|
||||
self._model.eval()
|
||||
|
||||
if self.verbose:
|
||||
@@ -70,7 +76,7 @@ class Synthesizer:
|
||||
|
||||
def synthesize_spectrograms(self, texts: List[str],
|
||||
embeddings: Union[np.ndarray, List[np.ndarray]],
|
||||
return_alignments=False):
|
||||
return_alignments=False, style_idx=0, min_stop_token=5, steps=2000):
|
||||
"""
|
||||
Synthesizes mel spectrograms from texts and speaker embeddings.
|
||||
|
||||
@@ -125,7 +131,7 @@ class Synthesizer:
|
||||
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
||||
|
||||
# Inference
|
||||
_, mels, alignments = self._model.generate(chars, speaker_embeddings)
|
||||
_, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token, steps=steps)
|
||||
mels = mels.detach().cpu().numpy()
|
||||
for m in mels:
|
||||
# Trim silence from end of each spectrogram
|
||||
@@ -143,7 +149,7 @@ class Synthesizer:
|
||||
Loads and preprocesses an audio file under the same conditions the audio files were used to
|
||||
train the synthesizer.
|
||||
"""
|
||||
wav = librosa.load(str(fpath), hparams.sample_rate)[0]
|
||||
wav = librosa.load(path=str(fpath), sr=hparams.sample_rate)[0]
|
||||
if hparams.rescale:
|
||||
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
||||
# denoise
|
||||
|
||||
145
synthesizer/models/global_style_token.py
Normal file
145
synthesizer/models/global_style_token.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as tFunctional
|
||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
|
||||
from synthesizer.hparams import hparams
|
||||
|
||||
|
||||
class GlobalStyleToken(nn.Module):
|
||||
"""
|
||||
inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
|
||||
speaker_embedding: speaker mel spectrograms [batch_size, num_spec_frames, num_mel]
|
||||
outputs: [batch_size, embedding_dim]
|
||||
"""
|
||||
def __init__(self, speaker_embedding_dim=None):
|
||||
|
||||
super().__init__()
|
||||
self.encoder = ReferenceEncoder()
|
||||
self.stl = STL(speaker_embedding_dim)
|
||||
|
||||
def forward(self, inputs, speaker_embedding=None):
|
||||
enc_out = self.encoder(inputs)
|
||||
# concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
|
||||
if hparams.use_ser_for_gst and speaker_embedding is not None:
|
||||
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
|
||||
style_embed = self.stl(enc_out)
|
||||
|
||||
return style_embed
|
||||
|
||||
|
||||
class ReferenceEncoder(nn.Module):
|
||||
'''
|
||||
inputs --- [N, Ty/r, n_mels*r] mels
|
||||
outputs --- [N, ref_enc_gru_size]
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
|
||||
super().__init__()
|
||||
K = len(hp.ref_enc_filters)
|
||||
filters = [1] + hp.ref_enc_filters
|
||||
convs = [nn.Conv2d(in_channels=filters[i],
|
||||
out_channels=filters[i + 1],
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1)) for i in range(K)]
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)])
|
||||
|
||||
out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K)
|
||||
self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
|
||||
hidden_size=hp.E // 2,
|
||||
batch_first=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
N = inputs.size(0)
|
||||
out = inputs.view(N, 1, -1, hp.n_mels) # [N, 1, Ty, n_mels]
|
||||
for conv, bn in zip(self.convs, self.bns):
|
||||
out = conv(out)
|
||||
out = bn(out)
|
||||
out = tFunctional.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
||||
|
||||
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
||||
T = out.size(1)
|
||||
N = out.size(0)
|
||||
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
||||
|
||||
self.gru.flatten_parameters()
|
||||
memory, out = self.gru(out) # out --- [1, N, E//2]
|
||||
|
||||
return out.squeeze(0)
|
||||
|
||||
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
||||
for i in range(n_convs):
|
||||
L = (L - kernel_size + 2 * pad) // stride + 1
|
||||
return L
|
||||
|
||||
|
||||
class STL(nn.Module):
|
||||
'''
|
||||
inputs --- [N, E//2]
|
||||
'''
|
||||
|
||||
def __init__(self, speaker_embedding_dim=None):
|
||||
|
||||
super().__init__()
|
||||
self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
|
||||
d_q = hp.E // 2
|
||||
d_k = hp.E // hp.num_heads
|
||||
# self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
|
||||
if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
|
||||
d_q += speaker_embedding_dim
|
||||
self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
|
||||
|
||||
init.normal_(self.embed, mean=0, std=0.5)
|
||||
|
||||
def forward(self, inputs):
|
||||
N = inputs.size(0)
|
||||
query = inputs.unsqueeze(1) # [N, 1, E//2]
|
||||
keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
|
||||
style_embed = self.attention(query, keys)
|
||||
|
||||
return style_embed
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
'''
|
||||
input:
|
||||
query --- [N, T_q, query_dim]
|
||||
key --- [N, T_k, key_dim]
|
||||
output:
|
||||
out --- [N, T_q, num_units]
|
||||
'''
|
||||
|
||||
def __init__(self, query_dim, key_dim, num_units, num_heads):
|
||||
|
||||
super().__init__()
|
||||
self.num_units = num_units
|
||||
self.num_heads = num_heads
|
||||
self.key_dim = key_dim
|
||||
|
||||
self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
|
||||
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
||||
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
||||
|
||||
def forward(self, query, key):
|
||||
querys = self.W_query(query) # [N, T_q, num_units]
|
||||
keys = self.W_key(key) # [N, T_k, num_units]
|
||||
values = self.W_value(key)
|
||||
|
||||
split_size = self.num_units // self.num_heads
|
||||
querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
|
||||
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
|
||||
# score = softmax(QK^T / (d_k ** 0.5))
|
||||
scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
|
||||
scores = scores / (self.key_dim ** 0.5)
|
||||
scores = tFunctional.softmax(scores, dim=3)
|
||||
|
||||
# out = score * V
|
||||
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
|
||||
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
|
||||
|
||||
return out
|
||||
@@ -3,8 +3,9 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from synthesizer.models.global_style_token import GlobalStyleToken
|
||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
||||
from synthesizer.hparams import hparams
|
||||
|
||||
|
||||
class HighwayNetwork(nn.Module):
|
||||
@@ -60,7 +61,7 @@ class Encoder(nn.Module):
|
||||
idx = 1
|
||||
|
||||
# Start by making a copy of each speaker embedding to match the input text length
|
||||
# The output of this has size (batch_size, num_chars * tts_embed_dims)
|
||||
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
|
||||
speaker_embedding_size = speaker_embedding.size()[idx]
|
||||
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
||||
|
||||
@@ -126,7 +127,7 @@ class CBHG(nn.Module):
|
||||
# Although we `_flatten_parameters()` on init, when using DataParallel
|
||||
# the model gets replicated, making it no longer guaranteed that the
|
||||
# weights are contiguous in GPU memory. Hence, we must call it again
|
||||
self._flatten_parameters()
|
||||
self.rnn.flatten_parameters()
|
||||
|
||||
# Save these for later
|
||||
residual = x
|
||||
@@ -213,7 +214,7 @@ class LSA(nn.Module):
|
||||
self.attention = None
|
||||
|
||||
def init_attention(self, encoder_seq_proj):
|
||||
device = next(self.parameters()).device # use same device as parameters
|
||||
device = encoder_seq_proj.device # use same device as parameters
|
||||
b, t, c = encoder_seq_proj.size()
|
||||
self.cumulative = torch.zeros(b, t, device=device)
|
||||
self.attention = torch.zeros(b, t, device=device)
|
||||
@@ -255,16 +256,17 @@ class Decoder(nn.Module):
|
||||
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
||||
dropout=dropout)
|
||||
self.attn_net = LSA(decoder_dims)
|
||||
if hparams.use_gst:
|
||||
speaker_embedding_size += gst_hp.E
|
||||
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
||||
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
||||
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
||||
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
||||
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
||||
|
||||
def zoneout(self, prev, current, p=0.1):
|
||||
device = next(self.parameters()).device # Use same device as parameters
|
||||
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
|
||||
def zoneout(self, prev, current, device, p=0.1):
|
||||
mask = torch.zeros(prev.size(),device=device).bernoulli_(p)
|
||||
return prev * mask + current * (1 - mask)
|
||||
|
||||
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
|
||||
@@ -272,7 +274,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# Need this for reshaping mels
|
||||
batch_size = encoder_seq.size(0)
|
||||
|
||||
device = encoder_seq.device
|
||||
# Unpack the hidden and cell states
|
||||
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
|
||||
rnn1_cell, rnn2_cell = cell_states
|
||||
@@ -298,7 +300,7 @@ class Decoder(nn.Module):
|
||||
# Compute first Residual RNN
|
||||
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
||||
if self.training:
|
||||
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
|
||||
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device)
|
||||
else:
|
||||
rnn1_hidden = rnn1_hidden_next
|
||||
x = x + rnn1_hidden
|
||||
@@ -306,7 +308,7 @@ class Decoder(nn.Module):
|
||||
# Compute second Residual RNN
|
||||
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
||||
if self.training:
|
||||
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
|
||||
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
|
||||
else:
|
||||
rnn2_hidden = rnn2_hidden_next
|
||||
x = x + rnn2_hidden
|
||||
@@ -337,7 +339,12 @@ class Tacotron(nn.Module):
|
||||
self.speaker_embedding_size = speaker_embedding_size
|
||||
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
||||
encoder_K, num_highways, dropout)
|
||||
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
|
||||
project_dims = encoder_dims + speaker_embedding_size
|
||||
if hparams.use_gst:
|
||||
project_dims += gst_hp.E
|
||||
self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
|
||||
if hparams.use_gst:
|
||||
self.gst = GlobalStyleToken(speaker_embedding_size)
|
||||
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
||||
dropout, speaker_embedding_size)
|
||||
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
||||
@@ -357,12 +364,19 @@ class Tacotron(nn.Module):
|
||||
@r.setter
|
||||
def r(self, value):
|
||||
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
||||
|
||||
def forward(self, x, m, speaker_embedding):
|
||||
device = next(self.parameters()).device # use same device as parameters
|
||||
def forward(self, texts, mels, speaker_embedding):
|
||||
device = texts.device # use same device as parameters
|
||||
|
||||
self.step += 1
|
||||
batch_size, _, steps = m.size()
|
||||
batch_size, _, steps = mels.size()
|
||||
|
||||
# Initialise all hidden states and pack into tuple
|
||||
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
||||
@@ -379,11 +393,20 @@ class Tacotron(nn.Module):
|
||||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# Need an initial context vector
|
||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
||||
size = self.encoder_dims + self.speaker_embedding_size
|
||||
if hparams.use_gst:
|
||||
size += gst_hp.E
|
||||
context_vec = torch.zeros(batch_size, size, device=device)
|
||||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(x, speaker_embedding)
|
||||
encoder_seq = self.encoder(texts, speaker_embedding)
|
||||
# put after encoder
|
||||
if hparams.use_gst and self.gst is not None:
|
||||
style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
|
||||
# style_embed = style_embed.expand_as(encoder_seq)
|
||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
@@ -391,10 +414,10 @@ class Tacotron(nn.Module):
|
||||
|
||||
# Run the decoder loop
|
||||
for t in range(0, steps, self.r):
|
||||
prenet_in = m[:, :, t - 1] if t > 0 else go_frame
|
||||
prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
|
||||
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
||||
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
||||
hidden_states, cell_states, context_vec, t, x)
|
||||
hidden_states, cell_states, context_vec, t, texts)
|
||||
mel_outputs.append(mel_frames)
|
||||
attn_scores.append(scores)
|
||||
stop_outputs.extend([stop_tokens] * self.r)
|
||||
@@ -414,9 +437,9 @@ class Tacotron(nn.Module):
|
||||
|
||||
return mel_outputs, linear, attn_scores, stop_outputs
|
||||
|
||||
def generate(self, x, speaker_embedding=None, steps=2000):
|
||||
def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
|
||||
self.eval()
|
||||
device = next(self.parameters()).device # use same device as parameters
|
||||
device = x.device # use same device as parameters
|
||||
|
||||
batch_size, _ = x.size()
|
||||
|
||||
@@ -435,11 +458,30 @@ class Tacotron(nn.Module):
|
||||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# Need an initial context vector
|
||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
||||
size = self.encoder_dims + self.speaker_embedding_size
|
||||
if hparams.use_gst:
|
||||
size += gst_hp.E
|
||||
context_vec = torch.zeros(batch_size, size, device=device)
|
||||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(x, speaker_embedding)
|
||||
|
||||
# put after encoder
|
||||
if hparams.use_gst and self.gst is not None:
|
||||
if style_idx >= 0 and style_idx < 10:
|
||||
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
||||
if device.type == 'cuda':
|
||||
query = query.cuda()
|
||||
gst_embed = torch.tanh(self.gst.stl.embed)
|
||||
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||
style_embed = self.gst.stl.attention(query, key)
|
||||
else:
|
||||
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||
# style_embed = style_embed.expand_as(encoder_seq)
|
||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
@@ -455,7 +497,7 @@ class Tacotron(nn.Module):
|
||||
attn_scores.append(scores)
|
||||
stop_outputs.extend([stop_tokens] * self.r)
|
||||
# Stop the loop when all stop tokens in batch exceed threshold
|
||||
if (stop_tokens > 0.5).all() and t > 10: break
|
||||
if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
|
||||
|
||||
# Concat the mel outputs into sequence
|
||||
mel_outputs = torch.cat(mel_outputs, dim=2)
|
||||
@@ -479,6 +521,15 @@ class Tacotron(nn.Module):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
||||
|
||||
def finetune_partial(self, whitelist_layers):
|
||||
self.zero_grad()
|
||||
for name, child in self.named_children():
|
||||
if name in whitelist_layers:
|
||||
print("Trainable Layer: %s" % name)
|
||||
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
|
||||
for param in child.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_step(self):
|
||||
return self.step.data.item()
|
||||
|
||||
@@ -490,11 +541,10 @@ class Tacotron(nn.Module):
|
||||
with open(path, "a") as f:
|
||||
print(msg, file=f)
|
||||
|
||||
def load(self, path, optimizer=None):
|
||||
def load(self, path, device, optimizer=None):
|
||||
# Use device of model params as location for loaded state
|
||||
device = next(self.parameters()).device
|
||||
checkpoint = torch.load(str(path), map_location=device)
|
||||
self.load_state_dict(checkpoint["model_state"])
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
|
||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||
|
||||
@@ -6,8 +6,8 @@ from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from encoder import inference as encoder
|
||||
from synthesizer.preprocess_speaker import preprocess_speaker_general, preprocess_speaker_bznsyp
|
||||
from synthesizer.preprocess_transcript import preprocess_transcript_bznsyp
|
||||
from synthesizer.preprocess_speaker import preprocess_speaker_general
|
||||
from synthesizer.preprocess_transcript import preprocess_transcript_aishell3, preprocess_transcript_magicdata
|
||||
|
||||
data_info = {
|
||||
"aidatatang_200zh": {
|
||||
@@ -18,19 +18,20 @@ data_info = {
|
||||
"magicdata": {
|
||||
"subfolders": ["train"],
|
||||
"trans_filepath": "train/TRANS.txt",
|
||||
"speak_func": preprocess_speaker_general
|
||||
"speak_func": preprocess_speaker_general,
|
||||
"transcript_func": preprocess_transcript_magicdata,
|
||||
},
|
||||
"aishell3":{
|
||||
"subfolders": ["train/wav"],
|
||||
"trans_filepath": "train/content.txt",
|
||||
"speak_func": preprocess_speaker_general,
|
||||
"transcript_func": preprocess_transcript_aishell3,
|
||||
},
|
||||
"data_aishell":{
|
||||
"subfolders": ["wav/train"],
|
||||
"trans_filepath": "transcript/aishell_transcript_v0.8.txt",
|
||||
"speak_func": preprocess_speaker_general
|
||||
},
|
||||
"BZNSYP":{
|
||||
"subfolders": ["Wave"],
|
||||
"trans_filepath": "ProsodyLabeling/000001-010000.txt",
|
||||
"speak_func": preprocess_speaker_bznsyp,
|
||||
"transcript_func": preprocess_transcript_bznsyp,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int,
|
||||
|
||||
@@ -61,9 +61,9 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
|
||||
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
|
||||
|
||||
|
||||
def _split_on_silences_aidatatang_200zh(wav_fpath, words, hparams):
|
||||
def _split_on_silences(wav_fpath, words, hparams):
|
||||
# Load the audio waveform
|
||||
wav, _ = librosa.load(wav_fpath, hparams.sample_rate)
|
||||
wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate)
|
||||
wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0]
|
||||
if hparams.rescale:
|
||||
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
||||
@@ -81,24 +81,19 @@ def _split_on_silences_aidatatang_200zh(wav_fpath, words, hparams):
|
||||
return wav, res
|
||||
|
||||
def preprocess_speaker_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool):
|
||||
wav_fpath_list = speaker_dir.glob("*.wav")
|
||||
return preprocess_speaker_internal(wav_fpath_list, out_dir, skip_existing, hparams, dict_info, no_alignments)
|
||||
|
||||
def preprocess_speaker_bznsyp(speaker_dir, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool):
|
||||
wav_fpath_list = [speaker_dir]
|
||||
return preprocess_speaker_internal(wav_fpath_list, out_dir, skip_existing, hparams, dict_info, no_alignments)
|
||||
|
||||
def preprocess_speaker_internal(wav_fpath_list, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool):
|
||||
# Iterate over each wav
|
||||
metadata = []
|
||||
for wav_fpath in wav_fpath_list:
|
||||
words = dict_info.get(wav_fpath.name.split(".")[0])
|
||||
words = dict_info.get(wav_fpath.name) if not words else words # try with wav
|
||||
if not words:
|
||||
print("no wordS")
|
||||
continue
|
||||
sub_basename = "%s_%02d" % (wav_fpath.name, 0)
|
||||
wav, text = _split_on_silences_aidatatang_200zh(wav_fpath, words, hparams)
|
||||
metadata.append(_process_utterance(wav, text, out_dir, sub_basename,
|
||||
skip_existing, hparams))
|
||||
return [m for m in metadata if m is not None]
|
||||
extensions = ["*.wav", "*.flac", "*.mp3"]
|
||||
for extension in extensions:
|
||||
wav_fpath_list = speaker_dir.glob(extension)
|
||||
# Iterate over each wav
|
||||
for wav_fpath in wav_fpath_list:
|
||||
words = dict_info.get(wav_fpath.name.split(".")[0])
|
||||
words = dict_info.get(wav_fpath.name) if not words else words # try with wav
|
||||
if not words:
|
||||
print("no wordS")
|
||||
continue
|
||||
sub_basename = "%s_%02d" % (wav_fpath.name, 0)
|
||||
wav, text = _split_on_silences(wav_fpath, words, hparams)
|
||||
metadata.append(_process_utterance(wav, text, out_dir, sub_basename,
|
||||
skip_existing, hparams))
|
||||
return [m for m in metadata if m is not None]
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
def preprocess_transcript_bznsyp(dict_info, dict_transcript):
|
||||
transList = []
|
||||
for t in dict_transcript:
|
||||
transList.append(t)
|
||||
for i in range(0, len(transList), 2):
|
||||
if not transList[i]:
|
||||
def preprocess_transcript_aishell3(dict_info, dict_transcript):
|
||||
for v in dict_transcript:
|
||||
if not v:
|
||||
continue
|
||||
key = transList[i].split("\t")[0]
|
||||
transcript = transList[i+1].strip().replace("\n","").replace("\t"," ")
|
||||
dict_info[key] = transcript
|
||||
v = v.strip().replace("\n","").replace("\t"," ").split(" ")
|
||||
transList = []
|
||||
for i in range(2, len(v), 2):
|
||||
transList.append(v[i])
|
||||
dict_info[v[0]] = " ".join(transList)
|
||||
|
||||
|
||||
def preprocess_transcript_magicdata(dict_info, dict_transcript):
|
||||
for v in dict_transcript:
|
||||
if not v:
|
||||
continue
|
||||
v = v.strip().replace("\n","").replace("\t"," ").split(" ")
|
||||
dict_info[v[0]] = " ".join(v[2:])
|
||||
|
||||
@@ -45,7 +45,7 @@ def run_synthesis(in_dir, out_dir, model_dir, hparams):
|
||||
model_dir = Path(model_dir)
|
||||
model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
|
||||
print("\nLoading weights at %s" % model_fpath)
|
||||
model.load(model_fpath)
|
||||
model.load(model_fpath, device)
|
||||
print("Tacotron weights loaded from step %d" % model.step)
|
||||
|
||||
# Synthesize using same reduction factor as the model is currently trained
|
||||
|
||||
@@ -73,6 +73,7 @@ def collate_synthesizer(batch):
|
||||
|
||||
# Speaker embedding (SV2TTS)
|
||||
embeds = [x[2] for x in batch]
|
||||
embeds = np.stack(embeds)
|
||||
|
||||
# Index (for vocoder preprocessing)
|
||||
indices = [x[3] for x in batch]
|
||||
|
||||
@@ -2,15 +2,17 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from synthesizer import audio
|
||||
from synthesizer.models.tacotron import Tacotron
|
||||
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
||||
from synthesizer.utils import ValueWindow, data_parallel_workaround
|
||||
from synthesizer.utils.plot import plot_spectrogram
|
||||
from synthesizer.utils.plot import plot_spectrogram, plot_spectrogram_and_trace
|
||||
from synthesizer.utils.symbols import symbols
|
||||
from synthesizer.utils.text import sequence_to_text
|
||||
from vocoder.display import *
|
||||
from datetime import datetime
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys
|
||||
@@ -23,7 +25,7 @@ def time_string():
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
backup_every: int, force_restart:bool, hparams):
|
||||
backup_every: int, log_every:int, force_restart:bool, hparams):
|
||||
|
||||
syn_dir = Path(syn_dir)
|
||||
models_dir = Path(models_dir)
|
||||
@@ -74,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
if num_chars != loaded_shape[0]:
|
||||
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
|
||||
num_chars != loaded_shape[0]
|
||||
# Try to scan config file
|
||||
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
|
||||
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
|
||||
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
|
||||
hparams.loadJson(json.load(f))
|
||||
else: # save a config
|
||||
hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))
|
||||
|
||||
|
||||
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
||||
@@ -92,7 +101,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = optim.Adam(model.parameters())
|
||||
optimizer = optim.Adam(model.parameters(), amsgrad=True)
|
||||
|
||||
# Load the weights
|
||||
if force_restart or not weights_fpath.exists():
|
||||
@@ -110,7 +119,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
|
||||
else:
|
||||
print("\nLoading weights at %s" % weights_fpath)
|
||||
model.load(weights_fpath, optimizer)
|
||||
model.load(weights_fpath, device, optimizer)
|
||||
print("Tacotron weights loaded from step %d" % model.step)
|
||||
|
||||
# Initialize the dataset
|
||||
@@ -123,6 +132,9 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
shuffle=True,
|
||||
pin_memory=True)
|
||||
|
||||
# tracing training step
|
||||
sw = SummaryWriter(log_dir=model_dir.joinpath("logs"))
|
||||
|
||||
for i, session in enumerate(hparams.tts_schedule):
|
||||
current_step = model.get_step()
|
||||
|
||||
@@ -142,7 +154,6 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
continue
|
||||
|
||||
model.r = r
|
||||
|
||||
# Begin the training
|
||||
simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
|
||||
("Batch Size", batch_size),
|
||||
@@ -151,6 +162,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
|
||||
for p in optimizer.param_groups:
|
||||
p["lr"] = lr
|
||||
if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0:
|
||||
model.finetune_partial(hparams.tts_finetune_layers)
|
||||
|
||||
data_loader = DataLoader(dataset,
|
||||
collate_fn=collate_synthesizer,
|
||||
@@ -208,18 +221,23 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
step = model.get_step()
|
||||
k = step // 1000
|
||||
|
||||
|
||||
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
|
||||
stream(msg)
|
||||
|
||||
if log_every != 0 and step % log_every == 0 :
|
||||
sw.add_scalar("training/loss", loss_window.average, step)
|
||||
|
||||
# Backup or save model as appropriate
|
||||
if backup_every != 0 and step % backup_every == 0 :
|
||||
backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
|
||||
backup_fpath = Path("{}/{}_{}.pt".format(str(weights_fpath.parent), run_id, step))
|
||||
model.save(backup_fpath, optimizer)
|
||||
|
||||
if save_every != 0 and step % save_every == 0 :
|
||||
# Must save latest optimizer state to ensure that resuming training
|
||||
# doesn't produce artifacts
|
||||
model.save(weights_fpath, optimizer)
|
||||
|
||||
|
||||
# Evaluate model to generate samples
|
||||
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
|
||||
@@ -233,7 +251,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
|
||||
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
|
||||
attention_len = mel_length // model.r
|
||||
|
||||
# eval_loss = F.mse_loss(mel_prediction, target_spectrogram)
|
||||
# sw.add_scalar("validing/loss", eval_loss.item(), step)
|
||||
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
|
||||
mel_prediction=mel_prediction,
|
||||
target_spectrogram=target_spectrogram,
|
||||
@@ -244,7 +263,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
wav_dir=wav_dir,
|
||||
sample_num=sample_idx + 1,
|
||||
loss=loss,
|
||||
hparams=hparams)
|
||||
hparams=hparams,
|
||||
sw=sw)
|
||||
|
||||
# Break out of loop to update training schedule
|
||||
if step >= max_step:
|
||||
@@ -254,10 +274,11 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||
print("")
|
||||
|
||||
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
||||
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
|
||||
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, sw):
|
||||
# Save some results for evaluation
|
||||
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
|
||||
save_attention(attention, attention_path)
|
||||
# save_attention(attention, attention_path)
|
||||
save_and_trace_attention(attention, attention_path, sw, step)
|
||||
|
||||
# save predicted mel spectrogram to disk (debug)
|
||||
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
|
||||
@@ -271,7 +292,15 @@ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
||||
# save real and predicted mel-spectrogram plot to disk (control purposes)
|
||||
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
|
||||
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
|
||||
plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
||||
target_spectrogram=target_spectrogram,
|
||||
max_len=target_spectrogram.size // hparams.num_mels)
|
||||
# plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
||||
# target_spectrogram=target_spectrogram,
|
||||
# max_len=target_spectrogram.size // hparams.num_mels)
|
||||
plot_spectrogram_and_trace(
|
||||
mel_prediction,
|
||||
str(spec_fpath),
|
||||
title=title_str,
|
||||
target_spectrogram=target_spectrogram,
|
||||
max_len=target_spectrogram.size // hparams.num_mels,
|
||||
sw=sw,
|
||||
step=step)
|
||||
print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
|
||||
|
||||
@@ -74,3 +74,42 @@ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, targ
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, format="png")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0):
|
||||
if max_len is not None:
|
||||
target_spectrogram = target_spectrogram[:max_len]
|
||||
pred_spectrogram = pred_spectrogram[:max_len]
|
||||
|
||||
if split_title:
|
||||
title = split_title_line(title)
|
||||
|
||||
fig = plt.figure(figsize=(10, 8))
|
||||
# Set common labels
|
||||
fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
|
||||
|
||||
#target spectrogram subplot
|
||||
if target_spectrogram is not None:
|
||||
ax1 = fig.add_subplot(311)
|
||||
ax2 = fig.add_subplot(312)
|
||||
|
||||
if auto_aspect:
|
||||
im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
|
||||
else:
|
||||
im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
|
||||
ax1.set_title("Target Mel-Spectrogram")
|
||||
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
|
||||
ax2.set_title("Predicted Mel-Spectrogram")
|
||||
else:
|
||||
ax2 = fig.add_subplot(211)
|
||||
|
||||
if auto_aspect:
|
||||
im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
|
||||
else:
|
||||
im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
|
||||
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, format="png")
|
||||
sw.add_figure("spectrogram", fig, step)
|
||||
plt.close()
|
||||
@@ -12,6 +12,7 @@ recognized_datasets = [
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("This method is deprecaded and will not be longer supported, please use 'pre.py'")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
|
||||
"and writes them to the disk. Audio files are also saved, to be used by the "
|
||||
|
||||
@@ -5,6 +5,7 @@ import argparse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("This method is deprecaded and will not be longer supported, please use 'pre.py'")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
|
||||
@@ -21,6 +21,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
|
||||
"Number of steps between backups of the model. Set to 0 to never make backups of the "
|
||||
"model.")
|
||||
parser.add_argument("-l", "--log_every", type=int, default=200, help= \
|
||||
"Number of steps between summary the training info in tensorboard")
|
||||
parser.add_argument("-f", "--force_restart", action="store_true", help= \
|
||||
"Do not load any saved model and restart from scratch.")
|
||||
parser.add_argument("--hparams", default="",
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from toolbox.ui import UI
|
||||
from encoder import inference as encoder
|
||||
from synthesizer.inference import Synthesizer
|
||||
from vocoder import inference as vocoder
|
||||
from vocoder.wavernn import inference as rnn_vocoder
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from pathlib import Path
|
||||
from time import perf_counter as timer
|
||||
from toolbox.utterance import Utterance
|
||||
@@ -9,9 +10,10 @@ import numpy as np
|
||||
import traceback
|
||||
import sys
|
||||
import torch
|
||||
import librosa
|
||||
import re
|
||||
from audioread.exceptions import NoBackendError
|
||||
|
||||
# 默认使用wavernn
|
||||
vocoder = rnn_vocoder
|
||||
|
||||
# Use this directory structure for your datasets, or modify it to fit your needs
|
||||
recognized_datasets = [
|
||||
@@ -45,21 +47,20 @@ recognized_datasets = [
|
||||
MAX_WAVES = 15
|
||||
|
||||
class Toolbox:
|
||||
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
|
||||
if not no_mp3_support:
|
||||
try:
|
||||
librosa.load("samples/6829_00000.mp3")
|
||||
except NoBackendError:
|
||||
print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
|
||||
"Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
|
||||
exit(-1)
|
||||
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
|
||||
self.no_mp3_support = no_mp3_support
|
||||
self.vc_mode = vc_mode
|
||||
sys.excepthook = self.excepthook
|
||||
self.datasets_root = datasets_root
|
||||
self.utterances = set()
|
||||
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
||||
|
||||
self.synthesizer = None # type: Synthesizer
|
||||
|
||||
# for ppg-based voice conversion
|
||||
self.extractor = None
|
||||
self.convertor = None # ppg2mel
|
||||
|
||||
self.current_wav = None
|
||||
self.waves_list = []
|
||||
self.waves_count = 0
|
||||
@@ -73,8 +74,9 @@ class Toolbox:
|
||||
self.trim_silences = False
|
||||
|
||||
# Initialize the events and the interface
|
||||
self.ui = UI()
|
||||
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
|
||||
self.ui = UI(vc_mode)
|
||||
self.style_idx = 0
|
||||
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
|
||||
self.setup_events()
|
||||
self.ui.start()
|
||||
|
||||
@@ -98,7 +100,11 @@ class Toolbox:
|
||||
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
||||
def func():
|
||||
self.synthesizer = None
|
||||
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
||||
if self.vc_mode:
|
||||
self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
|
||||
else:
|
||||
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
||||
|
||||
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
||||
|
||||
# Utterance selection
|
||||
@@ -111,6 +117,11 @@ class Toolbox:
|
||||
self.ui.stop_button.clicked.connect(self.ui.stop)
|
||||
self.ui.record_button.clicked.connect(self.record)
|
||||
|
||||
# Source Utterance selection
|
||||
if self.vc_mode:
|
||||
func = lambda: self.load_soruce_button(self.ui.selected_utterance)
|
||||
self.ui.load_soruce_button.clicked.connect(func)
|
||||
|
||||
#Audio
|
||||
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
||||
|
||||
@@ -122,12 +133,17 @@ class Toolbox:
|
||||
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
||||
|
||||
# Generation
|
||||
func = lambda: self.synthesize() or self.vocode()
|
||||
self.ui.generate_button.clicked.connect(func)
|
||||
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
||||
self.ui.vocode_button.clicked.connect(self.vocode)
|
||||
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
||||
|
||||
if self.vc_mode:
|
||||
func = lambda: self.convert() or self.vocode()
|
||||
self.ui.convert_button.clicked.connect(func)
|
||||
else:
|
||||
func = lambda: self.synthesize() or self.vocode()
|
||||
self.ui.generate_button.clicked.connect(func)
|
||||
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
||||
|
||||
# UMAP legend
|
||||
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
||||
|
||||
@@ -140,9 +156,9 @@ class Toolbox:
|
||||
def replay_last_wav(self):
|
||||
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
||||
|
||||
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
|
||||
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
|
||||
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
||||
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
|
||||
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
|
||||
self.ui.populate_gen_options(seed, self.trim_silences)
|
||||
|
||||
def load_from_browser(self, fpath=None):
|
||||
@@ -173,7 +189,10 @@ class Toolbox:
|
||||
self.ui.log("Loaded %s" % name)
|
||||
|
||||
self.add_real_utterance(wav, name, speaker_name)
|
||||
|
||||
|
||||
def load_soruce_button(self, utterance: Utterance):
|
||||
self.selected_source_utterance = utterance
|
||||
|
||||
def record(self):
|
||||
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
||||
if wav is None:
|
||||
@@ -198,7 +217,7 @@ class Toolbox:
|
||||
# Add the utterance
|
||||
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
||||
self.utterances.add(utterance)
|
||||
self.ui.register_utterance(utterance)
|
||||
self.ui.register_utterance(utterance, self.vc_mode)
|
||||
|
||||
# Plot it
|
||||
self.ui.draw_embed(embed, name, "current")
|
||||
@@ -236,7 +255,8 @@ class Toolbox:
|
||||
texts = processed_texts
|
||||
embed = self.ui.selected_utterance.embed
|
||||
embeds = [embed] * len(texts)
|
||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
|
||||
min_token = int(self.ui.token_slider.value())
|
||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
|
||||
breaks = [spec.shape[1] for spec in specs]
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
|
||||
@@ -270,7 +290,7 @@ class Toolbox:
|
||||
self.ui.set_loading(i, seq_len)
|
||||
if self.ui.current_vocoder_fpath is not None:
|
||||
self.ui.log("")
|
||||
wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
||||
wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
||||
else:
|
||||
self.ui.log("Waveform generation with Griffin-Lim... ")
|
||||
wav = Synthesizer.griffin_lim(spec)
|
||||
@@ -281,7 +301,7 @@ class Toolbox:
|
||||
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
||||
b_starts = np.concatenate(([0], b_ends[:-1]))
|
||||
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
||||
breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
|
||||
breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
|
||||
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
||||
|
||||
# Trim excessive silences
|
||||
@@ -290,7 +310,7 @@ class Toolbox:
|
||||
|
||||
# Play it
|
||||
wav = wav / np.abs(wav).max() * 0.97
|
||||
self.ui.play(wav, Synthesizer.sample_rate)
|
||||
self.ui.play(wav, sample_rate)
|
||||
|
||||
# Name it (history displayed in combobox)
|
||||
# TODO better naming for the combobox items?
|
||||
@@ -332,6 +352,67 @@ class Toolbox:
|
||||
self.ui.draw_embed(embed, name, "generated")
|
||||
self.ui.draw_umap_projections(self.utterances)
|
||||
|
||||
def convert(self):
|
||||
self.ui.log("Extract PPG and Converting...")
|
||||
self.ui.set_loading(1)
|
||||
|
||||
# Init
|
||||
if self.convertor is None:
|
||||
self.init_convertor()
|
||||
if self.extractor is None:
|
||||
self.init_extractor()
|
||||
|
||||
src_wav = self.selected_source_utterance.wav
|
||||
|
||||
# Compute the ppg
|
||||
if not self.extractor is None:
|
||||
ppg = self.extractor.extract_from_wav(src_wav)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
ref_wav = self.ui.selected_utterance.wav
|
||||
# Import necessary dependency of Voice Conversion
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
_, mel_pred, att_ws = self.convertor.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
|
||||
)
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
breaks = [mel_pred.shape[1]]
|
||||
mel_pred= mel_pred.detach().cpu().numpy()
|
||||
self.ui.draw_spec(mel_pred, "generated")
|
||||
self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_extractor(self):
|
||||
if self.ui.current_extractor_fpath is None:
|
||||
return
|
||||
model_fpath = self.ui.current_extractor_fpath
|
||||
self.ui.log("Loading the extractor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
import ppg_extractor as extractor
|
||||
self.extractor = extractor.load_model(model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_convertor(self):
|
||||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_fpath = self.ui.current_convertor_fpath
|
||||
self.ui.log("Loading the convertor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
import ppg2mel as convertor
|
||||
self.convertor = convertor.load_model( model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_encoder(self):
|
||||
model_fpath = self.ui.current_encoder_fpath
|
||||
|
||||
@@ -353,15 +434,31 @@ class Toolbox:
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_vocoder(self):
|
||||
|
||||
global vocoder
|
||||
model_fpath = self.ui.current_vocoder_fpath
|
||||
# Case of Griffin-lim
|
||||
if model_fpath is None:
|
||||
return
|
||||
# Sekect vocoder based on model name
|
||||
model_config_fpath = None
|
||||
if model_fpath.name[0] == "g":
|
||||
vocoder = gan_vocoder
|
||||
self.ui.log("set hifigan as vocoder")
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
|
||||
if self.vc_mode and self.ui.current_extractor_fpath is None:
|
||||
return
|
||||
if len(model_config_fpaths) > 0:
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
else:
|
||||
vocoder = rnn_vocoder
|
||||
self.ui.log("set wavernn as vocoder")
|
||||
|
||||
self.ui.log("Loading the vocoder %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
vocoder.load_model(model_fpath)
|
||||
vocoder.load_model(model_fpath, model_config_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
|
||||
BIN
toolbox/assets/mb.png
Normal file
BIN
toolbox/assets/mb.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.6 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user