Mirror of https://github.com/babysor/Realtime-Voice-Clone-Chinese.git (synced 2026-02-03 18:43:41 +08:00)

Comparing commits: webtoolbox...dev, 80 commits.
Commits (SHA1):
a191587417 d3ba597be9 6134c94b4d c04a1097bf 9b4f8cc6c9 96993a5c61 70cc3988d3 c5998bfe71
c997dbdf66 47cc597ad0 8c895ed2c6 2e57bf3f11 11a5e2a141 7f0d983da7 0353bfc6e6 9ec114a7c1
ddf612e87c 374cc89cfa 6009da7072 1c61a601d1 02ee514aa3 6532d65153 3fe0690cc6 79f424d614
3c97d22938 fc26c38152 6c01b92703 c36f02634a b05e7441ff 693de98f4d 252a5e11b3 b617a87ee4
ad22997614 9e072c2619 b79e9d68e4 0536874dec 4529479091 8ad9ba2b60 b56ec5ee1b 0bc34a5bc9
875fe15069 4728863f9d a4daf42868 b50c7984ab 26fe4a047d aff1b5313b 7dca74e032 a37b26a89c
902e1eb537 5c0e53a29a 4edebdfeba 6c8f3f4515 2bd323b7df 3674d8b5c6 80aaf32164 c396792b22
7c58fe01d1 724194a4de 31bc6656c3 aa35fb3139 727eafc51b d328ecba81 fad574118c b0c156a537
724809abf4 05cd1a54ea 245099c740 6dd2af49fe 8b43ec9a64 2a99f0ff05 a824b54122 81befb91b0
e2017d0314 547ac816df 6b4ab39601 b46e7a7866 8a384a1191 11154783d8 d52db0444e 790d11a58b
.github/FUNDING.yml (vendored, new file, 1 line)
@@ -0,0 +1 @@
+github: babysor
.github/ISSUE_TEMPLATE/issue.md (vendored, new file, 17 lines)
@@ -0,0 +1,17 @@
+---
+name: Issue
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Summary[问题简述(一句话)]**
+A clear and concise description of what the issue is.
+
+**Env & To Reproduce[复现与环境]**
+描述你用的环境、代码版本、模型
+
+**Screenshots[截图(如有)]**
+If applicable, add screenshots to help
.gitignore (vendored, 10 lines changed)
@@ -13,9 +13,9 @@
 *.bbl
 *.bcf
 *.toc
-*.wav
 *.sh
-synthesizer/saved_models/*
-vocoder/saved_models/*
-cp_hifigan/*
-!vocoder/saved_models/pretrained/*
+*/saved_models
+!vocoder/saved_models/pretrained/**
+!encoder/saved_models/pretrained.pt
+wavs
+log
.vscode/launch.json (vendored, 44 lines changed)
@@ -17,7 +17,7 @@
             "request": "launch",
             "program": "vocoder_preprocess.py",
             "console": "integratedTerminal",
-            "args": ["..\\..\\chs1"]
+            "args": ["..\\audiodata"]
         },
         {
             "name": "Python: Vocoder Train",
@@ -25,15 +25,49 @@
             "request": "launch",
             "program": "vocoder_train.py",
             "console": "integratedTerminal",
-            "args": ["dev", "..\\..\\chs1"]
+            "args": ["dev", "..\\audiodata"]
         },
         {
-            "name": "Python: demo box",
+            "name": "Python: Demo Box",
             "type": "python",
             "request": "launch",
             "program": "demo_toolbox.py",
             "console": "integratedTerminal",
-            "args": ["-d", "..\\..\\chs"]
-        }
+            "args": ["-d","..\\audiodata"]
+        },
+        {
+            "name": "Python: Demo Box VC",
+            "type": "python",
+            "request": "launch",
+            "program": "demo_toolbox.py",
+            "console": "integratedTerminal",
+            "args": ["-d","..\\audiodata","-vc"]
+        },
+        {
+            "name": "Python: Synth Train",
+            "type": "python",
+            "request": "launch",
+            "program": "synthesizer_train.py",
+            "console": "integratedTerminal",
+            "args": ["my_run", "..\\"]
+        },
+        {
+            "name": "Python: PPG Convert",
+            "type": "python",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
+                "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
+            ]
+        },
+        {
+            "name": "GUI",
+            "type": "python",
+            "request": "launch",
+            "program": "mkgui\\base\\_cli.py",
+            "console": "integratedTerminal",
+            "args": []
+        },
     ]
 }
README-CN.md (116 lines changed)
@@ -5,10 +5,10 @@
 ### [English](README.md) | 中文

-### [DEMO VIDEO](https://www.bilibili.com/video/BV1sA411P7wM/)
+### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki教程](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [训练教程](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)

 ## 特性
-🌍 **中文** 支持普通话并使用多种中文数据集进行测试:aidatatang_200zh, magicdata, aishell3, biaobei,MozillaCommonVoice 等
+🌍 **中文** 支持普通话并使用多种中文数据集进行测试:aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell 等

 🤩 **PyTorch** 适用于 pytorch,已在 1.9.0 版本(最新于 2021 年 8 月)中测试,GPU Tesla T4 和 GTX 2060
@@ -18,9 +18,19 @@
 🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用

+### 进行中的工作
+* GUI/客户端大升级与合并
+[X] 初始化框架 `./mkgui` (基于streamlit + fastapi)和 [技术设计](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
+[X] 增加 Voice Cloning and Conversion的演示页面
+[X] 增加Voice Conversion的预处理preprocessing 和训练 training 页面
+[ ] 增加其他的预处理preprocessing 和训练 training 页面
+* 模型后端基于ESPnet2升级
+
+## 开始
 ### 1. 安装要求
 > 按照原始存储库测试您是否已准备好所有环境。
-**Python 3.7 或更高版本** 需要运行工具箱。
+运行工具箱(demo_toolbox.py)需要 **Python 3.7 或更高版本** 。

 * 安装 [PyTorch](https://pytorch.org/get-started/locally/)。
 > 如果在用 pip 方式安装的时候出现 `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)` 这个错误可能是 python 版本过低,3.9 可以安装成功
@@ -31,11 +41,21 @@
 ### 2. 准备预训练模型
 考虑训练您自己专属的模型或者下载社区他人训练好的模型:
 > 近期创建了[知乎专题](https://www.zhihu.com/column/c_1425605280340504576) 将不定期更新炼丹小技巧or心得,也欢迎提问
-#### 2.1 使用数据集自己训练合成器模型(与2.2二选一)
+#### 2.1 使用数据集自己训练encoder模型 (可选)
+
+* 进行音频和梅尔频谱图预处理:
+`python encoder_preprocess.py <datasets_root>`
+使用`-d {dataset}` 指定数据集,支持 librispeech_other,voxceleb1,aidatatang_200zh,使用逗号分割处理多数据集。
+* 训练encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
+> 训练encoder使用了visdom。你可以加上`-no_visdom`禁用visdom,但是有可视化会更好。在单独的命令行/进程中运行"visdom"来启动visdom服务器。
+
+#### 2.2 使用数据集自己训练合成器模型(与2.3二选一)
 * 下载 数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
 * 进行音频和梅尔频谱图预处理:
-`python pre.py <datasets_root>`
-可以传入参数 --dataset `{dataset}` 支持 aidatatang_200zh, magicdata, aishell3
+`python pre.py <datasets_root> -d {dataset} -n {number}`
+可传入参数:
+* `-d {dataset}` 指定数据集,支持 aidatatang_200zh, magicdata, aishell3, data_aishell, 不传默认为aidatatang_200zh
+* `-n {number}` 指定并行数,CPU 11770k + 32GB实测10没有问题
 > 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`

 * 训练合成器:
@@ -43,16 +63,17 @@

 * 当您在训练文件夹 *synthesizer/saved_models/* 中看到注意线显示和损失满足您的需要时,请转到`启动程序`一步。

-#### 2.2使用社区预先训练好的合成器(与2.1二选一)
+#### 2.3使用社区预先训练好的合成器(与2.2二选一)
 > 当实在没有设备或者不想慢慢调试,可以使用社区贡献的模型(欢迎持续分享):

 | 作者 | 下载链接 | 效果预览 | 信息 |
 | --- | ----------- | ----- | ----- |
-| 作者 | https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA [百度盘链接](https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA ) 提取码:i183 | | 200k steps 只用aidatatang_200zh
-|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [百度盘链接](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) 提取码:1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps 台湾口音
-|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ 提取码:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps 旧版需根据[issue](https://github.com/babysor/MockingBird/issues/37)修复
+| 作者 | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [百度盘链接](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps 用3个开源数据集混合训练
+| 作者 | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [百度盘链接](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) 提取码:om7f | | 25k steps 用3个开源数据集混合训练, 切换到tag v0.0.1使用
+|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [百度盘链接](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) 提取码:1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps 台湾口音需切换到tag v0.0.1使用
+|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ 提取码:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps 注意:根据[issue](https://github.com/babysor/MockingBird/issues/37)修复 并切换到tag v0.0.1使用

-#### 2.3训练声码器 (可选)
+#### 2.4训练声码器 (可选)
 对效果影响不大,已经预置3款,如果希望自己训练可以参考以下命令。
 * 预处理数据:
 `python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
@@ -70,15 +91,10 @@
 ### 3. 启动程序或工具箱
 您可以尝试使用以下命令:

-### 3.1 启动Web程序:
+### 3.1 启动Web程序(v2):
 `python web.py`
 运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`
-<img width="578" alt="bd64cd80385754afa599e3840504f45" src="https://user-images.githubusercontent.com/7423248/134275205-c95e6bd8-4f41-4eb5-9143-0390627baee1.png">
-> 注:目前界面比较buggy,
-> * 第一次点击`录制`要等待几秒浏览器正常启动录音,否则会有重音
-> * 录制结束不要再点`录制`而是`停止`
 > * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
-> * 默认使用第一个找到的模型,有动手能力的可以看代码修改 `web\__init__.py`。

 ### 3.2 启动工具箱:
 `python demo_toolbox.py -d <datasets_root>`
@@ -86,48 +102,56 @@

 <img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">

-## 文件结构(目标读者:开发者)
-```
-├─archived_untest_files 废弃文件
-├─encoder encoder模型
-│ ├─data_objects
-│ └─saved_models 预训练好的模型
-├─samples 样例语音
-├─synthesizer synthesizer模型
-│ ├─models
-│ ├─saved_models 预训练好的模型
-│ └─utils 工具类库
-├─toolbox 图形化工具箱
-├─utils 工具类库
-├─vocoder vocoder模型(目前包含hifi-gan、wavrnn)
-│ ├─hifigan
-│ ├─saved_models 预训练好的模型
-│ └─wavernn
-└─web
-  ├─api
-  │ └─Web端接口
-  ├─config
-  │ └─ Web端配置文件
-  ├─static 前端静态脚本
-  │ └─js
-  ├─templates 前端模板
-  └─__init__.py Web端入口文件
-```
+### 4. 番外:语音转换Voice Conversion(PPG based)
+想像柯南拿着变声器然后发出毛利小五郎的声音吗?本项目现基于PPG-VC,引入额外两个模块(PPG extractor + PPG2Mel), 可以实现变声功能。(文档不全,尤其是训练部分,正在努力补充中)
+#### 4.0 准备环境
+* 确保项目以上环境已经安装ok,运行`pip install -r requirements_vc.txt` 来安装剩余的必要包。
+* 下载以下模型 链接:https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
+提取码:gh41
+* 24K采样率专用的vocoder(hifigan)到 *vocoder\saved_mode\xxx*
+* 预训练的ppg特征encoder(ppg_extractor)到 *ppg_extractor\saved_mode\xxx*
+* 预训练的PPG2Mel到 *ppg2mel\saved_mode\xxx*
+
+#### 4.1 使用数据集自己训练PPG2Mel模型 (可选)
+
+* 下载aidatatang_200zh数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
+* 进行音频和梅尔频谱图预处理:
+`python pre4ppg.py <datasets_root> -d {dataset} -n {number}`
+可传入参数:
+* `-d {dataset}` 指定数据集,支持 aidatatang_200zh, 不传默认为aidatatang_200zh
+* `-n {number}` 指定并行数,CPU 11770k在8的情况下,需要运行12到18小时!待优化
+> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`
+
+* 训练合成器, 注意在上一步先下载好`ppg2mel.yaml`, 修改里面的地址指向预训练好的文件夹:
+`python ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
+* 如果想要继续上一次的训练,可以通过`--load .\ppg2mel\saved_models\<old_pt_file>` 参数指定一个预训练模型文件。
+
+#### 4.2 启动工具箱VC模式
+您可以尝试使用以下命令:
+`python demo_toolbox.py vc -d <datasets_root>`
+> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
+<img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">

 ## 引用及论文
 > 该库一开始从仅支持英语的[Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) 分叉出来的,鸣谢作者。

 | URL | Designation | 标题 | 实现源码 |
 | --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | 本代码库 |
 | [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | 本代码库 |
-|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
+|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | 本代码库 |
 |[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
 |[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
 |[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | 本代码库 |

 ## 常見問題(FQ&A)
 #### 1.數據集哪裡下載?
-[aidatatang_200zh](http://www.openslr.org/62/)、[magicdata](http://www.openslr.org/68/)、[aishell3](http://www.openslr.org/93/)
+| 数据集 | OpenSLR地址 | 其他源 (Google Drive, Baidu网盘等) |
+| --- | ----------- | ---------------|
+| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
+| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
+| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
+| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
 > 解壓 aidatatang_200zh 後,還需將 `aidatatang_200zh\corpus\train`下的檔案全選解壓縮

 #### 2.`<datasets_root>`是什麼意思?
README.md (44 lines changed)
@@ -6,7 +6,7 @@
 > English | [中文](README-CN.md)

 ## Features
-🌍 **Chinese** supported mandarin and tested with multiple datasets: aidatatang_200zh, magicdata, aishell3, and etc.
+🌍 **Chinese** supported mandarin and tested with multiple datasets: aidatatang_200zh, magicdata, aishell3, data_aishell, and etc.

 🤩 **PyTorch** worked for pytorch, tested in version of 1.9.0(latest in August 2021), with GPU Tesla T4 and GTX 2060
@@ -16,7 +16,15 @@
 🌍 **Webserver Ready** to serve your result with remote calling

-### [DEMO VIDEO](https://www.bilibili.com/video/BV1sA411P7wM/)
+### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
+
+### Ongoing Works(Helps Needed)
+* Major upgrade on GUI/Client and unifying web and toolbox
+[X] Init framework `./mkgui` and [tech design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
+[X] Add demo part of Voice Cloning and Conversion
+[X] Add preprocessing and training for Voice Conversion
+[ ] Add preprocessing and training for Encoder/Synthesizer/Vocoder
+* Major upgrade on model backend based on ESPnet2(not yet started)

 ## Quick Start
@@ -32,27 +40,37 @@
 > Note that we are using the pretrained encoder/vocoder but synthesizer, since the original model is incompatible with the Chinese sympols. It means the demo_cli is not working at this moment.
 ### 2. Prepare your models
 You can either train your models or use existing ones:
-#### 2.1. Train synthesizer with your dataset
+
+#### 2.1 Train encoder with your dataset (Optional)
+
+* Preprocess with the audios and the mel spectrograms:
+`python encoder_preprocess.py <datasets_root>` Allowing parameter `--dataset {dataset}` to support the datasets you want to preprocess. Only the train set of these datasets will be used. Possible names: librispeech_other, voxceleb1, voxceleb2. Use comma to sperate multiple datasets.
+
+* Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
+> For training, the encoder uses visdom. You can disable it with `--no_visdom`, but it's nice to have. Run "visdom" in a separate CLI/process to start your visdom server.
+
+#### 2.2 Train synthesizer with your dataset
 * Download dataset and unzip: make sure you can access all .wav in folder
 * Preprocess with the audios and the mel spectrograms:
 `python pre.py <datasets_root>`
-Allowing parameter `--dataset {dataset}` to support aidatatang_200zh, magicdata, aishell3, etc.
+Allowing parameter `--dataset {dataset}` to support aidatatang_200zh, magicdata, aishell3, data_aishell, etc.If this parameter is not passed, the default dataset will be aidatatang_200zh.

 * Train the synthesizer:
 `python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`

 * Go to next step when you see attention line show and loss meet your need in training folder *synthesizer/saved_models/*.

-#### 2.2 Use pretrained model of synthesizer
+#### 2.3 Use pretrained model of synthesizer
 > Thanks to the community, some models will be shared:

 | author | Download link | Preview Video | Info |
 | --- | ----------- | ----- |----- |
-| @myself | https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA [Baidu](https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA ) code:i183 | | 200k steps only trained by aidatatang_200zh
-|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu Pan](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) Code:1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan
-|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code:2021 | https://www.bilibili.com/video/BV1uh411B7AD/
+| @author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps trained by multiple datasets
+| @author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code:om7f | | 25k steps trained by multiple datasets, only works under version 0.0.1
+|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan, only works under version 0.0.1
+|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1

-#### 2.3 Train vocoder (Optional)
+#### 2.4 Train vocoder (Optional)
 > note: vocoder has little difference in effect, so you may not need to train a new one.
 * Preprocess the data:
 `python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
@@ -77,6 +95,7 @@ You can then try the toolbox:

 | URL | Designation | Title | Implementation source |
 | --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
 | [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
 |[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
 |[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
@@ -85,7 +104,12 @@ You can then try the toolbox:

 ## F Q&A
 #### 1.Where can I download the dataset?
-[aidatatang_200zh](http://www.openslr.org/62/)、[magicdata](http://www.openslr.org/68/)、[aishell3](http://www.openslr.org/93/)
+| Dataset | Original Source | Alternative Sources |
+| --- | ----------- | ---------------|
+| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
+| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
+| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
+| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
 > After unzip aidatatang_200zh, you need to unzip all the files under `aidatatang_200zh\corpus\train`

 #### 2.What is`<datasets_root>`?
demo_toolbox.py
@@ -15,12 +15,18 @@ if __name__ == '__main__':
     parser.add_argument("-d", "--datasets_root", type=Path, help= \
         "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
         "supported datasets.", default=None)
+    parser.add_argument("-vc", "--vc_mode", action="store_true",
+                        help="Voice Conversion Mode(PPG based)")
     parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
                         help="Directory containing saved encoder models")
     parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
                         help="Directory containing saved synthesizer models")
     parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
                         help="Directory containing saved vocoder models")
+    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+                        help="Directory containing saved extrator models")
+    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+                        help="Directory containing saved convert models")
     parser.add_argument("--cpu", action="store_true", help=\
         "If True, processing is done on CPU, even when a GPU is available.")
     parser.add_argument("--seed", type=int, default=None, help=\
|||||||
@@ -34,8 +34,16 @@ def load_model(weights_fpath: Path, device=None):
|
|||||||
_model.load_state_dict(checkpoint["model_state"])
|
_model.load_state_dict(checkpoint["model_state"])
|
||||||
_model.eval()
|
_model.eval()
|
||||||
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
||||||
|
return _model
|
||||||
|
|
||||||
|
def set_model(model, device=None):
|
||||||
|
global _model, _device
|
||||||
|
_model = model
|
||||||
|
if device is None:
|
||||||
|
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
_device = device
|
||||||
|
_model.to(device)
|
||||||
|
|
||||||
def is_loaded():
|
def is_loaded():
|
||||||
return _model is not None
|
return _model is not None
|
||||||
|
|
||||||
@@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
|
|||||||
|
|
||||||
|
|
||||||
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
||||||
min_pad_coverage=0.75, overlap=0.5):
|
min_pad_coverage=0.75, overlap=0.5, rate=None):
|
||||||
"""
|
"""
|
||||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
||||||
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
||||||
@@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
|
|||||||
assert 0 <= overlap < 1
|
assert 0 <= overlap < 1
|
||||||
assert 0 < min_pad_coverage <= 1
|
assert 0 < min_pad_coverage <= 1
|
||||||
|
|
||||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
if rate != None:
|
||||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||||
|
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
||||||
|
else:
|
||||||
|
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||||
|
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||||
|
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||||
|
|
||||||
|
assert 0 < frame_step, "The rate is too high"
|
||||||
|
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
|
||||||
|
(sampling_rate / (samples_per_frame * partials_n_frames))
|
||||||
|
|
||||||
# Compute the slices
|
# Compute the slices
|
||||||
wav_slices, mel_slices = [], []
|
wav_slices, mel_slices = [], []
|
||||||
|
|||||||
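The practical effect of the new `rate` branch above: instead of deriving the stride between partial utterances from a fixed frame overlap, a caller can now request a given number of partials per second of audio. A minimal standalone sketch of the two stride computations, assuming the encoder's usual parameters of a 16 kHz sampling rate, 10 ms mel hop, and 160-frame partials (restated here for illustration; the real values live in the repo's encoder params):

```python
import numpy as np

# Assumed values, mirroring the encoder's default audio parameters.
sampling_rate = 16000     # Hz
mel_window_step = 10      # ms per mel frame
partials_n_frames = 160   # mel frames per partial utterance

def frame_step(rate=None, overlap=0.5):
    """Stride between consecutive partial utterances, in mel frames."""
    samples_per_frame = int(sampling_rate * mel_window_step / 1000)  # 160 samples
    if rate is not None:
        # New behaviour: `rate` partial utterances per second of audio.
        return int(np.round((sampling_rate / rate) / samples_per_frame))
    # Old behaviour: stride derived from the frame overlap ratio.
    return max(int(np.round(partials_n_frames * (1 - overlap))), 1)

print(frame_step())        # 80 frames: 50% overlap between 160-frame partials
print(frame_step(rate=2))  # 50 frames: two partials per second
```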
encoder/preprocess.py
@@ -117,6 +117,15 @@ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir,
     logger.finalize()
     print("Done preprocessing %s.\n" % dataset_name)

+
+def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False):
+    dataset_name = "aidatatang_200zh"
+    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+    if not dataset_root:
+        return
+    # Preprocess all speakers
+    speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*"))
+    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
+                             skip_existing, logger)

 def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
     for dataset_name in librispeech_datasets["train"]["other"]:
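A hedged usage sketch for the new entry point; the directory layout mirrors the one the README describes (`<datasets_root>/aidatatang_200zh/corpus/train/<speaker>/...`), and the concrete paths below are hypothetical:

```python
from pathlib import Path
from encoder.preprocess import preprocess_aidatatang_200zh

# Hypothetical locations; adjust to wherever the dataset was unpacked.
datasets_root = Path("D:/data")
out_dir = datasets_root / "SV2TTS" / "encoder"
out_dir.mkdir(parents=True, exist_ok=True)

# Walks corpus/train speaker by speaker and writes mel frames to out_dir,
# skipping utterances that were already processed on a previous run.
preprocess_aidatatang_200zh(datasets_root, out_dir, skip_existing=True)
```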
Binary file not shown.
encoder_preprocess.py
@@ -1,4 +1,4 @@
-from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2
+from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh
 from utils.argutils import print_args
 from pathlib import Path
 import argparse
@@ -10,17 +10,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
                     "writes them to the disk. This will allow you to train the encoder. The "
-                    "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
-                    "Ideally, you should have all three. You should extract them as they are "
-                    "after having downloaded them and put them in a same directory, e.g.:\n"
-                    "-[datasets_root]\n"
-                    "  -LibriSpeech\n"
-                    "    -train-other-500\n"
-                    "  -VoxCeleb1\n"
-                    "    -wav\n"
-                    "    -vox1_meta.csv\n"
-                    "  -VoxCeleb2\n"
-                    "    -dev",
+                    "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
         formatter_class=MyFormatter
     )
     parser.add_argument("datasets_root", type=Path, help=\
@@ -29,7 +19,7 @@ if __name__ == "__main__":
         "Path to the output directory that will contain the mel spectrograms. If left out, "
         "defaults to <datasets_root>/SV2TTS/encoder/")
     parser.add_argument("-d", "--datasets", type=str,
-                        default="librispeech_other,voxceleb1,voxceleb2", help=\
+                        default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
         "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
         "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
         "voxceleb2.")
@@ -63,6 +53,7 @@ if __name__ == "__main__":
         "librispeech_other": preprocess_librispeech,
         "voxceleb1": preprocess_voxceleb1,
         "voxceleb2": preprocess_voxceleb2,
+        "aidatatang_200zh": preprocess_aidatatang_200zh,
     }
     args = vars(args)
     for dataset in args.pop("datasets"):
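How the pieces above fit together: the `-d/--datasets` value is comma-separated, and each name is looked up in the `preprocess_func` table. A small sketch of that dispatch, using only names the script now registers:

```python
from encoder.preprocess import (preprocess_aidatatang_200zh, preprocess_librispeech,
                                preprocess_voxceleb1, preprocess_voxceleb2)

preprocess_func = {
    "librispeech_other": preprocess_librispeech,
    "voxceleb1": preprocess_voxceleb1,
    "voxceleb2": preprocess_voxceleb2,
    "aidatatang_200zh": preprocess_aidatatang_200zh,
}

# The new default of the -d/--datasets flag, split the same way the script does.
for dataset in "librispeech_other,voxceleb1,aidatatang_200zh".split(","):
    print(dataset, "->", preprocess_func[dataset].__name__)
```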
mkgui/__init__.py (new file, empty)

mkgui/app.py (new file, 143 lines)
@@ -0,0 +1,143 @@
from asyncio.windows_events import NULL
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from encoder import inference as encoder
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from synthesizer.inference import Synthesizer
from typing import Any
import matplotlib.pyplot as plt

# Constants
AUDIO_SAMPLES_DIR = 'samples\\'
SYN_MODELS_DIRT = "synthesizer\\saved_models"
ENC_MODELS_DIRT = "encoder\\saved_models"
VOC_MODELS_DIRT = "vocoder\\saved_models"
TEMP_SOURCE_AUDIO = "wavs/temp_source.wav"
TEMP_RESULT_AUDIO = "wavs/temp_result.wav"

# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoders models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")


class Input(BaseModel):
    message: str = Field(
        ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
    )
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    encoder: encoders = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    synthesizer: synthesizers = Field(
        ..., alias="合成模型",
        description="选择语音合成模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: tuple[AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implmeneted, it will be used instead of the default Output UI renderer.
        """
        src, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Result Audio)")
        streamlit_app.pyplot(fig)


def synthesize(input: Input) -> Output:
    """synthesize(合成)"""
    # load models
    encoder.load_model(Path(input.encoder.value))
    current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # load file
    if input.upload_audio_file != None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, wav)  # Make sure we get the correct wav

    source_spec = Synthesizer.make_spectrogram(wav)

    # preprocess
    encoder_wav = encoder.preprocess_wav(wav, sample_rate)
    embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    # Load input text
    texts = filter(None, input.message.split("\n"))
    punctuation = '!,。、,'  # punctuate and split/clean text
    processed_texts = []
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
    texts = processed_texts

    # synthesize and vocode
    embeds = [embed] * len(texts)
    specs = current_synt.synthesize_spectrograms(texts, embeds)
    spec = np.concatenate(specs, axis=1)
    sample_rate = Synthesizer.sample_rate
    wav, sample_rate = gan_vocoder.infer_waveform(spec)

    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()
    return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
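One step of `synthesize` worth seeing in isolation is the text cleaning: the message is split into one synthesizer input per clause by rewriting Chinese punctuation to newlines. The same regex and punctuation set, rewritten as a standalone sketch:

```python
import re

punctuation = '!,。、,'  # same set as mkgui/app.py

def split_message(message: str) -> list:
    processed_texts = []
    for text in filter(None, message.split("\n")):
        # Any run of the listed punctuation marks becomes a clause boundary.
        for piece in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if piece:
                processed_texts.append(piece.strip())
    return processed_texts

print(split_message("欢迎使用工具箱,现已支持中文输入!"))
# -> ['欢迎使用工具箱', '现已支持中文输入']
```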
mkgui/app_vc.py (new file, 167 lines)
@@ -0,0 +1,167 @@
from asyncio.windows_events import NULL
from synthesizer.inference import Synthesizer
from pydantic import BaseModel, Field
from encoder import inference as speacker_encoder
import torch
import os
from pathlib import Path
from enum import Enum
import ppg_extractor as Extractor
import ppg2mel as Convertor
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from typing import Any
import matplotlib.pyplot as plt


# Constants
AUDIO_SAMPLES_DIR = 'samples\\'
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
CONV_MODELS_DIRT = "ppg2mel\\saved_models"
VOC_MODELS_DIRT = "vocoder\\saved_models"
TEMP_SOURCE_AUDIO = "wavs/temp_source.wav"
TEMP_TARGET_AUDIO = "wavs/temp_target.wav"
TEMP_RESULT_AUDIO = "wavs/temp_result.wav"

# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoders models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")

class Input(BaseModel):
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    local_audio_file_target: audio_input_selection = Field(
        ..., alias="目标语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    extractor: extractors = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音编码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: tuple[AudioEntity, AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implmeneted, it will be used instead of the default Output UI renderer.
        """
        src, target, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(target.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Target Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Result Audio)")
        streamlit_app.pyplot(fig)

def convert(input: Input) -> Output:
    """convert(转换)"""
    # load models
    extractor = Extractor.load_model(Path(input.extractor.value))
    convertor = Convertor.load_model(Path(input.convertor.value))
    # current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # load file
    if input.upload_audio_file != None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        src_wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, src_wav)  # Make sure we get the correct wav

    if input.upload_audio_file_target != None:
        with open(TEMP_TARGET_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file_target.as_bytes())
            f.seek(0)
        ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
    else:
        ref_wav, _ = librosa.load(input.local_audio_file_target.value)
        write(TEMP_TARGET_AUDIO, sample_rate, ref_wav)  # Make sure we get the correct wav

    ppg = extractor.extract_from_wav(src_wav)
    # Import necessary dependency of Voice Conversion
    from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
    speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
    embed = speacker_encoder.embed_utterance(ref_wav)
    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
    min_len = min(ppg.shape[1], len(lf0_uv))
    ppg = ppg[:, :min_len]
    lf0_uv = lf0_uv[:min_len]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, mel_pred, att_ws = convertor.inference(
        ppg,
        logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
        spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
    )
    mel_pred = mel_pred.transpose(0, 1)
    breaks = [mel_pred.shape[1]]
    mel_pred = mel_pred.detach().cpu().numpy()

    # synthesize and vocode
    wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)

    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_TARGET_AUDIO, "rb") as f:
        target_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()

    return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
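A detail of `convert` above that is easy to miss: the PPG sequence and the converted lf0/uv sequence come from different front-ends, so their frame counts rarely agree, and both are truncated to the shorter one before the converter runs. A standalone sketch with made-up shapes (the dimensions below are illustrative only, not the real models' sizes):

```python
import numpy as np

ppg = np.zeros((1, 503, 144))   # (batch, frames, ppg_dim) -- hypothetical sizes
lf0_uv = np.zeros((500, 2))     # (frames, [lf0, uv])      -- hypothetical sizes

# Truncate both streams to the shorter frame count so they stay aligned.
min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
assert ppg.shape[1] == len(lf0_uv) == 500
```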
mkgui/base/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+
+from .core import Opyrator
mkgui/base/api/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .fastapi_app import create_api
mkgui/base/api/fastapi_utils.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Collection of utilities for FastAPI apps."""

import inspect
from typing import Any, Type

from fastapi import FastAPI, Form
from pydantic import BaseModel


def as_form(cls: Type[BaseModel]) -> Any:
    """Adds an as_form class method to decorated models.

    The as_form class method can be used with FastAPI endpoints
    """
    new_params = [
        inspect.Parameter(
            field.alias,
            inspect.Parameter.POSITIONAL_ONLY,
            default=(Form(field.default) if not field.required else Form(...)),
        )
        for field in cls.__fields__.values()
    ]

    async def _as_form(**data):  # type: ignore
        return cls(**data)

    sig = inspect.signature(_as_form)
    sig = sig.replace(parameters=new_params)
    _as_form.__signature__ = sig  # type: ignore
    setattr(cls, "as_form", _as_form)
    return cls


def patch_fastapi(app: FastAPI) -> None:
    """Patch function to allow relative url resolution.

    This patch is required to make fastapi fully functional with a relative url path.
    This code snippet can be copy-pasted to any Fastapi application.
    """
    from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
    from starlette.requests import Request
    from starlette.responses import HTMLResponse

    async def redoc_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        redoc_ui = get_redoc_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Redoc UI",
        )

        return HTMLResponse(redoc_ui.body.decode("utf-8"))

    async def swagger_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        swagger_ui = get_swagger_ui_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Swagger UI",
            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        )

        # insert request interceptor to have all request run on relativ path
        request_interceptor = (
            "requestInterceptor: (e) => {"
            "\n\t\t\tvar url = window.location.origin + window.location.pathname"
            '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
            "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);"  # noqa: W605
            "\n\t\t\te.contextUrl = url"
            "\n\t\t\te.url = url"
            "\n\t\t\treturn e;}"
        )

        return HTMLResponse(
            swagger_ui.body.decode("utf-8").replace(
                "dom_id: '#swagger-ui',",
                "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
            )
        )

    # remove old docs route and add our patched route
    routes_new = []
    for app_route in app.routes:
        if app_route.path == "/docs":  # type: ignore
            continue

        if app_route.path == "/redoc":  # type: ignore
            continue

        routes_new.append(app_route)

    app.router.routes = routes_new

    assert app.docs_url is not None
    app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
    assert app.redoc_url is not None
    app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)

    # Make graphql realtive
    from starlette import graphql

    graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
        "({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
    )
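A sketch of how `as_form` is meant to be wired into an endpoint. The endpoint and model below are hypothetical, and this assumes pydantic v1 field semantics (`field.required`, `field.alias`), which the helper relies on:

```python
from fastapi import Depends, FastAPI
from pydantic import BaseModel

from mkgui.base.api.fastapi_utils import as_form

@as_form
class LoginForm(BaseModel):   # hypothetical example model
    username: str
    password: str

app = FastAPI()

@app.post("/login")
async def login(form: LoginForm = Depends(LoginForm.as_form)):  # type: ignore[attr-defined]
    # as_form rebuilt the function signature from the model's fields, so
    # FastAPI reads the values from form-encoded data instead of a JSON body.
    return {"user": form.username}
```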
mkgui/base/components/__init__.py (new file, empty)
43
mkgui/base/components/outputs.py
Normal file
43
mkgui/base/components/outputs.py
Normal file
@@ -0,0 +1,43 @@
from typing import List

from pydantic import BaseModel


class ScoredLabel(BaseModel):
    label: str
    score: float


class ClassificationOutput(BaseModel):
    __root__: List[ScoredLabel]

    def __iter__(self):  # type: ignore
        return iter(self.__root__)

    def __getitem__(self, item):  # type: ignore
        return self.__root__[item]

    def render_output_ui(self, streamlit) -> None:  # type: ignore
        import plotly.express as px

        sorted_predictions = sorted(
            [prediction.dict() for prediction in self.__root__],
            key=lambda k: k["score"],
        )

        num_labels = len(sorted_predictions)
        if len(sorted_predictions) > 10:
            num_labels = streamlit.slider(
                "Maximum labels to show: ",
                min_value=1,
                max_value=len(sorted_predictions),
                value=len(sorted_predictions),
            )
        fig = px.bar(
            sorted_predictions[len(sorted_predictions) - num_labels :],
            x="score",
            y="label",
            orientation="h",
        )
        streamlit.plotly_chart(fig, use_container_width=True)
        # fig.show()
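# --- Illustrative usage sketch (editor's note, not part of this commit) ------
# `ClassificationOutput` is a pydantic "custom root" model: iterating it or
# indexing into it goes straight through to the wrapped list. A minimal sketch:
#
#   out = ClassificationOutput(__root__=[
#       ScoredLabel(label="cat", score=0.9),
#       ScoredLabel(label="dog", score=0.1),
#   ])
#   out[0].label             # -> "cat"
#   [s.label for s in out]   # -> ["cat", "dog"]
#
# `render_output_ui` then draws these scores as a horizontal plotly bar chart.
# ------------------------------------------------------------------------------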
46
mkgui/base/components/types.py
Normal file
46
mkgui/base/components/types.py
Normal file
@@ -0,0 +1,46 @@
import base64
from typing import Any, Dict, overload


class FileContent(str):
    def as_bytes(self) -> bytes:
        return base64.b64decode(self, validate=True)

    def as_str(self) -> str:
        return self.as_bytes().decode()

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(format="byte")

    @classmethod
    def __get_validators__(cls) -> Any:  # type: ignore
        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> "FileContent":
        if isinstance(value, FileContent):
            return value
        elif isinstance(value, str):
            return FileContent(value)
        elif isinstance(value, (bytes, bytearray, memoryview)):
            return FileContent(base64.b64encode(value).decode())
        else:
            raise Exception("Wrong type")


# # Currently unusable, because browsers offer no way to select a folder
# class DirectoryContent(FileContent):
#     @classmethod
#     def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
#         field_schema.update(format="path")

#     @classmethod
#     def validate(cls, value: Any) -> "DirectoryContent":
#         if isinstance(value, DirectoryContent):
#             return value
#         elif isinstance(value, str):
#             return DirectoryContent(value)
#         elif isinstance(value, (bytes, bytearray, memoryview)):
#             return DirectoryContent(base64.b64encode(value).decode())
#         else:
#             raise Exception("Wrong type")
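# --- Illustrative round trip (editor's note, not part of this commit) --------
# `FileContent` stores file payloads as a base64 *string* so they survive JSON
# serialization, while `as_bytes()` recovers the raw bytes:
#
#   fc = FileContent.validate(b"hello")  # bytes are base64-encoded on the way in
#   fc.as_str()   # -> "hello"
#   str(fc)       # -> "aGVsbG8=" (the base64 text itself)
# ------------------------------------------------------------------------------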
203
mkgui/base/core.py
Normal file
203
mkgui/base/core.py
Normal file
@@ -0,0 +1,203 @@
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints

from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as


def name_to_title(name: str) -> str:
    """Converts a camelCase or snake_case name to title case."""
    # If camelCase -> convert to snake case
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
    # Convert to title case
    return name.replace("_", " ").strip().title()


def is_compatible_type(type: Type) -> bool:
    """Returns `True` if the type is opyrator-compatible."""
    try:
        if issubclass(type, BaseModel):
            return True
    except Exception:
        pass

    try:
        # valid list type
        if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
            return True
    except Exception:
        pass

    return False


def get_input_type(func: Callable) -> Type:
    """Returns the input type of a given function (callable).

    Args:
        func: The function for which to get the input type.

    Raises:
        ValueError: If the function does not have a valid input type annotation.
    """
    type_hints = get_type_hints(func)

    if "input" not in type_hints:
        raise ValueError(
            "The callable MUST have a parameter with the name `input` with typing annotation. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    input_type = type_hints["input"]

    if not is_compatible_type(input_type):
        raise ValueError(
            "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    # TODO: return warning if more than one input parameters

    return input_type


def get_output_type(func: Callable) -> Type:
    """Returns the output type of a given function (callable).

    Args:
        func: The function for which to get the output type.

    Raises:
        ValueError: If the function does not have a valid output type annotation.
    """
    type_hints = get_type_hints(func)
    if "return" not in type_hints:
        raise ValueError(
            "The return type of the callable MUST be annotated with type hints. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    output_type = type_hints["return"]

    if not is_compatible_type(output_type):
        raise ValueError(
            "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    return output_type


def get_callable(import_string: str) -> Callable:
    """Import a callable from a string."""
    callable_separator = ":"
    if callable_separator not in import_string:
        # Use dot as separator
        callable_separator = "."

    if callable_separator not in import_string:
        raise ValueError("The callable path MUST specify the function. ")

    mod_name, callable_name = import_string.rsplit(callable_separator, 1)
    mod = importlib.import_module(mod_name)
    return getattr(mod, callable_name)


class Opyrator:
    def __init__(self, func: Union[Callable, str]) -> None:
        if isinstance(func, str):
            # Try to load the function from a string notation
            self.function = get_callable(func)
        else:
            self.function = func

        self._action = "Execute"
        self._input_type = None
        self._output_type = None

        if not callable(self.function):
            raise ValueError("The provided function parameter is not a callable.")

        if inspect.isclass(self.function):
            raise ValueError(
                "The provided callable is an uninitialized Class. This is not allowed."
            )

        if inspect.isfunction(self.function):
            # The provided callable is a function
            self._input_type = get_input_type(self.function)
            self._output_type = get_output_type(self.function)

            try:
                # Get name
                self._name = name_to_title(self.function.__name__)
            except Exception:
                pass

            try:
                # Get description from function
                doc_string = inspect.getdoc(self.function)
                if doc_string:
                    self._action = doc_string
            except Exception:
                pass
        elif hasattr(self.function, "__call__"):
            # The provided callable is a class instance that implements __call__
            self._input_type = get_input_type(self.function.__call__)  # type: ignore
            self._output_type = get_output_type(self.function.__call__)  # type: ignore

            try:
                # Get name
                self._name = name_to_title(type(self.function).__name__)
            except Exception:
                pass

            try:
                # Get action from the __call__ docstring
                doc_string = inspect.getdoc(self.function.__call__)  # type: ignore
                if doc_string:
                    self._action = doc_string

                if (
                    not self._action
                    or self._action == "Call"
                ):
                    # Get docstring from class instead of __call__ function
                    doc_string = inspect.getdoc(self.function)
                    if doc_string:
                        self._action = doc_string
            except Exception:
                pass
        else:
            raise ValueError("Unknown callable type.")

    @property
    def name(self) -> str:
        return self._name

    @property
    def action(self) -> str:
        return self._action

    @property
    def input_type(self) -> Any:
        return self._input_type

    @property
    def output_type(self) -> Any:
        return self._output_type

    def __call__(self, input: Any, **kwargs: Any) -> Any:

        input_obj = input

        if isinstance(input, str):
            # Allow json input
            input_obj = parse_raw_as(self.input_type, input)

        if isinstance(input, dict):
            # Allow dict input
            input_obj = parse_obj_as(self.input_type, input)

        return self.function(input_obj, **kwargs)
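# --- Illustrative usage sketch (editor's note, not part of this commit) ------
# An Opyrator wraps a function annotated with pydantic models and accepts the
# input as a model instance, a dict, or a raw JSON string. Assuming a trivial
# hypothetical model pair:
#
#   class EchoIn(BaseModel):
#       text: str
#
#   class EchoOut(BaseModel):
#       text: str
#
#   def echo(input: EchoIn) -> EchoOut:
#       """Echo the input text."""
#       return EchoOut(text=input.text)
#
#   op = Opyrator(echo)        # or Opyrator("my_module:echo")
#   op({"text": "hi"})         # dict input, parsed via parse_obj_as
#   op('{"text": "hi"}')       # JSON input, parsed via parse_raw_as
# ------------------------------------------------------------------------------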
1
mkgui/base/ui/__init__.py
Normal file
1
mkgui/base/ui/__init__.py
Normal file
@@ -0,0 +1 @@
from .streamlit_ui import render_streamlit_ui
129
mkgui/base/ui/schema_utils.py
Normal file
129
mkgui/base/ui/schema_utils.py
Normal file
@@ -0,0 +1,129 @@
from typing import Dict


def resolve_reference(reference: str, references: Dict) -> Dict:
    return references[reference.split("/")[-1]]


def get_single_reference_item(property: Dict, references: Dict) -> Dict:
    # Ref can either be directly in the properties or the first element of allOf
    reference = property.get("$ref")
    if reference is None:
        reference = property["allOf"][0]["$ref"]
    return resolve_reference(reference, references)


def is_single_string_property(property: Dict) -> bool:
    return property.get("type") == "string"


def is_single_datetime_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    return property.get("format") in ["date-time", "time", "date"]


def is_single_boolean_property(property: Dict) -> bool:
    return property.get("type") == "boolean"


def is_single_number_property(property: Dict) -> bool:
    return property.get("type") in ["integer", "number"]


def is_single_file_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    # TODO: binary?
    return property.get("format") == "byte"


def is_single_directory_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    return property.get("format") == "path"


def is_multi_enum_property(property: Dict, references: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("uniqueItems") is not True:
        # Only relevant if it is a set or other datastructures with unique items
        return False

    try:
        _ = resolve_reference(property["items"]["$ref"], references)["enum"]
        return True
    except Exception:
        return False


def is_single_enum_property(property: Dict, references: Dict) -> bool:
    try:
        _ = get_single_reference_item(property, references)["enum"]
        return True
    except Exception:
        return False


def is_single_dict_property(property: Dict) -> bool:
    if property.get("type") != "object":
        return False
    return "additionalProperties" in property


def is_single_reference(property: Dict) -> bool:
    if property.get("type") is not None:
        return False

    return bool(property.get("$ref"))


def is_multi_file_property(property: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("items") is None:
        return False

    try:
        # TODO: binary
        return property["items"]["format"] == "byte"
    except Exception:
        return False


def is_single_object(property: Dict, references: Dict) -> bool:
    try:
        object_reference = get_single_reference_item(property, references)
        if object_reference["type"] != "object":
            return False
        return "properties" in object_reference
    except Exception:
        return False


def is_property_list(property: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("items") is None:
        return False

    try:
        return property["items"]["type"] in ["string", "number", "integer"]
    except Exception:
        return False


def is_object_list_property(property: Dict, references: Dict) -> bool:
    if property.get("type") != "array":
        return False

    try:
        object_reference = resolve_reference(property["items"]["$ref"], references)
        if object_reference["type"] != "object":
            return False
        return "properties" in object_reference
    except Exception:
        return False
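# --- Illustrative example (editor's note, not part of this commit) -----------
# These predicates classify raw JSON-schema property dicts as produced by
# `BaseModel.schema()`. For instance:
#
#   is_single_number_property({"type": "integer"})           # -> True
#   is_single_file_property({"type": "string",
#                            "format": "byte"})              # -> True
#   is_single_reference({"$ref": "#/definitions/MyEnum"})    # -> True
#
# The enum checks additionally need the schema's "definitions" mapping in
# order to resolve the "$ref" pointer.
# ------------------------------------------------------------------------------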
885
mkgui/base/ui/streamlit_ui.py
Normal file
885
mkgui/base/ui/streamlit_ui.py
Normal file
@@ -0,0 +1,885 @@
import datetime
import inspect
import mimetypes
import sys
from os import getcwd, unlink
from platform import system
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Type
from PIL import Image

import pandas as pd
import streamlit as st
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pydantic import BaseModel, ValidationError, parse_obj_as

from mkgui.base import Opyrator
from mkgui.base.core import name_to_title
from mkgui.base.ui import schema_utils
from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS

STREAMLIT_RUNNER_SNIPPET = """
from mkgui.base.ui import render_streamlit_ui
from mkgui.base import Opyrator

import streamlit as st

# TODO: Make it configurable
# Page config can only be setup once
st.set_page_config(
    page_title="MockingBird",
    page_icon="🧊",
    layout="wide")

render_streamlit_ui()
"""

# with st.spinner("Loading MockingBird GUI. Please wait..."):
#     opyrator = Opyrator("{opyrator_path}")


def launch_ui(port: int = 8501) -> None:
    with NamedTemporaryFile(
        suffix=".py", mode="w", encoding="utf-8", delete=False
    ) as f:
        f.write(STREAMLIT_RUNNER_SNIPPET)
        f.seek(0)

        import subprocess

        python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
        if system() == "Windows":
            python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
            subprocess.run(
                "set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false",
                shell=True,
            )

        subprocess.run(
            f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
            shell=True,
        )

        f.close()
        unlink(f.name)
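# --- Editor's note (not part of this commit) ----------------------------------
# `launch_ui` writes STREAMLIT_RUNNER_SNIPPET to a temporary .py file and then
# shells out to `streamlit run` with the repository root appended to
# PYTHONPATH (the `set ... &&` prefix is the Windows cmd.exe spelling of the
# same environment setup). A typical call is simply:
#
#   launch_ui(port=8501)
# ------------------------------------------------------------------------------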


def function_has_named_arg(func: Callable, parameter: str) -> bool:
    try:
        sig = inspect.signature(func)
        for param in sig.parameters.values():
            if param.name == parameter:
                return True
    except Exception:
        return False
    return False


def has_output_ui_renderer(data_item: BaseModel) -> bool:
    return hasattr(data_item, "render_output_ui")


def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
    return hasattr(input_class, "render_input_ui")


def is_compatible_audio(mime_type: str) -> bool:
    return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]


def is_compatible_image(mime_type: str) -> bool:
    return mime_type in ["image/png", "image/jpeg"]


def is_compatible_video(mime_type: str) -> bool:
    return mime_type in ["video/mp4"]


class InputUI:
    def __init__(self, session_state, input_class: Type[BaseModel]):
        self._session_state = session_state
        self._input_class = input_class

        self._schema_properties = input_class.schema(by_alias=True).get(
            "properties", {}
        )
        self._schema_references = input_class.schema(by_alias=True).get(
            "definitions", {}
        )

    def render_ui(self, streamlit_app_root) -> None:
        if has_input_ui_renderer(self._input_class):
            # The input model has a rendering function
            # The rendering also returns the current state of input data
            self._session_state.input_data = self._input_class.render_input_ui(  # type: ignore
                st, self._session_state.input_data
            )
            return

        # print(self._schema_properties)
        for property_key in self._schema_properties.keys():
            property = self._schema_properties[property_key]

            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)

            try:
                if "input_data" in self._session_state:
                    self._store_value(
                        property_key,
                        self._render_property(streamlit_app_root, property_key, property),
                    )
            except Exception as e:
                print("Exception!", e)

    def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
        streamlit_kwargs = {
            "label": property.get("title"),
            "key": key,
        }

        if property.get("description"):
            streamlit_kwargs["help"] = property.get("description")
        return streamlit_kwargs

    def _store_value(self, key: str, value: Any) -> None:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # add value to this element
                data_element[key_element] = value
                return
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]

    def _get_value(self, key: str) -> Any:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # read the value from this element
                if key_element not in data_element:
                    return None
                return data_element[key_element]
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]
        return None

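    # --- Illustrative example (editor's note, not part of this commit) -------
    # `_store_value` / `_get_value` address nested input state with dotted
    # keys, creating intermediate dicts on demand:
    #
    #   self._store_value("speaker.name", "abc")
    #   # session_state.input_data == {"speaker": {"name": "abc"}}
    #   self._get_value("speaker.name")   # -> "abc"
    # --------------------------------------------------------------------------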
    def _render_single_datetime_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("format") == "time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.time.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.time_input(**streamlit_kwargs)
        elif property.get("format") == "date":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.date.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.date_input(**streamlit_kwargs)
        elif property.get("format") == "date-time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.datetime.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            with streamlit_app.container():
                streamlit_app.subheader(streamlit_kwargs.get("label"))
                if streamlit_kwargs.get("help"):
                    streamlit_app.text(streamlit_kwargs.get("help"))
                selected_date = None
                selected_time = None
                date_col, time_col = streamlit_app.columns(2)
                with date_col:
                    date_kwargs = {"label": "Date", "key": key + "-date-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            date_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).date()
                        except Exception:
                            pass
                    selected_date = streamlit_app.date_input(**date_kwargs)

                with time_col:
                    time_kwargs = {"label": "Time", "key": key + "-time-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            time_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).time()
                        except Exception:
                            pass
                    selected_time = streamlit_app.time_input(**time_kwargs)
                return datetime.datetime.combine(selected_date, selected_time)
        else:
            streamlit_app.warning(
                "Date format is not supported: " + str(property.get("format"))
            )

    def _render_single_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_file = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=False, type=file_extension
        )
        if uploaded_file is None:
            return None

        file_bytes = uploaded_file.getvalue()
        if property.get("mime_type"):
            if is_compatible_audio(property["mime_type"]):
                # Show audio
                streamlit_app.audio(file_bytes, format=property.get("mime_type"))
            if is_compatible_image(property["mime_type"]):
                # Show image
                streamlit_app.image(file_bytes)
            if is_compatible_video(property["mime_type"]):
                # Show video
                streamlit_app.video(file_bytes, format=property.get("mime_type"))
        return file_bytes

    def _render_single_string_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        elif property.get("example"):
            # TODO: also use example for other property types
            # Use example as value if it is provided
            streamlit_kwargs["value"] = property.get("example")

        if property.get("maxLength") is not None:
            streamlit_kwargs["max_chars"] = property.get("maxLength")

        if (
            property.get("format")
            or (
                property.get("maxLength") is not None
                and int(property.get("maxLength")) < 140  # type: ignore
            )
            or property.get("writeOnly")
        ):
            # If any format is set, use single text input
            # If max chars is set to less than 140, use single text input
            # If write only -> password field
            if property.get("writeOnly"):
                streamlit_kwargs["type"] = "password"
            return streamlit_app.text_input(**streamlit_kwargs)
        else:
            # Otherwise use multiline text area
            return streamlit_app.text_area(**streamlit_kwargs)

    def _render_multi_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        # TODO: how to select defaults
        return streamlit_app.multiselect(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )

        if property.get("default") is not None:
            try:
                streamlit_kwargs["index"] = reference_item["enum"].index(
                    property.get("default")
                )
            except Exception:
                # Use default selection
                pass

        return streamlit_app.selectbox(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_dict_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_dict = self._get_value(key)
        if not current_dict:
            current_dict = {}

        key_col, value_col = streamlit_app.columns(2)

        with key_col:
            updated_key = streamlit_app.text_input(
                "Key", value="", key=key + "-new-key"
            )

        with value_col:
            # TODO: also add boolean?
            value_kwargs = {"label": "Value", "key": key + "-new-value"}
            if property["additionalProperties"].get("type") == "integer":
                value_kwargs["value"] = 0  # type: ignore
                updated_value = streamlit_app.number_input(**value_kwargs)
            elif property["additionalProperties"].get("type") == "number":
                value_kwargs["value"] = 0.0  # type: ignore
                value_kwargs["format"] = "%f"
                updated_value = streamlit_app.number_input(**value_kwargs)
            else:
                value_kwargs["value"] = ""
                updated_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_dict = {}

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and updated_key
                ):
                    current_dict[updated_key] = updated_value

        streamlit_app.write(current_dict)

        return current_dict

    def _render_single_reference(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_property(streamlit_app, key, reference_item)

    def _render_multi_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_files = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=True, type=file_extension
        )
        uploaded_files_bytes = []
        if uploaded_files:
            for uploaded_file in uploaded_files:
                uploaded_files_bytes.append(uploaded_file.read())
        return uploaded_files_bytes

    def _render_single_boolean_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        return streamlit_app.checkbox(**streamlit_kwargs)

    def _render_single_number_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        number_transform = int
        if property.get("type") == "number":
            number_transform = float  # type: ignore
            streamlit_kwargs["format"] = "%f"

        if "multipleOf" in property:
            # Set step count based on the multipleOf parameter
            streamlit_kwargs["step"] = number_transform(property["multipleOf"])
        elif number_transform == int:
            # Set step size to 1 as default
            streamlit_kwargs["step"] = 1
        elif number_transform == float:
            # Set step size to 0.01 as default
            # TODO: adapt to default value
            streamlit_kwargs["step"] = 0.01

        if "minimum" in property:
            streamlit_kwargs["min_value"] = number_transform(property["minimum"])
        if "exclusiveMinimum" in property:
            streamlit_kwargs["min_value"] = number_transform(
                property["exclusiveMinimum"] + streamlit_kwargs["step"]
            )
        if "maximum" in property:
            streamlit_kwargs["max_value"] = number_transform(property["maximum"])

        if "exclusiveMaximum" in property:
            streamlit_kwargs["max_value"] = number_transform(
                property["exclusiveMaximum"] - streamlit_kwargs["step"]
            )

        if property.get("default") is not None:
            streamlit_kwargs["value"] = number_transform(property.get("default"))  # type: ignore
        else:
            if "min_value" in streamlit_kwargs:
                streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
            elif number_transform == int:
                streamlit_kwargs["value"] = 0
            else:
                # Set default value to step
                streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])

        if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
            # TODO: Only if less than X steps
            return streamlit_app.slider(**streamlit_kwargs)
        else:
            return streamlit_app.number_input(**streamlit_kwargs)

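    # --- Illustrative example (editor's note, not part of this commit) -------
    # `_render_single_number_input` maps JSON-schema numeric constraints onto
    # streamlit widget kwargs; e.g. for a pydantic field
    #
    #   n_processes: int = Field(2, ge=1, le=32)
    #
    # the schema {"type": "integer", "minimum": 1, "maximum": 32, "default": 2}
    # becomes step=1, min_value=1, max_value=32, value=2, and since both bounds
    # are present a slider is rendered instead of a number input.
    # --------------------------------------------------------------------------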
    def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
        properties = property["properties"]
        object_inputs = {}
        for property_key in properties:
            property = properties[property_key]
            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)
            # construct full key based on key parts -> required later to get the value
            full_key = key + "." + property_key
            object_inputs[property_key] = self._render_property(
                streamlit_app, full_key, property
            )
        return object_inputs

    def _render_single_object_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # Add title and subheader
        title = property.get("title")
        streamlit_app.subheader(title)
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        object_reference = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_object_input(streamlit_app, key, object_reference)

    def _render_property_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        value_kwargs = {"label": "Value", "key": key + "-new-value"}
        if property["items"]["type"] == "integer":
            value_kwargs["value"] = 0  # type: ignore
            new_value = streamlit_app.number_input(**value_kwargs)
        elif property["items"]["type"] == "number":
            value_kwargs["value"] = 0.0  # type: ignore
            value_kwargs["format"] = "%f"
            new_value = streamlit_app.number_input(**value_kwargs)
        else:
            value_kwargs["value"] = ""
            new_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and new_value is not None
                ):
                    current_list.append(new_value)

        streamlit_app.write(current_list)

        return current_list

    def _render_object_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # TODO: support max_items, and min_items properties

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        object_reference = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        input_data = self._render_object_input(streamlit_app, key, object_reference)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and input_data
                ):
                    current_list.append(input_data)

        streamlit_app.write(current_list)
        return current_list

    def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
        if schema_utils.is_single_enum_property(property, self._schema_references):
            return self._render_single_enum_input(streamlit_app, key, property)

        if schema_utils.is_multi_enum_property(property, self._schema_references):
            return self._render_multi_enum_input(streamlit_app, key, property)

        if schema_utils.is_single_file_property(property):
            return self._render_single_file_input(streamlit_app, key, property)

        if schema_utils.is_multi_file_property(property):
            return self._render_multi_file_input(streamlit_app, key, property)

        if schema_utils.is_single_datetime_property(property):
            return self._render_single_datetime_input(streamlit_app, key, property)

        if schema_utils.is_single_boolean_property(property):
            return self._render_single_boolean_input(streamlit_app, key, property)

        if schema_utils.is_single_dict_property(property):
            return self._render_single_dict_input(streamlit_app, key, property)

        if schema_utils.is_single_number_property(property):
            return self._render_single_number_input(streamlit_app, key, property)

        if schema_utils.is_single_string_property(property):
            return self._render_single_string_input(streamlit_app, key, property)

        if schema_utils.is_single_object(property, self._schema_references):
            return self._render_single_object_input(streamlit_app, key, property)

        if schema_utils.is_object_list_property(property, self._schema_references):
            return self._render_object_list_input(streamlit_app, key, property)

        if schema_utils.is_property_list(property):
            return self._render_property_list_input(streamlit_app, key, property)

        if schema_utils.is_single_reference(property):
            return self._render_single_reference(streamlit_app, key, property)

        streamlit_app.warning(
            "The type of the following property is currently not supported: "
            + str(property.get("title"))
        )
        raise Exception("Unsupported property")


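# --- Editor's note (not part of this commit) ----------------------------------
# `_render_property` is the central dispatcher: each predicate from
# schema_utils claims one schema shape and maps it to a renderer. A property
# like {"allOf": [{"$ref": "#/definitions/Dataset"}]} resolves through the
# schema references to an enum and is routed to the single-enum selectbox;
# anything no predicate recognizes falls through to the warning and raises.
# ------------------------------------------------------------------------------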
class OutputUI:
    def __init__(self, output_data: Any, input_data: Any):
        self._output_data = output_data
        self._input_data = input_data

    def render_ui(self, streamlit_app) -> None:
        try:
            if isinstance(self._output_data, BaseModel):
                self._render_single_output(streamlit_app, self._output_data)
                return
            if isinstance(self._output_data, list):
                self._render_list_output(streamlit_app, self._output_data)
                return
        except Exception as ex:
            streamlit_app.exception(ex)
        # Fall back to plain JSON rendering
        streamlit_app.json(jsonable_encoder(self._output_data))

    def _render_single_text_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            streamlit.code(str(value), language="plain")

    def _render_single_file_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            # TODO: Detect if it is a FileContent instance
            # TODO: detect if it is base64
            file_extension = ""
            if "mime_type" in property_schema:
                mime_type = property_schema["mime_type"]
                file_extension = mimetypes.guess_extension(mime_type) or ""

                if is_compatible_audio(mime_type):
                    streamlit.audio(value.as_bytes(), format=mime_type)
                    return

                if is_compatible_image(mime_type):
                    streamlit.image(value.as_bytes())
                    return

                if is_compatible_video(mime_type):
                    streamlit.video(value.as_bytes(), format=mime_type)
                    return

            filename = (
                (property_schema["title"] + file_extension)
                .lower()
                .strip()
                .replace(" ", "-")
            )
            streamlit.markdown(
                f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
                unsafe_allow_html=True,
            )

    def _render_single_complex_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))

        streamlit.json(jsonable_encoder(value))

    def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
        try:
            if has_output_ui_renderer(output_data):
                if function_has_named_arg(output_data.render_output_ui, "input"):  # type: ignore
                    # render method also requests the input data
                    output_data.render_output_ui(streamlit, input=self._input_data)  # type: ignore
                else:
                    output_data.render_output_ui(streamlit)  # type: ignore
                return
        except Exception:
            # Use default auto-generation methods if the custom rendering throws an exception
            logger.exception(
                "Failed to execute custom render_output_ui function. Using auto-generation instead"
            )

        model_schema = output_data.schema(by_alias=False)
        model_properties = model_schema.get("properties")
        definitions = model_schema.get("definitions")

        if model_properties:
            for property_key in output_data.__dict__:
                property_schema = model_properties.get(property_key)
                if not property_schema.get("title"):
                    # Set property key as fallback title
                    property_schema["title"] = property_key

                output_property_value = output_data.__dict__[property_key]

                if has_output_ui_renderer(output_property_value):
                    output_property_value.render_output_ui(streamlit)  # type: ignore
                    continue

                if isinstance(output_property_value, BaseModel):
                    # Render output recursively
                    streamlit.subheader(property_schema.get("title"))
                    if property_schema.get("description"):
                        streamlit.markdown(property_schema.get("description"))
                    self._render_single_output(streamlit, output_property_value)
                    continue

                if property_schema:
                    if schema_utils.is_single_file_property(property_schema):
                        self._render_single_file_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue

                    if (
                        schema_utils.is_single_string_property(property_schema)
                        or schema_utils.is_single_number_property(property_schema)
                        or schema_utils.is_single_datetime_property(property_schema)
                        or schema_utils.is_single_boolean_property(property_schema)
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue
                    if definitions and schema_utils.is_single_enum_property(
                        property_schema, definitions
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value.value
                        )
                        continue

                # TODO: render dict as table

                self._render_single_complex_property(
                    streamlit, property_schema, output_property_value
                )
        return

    def _render_list_output(self, streamlit: st, output_data: List) -> None:
        try:
            data_items: List = []
            for data_item in output_data:
                if has_output_ui_renderer(data_item):
                    # Render using the render function
                    data_item.render_output_ui(streamlit)  # type: ignore
                    continue
                data_items.append(data_item.dict())
            # Try to show as dataframe
            streamlit.table(pd.DataFrame(data_items))
        except Exception:
            # Fall back to plain JSON rendering
            streamlit.json(jsonable_encoder(output_data))


def getOpyrator(mode: str) -> Opyrator:
    if mode is None or mode.startswith('VC'):
        from mkgui.app_vc import convert
        return Opyrator(convert)
    if mode is None or mode.startswith('预处理'):
        from mkgui.preprocess import preprocess
        return Opyrator(preprocess)
    if mode is None or mode.startswith('模型训练'):
        from mkgui.train import train
        return Opyrator(train)
    from mkgui.app import synthesize
    return Opyrator(synthesize)


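# --- Illustrative example (editor's note, not part of this commit) -----------
# The sidebar mode string picks the wrapped entry point lazily, so only the
# selected pipeline's imports are paid for:
#
#   getOpyrator("VC拟音")    # -> Opyrator(convert)    from mkgui.app_vc
#   getOpyrator("预处理")    # -> Opyrator(preprocess) from mkgui.preprocess
#   getOpyrator("模型训练")  # -> Opyrator(train)      from mkgui.train
#   getOpyrator("AI拟音")    # -> Opyrator(synthesize) from mkgui.app (fallback)
# ------------------------------------------------------------------------------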
def render_streamlit_ui() -> None:
    # init
    session_state = st.session_state
    session_state.input_data = {}
    # Add custom css settings
    st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)

    with st.spinner("Loading MockingBird GUI. Please wait..."):
        session_state.mode = st.sidebar.selectbox(
            '模式选择',
            ( "AI拟音", "VC拟音", "预处理", "模型训练")
        )
        if "mode" in session_state:
            mode = session_state.mode
        else:
            mode = ""
        opyrator = getOpyrator(mode)
        title = opyrator.name + mode

    col1, col2, _ = st.columns(3)
    col2.title(title)
    col2.markdown("欢迎使用MockingBird Web 2")

    image = Image.open('.\\mkgui\\static\\mb.png')
    col1.image(image)

    st.markdown("---")
    left, right = st.columns([0.4, 0.6])

    with left:
        st.header("Control 控制")
        InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
        execute_selected = st.button(opyrator.action)
        if execute_selected:
            with st.spinner("Executing operation. Please wait..."):
                try:
                    input_data_obj = parse_obj_as(
                        opyrator.input_type, session_state.input_data
                    )
                    session_state.output_data = opyrator(input=input_data_obj)
                    session_state.latest_operation_input = input_data_obj  # should this really be saved as additional session object?
                except ValidationError as ex:
                    st.error(ex)
                else:
                    # st.success("Operation executed successfully.")
                    pass

    with right:
        st.header("Result 结果")
        if 'output_data' in session_state:
            OutputUI(
                session_state.output_data, session_state.latest_operation_input
            ).render_ui(st)
            if st.button("Clear"):
                # Clear all state
                for key in st.session_state.keys():
                    del st.session_state[key]
                session_state.input_data = {}
                st.experimental_rerun()
        else:
            # placeholder
            st.caption("请使用左侧控制板进行输入并运行获得结果")
13
mkgui/base/ui/streamlit_utils.py
Normal file
13
mkgui/base/ui/streamlit_utils.py
Normal file
@@ -0,0 +1,13 @@
CUSTOM_STREAMLIT_CSS = """
div[data-testid="stBlock"] button {
  width: 100% !important;
  margin-bottom: 20px !important;
  border-color: #bfbfbf !important;
}
section[data-testid="stSidebar"] div {
  max-width: 10rem;
}
pre code {
  white-space: pre-wrap;
}
"""
96
mkgui/preprocess.py
Normal file
96
mkgui/preprocess.py
Normal file
@@ -0,0 +1,96 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any


# Constants
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
ENC_MODELS_DIRT = "encoder\\saved_models"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="目标模型",
    )
    dataset: Dataset = Field(
        Dataset.AIDATATANG_200ZH, title="数据集选择",
    )
    datasets_root: str = Field(
        ..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
        format=True,
        example="..\\trainning_data\\"
    )
    output_root: str = Field(
        ..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
        format=True,
        example="..\\trainning_data\\"
    )
    n_processes: int = Field(
        2, alias="处理线程数", description="根据CPU线程数来设置",
        le=32, ge=1
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        sr, count = self.__root__
        streamlit_app.subheader(f"Dataset {sr} processed: {count} in total")

def preprocess(input: Input) -> Output:
    """Preprocess(预处理)"""
    finished = 0
    if input.model == Model.VC_PPG2MEL:
        from ppg2mel.preprocess import preprocess_dataset
        finished = preprocess_dataset(
            datasets_root=Path(input.datasets_root),
            dataset=input.dataset,
            out_dir=Path(input.output_root),
            n_processes=input.n_processes,
            ppg_encoder_model_fpath=Path(input.extractor.value),
            speaker_encoder_model=Path(input.encoder.value)
        )
    # TODO: pass useful return code
    return Output(__root__=(input.dataset, finished))
BIN
mkgui/static/mb.png
Normal file
BIN
mkgui/static/mb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 5.6 KiB |
156
mkgui/train.py
Normal file
156
mkgui/train.py
Normal file
@@ -0,0 +1,156 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch

# TODO: separator for *unix systems
# Constants
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
CONV_MODELS_DIRT = "ppg2mel\\saved_models"
ENC_MODELS_DIRT = "encoder\\saved_models"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )
    output_root: str = Field(
        ..., alias="输出目录(可选)", description="建议不填,保持默认",
        format=True,
        example=""
    )
    continue_mode: bool = Field(
        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    # TODO: move to hidden fields by default
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    njobs: int = Field(
        8, alias="进程数", description="适用于ppg2mel",
    )
    seed: int = Field(
        default=0, alias="初始随机数", description="适用于ppg2mel",
    )
    model_name: str = Field(
        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
        example="test"
    )
    model_config: str = Field(
        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
    )


class AudioEntity(BaseModel):
    content: bytes
    mel: Any


class Output(BaseModel):
    __root__: tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        name, count = self.__root__
        streamlit_app.subheader(f"Dataset {name} done: processed {count} files in total")

def train(input: Input) -> Output:
    """Train(训练)"""
    print(">>> OneShot VC training ...")
    params = AttrDict()
    params.update({
        "gpu": input.gpu,
        "cpu": not input.gpu,
        "njobs": input.njobs,
        "seed": input.seed,
        "verbose": input.verbose,
        "load": input.convertor.value,
        "warm_start": False,
    })
    if input.continue_mode:
        # trace old model and config
        p = Path(input.convertor.value)
        params.name = p.parent.name
        # search for a config file
        model_config_fpaths = list(p.parent.rglob("*.yaml"))
        if len(model_config_fpaths) == 0:
            raise Exception("No model yaml config found for convertor")
        config = HpsYaml(model_config_fpaths[0])
        params.ckpdir = p.parent.parent
        params.config = model_config_fpaths[0]
        params.logdir = os.path.join(p.parent, "log")
    else:
        # Make the config dict dot visitable
        config = HpsYaml(input.model_config)
    np.random.seed(input.seed)
    torch.manual_seed(input.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(input.seed)
    mode = "train"
    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
    solver = Solver(config, params, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")

    # TODO: pass useful return code
    return Output(__root__=(input.model_name, 0))
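The model pickers above are built by enumerating checkpoint files into ad-hoc Enums, so each UI option carries its `Path` as the member value and `choice.value` can be handed straight to a loader. A self-contained sketch of the same pattern; the directory is a hypothetical example:

```python
# Standalone sketch of the Enum-based checkpoint discovery used above:
# member name = file name, member value = the Path object.
from enum import Enum
from pathlib import Path

model_dir = Path("ppg2mel/saved_models")  # hypothetical directory
convertors = Enum("convertors", [(f.name, f) for f in model_dir.glob("**/*.pth")])

for member in convertors:
    print(member.name, "->", member.value)
```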
209
ppg2mel/__init__.py
Normal file
@@ -0,0 +1,209 @@
#!/usr/bin/env python3

# Copyright 2020 Songxiang Liu
# Apache 2.0

from typing import List

import torch
import torch.nn.functional as F

import numpy as np

from .utils.abs_model import AbsMelDecoder
from .rnn_decoder_mol import Decoder
from .utils.cnn_postnet import Postnet
from .utils.vc_utils import get_mask_from_lengths

from utils.load_yaml import HpsYaml

class MelDecoderMOLv2(AbsMelDecoder):
    """Use an encoder to preprocess ppg."""
    def __init__(
        self,
        num_speakers: int,
        spk_embed_dim: int,
        bottle_neck_feature_dim: int,
        encoder_dim: int = 256,
        encoder_downsample_rates: List = [2, 2],
        attention_rnn_dim: int = 512,
        decoder_rnn_dim: int = 512,
        num_decoder_rnn_layer: int = 1,
        concat_context_to_last: bool = True,
        prenet_dims: List = [256, 128],
        num_mixtures: int = 5,
        frames_per_step: int = 2,
        mask_padding: bool = True,
    ):
        super().__init__()

        self.mask_padding = mask_padding
        self.bottle_neck_feature_dim = bottle_neck_feature_dim
        self.num_mels = 80
        self.encoder_down_factor = np.cumprod(encoder_downsample_rates)[-1]
        self.frames_per_step = frames_per_step
        self.use_spk_dvec = True

        input_dim = bottle_neck_feature_dim

        # Downsampling convolution
        self.bnf_prenet = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[0],
                stride=encoder_downsample_rates[0],
                padding=encoder_downsample_rates[0]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[1],
                stride=encoder_downsample_rates[1],
                padding=encoder_downsample_rates[1]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        )
        decoder_enc_dim = encoder_dim
        self.pitch_convs = torch.nn.Sequential(
            torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[0],
                stride=encoder_downsample_rates[0],
                padding=encoder_downsample_rates[0]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[1],
                stride=encoder_downsample_rates[1],
                padding=encoder_downsample_rates[1]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        )

        self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)

        # Decoder
        self.decoder = Decoder(
            enc_dim=decoder_enc_dim,
            num_mels=self.num_mels,
            frames_per_step=frames_per_step,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=decoder_rnn_dim,
            num_decoder_rnn_layer=num_decoder_rnn_layer,
            prenet_dims=prenet_dims,
            num_mixtures=num_mixtures,
            use_stop_tokens=True,
            concat_context_to_last=concat_context_to_last,
            encoder_down_factor=self.encoder_down_factor,
        )

        # Mel-Spec Postnet: some residual CNN layers
        self.postnet = Postnet()

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
            mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
        return outputs

    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        output_att_ws: bool = False,
    ):
        decoder_inputs = self.bnf_prenet(
            bottle_neck_features.transpose(1, 2)
        ).transpose(1, 2)
        logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        decoder_inputs = decoder_inputs + logf0_uv

        assert spembs is not None
        spk_embeds = F.normalize(
            spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
        decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
        decoder_inputs = self.reduce_proj(decoder_inputs)

        # (B, num_mels, T_dec)
        T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
        mel_outputs, predicted_stop, alignments = self.decoder(
            decoder_inputs, speech, T_dec)
        ## Post-processing
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        if output_att_ws:
            return self.parse_output(
                [mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
        else:
            return self.parse_output(
                [mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)

        # return mel_outputs, mel_outputs_postnet

    def inference(
        self,
        bottle_neck_features: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ):
        decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
        logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        decoder_inputs = decoder_inputs + logf0_uv

        assert spembs is not None
        spk_embeds = F.normalize(
            spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
        bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
        bottle_neck_features = self.reduce_proj(bottle_neck_features)

        ## Decoder
        if bottle_neck_features.size(0) > 1:
            mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
        else:
            mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
        ## Post-processing
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        # outputs = mel_outputs_postnet[0]

        return mel_outputs[0], mel_outputs_postnet[0], alignments[0]

def load_model(model_file, device=None):
    # search for a config file next to the checkpoint
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    if len(model_config_fpaths) == 0:
        raise Exception("No model yaml config found for convertor")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = HpsYaml(model_config_fpaths[0])
    ppg2mel_model = MelDecoderMOLv2(
        **model_config["model"]
    ).to(device)
    ckpt = torch.load(model_file, map_location=device)
    ppg2mel_model.load_state_dict(ckpt["model"])
    ppg2mel_model.eval()
    return ppg2mel_model
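`load_model` locates the model's yaml config next to the checkpoint and restores an eval-mode `MelDecoderMOLv2`. A minimal usage sketch with dummy inputs; the checkpoint path and the `spk_embed_dim=256` value are hypothetical, and the tensor shapes follow the `inference` signature above (PPG features `(B, T, bnf_dim)`, log-f0/UV `(B, T, 2)`, speaker d-vector `(B, spk_embed_dim)`):

```python
# Usage sketch for the converter; checkpoint path and embed size are assumptions.
import torch
from pathlib import Path
import ppg2mel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ppg2mel.load_model(Path("ppg2mel/saved_models/demo/best_loss_step_1000.pth"), device)

bnf = torch.randn(1, 200, model.bottle_neck_feature_dim, device=device)
logf0_uv = torch.randn(1, 200, 2, device=device)
spemb = torch.randn(1, 256, device=device)  # assuming spk_embed_dim=256

with torch.no_grad():
    _, mel_postnet, _ = model.inference(bnf, logf0_uv=logf0_uv, spembs=spemb)
print(mel_postnet.shape)  # (T_mel, 80)
```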
113
ppg2mel/preprocess.py
Normal file
@@ -0,0 +1,113 @@
import os
from typing import Any

import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import soundfile
import resampy

from ppg_extractor import load_model
import encoder.inference as Encoder
from encoder.audio import preprocess_wav
from encoder import audio
from utils.f0_utils import compute_f0

from torch.multiprocessing import Pool, cpu_count
from functools import partial

SAMPLE_RATE = 16000

def _compute_bnf(
    wav: Any,
    output_fpath: str,
    device: torch.device,
    ppg_model_local: Any,
):
    """
    Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
    """
    ppg_model_local.to(device)
    wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
    wav_length = torch.LongTensor([wav.shape[0]]).to(device)
    with torch.no_grad():
        bnf = ppg_model_local(wav_tensor, wav_length)
    bnf_npy = bnf.squeeze(0).cpu().numpy()
    np.save(output_fpath, bnf_npy, allow_pickle=False)
    return bnf_npy, len(bnf_npy)

def _compute_f0_from_wav(wav, output_fpath):
    """Compute merged f0 values."""
    f0 = compute_f0(wav, SAMPLE_RATE)
    np.save(output_fpath, f0, allow_pickle=False)
    return f0, len(f0)

def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
    Encoder.set_model(encoder_model_local)
    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = Encoder.embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    np.save(output_fpath, embed, allow_pickle=False)
    return embed, len(embed)

def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
    # wav = preprocess_wav(wav_path)
    # try:
    wav, sr = soundfile.read(wav_path)
    if len(wav) < sr:
        return None, sr, len(wav)
    if sr != SAMPLE_RATE:
        wav = resampy.resample(wav, sr, SAMPLE_RATE)
        sr = SAMPLE_RATE
    # str.rstrip(".wav") strips trailing '.', 'w', 'a', 'v' characters rather
    # than the suffix, so drop the extension with splitext instead
    utt_id = os.path.splitext(os.path.basename(wav_path))[0]

    _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
    _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
    _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)

def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
    # Glob wav files
    wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
    print(f"Globbed {len(wav_file_list)} wav files.")

    out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
    encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
    if n_processes is None:
        n_processes = cpu_count()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
    job = Pool(n_processes).imap(func, wav_file_list)
    list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))

    # finish processing and mark
    t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
    d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
    e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
    for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
        id = os.path.basename(file).split(".f0.npy")[0]
        if id.endswith("01"):
            d_fid_file.write(id + "\n")
        elif id.endswith("09"):
            e_fid_file.write(id + "\n")
        else:
            t_fid_file.write(id + "\n")
    t_fid_file.close()
    d_fid_file.close()
    e_fid_file.close()
    return len(wav_file_list)
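Note the train/dev/eval split at the end: it keys purely off the utterance id suffix, sending ids ending in "01" to the dev list and "09" to eval. A self-contained sketch of the same rule (the utterance ids are made-up examples):

```python
# Self-contained sketch of the fid-list split rule used above.
ids = ["T0055G0013S0202", "T0055G0013S0001", "T0055G0013S0101", "T0055G0013S0009"]
split = {"train": [], "dev": [], "eval": []}
for utt_id in ids:
    if utt_id.endswith("01"):
        split["dev"].append(utt_id)      # ...S0001, ...S0101
    elif utt_id.endswith("09"):
        split["eval"].append(utt_id)     # ...S0009
    else:
        split["train"].append(utt_id)    # ...S0202
print(split)
```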
374
ppg2mel/rnn_decoder_mol.py
Normal file
@@ -0,0 +1,374 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils.mol_attention import MOLAttention
from .utils.basic_layers import Linear
from .utils.vc_utils import get_mask_from_lengths


class DecoderPrenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super().__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [Linear(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Decoder(nn.Module):
    """Mixture of Logistic (MoL) attention-based RNN Decoder."""
    def __init__(
        self,
        enc_dim,
        num_mels,
        frames_per_step,
        attention_rnn_dim,
        decoder_rnn_dim,
        prenet_dims,
        num_mixtures,
        encoder_down_factor=1,
        num_decoder_rnn_layer=1,
        use_stop_tokens=False,
        concat_context_to_last=False,
    ):
        super().__init__()
        self.enc_dim = enc_dim
        self.encoder_down_factor = encoder_down_factor
        self.num_mels = num_mels
        self.frames_per_step = frames_per_step
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dims = prenet_dims
        self.use_stop_tokens = use_stop_tokens
        self.num_decoder_rnn_layer = num_decoder_rnn_layer
        self.concat_context_to_last = concat_context_to_last

        # Mel prenet
        self.prenet = DecoderPrenet(num_mels, prenet_dims)
        self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)

        # Attention RNN
        self.attention_rnn = nn.LSTMCell(
            prenet_dims[-1] + enc_dim,
            attention_rnn_dim
        )

        # Attention
        self.attention_layer = MOLAttention(
            attention_rnn_dim,
            r=frames_per_step/encoder_down_factor,
            M=num_mixtures,
        )

        # Decoder RNN
        self.decoder_rnn_layers = nn.ModuleList()
        for i in range(num_decoder_rnn_layer):
            if i == 0:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        enc_dim + attention_rnn_dim,
                        decoder_rnn_dim))
            else:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        decoder_rnn_dim,
                        decoder_rnn_dim))
        # self.decoder_rnn = nn.LSTMCell(
        #     2 * enc_dim + attention_rnn_dim,
        #     decoder_rnn_dim
        # )
        if concat_context_to_last:
            self.linear_projection = Linear(
                enc_dim + decoder_rnn_dim,
                num_mels * frames_per_step
            )
        else:
            self.linear_projection = Linear(
                decoder_rnn_dim,
                num_mels * frames_per_step
            )

        # Stop-token layer
        if self.use_stop_tokens:
            if concat_context_to_last:
                self.stop_layer = Linear(
                    enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )
            else:
                self.stop_layer = Linear(
                    decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )

    def get_go_frame(self, memory):
        B = memory.size(0)
        go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
                               device=memory.device)
        return go_frame

    def initialize_decoder_states(self, memory, mask):
        device = next(self.parameters()).device
        B = memory.size(0)

        # attention rnn states
        self.attention_hidden = torch.zeros(
            (B, self.attention_rnn_dim), device=device)
        self.attention_cell = torch.zeros(
            (B, self.attention_rnn_dim), device=device)

        # decoder rnn states
        self.decoder_hiddens = []
        self.decoder_cells = []
        for i in range(self.num_decoder_rnn_layer):
            self.decoder_hiddens.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
            self.decoder_cells.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
        # self.decoder_hidden = torch.zeros(
        #     (B, self.decoder_rnn_dim), device=device)
        # self.decoder_cell = torch.zeros(
        #     (B, self.decoder_rnn_dim), device=device)

        self.attention_context = torch.zeros(
            (B, self.enc_dim), device=device)

        self.memory = memory
        # self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepare decoder inputs, i.e. gt mel
        Args:
            decoder_inputs: (B, T_out, n_mel_channels) inputs used for teacher-forced training.
        """
        decoder_inputs = decoder_inputs.reshape(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.frames_per_step), -1)
        # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        # (T_out//r, B, num_mels)
        decoder_inputs = decoder_inputs[:, :, -self.num_mels:]
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
        """ Prepares decoder outputs for output
        Args:
            mel_outputs:
            alignments:
        """
        # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out//r, B) -> (B, T_out//r)
        if stop_outputs is not None:
            if alignments.size(0) == 1:
                stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
            else:
                stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
            stop_outputs = stop_outputs.contiguous()
        # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        # (B, T_out, num_mels)
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.num_mels)
        return mel_outputs, alignments, stop_outputs

    def attend(self, decoder_input):
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_context, attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, None, self.mask)

        decoder_rnn_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)

        return decoder_rnn_input, self.attention_context, attention_weights

    def decode(self, decoder_input):
        for i in range(self.num_decoder_rnn_layer):
            if i == 0:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
            else:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
        return self.decoder_hiddens[-1]

    def forward(self, memory, mel_inputs, memory_lengths):
        """ Decoder forward pass for training
        Args:
            memory: (B, T_enc, enc_dim) Encoder outputs
            decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
            memory_lengths: (B, ) Encoder output lengths for attention masking.
        Returns:
            mel_outputs: (B, T, num_mels) mel outputs from the decoder
            alignments: (B, T//r, T_enc) attention weights.
        """
        # [1, B, num_mels]
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        # [T//r, B, num_mels]
        mel_inputs = self.parse_decoder_inputs(mel_inputs)
        # [T//r + 1, B, num_mels]
        mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
        # [T//r + 1, B, prenet_dim]
        decoder_inputs = self.prenet(mel_inputs)
        # decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths),
        )

        self.attention_layer.init_states(memory)
        # self.attention_layer_pitch.init_states(memory_pitch)

        mel_outputs, alignments = [], []
        if self.use_stop_tokens:
            stop_outputs = []
        else:
            stop_outputs = None
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            # decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]

            decoder_rnn_input, context, attention_weights = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_rnn_input)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            if self.use_stop_tokens:
                stop_output = self.stop_layer(decoder_rnn_output)
                stop_outputs += [stop_output.squeeze()]
            mel_outputs += [mel_output.squeeze(1)]  # ? perhaps don't need squeeze
            alignments += [attention_weights]
            # alignments_pitch += [attention_weights_pitch]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        if stop_outputs is None:
            return mel_outputs, alignments
        else:
            return mel_outputs, stop_outputs, alignments

    def inference(self, memory, stop_threshold=0.5):
        """ Decoder inference
        Args:
            memory: (1, T_enc, D_enc) Encoder outputs
        Returns:
            mel_outputs: mel outputs from the decoder
            alignments: sequence of attention weights from the decoder
        """
        # [1, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        # NOTE(sx): heuristic
        max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
        min_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            # mel_output, stop_output, alignment = self.decode(decoder_input)
            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            stop_output = self.stop_layer(decoder_rnn_output)

            mel_outputs += [mel_output.squeeze(1)]
            alignments += [alignment]

            if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                # print("Warning! Decoding steps reaches max decoder steps.")
                break

            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, _ = self.parse_decoder_outputs(
            mel_outputs, alignments, None)

        return mel_outputs, alignments

    def inference_batched(self, memory, stop_threshold=0.5):
        """ Decoder inference
        Args:
            memory: (B, T_enc, D_enc) Encoder outputs
        Returns:
            mel_outputs: mel outputs from the decoder
            alignments: sequence of attention weights from the decoder
        """
        # [1, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        stop_outputs = []
        # NOTE(sx): heuristic
        max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
        min_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            # mel_output, stop_output, alignment = self.decode(decoder_input)
            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            # (B, 1)
            stop_output = self.stop_layer(decoder_rnn_output)
            stop_outputs += [stop_output.squeeze()]
            # stop_outputs.append(stop_output)

            mel_outputs += [mel_output.squeeze(1)]
            alignments += [alignment]
            # print(stop_output.shape)
            if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
                    and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                # print("Warning! Decoding steps reaches max decoder steps.")
                break

            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        mel_outputs_stacked = []
        for mel, stop_logit in zip(mel_outputs, stop_outputs):
            idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
            mel_outputs_stacked.append(mel[:idx, :])
        mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
        return mel_outputs, alignments
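The two inference loops above bound decoding with a heuristic derived from the encoder output length: an encoder memory of `T_enc` frames maps to at most `T_enc * encoder_down_factor // frames_per_step` decoder steps, and the stop token is only honoured once at least `max_decoder_step - 5` steps have been produced. A worked example of the arithmetic:

```python
# Worked example of the decoder-step heuristic used in inference above.
T_enc = 100                 # encoder memory length
encoder_down_factor = 4     # product of the downsample rates [2, 2]
frames_per_step = 2

max_decoder_step = T_enc * encoder_down_factor // frames_per_step
min_decoder_step = T_enc * encoder_down_factor // frames_per_step - 5
print(max_decoder_step, min_decoder_step)  # 200 195
```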
62
ppg2mel/train.py
Normal file
@@ -0,0 +1,62 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description='Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ckpt/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    # paras.no_pin is consumed below, so this flag must stay registered
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')

    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot visitable
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
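For reference, a typical launch looks like `python ppg2mel/train.py --config config/vc.yaml --name myrun` (config path hypothetical). Passing `--load <ckpt>` alone resumes the optimizer state and global step as well as the weights; adding `--warm_start` restores model weights only, skipping any layers listed under the config's `ignore_layers`.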
1
ppg2mel/train/__init__.py
Normal file
@@ -0,0 +1 @@
#
50
ppg2mel/train/loss.py
Normal file
@@ -0,0 +1,50 @@
from typing import Dict
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.nets_utils import make_pad_mask


class MaskedMSELoss(nn.Module):
    def __init__(self, frames_per_step):
        super().__init__()
        self.frames_per_step = frames_per_step
        self.mel_loss_criterion = nn.MSELoss(reduction='none')
        # self.loss = nn.MSELoss()
        self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')

    def get_mask(self, lengths, max_len=None):
        # lengths: [B,]
        if max_len is None:
            max_len = torch.max(lengths)
        batch_size = lengths.size(0)
        seq_range = torch.arange(0, max_len).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
        seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
        return (seq_range_expand < seq_length_expand).float()

    def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
                stop_target, stop_pred):
        ## process stop_target
        B = stop_target.size(0)
        stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
        stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
        stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))

        mel_trg.requires_grad = False
        # (B, T, 1)
        mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
        # (B, T, D)
        mel_mask = mel_mask.expand_as(mel_trg)
        mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
        mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()

        mel_loss = mel_loss_pre + mel_loss_post

        # stop token loss
        stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()

        return mel_loss, stop_loss
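`get_mask` builds a (B, max_len) 0/1 matrix from per-sequence lengths, so every position past a sequence's true length contributes nothing and the MSE is averaged over valid positions only. A self-contained sketch of the masking step:

```python
# Self-contained sketch of the length masking used by MaskedMSELoss.get_mask.
import torch

lengths = torch.tensor([3, 5])
max_len = int(torch.max(lengths))
seq_range = torch.arange(0, max_len).long()                     # (max_len,)
mask = (seq_range.unsqueeze(0) < lengths.unsqueeze(1)).float()  # (B, max_len)
print(mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])

err = torch.ones(2, max_len)            # stand-in for per-position squared error
masked_mean = (err * mask).sum() / mask.sum()
print(masked_mean)                      # tensor(1.)
```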
45
ppg2mel/train/optim.py
Normal file
@@ -0,0 +1,45 @@
import torch
import numpy as np


class Optimizer():
    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
                 **kwargs):

        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
            self.opt = opt(parameters, lr=1.0)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: 1e-8 better?

    def get_opt_state_dict(self):
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        else:
            cur_lr = self.init_lr
        self.opt.zero_grad()
        return cur_lr

    def step(self):
        self.opt.step()

    def create_msg(self):
        return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
                .format(self.opt_type, self.init_lr, self.sch_type)]
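The 'warmup' branch is a Noam-style schedule: the torch optimizer is constructed with lr=1.0 and the effective rate `init_lr * warmup_step^0.5 * min((step+1) * warmup_step^-1.5, (step+1)^-0.5)` is written into the param groups on every `pre_step`, ramping linearly over the first 4000 steps and decaying roughly as 1/sqrt(step) afterwards. A worked example:

```python
# Worked example of the warmup learning-rate schedule above (Noam-style).
import numpy as np

init_lr, warmup_step = 1e-3, 4000.0
lr = lambda step: init_lr * warmup_step ** 0.5 * \
    np.minimum((step + 1) * warmup_step ** -1.5, (step + 1) ** -0.5)

for step in [0, 1999, 3999, 15999]:
    print(step, f"{lr(step):.2e}")
# ramps up to the peak lr = init_lr at step 3999, then decays as ~1/sqrt(step)
```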
10
ppg2mel/train/option.py
Normal file
@@ -0,0 +1,10 @@
# Default parameters which will be imported by solver
default_hparas = {
    'GRAD_CLIP': 5.0,       # Grad. clip threshold
    'PROGRESS_STEP': 100,   # Std. output refresh freq.
    # Decode steps for objective validation (step = ratio*input_txt_len)
    'DEV_STEP_RATIO': 1.2,
    # Number of examples (alignment/text) to show in tensorboard
    'DEV_N_EXAMPLE': 4,
    'TB_FLUSH_FREQ': 180    # Update frequency of tensorboard (secs)
}
217
ppg2mel/train/solver.py
Normal file
@@ -0,0 +1,217 @@
import os
import sys
import abc
import math
import yaml
import torch
from torch.utils.tensorboard import SummaryWriter

from .option import default_hparas
from utils.util import human_format, Timer
from utils.load_yaml import HpsYaml


class BaseSolver():
    '''
    Prototype Solver for all kinds of tasks
    Arguments
        config - yaml-styled config
        paras  - argparse outcome
        mode   - "train"/"test"
    '''

    def __init__(self, config, paras, mode="train"):
        # General Settings
        self.config = config  # loaded from yaml file
        self.paras = paras    # command line args
        self.mode = mode      # 'train' or 'test'
        for k, v in default_hparas.items():
            setattr(self, k, v)
        self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
            else torch.device('cpu')

        # Name experiment
        self.exp_name = paras.name
        if self.exp_name is None:
            if 'exp_name' in self.config:
                self.exp_name = self.config.exp_name
            else:
                # By default, exp is named after the config file
                self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
            if mode == 'train':
                self.exp_name += '_seed{}'.format(paras.seed)

        if mode == 'train':
            # Filepath setup
            os.makedirs(paras.ckpdir, exist_ok=True)
            self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
            os.makedirs(self.ckpdir, exist_ok=True)

            # Logger settings
            self.logdir = os.path.join(paras.logdir, self.exp_name)
            self.log = SummaryWriter(
                self.logdir, flush_secs=self.TB_FLUSH_FREQ)
            self.timer = Timer()

            # Hyper-parameters
            self.step = 0
            self.valid_step = config.hparas.valid_step
            self.max_step = config.hparas.max_step

            self.verbose('Exp. name : {}'.format(self.exp_name))
            self.verbose('Loading data... a large corpus may take a while.')

        # elif mode == 'test':
        #     # Output path
        #     os.makedirs(paras.outdir, exist_ok=True)
        #     self.ckpdir = os.path.join(paras.outdir, self.exp_name)
        #
        #     # Load training config to get acoustic feat and build model
        #     self.src_config = HpsYaml(config.src.config)
        #     self.paras.load = config.src.ckpt
        #
        #     self.verbose('Evaluating result of tr. config @ {}'.format(
        #         config.src.config))

    def backward(self, loss):
        '''
        Standard backward step with self.timer and debugger
        Arguments
            loss - the loss to perform loss.backward()
        '''
        self.timer.set()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.GRAD_CLIP)
        if math.isnan(grad_norm):
            self.verbose('Error : grad norm is NaN @ step '+str(self.step))
        else:
            self.optimizer.step()
        self.timer.cnt('bw')
        return grad_norm

    def load_ckpt(self):
        ''' Load ckpt if --load option is specified '''
        print(self.paras)
        if self.paras.load is not None:
            if self.paras.warm_start:
                self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
                ckpt = torch.load(
                    self.paras.load, map_location=self.device if self.mode == 'train'
                    else 'cpu')
                model_dict = ckpt['model']
                if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
                    model_dict = {k: v for k, v in model_dict.items()
                                  if k not in self.config.model.ignore_layers}
                    dummy_dict = self.model.state_dict()
                    dummy_dict.update(model_dict)
                    model_dict = dummy_dict
                self.model.load_state_dict(model_dict)
            else:
                # Load weights
                ckpt = torch.load(
                    self.paras.load, map_location=self.device if self.mode == 'train'
                    else 'cpu')
                self.model.load_state_dict(ckpt['model'])

            # Load task-dependent items
            if self.mode == 'train':
                self.step = ckpt['global_step']
                self.optimizer.load_opt_state_dict(ckpt['optimizer'])
                self.verbose('Load ckpt from {}, restarting at step {}'.format(
                    self.paras.load, self.step))
            else:
                for k, v in ckpt.items():
                    if type(v) is float:
                        metric, score = k, v
                self.model.eval()
                self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
                    self.paras.load, metric, score))

    def verbose(self, msg):
        ''' Verbose function for printing information to stdout '''
        if self.paras.verbose:
            if type(msg) == list:
                for m in msg:
                    print('[INFO]', m.ljust(100))
            else:
                print('[INFO]', msg.ljust(100))

    def progress(self, msg):
        ''' Verbose function for updating progress on stdout (do not include newline) '''
        if self.paras.verbose:
            sys.stdout.write("\033[K")  # Clear line
            print('[{}] {}'.format(human_format(self.step), msg), end='\r')

    def write_log(self, log_name, log_dict):
        '''
        Write log to TensorBoard
            log_name  - <str> Name of tensorboard variable
            log_value - <dict>/<array> Value of variable (e.g. dict of losses), passed if value = None
        '''
        if type(log_dict) is dict:
            log_dict = {key: val for key, val in log_dict.items() if (
                val is not None and not math.isnan(val))}
        if log_dict is None:
            pass
        elif len(log_dict) > 0:
            if 'align' in log_name or 'spec' in log_name:
                img, form = log_dict
                self.log.add_image(
                    log_name, img, global_step=self.step, dataformats=form)
            elif 'text' in log_name or 'hyp' in log_name:
                self.log.add_text(log_name, log_dict, self.step)
            else:
                self.log.add_scalars(log_name, log_dict, self.step)

    def save_checkpoint(self, f_name, metric, score, show_msg=True):
        '''
        Checkpoint saver
            f_name - <str> the name of the ckpt file (w/o prefix) to store; overwrites if it exists
            score  - <float> The value of the metric used to evaluate the model
        '''
        ckpt_path = os.path.join(self.ckpdir, f_name)
        full_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.get_opt_state_dict(),
            "global_step": self.step,
            metric: score
        }

        torch.save(full_dict, ckpt_path)
        if show_msg:
            self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
                         format(human_format(self.step), metric, score, ckpt_path))


    # ----------------------------------- Abstract Methods ------------------------------------------ #
    @abc.abstractmethod
    def load_data(self):
        '''
        Called by main to load all data
        After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def set_model(self):
        '''
        Called by main to set models
        After this call, model related attributes should be setup (e.g. self.l2_loss)
        The following MUST be setup
            - self.model (torch.nn.Module)
            - self.optimizer (src.Optimizer),
              init. w/ self.optimizer = src.Optimizer(self.model.parameters(), **self.config['hparas'])
        Loading a pre-trained model should also be performed here
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def exec(self):
        '''
        Called by main to execute training/inference
        '''
        raise NotImplementedError
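Concrete solvers subclass `BaseSolver` and supply the three abstract hooks; checkpointing, logging, and gradient clipping are inherited. A minimal skeleton, with the model and data details left as hypothetical placeholders:

```python
# Minimal skeleton of a concrete solver over BaseSolver; the bodies are
# hypothetical placeholders for whatever the task needs.
from ppg2mel.train.solver import BaseSolver


class MySolver(BaseSolver):
    def load_data(self):
        # must set up data attributes, e.g. self.train_dataloader
        ...

    def set_model(self):
        # must set self.model (torch.nn.Module) and self.optimizer,
        # then call self.load_ckpt() to honour --load / --warm_start
        ...

    def exec(self):
        # training loop: self.optimizer.pre_step(self.step), forward,
        # self.backward(loss), periodic validation and save_checkpoint()
        ...
```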
288
ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py
Normal file
@@ -0,0 +1,288 @@
import os, sys
# sys.path.append('/home/shaunxliu/projects/nnsp')
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import torch
from torch.utils.data import DataLoader
import numpy as np
from .solver import BaseSolver
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
# from src.rnn_ppg2mel import BiRnnPpg2MelModel
# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
from .loss import MaskedMSELoss
from .optim import Optimizer
from utils.util import human_format
from ppg2mel import MelDecoderMOLv2


class Solver(BaseSolver):
    """Customized Solver."""
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        self.num_att_plots = 5
        self.att_ws_dir = f"{self.logdir}/att_ws"
        os.makedirs(self.att_ws_dir, exist_ok=True)
        self.best_loss = np.inf

    def fetch_data(self, data):
        """Move data to device"""
        data = [i.to(self.device) for i in data]
        return data

    def load_data(self):
        """ Load data for training/validation/plotting."""
        train_dataset = OneshotVcDataset(
            meta_file=self.config.data.train_fid_list,
            vctk_ppg_dir=self.config.data.vctk_ppg_dir,
            libri_ppg_dir=self.config.data.libri_ppg_dir,
            vctk_f0_dir=self.config.data.vctk_f0_dir,
            libri_f0_dir=self.config.data.libri_f0_dir,
            vctk_wav_dir=self.config.data.vctk_wav_dir,
            libri_wav_dir=self.config.data.libri_wav_dir,
            vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
            libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
            ppg_file_ext=self.config.data.ppg_file_ext,
            min_max_norm_mel=self.config.data.min_max_norm_mel,
            mel_min=self.config.data.mel_min,
            mel_max=self.config.data.mel_max,
        )
        dev_dataset = OneshotVcDataset(
            meta_file=self.config.data.dev_fid_list,
            vctk_ppg_dir=self.config.data.vctk_ppg_dir,
            libri_ppg_dir=self.config.data.libri_ppg_dir,
            vctk_f0_dir=self.config.data.vctk_f0_dir,
            libri_f0_dir=self.config.data.libri_f0_dir,
            vctk_wav_dir=self.config.data.vctk_wav_dir,
            libri_wav_dir=self.config.data.libri_wav_dir,
            vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
            libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
            ppg_file_ext=self.config.data.ppg_file_ext,
            min_max_norm_mel=self.config.data.min_max_norm_mel,
            mel_min=self.config.data.mel_min,
            mel_max=self.config.data.mel_max,
        )
        self.train_dataloader = DataLoader(
            train_dataset,
            num_workers=self.paras.njobs,
            shuffle=True,
            batch_size=self.config.hparas.batch_size,
            pin_memory=False,
            drop_last=True,
            collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
                                         use_spk_dvec=True),
        )
        self.dev_dataloader = DataLoader(
            dev_dataset,
            num_workers=self.paras.njobs,
            shuffle=False,
            batch_size=self.config.hparas.batch_size,
            pin_memory=False,
            drop_last=False,
            collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
                                         use_spk_dvec=True),
        )
        self.plot_dataloader = DataLoader(
            dev_dataset,
            num_workers=self.paras.njobs,
            shuffle=False,
            batch_size=1,
            pin_memory=False,
            drop_last=False,
            collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
                                         use_spk_dvec=True,
                                         give_uttids=True),
        )
        msg = "Have prepared training set and dev set."
        self.verbose(msg)

    def load_pretrained_params(self):
        print("Load pretrained model from: ", self.config.data.pretrain_model_file)
        ignore_layer_prefixes = ["speaker_embedding_table"]
        pretrain_model_file = self.config.data.pretrain_model_file
        pretrain_ckpt = torch.load(
            pretrain_model_file, map_location=self.device
        )["model"]
        model_dict = self.model.state_dict()
        print(self.model)

        # 1. filter out unnecessary keys
        for prefix in ignore_layer_prefixes:
            pretrain_ckpt = {k: v
                             for k, v in pretrain_ckpt.items() if not k.startswith(prefix)
                             }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrain_ckpt)

        # 3. load the new state dict
        self.model.load_state_dict(model_dict)

def set_model(self):
|
||||||
|
"""Setup model and optimizer"""
|
||||||
|
# Model
|
||||||
|
print("[INFO] Model name: ", self.config["model_name"])
|
||||||
|
self.model = MelDecoderMOLv2(
|
||||||
|
**self.config["model"]
|
||||||
|
).to(self.device)
|
||||||
|
# self.load_pretrained_params()
|
||||||
|
|
||||||
|
# model_params = [{'params': self.model.spk_embedding.weight}]
|
||||||
|
model_params = [{'params': self.model.parameters()}]
|
||||||
|
|
||||||
|
# Loss criterion
|
||||||
|
self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
self.optimizer = Optimizer(model_params, **self.config["hparas"])
|
||||||
|
self.verbose(self.optimizer.create_msg())
|
||||||
|
|
||||||
|
# Automatically load pre-trained model if self.paras.load is given
|
||||||
|
self.load_ckpt()
|
||||||
|
|
||||||
|
    def exec(self):
        self.verbose("Total training steps {}.".format(
            human_format(self.max_step)))

        mel_loss = None
        n_epochs = 0
        # Set as current time
        self.timer.set()

        while self.step < self.max_step:
            for data in self.train_dataloader:
                # Pre-step: update lr_rate and do zero_grad
                lr_rate = self.optimizer.pre_step(self.step)
                total_loss = 0
                # data to device
                ppgs, lf0_uvs, mels, in_lengths, \
                    out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
                self.timer.cnt("rd")
                mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids
                )
                mel_loss, stop_loss = self.loss_criterion(
                    mel_outputs,
                    mel_outputs_postnet,
                    mels,
                    out_lengths,
                    stop_tokens,
                    predicted_stop
                )
                loss = mel_loss + stop_loss

                self.timer.cnt("fw")

                # Back-prop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
                                  .format(loss.cpu().item(), mel_loss.cpu().item(),
                                          stop_loss.cpu().item(), grad_norm, self.timer.show()))
                    self.write_log('loss', {'tr/loss': loss,
                                            'tr/mel-loss': mel_loss,
                                            'tr/stop-loss': stop_loss})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        self.log.close()

    def validate(self):
        self.model.eval()
        dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0

        for i, data in enumerate(self.dev_dataloader):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
            # Fetch data
            ppgs, lf0_uvs, mels, in_lengths, \
                out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
            with torch.no_grad():
                mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids
                )
                mel_loss, stop_loss = self.loss_criterion(
                    mel_outputs,
                    mel_outputs_postnet,
                    mels,
                    out_lengths,
                    stop_tokens,
                    predicted_stop
                )
                loss = mel_loss + stop_loss

                dev_loss += loss.cpu().item()
                dev_mel_loss += mel_loss.cpu().item()
                dev_stop_loss += stop_loss.cpu().item()

        dev_loss = dev_loss / (i + 1)
        dev_mel_loss = dev_mel_loss / (i + 1)
        dev_stop_loss = dev_stop_loss / (i + 1)
        self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
        self.write_log('loss', {'dv/loss': dev_loss,
                                'dv/mel-loss': dev_mel_loss,
                                'dv/stop-loss': dev_stop_loss})

        # plot attention
        for i, data in enumerate(self.plot_dataloader):
            if i == self.num_att_plots:
                break
            # Fetch data
            ppgs, lf0_uvs, mels, in_lengths, \
                out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
            fid = data[-1][0]
            with torch.no_grad():
                _, _, _, att_ws = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids,
                    output_att_ws=True
                )
            att_ws = att_ws.squeeze(0).cpu().numpy()
            att_ws = att_ws[None]
            w, h = plt.figaspect(1.0 / len(att_ws))
            fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
            axes = fig.subplots(1, len(att_ws))
            if len(att_ws) == 1:
                axes = [axes]

            for ax, aw in zip(axes, att_ws):
                ax.imshow(aw.astype(np.float32), aspect="auto")
                ax.set_title(f"{fid}")
                ax.set_xlabel("Input")
                ax.set_ylabel("Output")
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
            fig.savefig(fig_name)

        # Resume training
        self.model.train()
23 ppg2mel/utils/abs_model.py Normal file
@@ -0,0 +1,23 @@
from abc import ABC
from abc import abstractmethod

import torch

class AbsMelDecoder(torch.nn.Module, ABC):
    """The abstract PPG-based voice conversion class.

    This "model" is one of the mediator objects for the "Task" class.
    """

    @abstractmethod
    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        styleembs: torch.Tensor = None,
    ) -> torch.Tensor:
        raise NotImplementedError
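A minimal sketch (not part of the repo) of how a concrete decoder could satisfy this interface; the class name and layer sizes here are illustrative, and the repo's real implementation is the MelDecoderMOLv2 model used by the Solver above:

import torch
from ppg2mel.utils.abs_model import AbsMelDecoder

class TinyMelDecoder(AbsMelDecoder):
    """Illustrative subclass: project 144-dim PPG features to 80 mel bins."""
    def __init__(self, bnf_dim=144, num_mels=80):
        super().__init__()
        self.proj = torch.nn.Linear(bnf_dim, num_mels)

    def forward(self, bottle_neck_features, feature_lengths, speech,
                speech_lengths, logf0_uv=None, spembs=None, styleembs=None):
        # A real decoder would attend over the PPGs and condition on
        # logf0_uv and the speaker embedding; here we only project.
        return self.proj(bottle_neck_features)
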
79 ppg2mel/utils/basic_layers.py Normal file
@@ -0,0 +1,79 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1) \
         .transpose(0, 1) \
         .repeat(count, 1) \
         .transpose(0, 1) \
         .contiguous() \
         .view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x

class Linear(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)

class Conv1d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(Conv1d, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        # x: BxDxT
        return self.conv(x)



def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1) \
         .transpose(0, 1) \
         .repeat(count, 1) \
         .transpose(0, 1) \
         .contiguous() \
         .view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
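For reference, a quick check of what tile does (each row is repeated consecutively, matching torch.repeat_interleave):

import torch
from ppg2mel.utils.basic_layers import tile

x = torch.arange(6).view(2, 3)
tile(x, 2, dim=0)
# tensor([[0, 1, 2],
#         [0, 1, 2],
#         [3, 4, 5],
#         [3, 4, 5]])   # same result as x.repeat_interleave(2, dim=0)
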
52 ppg2mel/utils/cnn_postnet.py Normal file
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .basic_layers import Linear, Conv1d


class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """
    def __init__(self, num_mels=80,
                 num_layers=5,
                 hidden_dim=512,
                 kernel_size=5):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                Conv1d(
                    num_mels, hidden_dim,
                    kernel_size=kernel_size, stride=1,
                    padding=int((kernel_size - 1) / 2),
                    dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hidden_dim)))

        for i in range(1, num_layers - 1):
            self.convolutions.append(
                nn.Sequential(
                    Conv1d(
                        hidden_dim,
                        hidden_dim,
                        kernel_size=kernel_size, stride=1,
                        padding=int((kernel_size - 1) / 2),
                        dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hidden_dim)))

        self.convolutions.append(
            nn.Sequential(
                Conv1d(
                    hidden_dim, num_mels,
                    kernel_size=kernel_size, stride=1,
                    padding=int((kernel_size - 1) / 2),
                    dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(num_mels)))

    def forward(self, x):
        # x: (B, num_mels, T_dec)
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
        return x
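A shape check for the postnet; whether its output is added back onto the coarse mels as a residual (the Tacotron-2 convention) is up to the calling decoder, not this module:

import torch
from ppg2mel.utils.cnn_postnet import Postnet

postnet = Postnet(num_mels=80, num_layers=5, hidden_dim=512, kernel_size=5)
mels = torch.randn(4, 80, 200)      # (B, num_mels, T_dec)
residual = postnet(mels)            # same shape: (4, 80, 200)
refined = mels + residual           # typical residual-refinement usage
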
123 ppg2mel/utils/mol_attention.py Normal file
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class MOLAttention(nn.Module):
    """ Discretized Mixture of Logistic (MOL) attention.
    C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
    the GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
    """
    def __init__(
        self,
        query_dim,
        r=1,
        M=5,
    ):
        """
        Args:
            query_dim: attention_rnn_dim.
            M: number of mixtures.
        """
        super().__init__()
        if r < 1:
            self.r = float(r)
        else:
            self.r = int(r)
        self.M = M
        self.score_mask_value = 0.0  # -float("inf")
        self.eps = 1e-5
        # Position array for encoder time steps
        self.J = None
        # Query layer: [w, sigma, Delta]
        self.query_layer = torch.nn.Sequential(
            nn.Linear(query_dim, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, 3*M, bias=True)
        )
        self.mu_prev = None
        self.initialize_bias()

    def initialize_bias(self):
        """Initialize sigma and Delta."""
        # sigma
        torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
        # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
        # softplus(-0.432) = 0.5003
        if self.r == 2:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
        elif self.r == 4:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
        elif self.r == 1:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
        else:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)


    def init_states(self, memory):
        """Initialize mu_prev and J.
        This function should be called by the decoder before decoding one batch.
        Args:
            memory: (B, T, D_enc) encoder output.
        """
        B, T_enc, _ = memory.size()
        device = memory.device
        self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5  # NOTE: for discretize usage
        # self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float)
        self.mu_prev = torch.zeros(B, self.M).to(device)

    def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
        """
        att_rnn_h: attention rnn hidden state.
        memory: encoder outputs (B, T_enc, D).
        mask: binary mask for padded data (B, T_enc).
        """
        # [B, 3M]
        mixture_params = self.query_layer(att_rnn_h)

        # [B, M]
        w_hat = mixture_params[:, :self.M]
        sigma_hat = mixture_params[:, self.M:2*self.M]
        Delta_hat = mixture_params[:, 2*self.M:3*self.M]

        # print("w_hat: ", w_hat)
        # print("sigma_hat: ", sigma_hat)
        # print("Delta_hat: ", Delta_hat)

        # Dropout to de-correlate attention heads
        w_hat = F.dropout(w_hat, p=0.5, training=self.training)  # NOTE(sx): needed?

        # Mixture parameters
        w = torch.softmax(w_hat, dim=-1) + self.eps
        sigma = F.softplus(sigma_hat) + self.eps
        Delta = F.softplus(Delta_hat)
        mu_cur = self.mu_prev + Delta
        # print("w:", w)
        j = self.J[:memory.size(1) + 1]

        # Attention weights
        # CDF of logistic distribution
        phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
            (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))
        # print("phi_t:", phi_t)

        # Discretize attention weights
        # (B, T_enc + 1)
        alpha_t = torch.sum(phi_t, dim=1)
        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
        alpha_t[alpha_t == 0] = self.eps
        # print("alpha_t: ", alpha_t.size())
        # Apply masking
        if mask is not None:
            alpha_t.data.masked_fill_(mask, self.score_mask_value)

        context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
        if memory_pitch is not None:
            context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)

        self.mu_prev = mu_cur

        if memory_pitch is not None:
            return context, context_pitch, alpha_t
        return context, alpha_t
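One decoder step with MOLAttention, as a hedged sketch (the dimensions are illustrative). init_states must be called once per batch before decoding; because mu_cur = mu_prev + softplus(Delta), the mixture means can only move forward, which is what keeps the alignment monotonic:

import torch
from ppg2mel.utils.mol_attention import MOLAttention

attn = MOLAttention(query_dim=512, r=1, M=5)
memory = torch.randn(2, 100, 256)      # (B, T_enc, D_enc) encoder outputs
attn.init_states(memory)               # resets mu_prev and the position grid J
query = torch.randn(2, 512)            # attention-RNN hidden state for one step
context, alpha = attn(query, memory)   # context: (2, 256), alpha: (2, 100)
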
451 ppg2mel/utils/nets_utils.py Normal file
@@ -0,0 +1,451 @@
# -*- coding: utf-8 -*-

"""Network related utility tools."""

import logging
from typing import Dict

import numpy as np
import torch


def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    """
    assert isinstance(m, torch.nn.Module)
    device = next(m.parameters()).device
    return x.to(device)


def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError('length_dim cannot be 0: {}'.format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(slice(None) if i in (0, length_dim) else None
                    for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0),
        pad_targets.size(1),
        pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == 'c':
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor
            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if 'real' not in x or 'imag' not in x:
            raise ValueError("must have 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x['real'], x['imag'])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
                 "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
                 "but got {}".format(type(x)))
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)


def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the training args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == 'transformer':
        return np.array([1])

    elif mode == 'mt' and arch == 'rnn':
        # +1 means input (+1) and layers outputs (train_args.elayers)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        logging.warning('Subsampling is not performed for machine translation.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
            (mode == 'mt' and arch == 'rnn') or \
            (mode == 'st' and arch == 'rnn'):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mix':
        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mulenc':
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    'Encoder %d: Subsampling is not performed for vgg*. '
                    'It is performed in max pooling layers at CNN.', idx + 1)
            logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))


def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
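Since th_accuracy has no doctest above, a small worked example (values chosen by hand): the flat predictions are reshaped to (B, Lmax) and compared only where the target is not the ignore label:

import torch
from ppg2mel.utils.nets_utils import th_accuracy

logits = torch.tensor([[0.1, 0.9],
                       [0.8, 0.2],
                       [0.3, 0.7],
                       [0.6, 0.4]])          # (B*Lmax, D) with B=2, Lmax=2
targets = torch.tensor([[1, 0], [1, -1]])    # -1 marks padding
th_accuracy(logits, targets, ignore_label=-1)
# argmax rows -> [[1, 0], [1, 0]]; 3 non-ignored labels, all correct -> 1.0
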
22 ppg2mel/utils/vc_utils.py Normal file
@@ -0,0 +1,22 @@
import torch


def gcd(a, b):
    """Greatest common divisor."""
    a, b = (a, b) if a >= b else (b, a)
    if a % b == 0:
        return b
    else:
        return gcd(b, a % b)

def lcm(a, b):
    """Least common multiple"""
    return a * b // gcd(a, b)

def get_mask_from_lengths(lengths, max_len=None):
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask
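Quick sanity values for the helpers above; note that get_mask_from_lengths allocates its index tensor with torch.cuda.LongTensor, so it assumes a CUDA device is available:

import torch
from ppg2mel.utils.vc_utils import gcd, lcm, get_mask_from_lengths

gcd(12, 8)   # 4
lcm(4, 6)    # 12
lengths = torch.tensor([3, 1]).cuda()
get_mask_from_lengths(lengths)
# tensor([[ True,  True,  True],
#         [ True, False, False]], device='cuda:0')  -- True marks valid frames
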
67 ppg2mel_train.py Normal file
@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
                                     'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    ###

    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot-visitable
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
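The script always runs in training mode (mode = "train" is hard-coded). A typical invocation, with the config path taken from the --config help string (the YAML file itself must exist; the name here is illustrative):

python ppg2mel_train.py --config config/vc.yaml --name oneshot_vc --njobs 8

Checkpoints are then written under the --ckpdir default ppg2mel/saved_models/ and logs under the --logdir default log/.
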
102 ppg_extractor/__init__.py Normal file
@@ -0,0 +1,102 @@
import argparse
import torch
from pathlib import Path
import yaml

from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder

_model = None  # type: PPGModel
_device = None

class PPGModel(torch.nn.Module):
    def __init__(
        self,
        frontend,
        normalizer,
        encoder,
    ):
        super().__init__()
        self.frontend = frontend
        self.normalize = normalizer
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        """

        Args:
            speech (tensor): (B, L)
            speech_lengths (tensor): (B, )

        Returns:
            bottle_neck_feats (tensor): (B, L//hop_size, 144)

        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        feats, feats_lengths = self.normalize(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ):
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and feature extraction
            #       data_loader may send a time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extraction
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def extract_from_wav(self, src_wav):
        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
        return self(src_wav_tensor, src_wav_lengths)


def build_model(args):
    normalizer = UtteranceMVN(**args.normalize_conf)
    frontend = DefaultFrontend(**args.frontend_conf)
    encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
    model = PPGModel(frontend, normalizer, encoder)

    return model


def load_model(model_file, device=None):
    global _model, _device

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = device
    # search a config file
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    config_file = model_config_fpaths[0]
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)

    args = argparse.Namespace(**args)

    model = build_model(args)
    model_state_dict = model.state_dict()

    ckpt_state_dict = torch.load(model_file, map_location=_device)
    ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if 'encoder' in k}

    model_state_dict.update(ckpt_state_dict)
    model.load_state_dict(model_state_dict)

    _model = model.eval().to(_device)
    return _model
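A hedged usage sketch of the module-level loader. The checkpoint path is illustrative (load_model only requires a pathlib.Path whose parent directory contains the matching *.yaml config), and the waveform is a placeholder at whatever sample rate the frontend config expects:

from pathlib import Path
import numpy as np
import ppg_extractor

model = ppg_extractor.load_model(Path("ppg_extractor/saved_models/model.pt"))  # path is illustrative
wav = np.zeros(16000, dtype=np.float32)    # placeholder waveform (1 s at 16 kHz, if that is the frontend rate)
ppg = model.extract_from_wav(wav)          # (1, n_frames, 144) bottleneck features
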
398 ppg_extractor/e2e_asr_common.py Normal file
@@ -0,0 +1,398 @@
#!/usr/bin/env python3

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Common functions for ASR."""

import argparse
import editdistance
import json
import logging
import numpy as np
import six
import sys

from itertools import groupby


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    :param ended_hyps:
    :param i:
    :param M:
    :param D_end:
    :return:
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
    for m in six.moves.range(M):
        # get ended_hyps whose length is i - m
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
            if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
                count += 1

    if count == M:
        return True
    else:
        return False


# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param odim:
    :param lsm_type:
    :param blank:
    :param transcript:
    :return:
    """
    if transcript is not None:
        with open(transcript, 'rb') as f:
            trans_json = json.load(f)['utts']

    if lsm_type == 'unigram':
        assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        labelcount[odim - 1] = len(transcript)  # count <eos>
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error(
            "Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()

    return labeldist


def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
    """Return the output size of the VGG frontend.

    :param in_channel: input channel size
    :param out_channel: output channel size
    :return: output size
    :rtype int
    """
    idim = idim / in_channel
    if downsample:
        idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 1st max pooling
        idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 2nd max pooling
    return int(idim) * out_channel  # number of channels


class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list:
    :param sym_space:
    :param sym_blank:
    :return:
    """

    def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
                 trans_type="char"):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer
        self.trans_type = trans_type
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level WER score
        :rtype float
        :return: sentence-level CER score
        :rtype float
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype float
        """
        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat, seq_true = [], []
            for idx in y_hat:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_hat.append(self.char_list[int(idx)])

            for idx in y_true:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_true.append(self.char_list[int(idx)])
            if self.trans_type == "char":
                hyp_chars = "".join(seq_hat)
                ref_chars = "".join(seq_true)
            else:
                hyp_chars = " ".join(seq_hat)
                ref_chars = " ".join(seq_true)

            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        :param torch.Tensor seqs_hat: prediction (batch, seqlen)
        :param torch.Tensor seqs_true: reference (batch, seqlen)
        :return: token list of prediction
        :rtype list
        :return: token list of reference
        :rtype list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # To avoid a wrongly higher WER than the one obtained from decoding,
            # the eos position from y_true is used to mark the eos in y_hat,
            # because y_hat is not padded with -1.
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            # seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            # seq_true_text = "".join(seq_true).replace(self.space, ' ')
            seq_true_text = " ".join(seq_true).replace(self.space, ' ')
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype float
        """
        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(' ', '')
            ref_chars = seq_true_text.replace(' ', '')
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype float
        """
        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)


class ErrorCalculatorTrans(object):
    """Calculate CER and WER for transducer models.

    Args:
        decoder (nn.Module): decoder module
        args (Namespace): argument Namespace containing options
        report_cer (boolean): compute CER option
        report_wer (boolean): compute WER option

    """

    def __init__(self, decoder, args, report_cer=False, report_wer=False):
        """Construct an ErrorCalculator object for transducer model."""
        super(ErrorCalculatorTrans, self).__init__()

        self.dec = decoder

        recog_args = {'beam_size': args.beam_size,
                      'nbest': args.nbest,
                      'space': args.sym_space,
                      'score_norm_transducer': args.score_norm_transducer}

        self.recog_args = argparse.Namespace(**recog_args)

        self.char_list = args.char_list
        self.space = args.sym_space
        self.blank = args.sym_blank

        self.report_cer = args.report_cer
        self.report_wer = args.report_wer

    def __call__(self, hs_pad, ys_pad):
        """Calculate sentence-level WER/CER score for transducer models.

        Args:
            hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D)
            ys_pad (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): sentence-level CER score
            (float): sentence-level WER score

        """
        cer, wer = None, None

        if not self.report_cer and not self.report_wer:
            return cer, wer

        batchsize = int(hs_pad.size(0))
        batch_nbest = []

        for b in six.moves.range(batchsize):
            if self.recog_args.beam_size == 1:
                nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
            else:
                nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
            batch_nbest.append(nbest_hyps)

        ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())

        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)

        return cer, wer

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        Args:
            ys_hat (torch.Tensor): prediction (batch, seqlen)
            ys_pad (torch.Tensor): reference (batch, seqlen)

        Returns:
            (list): token list of prediction
            (list): token list of reference

        """
        seqs_hat, seqs_true = [], []

        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]

            eos_true = np.where(y_true == -1)[0]
            eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)

            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]

            seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            seq_true_text = "".join(seq_true).replace(self.space, ' ')

            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)

        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score for transducer model.

        Args:
            seqs_hat (torch.Tensor): prediction (batch, seqlen)
            seqs_true (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): average sentence-level CER score

        """
        char_eds, char_ref_lens = [], []

        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(' ', '')
            ref_chars = seq_true_text.replace(' ', '')

            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))

        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score for transducer model.

        Args:
            seqs_hat (torch.Tensor): prediction (batch, seqlen)
            seqs_true (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): average sentence-level WER score

        """
        word_eds, word_ref_lens = [], []

        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()

            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))

        return float(sum(word_eds)) / sum(word_ref_lens)
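A small worked example for ErrorCalculator (the token ids and char_list are invented for illustration): with one hypothesis/reference pair differing in one token, the sentence-level CER and WER both come out to 0.5:

import numpy as np
from ppg_extractor.e2e_asr_common import ErrorCalculator

char_list = ['<blank>', 'a', 'b', '<space>', '<eos>']
calc = ErrorCalculator(char_list, sym_space='<space>', sym_blank='<blank>',
                       report_cer=True, report_wer=True)
ys_hat = np.array([[1, 3, 2, -1]])   # "a b"
ys_pad = np.array([[1, 3, 1, -1]])   # "a a"; -1 marks padding/eos
cer, wer = calc(ys_hat, ys_pad)      # cer == 0.5, wer == 0.5
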
0 ppg_extractor/encoder/__init__.py Normal file
183 ppg_extractor/encoder/attention.py Normal file
@@ -0,0 +1,183 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    :param int n_head: the number of heads
    :param int n_feat: the number of features
    :param float dropout_rate: dropout rate

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor key: (batch, time2, size)
        :param torch.Tensor value: (batch, time2, size)
        :return torch.Tensor transformed query, key and value

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        :param torch.Tensor value: (batch, head, time2, size)
        :param torch.Tensor scores: (batch, head, time1, time2)
        :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
        :return torch.Tensor transformed `value` (batch, time1, d_model)
            weighted by the attention score (batch, time1, time2)

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute 'Scaled Dot Product Attention'.

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor key: (batch, time2, size)
        :param torch.Tensor value: (batch, time2, size)
        :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
        :param torch.nn.Dropout dropout:
        :return torch.Tensor: attention output (batch, time1, d_model)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||||
|
return self.forward_attention(v, scores, mask)
|
||||||
|
|
||||||
|
|
||||||
|
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||||
|
"""Multi-Head Attention layer with relative position encoding.
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1901.02860
|
||||||
|
|
||||||
|
:param int n_head: the number of head s
|
||||||
|
:param int n_feat: the number of features
|
||||||
|
:param float dropout_rate: dropout rate
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_head, n_feat, dropout_rate):
|
||||||
|
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||||
|
super().__init__(n_head, n_feat, dropout_rate)
|
||||||
|
# linear transformation for positional ecoding
|
||||||
|
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||||
|
# these two learnable bias are used in matrix c and matrix d
|
||||||
|
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||||
|
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||||
|
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||||
|
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||||
|
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||||
|
|
||||||
|
def rel_shift(self, x, zero_triu=False):
|
||||||
|
"""Compute relative positinal encoding.
|
||||||
|
|
||||||
|
:param torch.Tensor x: (batch, time, size)
|
||||||
|
:param bool zero_triu: return the lower triangular part of the matrix
|
||||||
|
"""
|
||||||
|
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
||||||
|
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||||
|
|
||||||
|
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||||
|
x = x_padded[:, :, 1:].view_as(x)
|
||||||
|
|
||||||
|
if zero_triu:
|
||||||
|
ones = torch.ones((x.size(2), x.size(3)))
|
||||||
|
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, query, key, value, pos_emb, mask):
|
||||||
|
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||||
|
|
||||||
|
:param torch.Tensor query: (batch, time1, size)
|
||||||
|
:param torch.Tensor key: (batch, time2, size)
|
||||||
|
:param torch.Tensor value: (batch, time2, size)
|
||||||
|
:param torch.Tensor pos_emb: (batch, time1, size)
|
||||||
|
:param torch.Tensor mask: (batch, time1, time2)
|
||||||
|
:param torch.nn.Dropout dropout:
|
||||||
|
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||||
|
"""
|
||||||
|
q, k, v = self.forward_qkv(query, key, value)
|
||||||
|
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||||
|
|
||||||
|
n_batch_pos = pos_emb.size(0)
|
||||||
|
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||||
|
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
|
# (batch, head, time1, d_k)
|
||||||
|
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||||
|
# (batch, head, time1, d_k)
|
||||||
|
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||||
|
|
||||||
|
# compute attention score
|
||||||
|
# first compute matrix a and matrix c
|
||||||
|
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||||
|
# (batch, head, time1, time2)
|
||||||
|
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||||
|
|
||||||
|
# compute matrix b and matrix d
|
||||||
|
# (batch, head, time1, time2)
|
||||||
|
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||||
|
matrix_bd = self.rel_shift(matrix_bd)
|
||||||
|
|
||||||
|
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||||
|
self.d_k
|
||||||
|
) # (batch, head, time1, time2)
|
||||||
|
|
||||||
|
return self.forward_attention(v, scores, mask)
|
||||||
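A quick shape check for the layers above (a minimal sketch with made-up sizes; the boolean mask marks valid key positions):

import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)                    # (batch, time, n_feat)
mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, time2)
out = mha(x, x, x, mask)                       # self-attention: (2, 50, 256)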
262  ppg_extractor/encoder/conformer_encoder.py  Normal file
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""

import logging
import torch
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from .convolution import ConvolutionModule
from .encoder_layer import EncoderLayer
from ..nets_utils import get_activation, make_pad_mask
from .vgg import VGG2L
from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding
from .layer_norm import LayerNorm
from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d
from .positionwise_feed_forward import PositionwiseFeedForward
from .repeat import repeat
from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling


class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module.

    :param int input_size: input dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of encoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate in attention
    :param float positional_dropout_rate: dropout rate after adding positional encoding
    :param str or torch.nn.Module input_layer: input layer type
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, an additional linear layer will be applied,
        i.e. x -> x + linear(concat(x, att(x)));
        if False, no additional linear layer will be applied, i.e. x -> x + att(x)
    :param str positionwise_layer_type: "linear" or "conv1d"
    :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
    :param str pos_enc_layer_type: encoder positional encoding layer type
    :param str selfattention_layer_type: encoder attention layer type
    :param str activation_type: encoder activation function type
    :param bool macaron_style: whether to use macaron style for positionwise layer
    :param bool use_cnn_module: whether to use convolution module
    :param int cnn_module_kernel: kernel size of convolution module
    :param int padding_idx: padding_idx for input_layer=embed
    :param bool no_subsample: whether to skip time subsampling in the conv2d frontend
    :param bool subsample_by_2: subsample time by 2 (instead of 4) in the conv2d frontend
    """

    def __init__(
        self,
        input_size,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        macaron_style=False,
        pos_enc_layer_type="abs_pos",
        selfattention_layer_type="selfattn",
        activation_type="swish",
        use_cnn_module=False,
        cnn_module_kernel=31,
        padding_idx=-1,
        no_subsample=False,
        subsample_by_2=False,
    ):
        """Construct an Encoder object."""
        super().__init__()

        self._output_size = attention_dim
        idim = input_size

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            logging.info("Encoder input layer type: conv2d")
            if no_subsample:
                self.embed = Conv2dNoSubsampling(
                    idim,
                    attention_dim,
                    dropout_rate,
                    pos_enc_class(attention_dim, positional_dropout_rate),
                )
            else:
                self.embed = Conv2dSubsampling(
                    idim,
                    attention_dim,
                    dropout_rate,
                    pos_enc_class(attention_dim, positional_dropout_rate),
                    subsample_by_2,  # NOTE(Sx): added by songxiang
                )
        elif input_layer == "vgg2l":
            self.embed = VGG2L(idim, attention_dim)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input lengths (B)
            prev_states: not used yet
        Returns:
            Position-embedded tensor, output lengths, and None
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)):
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None

    # def forward(self, xs, masks):
    #     """Encode input sequence.
    #
    #     :param torch.Tensor xs: input tensor
    #     :param torch.Tensor masks: input mask
    #     :return: position embedded tensor and mask
    #     :rtype Tuple[torch.Tensor, torch.Tensor]:
    #     """
    #     if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
    #         xs, masks = self.embed(xs, masks)
    #     else:
    #         xs = self.embed(xs)
    #
    #     xs, masks = self.encoders(xs, masks)
    #     if isinstance(xs, tuple):
    #         xs = xs[0]
    #
    #     if self.normalize_before:
    #         xs = self.after_norm(xs)
    #     return xs, masks
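A minimal usage sketch with assumed sizes (80-dim mel features) and the relative-position configuration the rel_selfattn path implies; the default conv2d frontend subsamples time by 4:

import torch

enc = ConformerEncoder(input_size=80, attention_dim=256, num_blocks=6,
                       pos_enc_layer_type="rel_pos",
                       selfattention_layer_type="rel_selfattn",
                       macaron_style=True, use_cnn_module=True)
xs = torch.randn(2, 100, 80)     # (B, L, D) padded features
ilens = torch.tensor([100, 72])  # true lengths per utterance
out, olens, _ = enc(xs, ilens)   # out: (2, 25, 256), olens: tensor([25, 18])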
74  ppg_extractor/encoder/convolution.py  Normal file
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""ConvolutionModule definition."""

from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    :param int channels: channels of cnn
    :param int kernel_size: kernel size of cnn

    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        :param torch.Tensor x: (batch, time, size)
        :return torch.Tensor: convolved `value` (batch, time, d_model)
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D depthwise conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
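Since the depthwise convolution uses 'SAME' padding, the module preserves the time dimension; a quick sketch with assumed sizes:

import torch

conv = ConvolutionModule(channels=256, kernel_size=31)
x = torch.randn(2, 50, 256)  # (batch, time, channels)
y = conv(x)                  # (2, 50, 256): time length unchanged
assert y.shape == x.shape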
166  ppg_extractor/encoder/embedding.py  Normal file
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Positional Encoding Module."""

import math

import torch


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but have omitted it since.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.

    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length
    :param reverse: whether to reverse the input position

    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See also: Sec. 3.2 of https://arxiv.org/pdf/1809.08895.pdf

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length

        """
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length

        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: x. Its shape is (batch, time, ...)
            torch.Tensor: pos_emb. Its shape is (1, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x), self.dropout(pos_emb)
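Note the differing return signatures: the absolute variants return a single encoded tensor, while RelPositionalEncoding returns the scaled input together with a separate position embedding to feed RelPositionMultiHeadedAttention. A small sketch with assumed sizes:

import torch

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
x_scaled, pos_emb = pos_enc(x)  # x_scaled: (2, 50, 256), pos_emb: (1, 50, 256)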
217  ppg_extractor/encoder/encoder.py  Normal file
@@ -0,0 +1,217 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""

import logging
import torch

from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.nets_utils import get_activation
from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling


class Encoder(torch.nn.Module):
    """Conformer encoder module.

    :param int idim: input dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of encoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate in attention
    :param float positional_dropout_rate: dropout rate after adding positional encoding
    :param str or torch.nn.Module input_layer: input layer type
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, an additional linear layer will be applied,
        i.e. x -> x + linear(concat(x, att(x)));
        if False, no additional linear layer will be applied, i.e. x -> x + att(x)
    :param str positionwise_layer_type: "linear" or "conv1d"
    :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
    :param str pos_enc_layer_type: encoder positional encoding layer type
    :param str selfattention_layer_type: encoder attention layer type
    :param str activation_type: encoder activation function type
    :param bool macaron_style: whether to use macaron style for positionwise layer
    :param bool use_cnn_module: whether to use convolution module
    :param int cnn_module_kernel: kernel size of convolution module
    :param int padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        macaron_style=False,
        pos_enc_layer_type="abs_pos",
        selfattention_layer_type="selfattn",
        activation_type="swish",
        use_cnn_module=False,
        cnn_module_kernel=31,
        padding_idx=-1,
    ):
        """Construct an Encoder object."""
        super(Encoder, self).__init__()

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "vgg2l":
            self.embed = VGG2L(idim, attention_dim)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self, xs, masks):
        """Encode input sequence.

        :param torch.Tensor xs: input tensor
        :param torch.Tensor masks: input mask
        :return: position embedded tensor and mask
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        xs, masks = self.encoders(xs, masks)
        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks
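Unlike ConformerEncoder.forward above, which takes frame counts, this Encoder.forward takes a precomputed boolean mask. A sketch of building one from lengths, assuming the espnet package these imports come from is installed:

import torch

enc = Encoder(idim=80, attention_dim=256, num_blocks=6)
xs = torch.randn(2, 100, 80)
lengths = torch.tensor([100, 72])
masks = torch.arange(100)[None, None, :] < lengths[:, None, None]  # (B, 1, T) bool
ys, out_masks = enc(xs, masks)  # ys: (2, 24, 256) after espnet's ~4x conv2d frontend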
152  ppg_extractor/encoder/encoder_layer.py  Normal file
@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder self-attention layer definition."""

import torch

from torch import nn

from .layer_norm import LayerNorm


class EncoderLayer(nn.Module):
    """Encoder layer module.

    :param int size: input dim
    :param self_attn: self-attention module
        (MultiHeadedAttention or RelPositionMultiHeadedAttention)
    :param feed_forward: feed-forward module (PositionwiseFeedForward)
    :param feed_forward_macaron: feed-forward module for the macaron-style
        branch (PositionwiseFeedForward), or None to disable it
    :param conv_module: convolution module (ConvolutionModule), or None to disable it
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, an additional linear layer will be applied,
        i.e. x -> x + linear(concat(x, att(x)));
        if False, no additional linear layer will be applied, i.e. x -> x + att(x)

    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        :param torch.Tensor x_input: encoded source features, with or without pos_emb:
            tuple((batch, max_time_in, size), (1, max_time_in, size))
            or (batch, max_time_in, size)
        :param torch.Tensor mask: mask for x (batch, max_time_in)
        :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
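Wiring a single layer by hand, as a sketch built from the attention and feed-forward modules defined in this package (macaron and convolution branches disabled):

import torch

attn = MultiHeadedAttention(4, 256, 0.1)
ff = PositionwiseFeedForward(256, 1024, 0.1)
layer = EncoderLayer(256, attn, ff, feed_forward_macaron=None,
                     conv_module=None, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
mask = torch.ones(2, 1, 50, dtype=torch.bool)
y, mask = layer(x, mask)  # y: (2, 50, 256)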
33  ppg_extractor/encoder/layer_norm.py  Normal file
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer-normalized tensor
        :rtype: torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
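The dim argument lets the same class normalize channel-first tensors; a quick sketch:

import torch

norm = LayerNorm(256, dim=1)
x = torch.randn(2, 256, 50)  # (batch, channels, time)
y = norm(x)                  # normalized over the 256 channels, shape unchanged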
105  ppg_extractor/encoder/multi_layer_conv.py  Normal file
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""

import torch


class MultiLayeredConv1d(torch.nn.Module):
    """Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in a Transformer block, introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(MultiLayeredConv1d, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            in_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., in_chans).

        Returns:
            Tensor: Batch of output tensors (B, ..., in_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


class Conv1dLinear(torch.nn.Module):
    """Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d, which replaces the second conv layer with a linear layer.

    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(Conv1dLinear, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., in_chans).

        Returns:
            Tensor: Batch of output tensors (B, ..., in_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
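Both variants are drop-in replacements for the position-wise feed-forward block; the conv1d kernel additionally mixes neighboring frames. A shape sketch with assumed sizes:

import torch

ffn = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
x = torch.randn(2, 50, 256)
y = ffn(x)  # (2, 50, 256): 'SAME' padding keeps the time axis intact
assert y.shape == x.shape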
31  ppg_extractor/encoder/positionwise_feed_forward.py  Normal file
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    :param int idim: input dimension
    :param int hidden_units: number of hidden units
    :param float dropout_rate: dropout rate

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
30  ppg_extractor/encoder/repeat.py  Normal file
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Repeat the same layer definition."""

import torch


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def forward(self, *args):
        """Pass all arguments through each module in turn."""
        for m in self:
            args = m(*args)
        return args


def repeat(N, fn):
    """Repeat a module N times.

    :param int N: number of repeats
    :param function fn: function to generate a module, given its index
    :return: repeated modules
    :rtype: MultiSequential
    """
    return MultiSequential(*[fn(n) for n in range(N)])
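Note that MultiSequential threads every return value of one layer into the next, so the repeated module must accept and return the same tuple shape (as EncoderLayer does with (x, mask)). A toy sketch:

import torch

class AddOne(torch.nn.Module):
    # toy layer with the multi-input/multi-output signature repeat expects
    def forward(self, x, mask):
        return x + 1, mask

stack = repeat(3, lambda n: AddOne())
x, mask = stack(torch.zeros(2, 5), None)
print(x[0, 0])  # tensor(3.): three repeated layers applied in sequence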
218  ppg_extractor/encoder/subsampling.py  Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright 2019 Shigeki Karita
|
||||||
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
"""Subsampling layer definition."""
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling(torch.nn.Module):
|
||||||
|
"""Convolutional 2D subsampling (to 1/4 length or 1/2 length).
|
||||||
|
|
||||||
|
:param int idim: input dim
|
||||||
|
:param int odim: output dim
|
||||||
|
:param flaot dropout_rate: dropout rate
|
||||||
|
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, idim, odim, dropout_rate, pos_enc=None,
|
||||||
|
subsample_by_2=False,
|
||||||
|
):
|
||||||
|
"""Construct an Conv2dSubsampling object."""
|
||||||
|
super(Conv2dSubsampling, self).__init__()
|
||||||
|
self.subsample_by_2 = subsample_by_2
|
||||||
|
if subsample_by_2:
|
||||||
|
self.conv = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.out = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(odim * (idim // 2), odim),
|
||||||
|
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.out = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(odim * (idim // 4), odim),
|
||||||
|
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
:param torch.Tensor x: input tensor
|
||||||
|
:param torch.Tensor x_mask: input mask
|
||||||
|
:return: subsampled x and mask
|
||||||
|
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
if x_mask is None:
|
||||||
|
return x, None
|
||||||
|
if self.subsample_by_2:
|
||||||
|
return x, x_mask[:, :, ::2]
|
||||||
|
else:
|
||||||
|
return x, x_mask[:, :, ::2][:, :, ::2]
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||||
|
return the positioning encoding.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if key != -1:
|
||||||
|
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||||
|
return self.out[key]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dNoSubsampling(torch.nn.Module):
|
||||||
|
"""Convolutional 2D without subsampling.
|
||||||
|
|
||||||
|
:param int idim: input dim
|
||||||
|
:param int odim: output dim
|
||||||
|
:param flaot dropout_rate: dropout rate
|
||||||
|
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
|
||||||
|
"""Construct an Conv2dSubsampling object."""
|
||||||
|
super().__init__()
|
||||||
|
logging.info("Encoder does not do down-sample on mel-spectrogram.")
|
||||||
|
self.conv = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.out = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(odim * idim, odim),
|
||||||
|
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
:param torch.Tensor x: input tensor
|
||||||
|
:param torch.Tensor x_mask: input mask
|
||||||
|
:return: subsampled x and mask
|
||||||
|
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
if x_mask is None:
|
||||||
|
return x, None
|
||||||
|
return x, x_mask
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||||
|
return the positioning encoding.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if key != -1:
|
||||||
|
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||||
|
return self.out[key]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling6(torch.nn.Module):
|
||||||
|
"""Convolutional 2D subsampling (to 1/6 length).
|
||||||
|
|
||||||
|
:param int idim: input dim
|
||||||
|
:param int odim: output dim
|
||||||
|
:param flaot dropout_rate: dropout rate
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, idim, odim, dropout_rate):
|
||||||
|
"""Construct an Conv2dSubsampling object."""
|
||||||
|
super(Conv2dSubsampling6, self).__init__()
|
||||||
|
self.conv = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(1, odim, 3, 2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(odim, odim, 5, 3),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.out = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
|
||||||
|
PositionalEncoding(odim, dropout_rate),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
:param torch.Tensor x: input tensor
|
||||||
|
:param torch.Tensor x_mask: input mask
|
||||||
|
:return: subsampled x and mask
|
||||||
|
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
if x_mask is None:
|
||||||
|
return x, None
|
||||||
|
return x, x_mask[:, :, :-2:2][:, :, :-4:3]
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling8(torch.nn.Module):
|
||||||
|
"""Convolutional 2D subsampling (to 1/8 length).
|
||||||
|
|
||||||
|
:param int idim: input dim
|
||||||
|
:param int odim: output dim
|
||||||
|
:param flaot dropout_rate: dropout rate
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, idim, odim, dropout_rate):
|
||||||
|
"""Construct an Conv2dSubsampling object."""
|
||||||
|
super(Conv2dSubsampling8, self).__init__()
|
||||||
|
self.conv = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(1, odim, 3, 2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.out = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
|
||||||
|
PositionalEncoding(odim, dropout_rate),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
:param torch.Tensor x: input tensor
|
||||||
|
:param torch.Tensor x_mask: input mask
|
||||||
|
:return: subsampled x and mask
|
||||||
|
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
"""
|
||||||
|
x = x.unsqueeze(1) # (b, c, t, f)
|
||||||
|
x = self.conv(x)
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
if x_mask is None:
|
||||||
|
return x, None
|
||||||
|
return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
|
||||||
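A minimal shape check for the subsampling classes above (illustrative only, not part of the diff; it assumes the module context where `PositionalEncoding` is defined, and the `idim`/`odim` values are arbitrary):

import torch

layer = Conv2dSubsampling8(idim=80, odim=256, dropout_rate=0.1)
x = torch.randn(2, 100, 80)                       # (batch, time=100, feat=80)
x_mask = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
y, y_mask = layer(x, x_mask)
# Time is reduced to roughly 1/8: y is (2, 11, 256), y_mask is (2, 1, 11).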
18 ppg_extractor/encoder/swish.py Normal file
@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Swish() activation function for Conformer."""

import torch


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x):
        """Return Swish activation function."""
        return x * torch.sigmoid(x)
77 ppg_extractor/encoder/vgg.py Normal file
@@ -0,0 +1,77 @@
"""VGG2L definition for transformer-transducer."""

import torch


class VGG2L(torch.nn.Module):
    """VGG2L module for transformer-transducer encoder."""

    def __init__(self, idim, odim):
        """Construct a VGG2L object.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs

        """
        super(VGG2L, self).__init__()

        self.vgg2l = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((3, 2)),
            torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((2, 2)),
        )

        self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)

    def forward(self, x, x_mask):
        """VGG2L forward for x.

        Args:
            x (torch.Tensor): input tensor (B, T, idim)
            x_mask (torch.Tensor): (B, 1, T)

        Returns:
            x (torch.Tensor): output tensor (B, sub(T), attention_dim)
            x_mask (torch.Tensor): (B, 1, sub(T))

        """
        x = x.unsqueeze(1)
        x = self.vgg2l(x)

        b, c, t, f = x.size()

        x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))

        if x_mask is None:
            return x, None
        else:
            x_mask = self.create_new_mask(x_mask, x)

        return x, x_mask

    def create_new_mask(self, x_mask, x):
        """Create a subsampled version of x_mask.

        Args:
            x_mask (torch.Tensor): (B, 1, T)
            x (torch.Tensor): (B, sub(T), attention_dim)

        Returns:
            x_mask (torch.Tensor): (B, 1, sub(T))

        """
        x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
        x_mask = x_mask[:, :, :x_t1][:, :, ::3]

        x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
        x_mask = x_mask[:, :, :x_t2][:, :, ::2]

        return x_mask
298 ppg_extractor/encoders.py Normal file
@@ -0,0 +1,298 @@
import logging
import six

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from .e2e_asr_common import get_vgg2l_odim
from .nets_utils import make_pad_mask, to_device


class RNNP(torch.nn.Module):
    """RNN with projection layer module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulting in cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        bidir = typ[0] == "b"
        for i in six.moves.range(elayers):
            if i == 0:
                inputdim = idim
            else:
                inputdim = hdim
            rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
                                batch_first=True) if "lstm" in typ \
                else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
            # bottleneck layer to merge
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
        elayer_states = []
        for layer in six.moves.range(self.elayers):
            xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                ilens = [int(i + 1) // sub for i in ilens]
            # (sum _utt frame_utt) x dim
            projected = getattr(self, 'bt' + str(layer))(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            if layer == self.elayers - 1:
                xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            else:
                xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim


class RNN(torch.nn.Module):
    """RNN module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulting in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
                                   dropout=dropout, bidirectional=bidir) if "lstm" in typ \
            else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
                              bidirectional=bidir)
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
        xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
        self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed, it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(self.l_last(
            ys_pad.contiguous().view(-1, ys_pad.size(2))))
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim


def reset_backward_rnn_state(states):
    """Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs"""
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.
    else:
        states[1::2] = 0.
    return states


class VGG2L(torch.nn.Module):
    """VGG-like module

    :param int in_channel: number of input channels
    """

    def __init__(self, in_channel=1, downsample=True):
        super(VGG2L, self).__init__()
        # CNN layer (VGG motivated)
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.in_channel = in_channel
        self.downsample = downsample
        if downsample:
            self.stride = 2
        else:
            self.stride = 1

    def forward(self, xs_pad, ilens, **kwargs):
        """VGG2L forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))

        # x: utt x frame x dim
        # xs_pad = F.pad_sequence(xs_pad)

        # x: utt x 1 (input channel num) x frame x dim
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
                             xs_pad.size(2) // self.in_channel).transpose(1, 2)

        # NOTE: max_pool1d ?
        xs_pad = F.relu(self.conv1_1(xs_pad))
        xs_pad = F.relu(self.conv1_2(xs_pad))
        if self.downsample:
            xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)

        xs_pad = F.relu(self.conv2_1(xs_pad))
        xs_pad = F.relu(self.conv2_2(xs_pad))
        if self.downsample:
            xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        if self.downsample:
            ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
            ilens = np.array(
                np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()

        # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
        return xs_pad, ilens, None  # no state in this layer


class Encoder(torch.nn.Module):
    """Encoder module

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers: number of layers of encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    """

    def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
        super(Encoder, self).__init__()
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
            logging.error("Error: need to specify an appropriate encoder architecture")

        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList([VGG2L(in_channel),
                                                RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
                                                     eprojs,
                                                     subsample, dropout, typ=typ)])
                logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
            else:
                self.enc = torch.nn.ModuleList([VGG2L(in_channel),
                                                RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
                                                    eprojs,
                                                    dropout, typ=typ)])
                logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
        else:
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
                logging.info(typ.upper() + ' with every-layer projection for encoder')
            else:
                self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
                logging.info(typ.upper() + ' without projection for encoder')

    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # make mask to remove bias value in padded part
        mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states


def encoder_for(args, idim, subsample):
    """Instantiates an encoder module given the program arguments

    :param Namespace args: The arguments
    :param int or List of integer idim: dimension of input, e.g. 83, or
                                        List of dimensions of inputs, e.g. [83,83]
    :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
                                           List of subsample factors of each encoder, e.g. [[1,2,2,1,1], [1,2,2,1,1]]
    :rtype: torch.nn.Module
    :return: The encoder module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
    elif num_encs >= 1:
        enc_list = torch.nn.ModuleList()
        for idx in range(num_encs):
            enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx],
                          args.dropout_rate[idx])
            enc_list.append(enc)
        return enc_list
    else:
        raise ValueError("Number of encoders needs to be one or more: {}".format(num_encs))
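An illustrative forward pass through the plain `RNN` encoder above (not part of the diff; sizes are arbitrary, and lengths are pre-sorted since `pack_padded_sequence` is called with its default `enforce_sorted=True` here):

import torch

enc = RNN(idim=80, elayers=2, cdim=320, hdim=256, dropout=0.0, typ="blstm")
xs = torch.randn(3, 50, 80)          # (B, Tmax, idim)
ilens = torch.tensor([50, 40, 30])   # descending lengths
hs, hlens, states = enc(xs, ilens)
# hs is (3, 50, 256): bidirectional LSTM output projected by l_last with tanh.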
115 ppg_extractor/frontend.py Normal file
@@ -0,0 +1,115 @@
import copy
from typing import Tuple
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from .log_mel import LogMel
from .stft import Stft


class DefaultFrontend(torch.nn.Module):
    """Conventional frontend structure for ASR

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 1024,
        win_length: int = 800,
        hop_length: int = 160,
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        norm=1,
        frontend_conf=None,  # Optional[dict] = get_default_kwargs(Frontend),
        kaldi_padding_mode=False,
        downsample_rate: int = 1,
    ):
        super().__init__()
        self.downsample_rate = downsample_rate

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
            kaldi_padding_mode=kaldi_padding_mode
        )
        if frontend_conf is not None:
            # NOTE: requires a Frontend (speech enhancement) implementation in scope
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
        )
        self.n_mels = n_mels

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])

        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        # NOTE(sx): pad so that the time axis is divisible by downsample_rate
        max_len = input_feats.size(1)
        if self.downsample_rate > 1 and max_len % self.downsample_rate != 0:
            padding = self.downsample_rate - max_len % self.downsample_rate
            # print("Logmel: ", input_feats.size())
            input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding),
                                                  "constant", 0)
            # print("Logmel(after padding): ", input_feats.size())
            feats_lens[torch.argmax(feats_lens)] = max_len + padding

        return input_feats, feats_lens
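A quick sketch of `DefaultFrontend` end to end (illustrative only, not part of the diff; random audio stands in for real 16 kHz speech):

import torch

frontend = DefaultFrontend(fs=16000, n_fft=1024, win_length=800, hop_length=160, n_mels=80)
wav = torch.randn(2, 16000)              # (batch, samples): one second at 16 kHz
ilens = torch.tensor([16000, 16000])
feats, feats_lens = frontend(wav, ilens)
# feats is (2, 101, 80) log-mel features; feats_lens is tensor([101, 101]).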
74 ppg_extractor/log_mel.py Normal file
@@ -0,0 +1,74 @@
import librosa
import numpy as np
import torch
from typing import Tuple

from .nets_utils import make_pad_mask


class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization). Otherwise, leave all the triangles aiming for
            a peak value of 1.0

    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
        inv_mel = np.linalg.pinv(melmat)
        self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
465 ppg_extractor/nets_utils.py Normal file
@@ -0,0 +1,465 @@
# -*- coding: utf-8 -*-

"""Network related utility tools."""

import logging
from typing import Dict

import numpy as np
import torch


def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    """
    assert isinstance(m, torch.nn.Module)
    device = next(m.parameters()).device
    return x.to(device)


def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError('length_dim cannot be 0: {}'.format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(slice(None) if i in (0, length_dim) else None
                    for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
                    dtype=torch.uint8 in PyTorch 1.2-
                    dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0),
        pad_targets.size(1),
        pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == 'c':
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor
            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if 'real' not in x or 'imag' not in x:
            raise ValueError("must have 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x['real'], x['imag'])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
                 "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
                 "but got {}".format(type(x)))
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)


def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the training args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == 'transformer':
        return np.array([1])

    elif mode == 'mt' and arch == 'rnn':
        # +1 means input (+1) and layers outputs (train_args.elayers)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        logging.warning('Subsampling is not performed for machine translation.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
            (mode == 'mt' and arch == 'rnn') or \
            (mode == 'st' and arch == 'rnn'):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mix':
        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mulenc':
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    'Encoder %d: Subsampling is not performed for vgg*. '
                    'It is performed in max pooling layers at CNN.', idx + 1)
            logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))


def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v


def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from .encoder.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }

    return activation_funcs[act]()
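A small sanity check for `th_accuracy` above (illustrative only, not part of the diff; shapes and labels are hypothetical):

import torch

logits = torch.randn(2 * 3, 4)                    # (B * Lmax, D): 2 utts, 3 frames, 4 classes
targets = torch.tensor([[0, 1, -1], [2, 3, 3]])   # (B, Lmax); -1 marks padding
acc = th_accuracy(logits, targets, ignore_label=-1)
# acc counts matches only over the 5 non-padded frames, so 0.0 <= acc <= 1.0.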
118 ppg_extractor/stft.py Normal file
@@ -0,0 +1,118 @@
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from .nets_utils import make_pad_mask


class Stft(torch.nn.Module):
    def __init__(
        self,
        n_fft: int = 512,
        win_length: Union[int, None] = 512,
        hop_length: int = 128,
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        kaldi_padding_mode=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode
        self.normalized = normalized
        self.onesided = onesided
        self.kaldi_padding_mode = kaldi_padding_mode
        if self.kaldi_padding_mode:
            self.win_length = 400

    def extra_repr(self):
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"pad_mode={self.pad_mode}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)

        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if not self.kaldi_padding_mode:
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                pad_mode=self.pad_mode,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=False
            )
        else:
            # NOTE(sx): Use Kaldi-fashion padding, maybe wrong
            num_pads = self.n_fft - self.win_length
            input = torch.nn.functional.pad(input, (num_pads, 0))
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=False,
                pad_mode=self.pad_mode,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=False
            )

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
                1, 2
            )

        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad
            olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1
            # olens = (ilens - self.win_length) // self.hop_length + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens
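A quick sketch of `Stft` on random audio (illustrative only, not part of the diff; the 16 kHz settings mirror the `DefaultFrontend` defaults above):

import torch

stft = Stft(n_fft=1024, win_length=800, hop_length=160)
wav = torch.randn(2, 16000)            # (batch, samples): one second at 16 kHz
ilens = torch.tensor([16000, 12000])   # true lengths before batch padding
spec, olens = stft(wav, ilens)
# onesided=True gives n_fft // 2 + 1 = 513 bins: spec is (2, 101, 513, 2),
# where the trailing 2 holds real/imag parts; olens is tensor([101, 76]).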
82 ppg_extractor/utterance_mvn.py Normal file
@@ -0,0 +1,82 @@
from typing import Tuple

import torch

from .nets_utils import make_pad_mask


class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Args:
            x: (B, L, ...)
            ilens: (B,)

        """
        return utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means: whether to normalize means
        norm_vars: whether to normalize variances
        eps: lower bound for the standard deviation

    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
    # Zero padding
    if x.requires_grad:
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
    else:
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    # mean: (B, 1, D)
    mean = x.sum(dim=1, keepdim=True) / ilens_

    if norm_means:
        x -= mean

        if norm_vars:
            var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x = x / std
        return x, ilens
    else:
        if norm_vars:
            y = x - mean
            y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
            var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x /= std
        return x, ilens
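And a minimal check that `utterance_mvn` centres each utterance over its valid frames only (illustrative only, not part of the diff; shapes assumed):

import torch

x = torch.randn(2, 50, 80)           # (B, T, D); padding is zeroed internally
ilens = torch.tensor([50, 30])
y, _ = utterance_mvn(x, ilens, norm_means=True, norm_vars=False)
# The per-utterance mean over the first ilens[b] frames is now ~0:
assert y[1, :30].mean(dim=0).abs().max() < 1e-4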
5 pre.py
@@ -12,7 +12,8 @@ import argparse
 recognized_datasets = [
     "aidatatang_200zh",
     "magicdata",
-    "aishell3"
+    "aishell3",
+    "data_aishell"
 ]

 if __name__ == "__main__":
@@ -40,7 +41,7 @@ if __name__ == "__main__":
         "Use this option when dataset does not include alignments\
     (these are used to split long audio files into sub-utterances.)")
     parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
-        "Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3.")
+        "Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.")
     parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\
         "Path your trained encoder model.")
     parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\
49 pre4ppg.py Normal file
@@ -0,0 +1,49 @@
from pathlib import Path
import argparse

from ppg2mel.preprocess import preprocess_dataset

recognized_datasets = [
    "aidatatang_200zh",
    "aidatatang_200zh_s",  # sample
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, to be used by the "
                    "ppg2mel model for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your datasets.")
    parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process, allowing values: aidatatang_200zh.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
    parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
        "Number of processes in parallel.")
    # parser.add_argument("-s", "--skip_existing", action="store_true", help=\
    #     "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
    #     "interrupted. ")
    # parser.add_argument("--hparams", type=str, default="", help=\
    #     "Hyperparameter overrides as a comma-separated list of name-value pairs")
    # parser.add_argument("--no_trim", action="store_true", help=\
    #     "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
        "Path to your trained ppg encoder model.")
    parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
        "Path to your trained speaker encoder model.")
    args = parser.parse_args()

    assert args.dataset in recognized_datasets, 'is not supported, file an issue to propose a new one'

    # Create directories
    assert args.datasets_root.exists()
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    args.out_dir.mkdir(exist_ok=True, parents=True)

    preprocess_dataset(**vars(args))
requirements.txt
@@ -1,6 +1,6 @@
 umap-learn
 visdom
-librosa>=0.8.0
+librosa==0.8.1
 matplotlib>=3.3.0
 numpy==1.19.3; platform_system == "Windows"
 numpy==1.19.4; platform_system != "Windows"
@@ -17,6 +17,11 @@ webrtcvad; platform_system != "Windows"
 pypinyin
 flask
 flask_wtf
-flask_cors
+flask_cors==3.0.10
 gevent==21.8.0
 flask_restx
+tensorboard
+streamlit==1.8.0
+PyYAML==5.4.1
+torch_complex
+espnet
142 run.py Normal file
@@ -0,0 +1,142 @@
+import time
+import os
+import argparse
+import torch
+import numpy as np
+import glob
+from pathlib import Path
+from tqdm import tqdm
+from ppg_extractor import load_model
+import librosa
+import soundfile as sf
+from utils.load_yaml import HpsYaml
+
+from encoder.audio import preprocess_wav
+from encoder import inference as speacker_encoder
+from vocoder.hifigan import inference as vocoder
+from ppg2mel import MelDecoderMOLv2
+from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
+
+
+def _build_ppg2mel_model(model_config, model_file, device):
+    ppg2mel_model = MelDecoderMOLv2(
+        **model_config["model"]
+    ).to(device)
+    ckpt = torch.load(model_file, map_location=device)
+    ppg2mel_model.load_state_dict(ckpt["model"])
+    ppg2mel_model.eval()
+    return ppg2mel_model
+
+
+@torch.no_grad()
+def convert(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    output_dir = args.output_dir
+    os.makedirs(output_dir, exist_ok=True)
+
+    step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
+
+    # Build models
+    print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
+    ppg_model = load_model(
+        Path('./ppg_extractor/saved_models/24epoch.pt'),
+        device,
+    )
+    ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
+    # vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
+    vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')
+
+    # Data related
+    ref_wav_path = args.ref_wav_path
+    ref_wav = preprocess_wav(ref_wav_path)
+    ref_fid = os.path.basename(ref_wav_path)[:-4]
+
+    # TODO: specify encoder
+    speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
+    ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
+    ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
+    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
+
+    source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
+    print(f"Number of source utterances: {len(source_file_list)}.")
+
+    total_rtf = 0.0
+    cnt = 0
+    for src_wav_path in tqdm(source_file_list):
+        # Load the audio to a numpy array:
+        src_wav, _ = librosa.load(src_wav_path, sr=16000)
+        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
+        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
+        ppg = ppg_model(src_wav_tensor, src_wav_lengths)
+
+        lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
+        min_len = min(ppg.shape[1], len(lf0_uv))
+
+        ppg = ppg[:, :min_len]
+        lf0_uv = lf0_uv[:min_len]
+
+        start = time.time()
+        _, mel_pred, att_ws = ppg2mel_model.inference(
+            ppg,
+            logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
+            spembs=ref_spk_dvec,
+        )
+        src_fid = os.path.basename(src_wav_path)[:-4]
+        wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
+        mel_len = mel_pred.shape[0]
+        rtf = (time.time() - start) / (0.01 * mel_len)
+        total_rtf += rtf
+        cnt += 1
+        # continue
+        mel_pred = mel_pred.transpose(0, 1)
+        y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
+        sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")
+
+    print("RTF:")
+    print(total_rtf / cnt)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Conversion from wave input")
+    parser.add_argument(
+        "--wav_dir",
+        type=str,
+        default=None,
+        required=True,
+        help="Source wave directory.",
+    )
+    parser.add_argument(
+        "--ref_wav_path",
+        type=str,
+        required=True,
+        help="Reference wave file path.",
+    )
+    parser.add_argument(
+        "--ppg2mel_model_train_config", "-c",
+        type=str,
+        default=None,
+        required=True,
+        help="Training config file (yaml file)",
+    )
+    parser.add_argument(
+        "--ppg2mel_model_file", "-m",
+        type=str,
+        default=None,
+        required=True,
+        help="ppg2mel model checkpoint file path",
+    )
+    parser.add_argument(
+        "--output_dir", "-o",
+        type=str,
+        default="vc_gens_vctk_oneshot",
+        help="Output folder to save the converted wave.",
+    )
+
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    convert(args)
+
+
+if __name__ == "__main__":
+    main()
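Note: the real-time-factor bookkeeping in run.py divides wall-clock synthesis time by `0.01 * mel_len`, i.e. it assumes one mel frame covers 10 ms (a 160-sample hop at the 16 kHz load rate used above). A quick sanity check of that arithmetic, with made-up numbers:

    hop_seconds = 160 / 16000   # 0.01 s per mel frame at 16 kHz (assumed hop)
    mel_len = 500               # mel frames for a ~5 s utterance
    elapsed = 1.25              # seconds spent inside ppg2mel inference
    rtf = elapsed / (hop_seconds * mel_len)
    print(rtf)                  # 0.25, i.e. 4x faster than real time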
BIN samples/T0055G0013S0005.wav Normal file
Binary file not shown.
@@ -167,7 +167,7 @@ def _mel_to_linear(mel_spectrogram, hparams):

 def _build_mel_basis(hparams):
     assert hparams.fmax <= hparams.sample_rate // 2
-    return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
+    return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
                                fmin=hparams.fmin, fmax=hparams.fmax)


 def _amp_to_db(x, hparams):
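Note: the keyword-argument form matters because librosa made the arguments of `librosa.filters.mel` keyword-only in 0.10 (after a deprecation in 0.9), while 0.8.x accepts keywords as well, so the new call is portable across versions. A minimal check (parameter values illustrative):

    import librosa

    # Keyword form works on librosa 0.8.x and is required from 0.10 onward.
    mel_basis = librosa.filters.mel(sr=16000, n_fft=800, n_mels=80, fmin=55, fmax=7600)
    print(mel_basis.shape)  # (80, 1 + 800 // 2) == (80, 401)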
13 synthesizer/gst_hyperparameters.py Normal file
@@ -0,0 +1,13 @@
+class GSTHyperparameters():
+    E = 512
+
+    # reference encoder
+    ref_enc_filters = [32, 32, 64, 64, 128, 128]
+
+    # style token layer
+    token_num = 10
+    # token_emb_size = 256
+    num_heads = 8
+
+    n_mels = 256  # Number of Mel banks to generate
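Note: these constants fix the attention geometry of the style-token layer defined further down: each of the `token_num` tokens is stored with `E // num_heads` channels, the reference-encoder GRU emits an `E // 2` query, and the multi-head attention projects both up to `E`. The derived sizes, assuming the defaults above:

    E, num_heads, token_num = 512, 8, 10
    key_dim = E // num_heads    # 64 channels per stored style token
    query_dim = E // 2          # 256, the reference-encoder GRU output
    print(token_num, key_dim, query_dim)  # 10 64 256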
@@ -1,5 +1,6 @@
 import ast
 import pprint
+import json

 class HParams(object):
     def __init__(self, **kwargs): self.__dict__.update(kwargs)
@@ -18,6 +19,19 @@ class HParams(object):
             self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
         return self

+    def loadJson(self, dict):
+        print("\nLoading the json with %s\n" % dict)
+        for k in dict.keys():
+            if k not in ["tts_schedule", "tts_finetune_layers"]:
+                self.__dict__[k] = dict[k]
+        return self
+
+    def dumpJson(self, fp):
+        print("\nSaving the json with %s\n" % fp)
+        with fp.open("w", encoding="utf-8") as f:
+            json.dump(self.__dict__, f)
+        return self
+
 hparams = HParams(
     ### Signal Processing (used in both synthesizer and vocoder)
     sample_rate = 16000,
@@ -49,19 +63,24 @@ hparams = HParams(
                                            # frame that has all values < -3.4

     ### Tacotron Training
-    tts_schedule = [(2,  1e-3,  20_000,  24),   # Progressive training schedule
-                    (2,  5e-4,  40_000,  24),   # (r, lr, step, batch_size)
-                    (2,  2e-4,  80_000,  24),   #
-                    (2,  1e-4, 160_000,  24),   # r = reduction factor (# of mel frames
-                    (2,  3e-5, 320_000,  24),   #     synthesized for each decoder iteration)
-                    (2,  1e-5, 640_000,  24)],  # lr = learning rate
+    tts_schedule = [(2,  1e-3,  10_000,  12),   # Progressive training schedule
+                    (2,  5e-4,  15_000,  12),   # (r, lr, step, batch_size)
+                    (2,  2e-4,  20_000,  12),   # (r, lr, step, batch_size)
+                    (2,  1e-4,  30_000,  12),   #
+                    (2,  5e-5,  40_000,  12),   #
+                    (2,  1e-5,  60_000,  12),   #
+                    (2,  5e-6, 160_000,  12),   # r = reduction factor (# of mel frames
+                    (2,  3e-6, 320_000,  12),   #     synthesized for each decoder iteration)
+                    (2,  1e-6, 640_000,  12)],  # lr = learning rate

     tts_clip_grad_norm = 1.0,               # clips the gradient norm to prevent explosion - set to None if not needed
     tts_eval_interval = 500,                # Number of steps between model evaluation (sample generation)
                                             # Set to -1 to generate after completing epoch, or 0 to disable

     tts_eval_num_samples = 1,               # Makes this number of samples

+    ## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj
+    tts_finetune_layers = [],
+
     ### Data Preprocessing
     max_mel_frames = 900,
     rescale = True,
@@ -86,4 +105,6 @@ hparams = HParams(
     speaker_embedding_size = 256,           # Dimension for the speaker embedding
     silence_min_duration_split = 0.4,       # Duration in seconds of a silence for an utterance to be split
     utterance_min_duration = 1.6,           # Duration in seconds below which utterances are discarded
+    use_gst = True,                         # Whether to use global style token
+    use_ser_for_gst = True,                 # Whether to use speaker embedding referenced for global style token
 )
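Note: `dumpJson`/`loadJson` let a run persist its hyperparameters next to the checkpoint and restore them later; `tts_schedule` and `tts_finetune_layers` are skipped on load, presumably because their tuple/list structure does not survive a naive JSON round trip. A minimal round-trip sketch (path illustrative, directory assumed to exist):

    import json
    from pathlib import Path
    from synthesizer.hparams import hparams

    cfg = Path("synthesizer/saved_models/my_run.json")
    hparams.dumpJson(cfg)                      # write the current settings
    with cfg.open("r", encoding="utf-8") as f:
        hparams.loadJson(json.load(f))         # schedule keys keep their in-code values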
@@ -10,6 +10,7 @@ from typing import Union, List
 import numpy as np
 import librosa
 from utils import logmmse
+import json
 from pypinyin import lazy_pinyin, Style

 class Synthesizer:
@@ -44,6 +45,11 @@ class Synthesizer:
         return self._model is not None

     def load(self):
+        # Try to scan config file
+        model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
+        if len(model_config_fpaths) > 0 and model_config_fpaths[0].exists():
+            with model_config_fpaths[0].open("r", encoding="utf-8") as f:
+                hparams.loadJson(json.load(f))
         """
         Instantiates and loads the model given the weights file that was passed in the constructor.
         """
@@ -62,7 +68,7 @@ class Synthesizer:
             stop_threshold=hparams.tts_stop_threshold,
             speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)

-        self._model.load(self.model_fpath)
+        self._model.load(self.model_fpath, self.device)
         self._model.eval()

         if self.verbose:
@@ -70,7 +76,7 @@ class Synthesizer:

     def synthesize_spectrograms(self, texts: List[str],
                                 embeddings: Union[np.ndarray, List[np.ndarray]],
-                                return_alignments=False):
+                                return_alignments=False, style_idx=0, min_stop_token=5, steps=2000):
         """
         Synthesizes mel spectrograms from texts and speaker embeddings.

@@ -125,7 +131,7 @@ class Synthesizer:
             speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)

             # Inference
-            _, mels, alignments = self._model.generate(chars, speaker_embeddings)
+            _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token, steps=steps)
             mels = mels.detach().cpu().numpy()
             for m in mels:
                 # Trim silence from end of each spectrogram
@@ -143,7 +149,7 @@ class Synthesizer:
         Loads and preprocesses an audio file under the same conditions the audio files were used to
         train the synthesizer.
         """
-        wav = librosa.load(str(fpath), hparams.sample_rate)[0]
+        wav = librosa.load(path=str(fpath), sr=hparams.sample_rate)[0]
         if hparams.rescale:
             wav = wav / np.abs(wav).max() * hparams.rescaling_max
         # denoise
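Note: the widened `synthesize_spectrograms` signature exposes the new generation knobs end to end: `style_idx` selects one of the 10 GST tokens (a negative value falls back to deriving style from the speaker embedding, see `Tacotron.generate` below), `min_stop_token` loosens or tightens the stop criterion, and `steps` caps decoder iterations. A hedged calling sketch (model path and embedding are stand-ins):

    import numpy as np
    from pathlib import Path
    from synthesizer.inference import Synthesizer

    synth = Synthesizer(Path("synthesizer/saved_models/mandarin/mandarin.pt"))
    embed = np.random.rand(256).astype(np.float32)  # stand-in for a real d-vector
    embed /= np.linalg.norm(embed)
    specs = synth.synthesize_spectrograms(
        ["欢迎使用语音克隆工具"], [embed],
        style_idx=3,        # one of the 10 global style tokens
        min_stop_token=5,   # stop once stop_tokens * 10 > 5
        steps=400,          # hard cap on decoder steps
    )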
145 synthesizer/models/global_style_token.py Normal file
@@ -0,0 +1,145 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as tFunctional
+from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
+from synthesizer.hparams import hparams
+
+
+class GlobalStyleToken(nn.Module):
+    """
+    inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
+    speaker_embedding: speaker embedding [batch_size, embedding_dim]
+    outputs: [batch_size, embedding_dim]
+    """
+    def __init__(self, speaker_embedding_dim=None):
+        super().__init__()
+        self.encoder = ReferenceEncoder()
+        self.stl = STL(speaker_embedding_dim)
+
+    def forward(self, inputs, speaker_embedding=None):
+        enc_out = self.encoder(inputs)
+        # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
+        if hparams.use_ser_for_gst and speaker_embedding is not None:
+            enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
+        style_embed = self.stl(enc_out)
+
+        return style_embed
+
+
+class ReferenceEncoder(nn.Module):
+    '''
+    inputs --- [N, Ty/r, n_mels*r]  mels
+    outputs --- [N, ref_enc_gru_size]
+    '''
+
+    def __init__(self):
+        super().__init__()
+        K = len(hp.ref_enc_filters)
+        filters = [1] + hp.ref_enc_filters
+        convs = [nn.Conv2d(in_channels=filters[i],
+                           out_channels=filters[i + 1],
+                           kernel_size=(3, 3),
+                           stride=(2, 2),
+                           padding=(1, 1)) for i in range(K)]
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)])
+
+        out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K)
+        self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
+                          hidden_size=hp.E // 2,
+                          batch_first=True)
+
+    def forward(self, inputs):
+        N = inputs.size(0)
+        out = inputs.view(N, 1, -1, hp.n_mels)  # [N, 1, Ty, n_mels]
+        for conv, bn in zip(self.convs, self.bns):
+            out = conv(out)
+            out = bn(out)
+            out = tFunctional.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]
+
+        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
+        T = out.size(1)
+        N = out.size(0)
+        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]
+
+        self.gru.flatten_parameters()
+        memory, out = self.gru(out)  # out --- [1, N, E//2]
+
+        return out.squeeze(0)
+
+    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
+        for i in range(n_convs):
+            L = (L - kernel_size + 2 * pad) // stride + 1
+        return L
+
+
+class STL(nn.Module):
+    '''
+    inputs --- [N, E//2]
+    '''
+
+    def __init__(self, speaker_embedding_dim=None):
+        super().__init__()
+        self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
+        d_q = hp.E // 2
+        d_k = hp.E // hp.num_heads
+        # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
+        if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
+            d_q += speaker_embedding_dim
+        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
+
+        init.normal_(self.embed, mean=0, std=0.5)
+
+    def forward(self, inputs):
+        N = inputs.size(0)
+        query = inputs.unsqueeze(1)  # [N, 1, E//2]
+        keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
+        style_embed = self.attention(query, keys)
+
+        return style_embed
+
+
+class MultiHeadAttention(nn.Module):
+    '''
+    input:
+        query --- [N, T_q, query_dim]
+        key --- [N, T_k, key_dim]
+    output:
+        out --- [N, T_q, num_units]
+    '''
+
+    def __init__(self, query_dim, key_dim, num_units, num_heads):
+        super().__init__()
+        self.num_units = num_units
+        self.num_heads = num_heads
+        self.key_dim = key_dim
+
+        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
+        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
+        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
+
+    def forward(self, query, key):
+        querys = self.W_query(query)  # [N, T_q, num_units]
+        keys = self.W_key(key)  # [N, T_k, num_units]
+        values = self.W_value(key)
+
+        split_size = self.num_units // self.num_heads
+        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
+        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
+        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
+
+        # score = softmax(QK^T / (d_k ** 0.5))
+        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
+        scores = scores / (self.key_dim ** 0.5)
+        scores = tFunctional.softmax(scores, dim=3)
+
+        # out = score * V
+        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
+        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
+
+        return out
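Note: with the defaults in gst_hyperparameters.py (E=512, 8 heads, 10 tokens, n_mels=256), the six stride-2 convolutions shrink the 256 mel channels down to 4, so the GRU sees 128 * 4 = 512 inputs and emits the 256-dim style query. A quick shape check, assuming those defaults:

    import torch
    from synthesizer.models.global_style_token import GlobalStyleToken

    gst = GlobalStyleToken()             # no speaker embedding concatenated here
    ref_mels = torch.randn(2, 200, 256)  # [batch, frames, hp.n_mels]
    style = gst(ref_mels)                # 256-dim query attends over 10 tokens
    print(style.shape)                   # torch.Size([2, 1, 512])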
@@ -3,8 +3,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from pathlib import Path
-from typing import Union
+from synthesizer.models.global_style_token import GlobalStyleToken
+from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
+from synthesizer.hparams import hparams


 class HighwayNetwork(nn.Module):
@@ -60,7 +61,7 @@ class Encoder(nn.Module):
         idx = 1

         # Start by making a copy of each speaker embedding to match the input text length
-        # The output of this has size (batch_size, num_chars * tts_embed_dims)
+        # The output of this has size (batch_size, num_chars * speaker_embedding_size)
         speaker_embedding_size = speaker_embedding.size()[idx]
         e = speaker_embedding.repeat_interleave(num_chars, dim=idx)

@@ -126,7 +127,7 @@ class CBHG(nn.Module):
         # Although we `_flatten_parameters()` on init, when using DataParallel
         # the model gets replicated, making it no longer guaranteed that the
         # weights are contiguous in GPU memory. Hence, we must call it again
-        self._flatten_parameters()
+        self.rnn.flatten_parameters()

         # Save these for later
         residual = x
@@ -213,7 +214,7 @@ class LSA(nn.Module):
         self.attention = None

     def init_attention(self, encoder_seq_proj):
-        device = next(self.parameters()).device  # use same device as parameters
+        device = encoder_seq_proj.device  # use same device as the input
         b, t, c = encoder_seq_proj.size()
         self.cumulative = torch.zeros(b, t, device=device)
         self.attention = torch.zeros(b, t, device=device)
@@ -255,16 +256,17 @@ class Decoder(nn.Module):
         self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                              dropout=dropout)
         self.attn_net = LSA(decoder_dims)
+        if hparams.use_gst:
+            speaker_embedding_size += gst_hp.E
         self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
         self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
         self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
         self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)

-    def zoneout(self, prev, current, p=0.1):
-        device = next(self.parameters()).device  # Use same device as parameters
-        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
+    def zoneout(self, prev, current, device, p=0.1):
+        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
         return prev * mask + current * (1 - mask)

     def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
@@ -272,7 +274,7 @@ class Decoder(nn.Module):

         # Need this for reshaping mels
         batch_size = encoder_seq.size(0)
+        device = encoder_seq.device
         # Unpack the hidden and cell states
         attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
         rnn1_cell, rnn2_cell = cell_states
@@ -298,7 +300,7 @@ class Decoder(nn.Module):
         # Compute first Residual RNN
         rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
         if self.training:
-            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
+            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
         else:
             rnn1_hidden = rnn1_hidden_next
         x = x + rnn1_hidden
@@ -306,7 +308,7 @@ class Decoder(nn.Module):
         # Compute second Residual RNN
         rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
         if self.training:
-            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
+            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
         else:
             rnn2_hidden = rnn2_hidden_next
         x = x + rnn2_hidden
@@ -337,7 +339,12 @@ class Tacotron(nn.Module):
         self.speaker_embedding_size = speaker_embedding_size
         self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                                encoder_K, num_highways, dropout)
-        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
+        project_dims = encoder_dims + speaker_embedding_size
+        if hparams.use_gst:
+            project_dims += gst_hp.E
+        self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
+        if hparams.use_gst:
+            self.gst = GlobalStyleToken(speaker_embedding_size)
         self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                                dropout, speaker_embedding_size)
         self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
@@ -357,12 +364,19 @@ class Tacotron(nn.Module):
     @r.setter
     def r(self, value):
         self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

+    @staticmethod
+    def _concat_speaker_embedding(outputs, speaker_embeddings):
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
+        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
+        return outputs
+
-    def forward(self, x, m, speaker_embedding):
-        device = next(self.parameters()).device  # use same device as parameters
+    def forward(self, texts, mels, speaker_embedding):
+        device = texts.device  # use same device as the inputs

         self.step += 1
-        batch_size, _, steps = m.size()
+        batch_size, _, steps = mels.size()

         # Initialise all hidden states and pack into tuple
         attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
@@ -379,11 +393,20 @@ class Tacotron(nn.Module):
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)

         # Need an initial context vector
-        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
+        size = self.encoder_dims + self.speaker_embedding_size
+        if hparams.use_gst:
+            size += gst_hp.E
+        context_vec = torch.zeros(batch_size, size, device=device)

         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
-        encoder_seq = self.encoder(x, speaker_embedding)
+        encoder_seq = self.encoder(texts, speaker_embedding)
+        # put after encoder
+        if hparams.use_gst and self.gst is not None:
+            style_embed = self.gst(speaker_embedding, speaker_embedding)  # for training, the speaker embedding can serve as both style input and reference
+            # style_embed = style_embed.expand_as(encoder_seq)
+            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         encoder_seq_proj = self.encoder_proj(encoder_seq)

         # Need a couple of lists for outputs
@@ -391,10 +414,10 @@ class Tacotron(nn.Module):

         # Run the decoder loop
         for t in range(0, steps, self.r):
-            prenet_in = m[:, :, t - 1] if t > 0 else go_frame
+            prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
             mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                 self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
-                             hidden_states, cell_states, context_vec, t, x)
+                             hidden_states, cell_states, context_vec, t, texts)
             mel_outputs.append(mel_frames)
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)
@@ -414,9 +437,9 @@ class Tacotron(nn.Module):

         return mel_outputs, linear, attn_scores, stop_outputs

-    def generate(self, x, speaker_embedding=None, steps=2000):
+    def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
-        device = next(self.parameters()).device  # use same device as parameters
+        device = x.device  # use same device as the input

         batch_size, _ = x.size()
@@ -435,11 +458,30 @@ class Tacotron(nn.Module):
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)

         # Need an initial context vector
-        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
+        size = self.encoder_dims + self.speaker_embedding_size
+        if hparams.use_gst:
+            size += gst_hp.E
+        context_vec = torch.zeros(batch_size, size, device=device)

         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(x, speaker_embedding)
+
+        # put after encoder
+        if hparams.use_gst and self.gst is not None:
+            if style_idx >= 0 and style_idx < 10:
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+                if device.type == 'cuda':
+                    query = query.cuda()
+                gst_embed = torch.tanh(self.gst.stl.embed)
+                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+                style_embed = self.gst.stl.attention(query, key)
+            else:
+                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
+            # style_embed = style_embed.expand_as(encoder_seq)
+            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
         encoder_seq_proj = self.encoder_proj(encoder_seq)

         # Need a couple of lists for outputs
@@ -455,7 +497,7 @@ class Tacotron(nn.Module):
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)
             # Stop the loop when all stop tokens in batch exceed threshold
-            if (stop_tokens > 0.5).all() and t > 10: break
+            if (stop_tokens * 10 > min_stop_token).all() and t > 10: break

         # Concat the mel outputs into sequence
         mel_outputs = torch.cat(mel_outputs, dim=2)
@@ -479,6 +521,15 @@ class Tacotron(nn.Module):
         for p in self.parameters():
             if p.dim() > 1: nn.init.xavier_uniform_(p)

+    def finetune_partial(self, whitelist_layers):
+        self.zero_grad()
+        for name, child in self.named_children():
+            if name in whitelist_layers:
+                print("Trainable Layer: %s" % name)
+                print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
+            else:
+                for param in child.parameters():
+                    param.requires_grad = False
+
     def get_step(self):
         return self.step.data.item()
@@ -490,11 +541,10 @@ class Tacotron(nn.Module):
         with open(path, "a") as f:
             print(msg, file=f)

-    def load(self, path, optimizer=None):
+    def load(self, path, device, optimizer=None):
         # Use the given device as location for the loaded state
-        device = next(self.parameters()).device
         checkpoint = torch.load(str(path), map_location=device)
-        self.load_state_dict(checkpoint["model_state"])
+        self.load_state_dict(checkpoint["model_state"], strict=False)

         if "optimizer_state" in checkpoint and optimizer is not None:
             optimizer.load_state_dict(checkpoint["optimizer_state"])
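Note: zoneout, used on both residual LSTMs above, regularizes training by randomly keeping each unit's previous hidden state instead of the new one; the refactor only threads `device` through explicitly instead of querying `next(self.parameters())` on every decoder step. The operation in isolation:

    import torch

    def zoneout(prev, current, device, p=0.1):
        # With probability p keep the old state per unit, else take the new one.
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    prev, current = torch.ones(8), torch.zeros(8)
    print(zoneout(prev, current, torch.device("cpu")))  # ~10% of entries stay 1.0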
@@ -7,7 +7,7 @@ from tqdm import tqdm
 import numpy as np
 from encoder import inference as encoder
 from synthesizer.preprocess_speaker import preprocess_speaker_general
-from synthesizer.preprocess_transcript import preprocess_transcript_aishell3
+from synthesizer.preprocess_transcript import preprocess_transcript_aishell3, preprocess_transcript_magicdata

 data_info = {
     "aidatatang_200zh": {
@@ -18,13 +18,19 @@ data_info = {
     "magicdata": {
         "subfolders": ["train"],
         "trans_filepath": "train/TRANS.txt",
-        "speak_func": preprocess_speaker_general
+        "speak_func": preprocess_speaker_general,
+        "transcript_func": preprocess_transcript_magicdata,
     },
     "aishell3":{
         "subfolders": ["train/wav"],
         "trans_filepath": "train/content.txt",
         "speak_func": preprocess_speaker_general,
         "transcript_func": preprocess_transcript_aishell3,
+    },
+    "data_aishell":{
+        "subfolders": ["wav/train"],
+        "trans_filepath": "transcript/aishell_transcript_v0.8.txt",
+        "speak_func": preprocess_speaker_general
     }
 }
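Note: `data_info` is the dataset registry: each entry names the wav subfolders, the transcript file, and the speaker/transcript parsing hooks, so supporting another corpus is one dictionary entry plus (at most) a parser. A hypothetical entry, with made-up layout names:

    data_info["my_corpus"] = {
        "subfolders": ["wavs"],                # where the speaker folders live
        "trans_filepath": "transcript.tsv",    # relative to the corpus root
        "speak_func": preprocess_speaker_general,
        "transcript_func": preprocess_transcript_magicdata,  # reuse if the format matches
    }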
@@ -63,7 +63,7 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,

 def _split_on_silences(wav_fpath, words, hparams):
     # Load the audio waveform
-    wav, _ = librosa.load(wav_fpath, hparams.sample_rate)
+    wav, _ = librosa.load(wav_fpath, sr=hparams.sample_rate)
     wav = librosa.effects.trim(wav, top_db=40, frame_length=2048, hop_length=512)[0]
     if hparams.rescale:
         wav = wav / np.abs(wav).max() * hparams.rescaling_max
@@ -6,4 +6,13 @@ def preprocess_transcript_aishell3(dict_info, dict_transcript):
         transList = []
         for i in range(2, len(v), 2):
             transList.append(v[i])
         dict_info[v[0]] = " ".join(transList)
+
+
+def preprocess_transcript_magicdata(dict_info, dict_transcript):
+    for v in dict_transcript:
+        if not v:
+            continue
+        v = v.strip().replace("\n", "").replace("\t", " ").split(" ")
+        dict_info[v[0]] = " ".join(v[2:])
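Note: MAGICDATA's TRANS.txt is tab-separated as utterance_id, speaker_id, transcript; the parser above normalizes tabs to spaces, keys on the utterance id, and joins everything after the speaker column. A trace on a made-up line:

    dict_info = {}
    line = "38_5716_20170915093303.wav\t14_3466\t今天 天气 不错\n"
    v = line.strip().replace("\n", "").replace("\t", " ").split(" ")
    dict_info[v[0]] = " ".join(v[2:])
    print(dict_info)  # {'38_5716_20170915093303.wav': '今天 天气 不错'}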
@@ -45,7 +45,7 @@ def run_synthesis(in_dir, out_dir, model_dir, hparams):
     model_dir = Path(model_dir)
     model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
     print("\nLoading weights at %s" % model_fpath)
-    model.load(model_fpath)
+    model.load(model_fpath, device)
     print("Tacotron weights loaded from step %d" % model.step)

     # Synthesize using same reduction factor as the model is currently trained
@@ -73,6 +73,7 @@ def collate_synthesizer(batch):

     # Speaker embedding (SV2TTS)
     embeds = [x[2] for x in batch]
+    embeds = np.stack(embeds)

     # Index (for vocoder preprocessing)
     indices = [x[3] for x in batch]
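Note: stacking the per-utterance embeddings into one contiguous `(batch, 256)` ndarray before the later `torch.tensor(...)` call avoids the slow element-by-element conversion from a Python list of arrays, which newer PyTorch versions also warn about. In short:

    import numpy as np
    import torch

    embeds = [np.random.rand(256).astype(np.float32) for _ in range(12)]
    batch = torch.tensor(np.stack(embeds))  # one contiguous (12, 256) copy
    print(batch.shape)                      # torch.Size([12, 256])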
@@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols
 from synthesizer.utils.text import sequence_to_text
 from vocoder.display import *
 from datetime import datetime
+import json
 import numpy as np
 from pathlib import Path
 import sys
@@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
         if num_chars != loaded_shape[0]:
             print("WARNING: you are using compatible mode due to wrong symbols length, please modify variable _characters in `utils\symbols.py`")
             num_chars = loaded_shape[0]
+        # Try to scan config file
+        model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
+        if len(model_config_fpaths) > 0 and model_config_fpaths[0].exists():
+            with model_config_fpaths[0].open("r", encoding="utf-8") as f:
+                hparams.loadJson(json.load(f))
+        else:  # save a config
+            hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))

     model = Tacotron(embed_dims=hparams.tts_embed_dims,
@@ -93,7 +101,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                      speaker_embedding_size=hparams.speaker_embedding_size).to(device)

     # Initialize the optimizer
-    optimizer = optim.Adam(model.parameters())
+    optimizer = optim.Adam(model.parameters(), amsgrad=True)

     # Load the weights
     if force_restart or not weights_fpath.exists():
@@ -111,7 +119,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,

     else:
         print("\nLoading weights at %s" % weights_fpath)
-        model.load(weights_fpath, optimizer)
+        model.load(weights_fpath, device, optimizer)
         print("Tacotron weights loaded from step %d" % model.step)

     # Initialize the dataset
@@ -146,7 +154,6 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
             continue

         model.r = r
-
         # Begin the training
         simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                       ("Batch Size", batch_size),
@@ -155,6 +162,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,

         for p in optimizer.param_groups:
             p["lr"] = lr
+        if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0:
+            model.finetune_partial(hparams.tts_finetune_layers)

         data_loader = DataLoader(dataset,
                                  collate_fn=collate_synthesizer,
@@ -221,7 +230,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,

         # Backup or save model as appropriate
         if backup_every != 0 and step % backup_every == 0:
-            backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
+            backup_fpath = Path("{}/{}_{}.pt".format(str(weights_fpath.parent), run_id, step))
             model.save(backup_fpath, optimizer)

         if save_every != 0 and step % save_every == 0:
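Note: with `tts_finetune_layers` set (in hparams or in the run's JSON), `finetune_partial` runs before the training loop and freezes every child module not on the whitelist; the names follow `Tacotron.named_children()` (the hparams comment lists encoder, encoder_proj, gst, decoder, postnet, post_proj). A hedged way to verify what remains trainable, assuming a constructed `model`:

    hparams.tts_finetune_layers = ["decoder", "postnet"]
    model.finetune_partial(hparams.tts_finetune_layers)
    trainable = {n.split(".")[0] for n, p in model.named_parameters() if p.requires_grad}
    print(sorted(trainable))  # expect only the whitelisted top-level modules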
@@ -10,9 +10,7 @@ import numpy as np
|
|||||||
import traceback
|
import traceback
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
import librosa
|
|
||||||
import re
|
import re
|
||||||
from audioread.exceptions import NoBackendError
|
|
||||||
|
|
||||||
# 默认使用wavernn
|
# 默认使用wavernn
|
||||||
vocoder = rnn_vocoder
|
vocoder = rnn_vocoder
|
||||||
@@ -49,14 +47,20 @@ recognized_datasets = [
|
|||||||
MAX_WAVES = 15
|
MAX_WAVES = 15
|
||||||
|
|
||||||
class Toolbox:
|
class Toolbox:
|
||||||
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
|
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
|
||||||
self.no_mp3_support = no_mp3_support
|
self.no_mp3_support = no_mp3_support
|
||||||
|
self.vc_mode = vc_mode
|
||||||
sys.excepthook = self.excepthook
|
sys.excepthook = self.excepthook
|
||||||
self.datasets_root = datasets_root
|
self.datasets_root = datasets_root
|
||||||
self.utterances = set()
|
self.utterances = set()
|
||||||
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
||||||
|
|
||||||
self.synthesizer = None # type: Synthesizer
|
self.synthesizer = None # type: Synthesizer
|
||||||
|
|
||||||
|
# for ppg-based voice conversion
|
||||||
|
self.extractor = None
|
||||||
|
self.convertor = None # ppg2mel
|
||||||
|
|
||||||
self.current_wav = None
|
self.current_wav = None
|
||||||
self.waves_list = []
|
self.waves_list = []
|
||||||
self.waves_count = 0
|
self.waves_count = 0
|
||||||
@@ -70,8 +74,9 @@ class Toolbox:
|
|||||||
self.trim_silences = False
|
self.trim_silences = False
|
||||||
|
|
||||||
# Initialize the events and the interface
|
# Initialize the events and the interface
|
||||||
self.ui = UI()
|
self.ui = UI(vc_mode)
|
||||||
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
|
self.style_idx = 0
|
||||||
|
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
|
||||||
self.setup_events()
|
self.setup_events()
|
||||||
self.ui.start()
|
self.ui.start()
|
||||||
|
|
||||||
@@ -95,7 +100,11 @@ class Toolbox:
|
|||||||
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
||||||
def func():
|
def func():
|
||||||
self.synthesizer = None
|
self.synthesizer = None
|
||||||
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
if self.vc_mode:
|
||||||
|
self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
|
||||||
|
else:
|
||||||
|
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
||||||
|
|
||||||
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
||||||
|
|
||||||
# Utterance selection
|
# Utterance selection
|
||||||
@@ -108,6 +117,11 @@ class Toolbox:
|
|||||||
self.ui.stop_button.clicked.connect(self.ui.stop)
|
self.ui.stop_button.clicked.connect(self.ui.stop)
|
||||||
self.ui.record_button.clicked.connect(self.record)
|
self.ui.record_button.clicked.connect(self.record)
|
||||||
|
|
||||||
|
# Source Utterance selection
|
||||||
|
if self.vc_mode:
|
||||||
|
func = lambda: self.load_soruce_button(self.ui.selected_utterance)
|
||||||
|
self.ui.load_soruce_button.clicked.connect(func)
|
||||||
|
|
||||||
#Audio
|
#Audio
|
||||||
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
||||||
|
|
||||||
@@ -119,12 +133,17 @@ class Toolbox:
|
|||||||
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
||||||
|
|
||||||
# Generation
|
# Generation
|
||||||
func = lambda: self.synthesize() or self.vocode()
|
|
||||||
self.ui.generate_button.clicked.connect(func)
|
|
||||||
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
|
||||||
self.ui.vocode_button.clicked.connect(self.vocode)
|
self.ui.vocode_button.clicked.connect(self.vocode)
|
||||||
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
||||||
|
|
||||||
|
if self.vc_mode:
|
||||||
|
func = lambda: self.convert() or self.vocode()
|
||||||
|
self.ui.convert_button.clicked.connect(func)
|
||||||
|
else:
|
||||||
|
func = lambda: self.synthesize() or self.vocode()
|
||||||
|
self.ui.generate_button.clicked.connect(func)
|
||||||
|
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
||||||
|
|
||||||
# UMAP legend
|
# UMAP legend
|
||||||
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
||||||
|
|
||||||
@@ -137,9 +156,9 @@ class Toolbox:
|
|||||||
def replay_last_wav(self):
|
def replay_last_wav(self):
|
||||||
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
||||||
|
|
||||||
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
|
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
|
||||||
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
||||||
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
|
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
|
||||||
self.ui.populate_gen_options(seed, self.trim_silences)
|
self.ui.populate_gen_options(seed, self.trim_silences)
|
||||||
|
|
||||||
def load_from_browser(self, fpath=None):
|
def load_from_browser(self, fpath=None):
|
||||||
@@ -170,7 +189,10 @@ class Toolbox:
|
|||||||
self.ui.log("Loaded %s" % name)
|
self.ui.log("Loaded %s" % name)
|
||||||
|
|
||||||
self.add_real_utterance(wav, name, speaker_name)
|
self.add_real_utterance(wav, name, speaker_name)
|
||||||
|
|
||||||
|
def load_soruce_button(self, utterance: Utterance):
|
||||||
|
self.selected_source_utterance = utterance
|
||||||
|
|
||||||
def record(self):
|
def record(self):
|
||||||
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
||||||
if wav is None:
|
if wav is None:
|
||||||
@@ -195,7 +217,7 @@ class Toolbox:
|
|||||||
# Add the utterance
|
# Add the utterance
|
||||||
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
||||||
self.utterances.add(utterance)
|
self.utterances.add(utterance)
|
||||||
self.ui.register_utterance(utterance)
|
self.ui.register_utterance(utterance, self.vc_mode)
|
||||||
|
|
||||||
# Plot it
|
# Plot it
|
||||||
self.ui.draw_embed(embed, name, "current")
|
self.ui.draw_embed(embed, name, "current")
|
||||||
@@ -233,7 +255,8 @@ class Toolbox:
|
|||||||
texts = processed_texts
|
texts = processed_texts
|
||||||
embed = self.ui.selected_utterance.embed
|
embed = self.ui.selected_utterance.embed
|
||||||
embeds = [embed] * len(texts)
|
embeds = [embed] * len(texts)
|
||||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
|
min_token = int(self.ui.token_slider.value())
|
||||||
|
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
|
||||||
breaks = [spec.shape[1] for spec in specs]
|
breaks = [spec.shape[1] for spec in specs]
|
||||||
spec = np.concatenate(specs, axis=1)
|
spec = np.concatenate(specs, axis=1)
|
||||||
|
|
||||||
@@ -267,7 +290,7 @@ class Toolbox:
|
|||||||
self.ui.set_loading(i, seq_len)
|
self.ui.set_loading(i, seq_len)
|
||||||
if self.ui.current_vocoder_fpath is not None:
|
if self.ui.current_vocoder_fpath is not None:
|
||||||
self.ui.log("")
|
self.ui.log("")
|
||||||
wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
||||||
else:
|
else:
|
||||||
self.ui.log("Waveform generation with Griffin-Lim... ")
|
self.ui.log("Waveform generation with Griffin-Lim... ")
|
||||||
wav = Synthesizer.griffin_lim(spec)
|
wav = Synthesizer.griffin_lim(spec)
|
||||||
@@ -278,7 +301,7 @@ class Toolbox:
|
|||||||
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
||||||
b_starts = np.concatenate(([0], b_ends[:-1]))
|
b_starts = np.concatenate(([0], b_ends[:-1]))
|
||||||
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
||||||
breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
|
breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
|
||||||
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
||||||
|
|
||||||
# Trim excessive silences
|
# Trim excessive silences
|
||||||
@@ -287,7 +310,7 @@ class Toolbox:
|
|||||||
|
|
||||||
# Play it
|
# Play it
|
||||||
wav = wav / np.abs(wav).max() * 0.97
|
wav = wav / np.abs(wav).max() * 0.97
|
||||||
self.ui.play(wav, Synthesizer.sample_rate)
|
self.ui.play(wav, sample_rate)
|
||||||
|
|
||||||
# Name it (history displayed in combobox)
|
# Name it (history displayed in combobox)
|
||||||
# TODO better naming for the combobox items?
|
# TODO better naming for the combobox items?
|
||||||
@@ -329,6 +352,67 @@ class Toolbox:
         self.ui.draw_embed(embed, name, "generated")
         self.ui.draw_umap_projections(self.utterances)

+    def convert(self):
+        self.ui.log("Extract PPG and Converting...")
+        self.ui.set_loading(1)
+
+        # Init
+        if self.convertor is None:
+            self.init_convertor()
+        if self.extractor is None:
+            self.init_extractor()
+
+        src_wav = self.selected_source_utterance.wav
+
+        # Compute the ppg
+        if self.extractor is not None:
+            ppg = self.extractor.extract_from_wav(src_wav)
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        ref_wav = self.ui.selected_utterance.wav
+        # Import necessary dependency of Voice Conversion
+        from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
+        ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
+        lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
+        min_len = min(ppg.shape[1], len(lf0_uv))
+        ppg = ppg[:, :min_len]
+        lf0_uv = lf0_uv[:min_len]
+        _, mel_pred, att_ws = self.convertor.inference(
+            ppg,
+            logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
+            spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
+        )
+        mel_pred = mel_pred.transpose(0, 1)
+        breaks = [mel_pred.shape[1]]
+        mel_pred = mel_pred.detach().cpu().numpy()
+        self.ui.draw_spec(mel_pred, "generated")
+        self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
+        self.ui.set_loading(0)
+
+    def init_extractor(self):
+        if self.ui.current_extractor_fpath is None:
+            return
+        model_fpath = self.ui.current_extractor_fpath
+        self.ui.log("Loading the extractor %s... " % model_fpath)
+        self.ui.set_loading(1)
+        start = timer()
+        import ppg_extractor as extractor
+        self.extractor = extractor.load_model(model_fpath)
+        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
+        self.ui.set_loading(0)
+
+    def init_convertor(self):
+        if self.ui.current_convertor_fpath is None:
+            return
+        model_fpath = self.ui.current_convertor_fpath
+        self.ui.log("Loading the convertor %s... " % model_fpath)
+        self.ui.set_loading(1)
+        start = timer()
+        import ppg2mel as convertor
+        self.convertor = convertor.load_model(model_fpath)
+        self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
+        self.ui.set_loading(0)
+
     def init_encoder(self):
         model_fpath = self.ui.current_encoder_fpath
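Taken together, the new convert() path is a PPG-based one-shot voice conversion pipeline: a PPG extractor strips speaker identity from the source audio, the source log-f0 is re-standardized to the reference speaker's pitch statistics, and a PPG-to-mel convertor re-renders the content conditioned on the reference speaker embedding. Below is a minimal sketch of that pipeline outside the GUI; it assumes ppg_extractor.load_model, ppg2mel.load_model and extract_from_wav behave as in the methods above, and every file path is a hypothetical placeholder.

import numpy as np
import torch
import librosa
import ppg_extractor
import ppg2mel
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = ppg_extractor.load_model("ppg_extractor/saved_models/extractor.pt")  # hypothetical path
convertor = ppg2mel.load_model("ppg2mel/saved_models/convertor.pth")             # hypothetical path

src_wav, _ = librosa.load("source.wav", sr=16000)      # content to convert
ref_wav, _ = librosa.load("reference.wav", sr=16000)   # target-speaker sample
ref_embed = np.load("reference_embed.npy")             # d-vector from the speaker encoder

ppg = extractor.extract_from_wav(src_wav)              # linguistic features, speaker-free
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))               # align PPG and prosody streams
_, mel_pred, _ = convertor.inference(
    ppg[:, :min_len],
    logf0_uv=torch.from_numpy(lf0_uv[:min_len]).unsqueeze(0).float().to(device),
    spembs=torch.from_numpy(ref_embed).unsqueeze(0).to(device),
)                                                      # mel_pred then goes to the vocoder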
@@ -356,12 +440,17 @@ class Toolbox:
         # Case of Griffin-lim
         if model_fpath is None:
             return

-        # Select vocoder based on model name
+        # Select vocoder based on model name
+        model_config_fpath = None
         if model_fpath.name[0] == "g":
             vocoder = gan_vocoder
             self.ui.log("set hifigan as vocoder")
+            # search a config file
+            model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
+            if self.vc_mode and self.ui.current_extractor_fpath is None:
+                return
+            if len(model_config_fpaths) > 0:
+                model_config_fpath = model_config_fpaths[0]
         else:
             vocoder = rnn_vocoder
             self.ui.log("set wavernn as vocoder")
@@ -369,7 +458,7 @@ class Toolbox:
         self.ui.log("Loading the vocoder %s... " % model_fpath)
         self.ui.set_loading(1)
         start = timer()
-        vocoder.load_model(model_fpath)
+        vocoder.load_model(model_fpath, model_config_fpath)
         self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
         self.ui.set_loading(0)

toolbox/assets/mb.png (new binary file, 5.6 KiB; binary content not shown)
toolbox/ui.py (261 changed lines)
@@ -1,8 +1,9 @@
+from PyQt5.QtCore import Qt, QStringListModel
+from PyQt5 import QtGui
+from PyQt5.QtWidgets import *
 import matplotlib.pyplot as plt
 from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
 from matplotlib.figure import Figure
-from PyQt5.QtCore import Qt, QStringListModel
-from PyQt5.QtWidgets import *
 from encoder.inference import plot_embedding_as_heatmap
 from toolbox.utterance import Utterance
 from pathlib import Path
@@ -325,30 +326,51 @@ class UI(QDialog):
     def current_vocoder_fpath(self):
         return self.vocoder_box.itemData(self.vocoder_box.currentIndex())

+    @property
+    def current_extractor_fpath(self):
+        return self.extractor_box.itemData(self.extractor_box.currentIndex())
+
+    @property
+    def current_convertor_fpath(self):
+        return self.convertor_box.itemData(self.convertor_box.currentIndex())
+
     def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
-                        vocoder_models_dir: Path):
+                        vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool):
         # Encoder
         encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
         if len(encoder_fpaths) == 0:
             raise Exception("No encoder models found in %s" % encoder_models_dir)
         self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])

-        # Synthesizer
-        synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
-        if len(synthesizer_fpaths) == 0:
-            raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
-        self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
+        if vc_mode:
+            # Extractor
+            extractor_fpaths = list(extractor_models_dir.glob("*.pt"))
+            if len(extractor_fpaths) == 0:
+                self.log("No extractor models found in %s" % extractor_fpaths)
+            self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths])
+
+            # Convertor
+            convertor_fpaths = list(convertor_models_dir.glob("*.pth"))
+            if len(convertor_fpaths) == 0:
+                self.log("No convertor models found in %s" % convertor_fpaths)
+            self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths])
+        else:
+            # Synthesizer
+            synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
+            if len(synthesizer_fpaths) == 0:
+                raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
+            self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])

         # Vocoder
         vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
         vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
         self.repopulate_box(self.vocoder_box, vocoder_items)

     @property
     def selected_utterance(self):
         return self.utterance_history.itemData(self.utterance_history.currentIndex())

-    def register_utterance(self, utterance: Utterance):
+    def register_utterance(self, utterance: Utterance, vc_mode):
         self.utterance_history.blockSignals(True)
         self.utterance_history.insertItem(0, utterance.name, utterance)
         self.utterance_history.setCurrentIndex(0)
@@ -358,8 +380,11 @@ class UI(QDialog):
             self.utterance_history.removeItem(self.max_saved_utterances)

         self.play_button.setDisabled(False)
-        self.generate_button.setDisabled(False)
-        self.synthesize_button.setDisabled(False)
+        if vc_mode:
+            self.convert_button.setDisabled(False)
+        else:
+            self.generate_button.setDisabled(False)
+            self.synthesize_button.setDisabled(False)

     def log(self, line, mode="newline"):
         if mode == "newline":
@@ -401,7 +426,7 @@ class UI(QDialog):
         else:
             self.seed_textbox.setEnabled(False)

-    def reset_interface(self):
+    def reset_interface(self, vc_mode):
         self.draw_embed(None, None, "current")
         self.draw_embed(None, None, "generated")
         self.draw_spec(None, "current")
@@ -409,18 +434,24 @@ class UI(QDialog):
         self.draw_umap_projections(set())
         self.set_loading(0)
         self.play_button.setDisabled(True)
-        self.generate_button.setDisabled(True)
-        self.synthesize_button.setDisabled(True)
+        if vc_mode:
+            self.convert_button.setDisabled(True)
+        else:
+            self.generate_button.setDisabled(True)
+            self.synthesize_button.setDisabled(True)
         self.vocode_button.setDisabled(True)
         self.replay_wav_button.setDisabled(True)
         self.export_wav_button.setDisabled(True)
         [self.log("") for _ in range(self.max_log_lines)]

-    def __init__(self):
+    def __init__(self, vc_mode):
         ## Initialize the application
         self.app = QApplication(sys.argv)
         super().__init__(None)
-        self.setWindowTitle("SV2TTS toolbox")
+        self.setWindowTitle("MockingBird GUI")
+        self.setWindowIcon(QtGui.QIcon('toolbox\\assets\\mb.png'))
+        self.setWindowFlag(Qt.WindowMinimizeButtonHint, True)
+        self.setWindowFlag(Qt.WindowMaximizeButtonHint, True)

         ## Main layouts
@@ -430,21 +461,24 @@ class UI(QDialog):

         # Browser
         browser_layout = QGridLayout()
-        root_layout.addLayout(browser_layout, 0, 0, 1, 2)
+        root_layout.addLayout(browser_layout, 0, 0, 1, 8)

         # Generation
         gen_layout = QVBoxLayout()
-        root_layout.addLayout(gen_layout, 0, 2, 1, 2)
-
-        # Projections
-        self.projections_layout = QVBoxLayout()
-        root_layout.addLayout(self.projections_layout, 1, 0, 1, 1)
+        root_layout.addLayout(gen_layout, 0, 8)

         # Visualizations
         vis_layout = QVBoxLayout()
-        root_layout.addLayout(vis_layout, 1, 1, 1, 3)
+        root_layout.addLayout(vis_layout, 1, 0, 2, 8)
+
+        # Output
+        output_layout = QGridLayout()
+        vis_layout.addLayout(output_layout, 0)
+
+        # Projections
+        self.projections_layout = QVBoxLayout()
+        root_layout.addLayout(self.projections_layout, 1, 8, 2, 2)

         ## Projections
         # UMap
         fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
@@ -458,84 +492,102 @@ class UI(QDialog):
         ## Browser
         # Dataset, speaker and utterance selection
         i = 0
-        self.dataset_box = QComboBox()
-        browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
-        browser_layout.addWidget(self.dataset_box, i + 1, 0)
-        self.speaker_box = QComboBox()
-        browser_layout.addWidget(QLabel("<b>Speaker</b>"), i, 1)
-        browser_layout.addWidget(self.speaker_box, i + 1, 1)
-        self.utterance_box = QComboBox()
-        browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
-        browser_layout.addWidget(self.utterance_box, i + 1, 2)
-        self.browser_load_button = QPushButton("Load")
-        browser_layout.addWidget(self.browser_load_button, i + 1, 3)
-        i += 2
-
-        # Random buttons
+        source_groupbox = QGroupBox('Source(源音频)')
+        source_layout = QGridLayout()
+        source_groupbox.setLayout(source_layout)
+        browser_layout.addWidget(source_groupbox, i, 0, 1, 5)
+
+        self.dataset_box = QComboBox()
+        source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0)
+        source_layout.addWidget(self.dataset_box, i, 1)
         self.random_dataset_button = QPushButton("Random")
-        browser_layout.addWidget(self.random_dataset_button, i, 0)
+        source_layout.addWidget(self.random_dataset_button, i, 2)
+        i += 1
+        self.speaker_box = QComboBox()
+        source_layout.addWidget(QLabel("Speaker(说话者)"), i, 0)
+        source_layout.addWidget(self.speaker_box, i, 1)
         self.random_speaker_button = QPushButton("Random")
-        browser_layout.addWidget(self.random_speaker_button, i, 1)
+        source_layout.addWidget(self.random_speaker_button, i, 2)
+        i += 1
+        self.utterance_box = QComboBox()
+        source_layout.addWidget(QLabel("Utterance(音频):"), i, 0)
+        source_layout.addWidget(self.utterance_box, i, 1)
         self.random_utterance_button = QPushButton("Random")
-        browser_layout.addWidget(self.random_utterance_button, i, 2)
+        source_layout.addWidget(self.random_utterance_button, i, 2)
+
+        i += 1
+        source_layout.addWidget(QLabel("<b>Use(使用):</b>"), i, 0)
+        self.browser_load_button = QPushButton("Load Above(加载上面)")
+        source_layout.addWidget(self.browser_load_button, i, 1, 1, 2)
         self.auto_next_checkbox = QCheckBox("Auto select next")
         self.auto_next_checkbox.setChecked(True)
-        browser_layout.addWidget(self.auto_next_checkbox, i, 3)
-        i += 1
+        source_layout.addWidget(self.auto_next_checkbox, i+1, 1)
+        self.browser_browse_button = QPushButton("Browse(打开本地)")
+        source_layout.addWidget(self.browser_browse_button, i, 3)
+        self.record_button = QPushButton("Record(录音)")
+        source_layout.addWidget(self.record_button, i+1, 3)
+
+        i += 2
         # Utterance box
-        browser_layout.addWidget(QLabel("<b>Use embedding from:</b>"), i, 0)
+        browser_layout.addWidget(QLabel("<b>Current(当前):</b>"), i, 0)
         self.utterance_history = QComboBox()
-        browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
-        i += 1
-
-        # Random & next utterance buttons
-        self.browser_browse_button = QPushButton("Browse")
-        browser_layout.addWidget(self.browser_browse_button, i, 0)
-        self.record_button = QPushButton("Record")
-        browser_layout.addWidget(self.record_button, i, 1)
-        self.play_button = QPushButton("Play")
+        browser_layout.addWidget(self.utterance_history, i, 1)
+        self.play_button = QPushButton("Play(播放)")
         browser_layout.addWidget(self.play_button, i, 2)
-        self.stop_button = QPushButton("Stop")
+        self.stop_button = QPushButton("Stop(暂停)")
         browser_layout.addWidget(self.stop_button, i, 3)
-        i += 1
+        if vc_mode:
+            self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)")
+            browser_layout.addWidget(self.load_soruce_button, i, 4)
+
+        i += 1
+        model_groupbox = QGroupBox('Models(模型选择)')
+        model_layout = QHBoxLayout()
+        model_groupbox.setLayout(model_layout)
+        browser_layout.addWidget(model_groupbox, i, 0, 2, 5)
+
         # Model and audio output selection
         self.encoder_box = QComboBox()
-        browser_layout.addWidget(QLabel("<b>Encoder</b>"), i, 0)
-        browser_layout.addWidget(self.encoder_box, i + 1, 0)
+        model_layout.addWidget(QLabel("Encoder:"))
+        model_layout.addWidget(self.encoder_box)
         self.synthesizer_box = QComboBox()
-        browser_layout.addWidget(QLabel("<b>Synthesizer</b>"), i, 1)
-        browser_layout.addWidget(self.synthesizer_box, i + 1, 1)
+        if vc_mode:
+            self.extractor_box = QComboBox()
+            model_layout.addWidget(QLabel("Extractor:"))
+            model_layout.addWidget(self.extractor_box)
+            self.convertor_box = QComboBox()
+            model_layout.addWidget(QLabel("Convertor:"))
+            model_layout.addWidget(self.convertor_box)
+        else:
+            model_layout.addWidget(QLabel("Synthesizer:"))
+            model_layout.addWidget(self.synthesizer_box)
         self.vocoder_box = QComboBox()
-        browser_layout.addWidget(QLabel("<b>Vocoder</b>"), i, 2)
-        browser_layout.addWidget(self.vocoder_box, i + 1, 2)
-
-        self.audio_out_devices_cb=QComboBox()
-        browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 3)
-        browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3)
-        i += 2
+        model_layout.addWidget(QLabel("Vocoder:"))
+        model_layout.addWidget(self.vocoder_box)

         #Replay & Save Audio
-        browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
+        i = 0
+        output_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
         self.waves_cb = QComboBox()
         self.waves_cb_model = QStringListModel()
         self.waves_cb.setModel(self.waves_cb_model)
         self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
-        browser_layout.addWidget(self.waves_cb, i, 1)
+        output_layout.addWidget(self.waves_cb, i, 1)
         self.replay_wav_button = QPushButton("Replay")
         self.replay_wav_button.setToolTip("Replay last generated vocoder")
-        browser_layout.addWidget(self.replay_wav_button, i, 2)
+        output_layout.addWidget(self.replay_wav_button, i, 2)
         self.export_wav_button = QPushButton("Export")
         self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
-        browser_layout.addWidget(self.export_wav_button, i, 3)
+        output_layout.addWidget(self.export_wav_button, i, 3)
+        self.audio_out_devices_cb=QComboBox()
         i += 1
+        output_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 0)
+        output_layout.addWidget(self.audio_out_devices_cb, i, 1)

         ## Embed & spectrograms
         vis_layout.addStretch()
+        # TODO: add spectrograms for source
         gridspec_kw = {"width_ratios": [1, 4]}
         fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                             gridspec_kw=gridspec_kw)
@@ -552,21 +604,27 @@ class UI(QDialog):
             for side in ["top", "right", "bottom", "left"]:
                 ax.spines[side].set_visible(False)

         ## Generation
         self.text_prompt = QPlainTextEdit(default_text)
         gen_layout.addWidget(self.text_prompt, stretch=1)

-        self.generate_button = QPushButton("Synthesize and vocode")
-        gen_layout.addWidget(self.generate_button)
-
-        layout = QHBoxLayout()
-        self.synthesize_button = QPushButton("Synthesize only")
-        layout.addWidget(self.synthesize_button)
+        if vc_mode:
+            layout = QHBoxLayout()
+            self.convert_button = QPushButton("Extract and Convert")
+            layout.addWidget(self.convert_button)
+            gen_layout.addLayout(layout)
+        else:
+            self.generate_button = QPushButton("Synthesize and vocode")
+            gen_layout.addWidget(self.generate_button)
+            layout = QHBoxLayout()
+            self.synthesize_button = QPushButton("Synthesize only")
+            layout.addWidget(self.synthesize_button)

         self.vocode_button = QPushButton("Vocode only")
         layout.addWidget(self.vocode_button)
         gen_layout.addLayout(layout)

         layout_seed = QGridLayout()
         self.random_seed_checkbox = QCheckBox("Random seed:")
         self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
@@ -578,6 +636,45 @@ class UI(QDialog):
         self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
             " This feature requires `webrtcvad` to be installed.")
         layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
+        self.style_slider = QSlider(Qt.Horizontal)
+        self.style_slider.setTickInterval(1)
+        self.style_slider.setFocusPolicy(Qt.NoFocus)
+        self.style_slider.setSingleStep(1)
+        self.style_slider.setRange(-1, 9)
+        self.style_value_label = QLabel("-1")
+        self.style_slider.setValue(-1)
+        layout_seed.addWidget(QLabel("Style:"), 1, 0)
+
+        self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
+        layout_seed.addWidget(self.style_value_label, 1, 1)
+        layout_seed.addWidget(self.style_slider, 1, 3)
+
+        self.token_slider = QSlider(Qt.Horizontal)
+        self.token_slider.setTickInterval(1)
+        self.token_slider.setFocusPolicy(Qt.NoFocus)
+        self.token_slider.setSingleStep(1)
+        self.token_slider.setRange(3, 9)
+        self.token_value_label = QLabel("5")
+        self.token_slider.setValue(4)
+        layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0)
+
+        self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
+        layout_seed.addWidget(self.token_value_label, 2, 1)
+        layout_seed.addWidget(self.token_slider, 2, 3)
+
+        self.length_slider = QSlider(Qt.Horizontal)
+        self.length_slider.setTickInterval(1)
+        self.length_slider.setFocusPolicy(Qt.NoFocus)
+        self.length_slider.setSingleStep(1)
+        self.length_slider.setRange(1, 10)
+        self.length_value_label = QLabel("2")
+        self.length_slider.setValue(2)
+        layout_seed.addWidget(QLabel("MaxLength(最大句长):"), 3, 0)
+
+        self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s))
+        layout_seed.addWidget(self.length_value_label, 3, 1)
+        layout_seed.addWidget(self.length_slider, 3, 3)
+
         gen_layout.addLayout(layout_seed)

         self.loading_bar = QProgressBar()
@@ -591,11 +688,11 @@ class UI(QDialog):


         ## Set the size of the window and of the elements
-        max_size = QDesktopWidget().availableGeometry(self).size() * 0.8
+        max_size = QDesktopWidget().availableGeometry(self).size() * 0.5
         self.resize(max_size)

         ## Finalize the display
-        self.reset_interface()
+        self.reset_interface(vc_mode)
         self.show()

     def start(self):
train.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+import sys
+import torch
+import argparse
+import numpy as np
+from utils.load_yaml import HpsYaml
+from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+
+# For reproducibility; commenting these out may speed up training
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(description='Training PPG2Mel VC model.')
+    parser.add_argument('--config', type=str,
+                        help='Path to experiment config, e.g., config/vc.yaml')
+    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
+    parser.add_argument('--logdir', default='log/', type=str,
+                        help='Logging path.', required=False)
+    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
+                        help='Checkpoint path.', required=False)
+    parser.add_argument('--outdir', default='result/', type=str,
+                        help='Decode output path.', required=False)
+    parser.add_argument('--load', default=None, type=str,
+                        help='Load pre-trained model (for training only)', required=False)
+    parser.add_argument('--warm_start', action='store_true',
+                        help='Load model weights only, ignore specified layers.')
+    parser.add_argument('--seed', default=0, type=int,
+                        help='Random seed for reproducible results.', required=False)
+    parser.add_argument('--njobs', default=8, type=int,
+                        help='Number of threads for dataloader/decoding.', required=False)
+    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
+    parser.add_argument('--no-pin', action='store_true',
+                        help='Disable pin-memory for dataloader')
+    parser.add_argument('--test', action='store_true', help='Test the model.')
+    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
+    parser.add_argument('--finetune', action='store_true', help='Finetune model')
+    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
+    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
+    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+
+    ###
+
+    paras = parser.parse_args()
+    setattr(paras, 'gpu', not paras.cpu)
+    setattr(paras, 'pin_memory', not paras.no_pin)
+    setattr(paras, 'verbose', not paras.no_msg)
+    # Make the config dict dot visitable
+    config = HpsYaml(paras.config)
+
+    np.random.seed(paras.seed)
+    torch.manual_seed(paras.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(paras.seed)
+
+    print(">>> OneShot VC training ...")
+    mode = "train"
+    solver = Solver(config, paras, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+    sys.exit(0)
+
+if __name__ == "__main__":
+    main()
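Since train.py resolves everything from argparse plus a YAML config, a quick way to exercise it is to patch sys.argv before calling main(). A hedged smoke-test sketch; the config path below is an assumption and must point at a real PPG2Mel experiment YAML.

# Hypothetical smoke run of the new entry point on CPU; the YAML path is an
# assumption -- substitute an actual PPG2Mel experiment config.
import sys
sys.argv = ["train.py", "--config", "ppg2mel/vc_config.yaml",
            "--oneshotvc", "--cpu", "--no-pin"]
import train
train.main()   # Solver.load_data() -> set_model() -> exec(), then sys.exit(0)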
utils/audio_utils.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import torch
+import torch.utils.data
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def _dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def _spectral_normalize_torch(magnitudes):
+    output = _dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(
+    y,
+    n_fft,
+    num_mels,
+    sampling_rate,
+    hop_size,
+    win_size,
+    fmin,
+    fmax,
+    center=False,
+    output_energy=False,
+):
+    if torch.min(y) < -1.:
+        print('min value is ', torch.min(y))
+    if torch.max(y) > 1.:
+        print('max value is ', torch.max(y))
+
+    global mel_basis, hann_window
+    if fmax not in mel_basis:
+        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = y.squeeze(1)
+
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+    mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
+    mel_spec = _spectral_normalize_torch(mel_spec)
+    if output_energy:
+        energy = torch.norm(spec, dim=1)
+        return mel_spec, energy
+    else:
+        return mel_spec
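The helper expects a float tensor shaped [batch, samples] scaled to [-1, 1] and returns a log-compressed mel of shape [batch, num_mels, frames]. A minimal sketch with the 16 kHz parameters used elsewhere in this diff; input.wav is a placeholder.

import torch
from librosa.util import normalize
from utils.audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram

audio, sr = load_wav("input.wav")                 # int16 PCM and its sample rate
audio = normalize(audio / MAX_WAV_VALUE) * 0.95   # scale into [-0.95, 0.95]
y = torch.FloatTensor(audio).unsqueeze(0)         # [1, T]
mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=sr,
                      hop_size=160, win_size=1024, fmin=80, fmax=8000)
print(mel.shape)                                  # torch.Size([1, 80, n_frames])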
utils/data_load.py (new file, 214 lines)
@@ -0,0 +1,214 @@
+import random
+import numpy as np
+import torch
+from utils.f0_utils import get_cont_lf0
+import resampy
+from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram
+from librosa.util import normalize
+import os
+
+
+SAMPLE_RATE=16000
+
+def read_fids(fid_list_f):
+    with open(fid_list_f, 'r') as f:
+        fids = [l.strip().split()[0] for l in f if l.strip()]
+    return fids
+
+class OneshotVcDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        meta_file: str,
+        vctk_ppg_dir: str,
+        libri_ppg_dir: str,
+        vctk_f0_dir: str,
+        libri_f0_dir: str,
+        vctk_wav_dir: str,
+        libri_wav_dir: str,
+        vctk_spk_dvec_dir: str,
+        libri_spk_dvec_dir: str,
+        min_max_norm_mel: bool = False,
+        mel_min: float = None,
+        mel_max: float = None,
+        ppg_file_ext: str = "ling_feat.npy",
+        f0_file_ext: str = "f0.npy",
+        wav_file_ext: str = "wav",
+    ):
+        self.fid_list = read_fids(meta_file)
+        self.vctk_ppg_dir = vctk_ppg_dir
+        self.libri_ppg_dir = libri_ppg_dir
+        self.vctk_f0_dir = vctk_f0_dir
+        self.libri_f0_dir = libri_f0_dir
+        self.vctk_wav_dir = vctk_wav_dir
+        self.libri_wav_dir = libri_wav_dir
+        self.vctk_spk_dvec_dir = vctk_spk_dvec_dir
+        self.libri_spk_dvec_dir = libri_spk_dvec_dir
+
+        self.ppg_file_ext = ppg_file_ext
+        self.f0_file_ext = f0_file_ext
+        self.wav_file_ext = wav_file_ext
+
+        self.min_max_norm_mel = min_max_norm_mel
+        if min_max_norm_mel:
+            print("[INFO] Min-Max normalize Melspec.")
+            assert mel_min is not None
+            assert mel_max is not None
+            self.mel_max = mel_max
+            self.mel_min = mel_min
+
+        random.seed(1234)
+        random.shuffle(self.fid_list)
+        print(f'[INFO] Got {len(self.fid_list)} samples.')
+
+    def __len__(self):
+        return len(self.fid_list)
+
+    def get_spk_dvec(self, fid):
+        spk_name = fid
+        if spk_name.startswith("p"):
+            spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy"
+        else:
+            spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy"
+        return torch.from_numpy(np.load(spk_dvec_path))
+
+    def compute_mel(self, wav_path):
+        audio, sr = load_wav(wav_path)
+        if sr != SAMPLE_RATE:
+            audio = resampy.resample(audio, sr, SAMPLE_RATE)
+        audio = audio / MAX_WAV_VALUE
+        audio = normalize(audio) * 0.95
+        audio = torch.FloatTensor(audio).unsqueeze(0)
+        melspec = mel_spectrogram(
+            audio,
+            n_fft=1024,
+            num_mels=80,
+            sampling_rate=SAMPLE_RATE,
+            hop_size=160,
+            win_size=1024,
+            fmin=80,
+            fmax=8000,
+        )
+        return melspec.squeeze(0).numpy().T
+
+    def bin_level_min_max_norm(self, melspec):
+        # frequency bin level min-max normalization to [-4, 4]
+        mel = (melspec - self.mel_min) / (self.mel_max - self.mel_min) * 8.0 - 4.0
+        return np.clip(mel, -4., 4.)
+
+    def __getitem__(self, index):
+        fid = self.fid_list[index]
+
+        # 1. Load features
+        if fid.startswith("p"):
+            # vctk
+            sub = fid.split("_")[0]
+            ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
+            f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
+            mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
+        else:
+            # aidatatang
+            sub = fid[5:10]
+            ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
+            f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
+            mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
+        if self.min_max_norm_mel:
+            mel = self.bin_level_min_max_norm(mel)
+
+        f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid)
+        spk_dvec = self.get_spk_dvec(fid)
+
+        # 2. Convert f0 to continuous log-f0 and u/v flags
+        uv, cont_lf0 = get_cont_lf0(f0, 10.0, False)
+        # cont_lf0 = (cont_lf0 - np.amin(cont_lf0)) / (np.amax(cont_lf0) - np.amin(cont_lf0))
+        # cont_lf0 = self.utt_mvn(cont_lf0)
+        lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
+
+        # uv, cont_f0 = convert_continuous_f0(f0)
+        # cont_f0 = (cont_f0 - np.amin(cont_f0)) / (np.amax(cont_f0) - np.amin(cont_f0))
+        # lf0_uv = np.concatenate([cont_f0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
+
+        # 3. Convert numpy array to torch.tensor
+        ppg = torch.from_numpy(ppg)
+        lf0_uv = torch.from_numpy(lf0_uv)
+        mel = torch.from_numpy(mel)
+
+        return (ppg, lf0_uv, mel, spk_dvec, fid)
+
+    def check_lengths(self, f0, ppg, mel, fid):
+        LEN_THRESH = 10
+        assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \
+            f"{abs(len(ppg) - len(f0))}: for file {fid}"
+        assert abs(len(mel) - len(f0)) <= LEN_THRESH, \
+            f"{abs(len(mel) - len(f0))}: for file {fid}"
+
+    def _adjust_lengths(self, f0, ppg, mel, fid):
+        self.check_lengths(f0, ppg, mel, fid)
+        min_len = min(
+            len(f0),
+            len(ppg),
+            len(mel),
+        )
+        f0 = f0[:min_len]
+        ppg = ppg[:min_len]
+        mel = mel[:min_len]
+        return f0, ppg, mel
+
+
+class MultiSpkVcCollate():
+    """Zero-pads model inputs and targets based on number of frames per step"""
+    def __init__(self, n_frames_per_step=1, give_uttids=False,
+                 f02ppg_length_ratio=1, use_spk_dvec=False):
+        self.n_frames_per_step = n_frames_per_step
+        self.give_uttids = give_uttids
+        self.f02ppg_length_ratio = f02ppg_length_ratio
+        self.use_spk_dvec = use_spk_dvec
+
+    def __call__(self, batch):
+        batch_size = len(batch)
+        # Prepare different features
+        ppgs = [x[0] for x in batch]
+        lf0_uvs = [x[1] for x in batch]
+        mels = [x[2] for x in batch]
+        fids = [x[-1] for x in batch]
+        if len(batch[0]) == 5:
+            spk_ids = [x[3] for x in batch]
+            if self.use_spk_dvec:
+                # use d-vector
+                spk_ids = torch.stack(spk_ids).float()
+            else:
+                # use one-hot ids
+                spk_ids = torch.LongTensor(spk_ids)
+        # Pad features into chunk
+        ppg_lengths = [x.shape[0] for x in ppgs]
+        mel_lengths = [x.shape[0] for x in mels]
+        max_ppg_len = max(ppg_lengths)
+        max_mel_len = max(mel_lengths)
+        if max_mel_len % self.n_frames_per_step != 0:
+            max_mel_len += (self.n_frames_per_step - max_mel_len % self.n_frames_per_step)
+        ppg_dim = ppgs[0].shape[1]
+        mel_dim = mels[0].shape[1]
+        ppgs_padded = torch.FloatTensor(batch_size, max_ppg_len, ppg_dim).zero_()
+        mels_padded = torch.FloatTensor(batch_size, max_mel_len, mel_dim).zero_()
+        lf0_uvs_padded = torch.FloatTensor(batch_size, self.f02ppg_length_ratio * max_ppg_len, 2).zero_()
+        stop_tokens = torch.FloatTensor(batch_size, max_mel_len).zero_()
+        for i in range(batch_size):
+            cur_ppg_len = ppgs[i].shape[0]
+            cur_mel_len = mels[i].shape[0]
+            ppgs_padded[i, :cur_ppg_len, :] = ppgs[i]
+            lf0_uvs_padded[i, :self.f02ppg_length_ratio*cur_ppg_len, :] = lf0_uvs[i]
+            mels_padded[i, :cur_mel_len, :] = mels[i]
+            stop_tokens[i, cur_ppg_len-self.n_frames_per_step:] = 1
+        if len(batch[0]) == 5:
+            ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \
+                torch.LongTensor(mel_lengths), spk_ids, stop_tokens)
+            if self.give_uttids:
+                return ret_tup + (fids, )
+            else:
+                return ret_tup
+        else:
+            ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \
+                torch.LongTensor(mel_lengths), stop_tokens)
+            if self.give_uttids:
+                return ret_tup + (fids, )
+            else:
+                return ret_tup
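A sketch of how the dataset and collate class are likely meant to be combined into a DataLoader; all directory paths and the batch settings below are assumptions, not values taken from the diff.

from torch.utils.data import DataLoader
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate

dataset = OneshotVcDataset(
    meta_file="data/train_fids.txt",                      # hypothetical paths
    vctk_ppg_dir="data/vctk/ppg",  libri_ppg_dir="data/libri/ppg",
    vctk_f0_dir="data/vctk/f0",    libri_f0_dir="data/libri/f0",
    vctk_wav_dir="data/vctk/wav",  libri_wav_dir="data/libri/wav",
    vctk_spk_dvec_dir="data/vctk/dvec", libri_spk_dvec_dir="data/libri/dvec",
)
loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    collate_fn=MultiSpkVcCollate(n_frames_per_step=4,
                                                 use_spk_dvec=True))
# Each dataset item has 5 fields, so the collate also returns speaker d-vectors.
ppgs, lf0_uvs, mels, ppg_lens, mel_lens, spk_dvecs, stops = next(iter(loader))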
utils/f0_utils.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+import logging
+import numpy as np
+import pyworld
+from scipy.interpolate import interp1d
+from scipy.signal import firwin, get_window, lfilter
+
+def compute_mean_std(lf0):
+    nonzero_indices = np.nonzero(lf0)
+    mean = np.mean(lf0[nonzero_indices])
+    std = np.std(lf0[nonzero_indices])
+    return mean, std
+
+
+def compute_f0(wav, sr=16000, frame_period=10.0):
+    """Compute f0 from wav using pyworld harvest algorithm."""
+    wav = wav.astype(np.float64)
+    f0, _ = pyworld.harvest(
+        wav, sr, frame_period=frame_period, f0_floor=80.0, f0_ceil=600.0)
+    return f0.astype(np.float32)
+
+def f02lf0(f0):
+    lf0 = f0.copy()
+    nonzero_indices = np.nonzero(f0)
+    lf0[nonzero_indices] = np.log(f0[nonzero_indices])
+    return lf0
+
+def get_converted_lf0uv(
+    wav,
+    lf0_mean_trg,
+    lf0_std_trg,
+    convert=True,
+):
+    f0_src = compute_f0(wav)
+    if not convert:
+        uv, cont_lf0 = get_cont_lf0(f0_src)
+        lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
+        return lf0_uv
+
+    lf0_src = f02lf0(f0_src)
+    lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src)
+
+    lf0_vc = lf0_src.copy()
+    lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg
+    f0_vc = lf0_vc.copy()
+    f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0])
+
+    uv, cont_lf0_vc = get_cont_lf0(f0_vc)
+    lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1)
+    return lf0_uv
+
+def low_pass_filter(x, fs, cutoff=70, padding=True):
+    """FUNCTION TO APPLY LOW PASS FILTER
+
+    Args:
+        x (ndarray): Waveform sequence
+        fs (int): Sampling frequency
+        cutoff (float): Cutoff frequency of low pass filter
+
+    Return:
+        (ndarray): Low pass filtered waveform sequence
+    """
+
+    nyquist = fs // 2
+    norm_cutoff = cutoff / nyquist
+
+    # low cut filter
+    numtaps = 255
+    fil = firwin(numtaps, norm_cutoff)
+    x_pad = np.pad(x, (numtaps, numtaps), 'edge')
+    lpf_x = lfilter(fil, 1, x_pad)
+    lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2]
+
+    return lpf_x
+
+
+def convert_continuos_f0(f0):
+    """CONVERT F0 TO CONTINUOUS F0
+
+    Args:
+        f0 (ndarray): original f0 sequence with the shape (T)
+
+    Return:
+        (ndarray): continuous f0 with the shape (T)
+    """
+    # get uv information as binary
+    uv = np.float32(f0 != 0)
+
+    # get start and end of f0
+    if (f0 == 0).all():
+        logging.warn("all of the f0 values are 0.")
+        return uv, f0
+    start_f0 = f0[f0 != 0][0]
+    end_f0 = f0[f0 != 0][-1]
+
+    # padding start and end of f0 sequence
+    start_idx = np.where(f0 == start_f0)[0][0]
+    end_idx = np.where(f0 == end_f0)[0][-1]
+    f0[:start_idx] = start_f0
+    f0[end_idx:] = end_f0
+
+    # get non-zero frame index
+    nz_frames = np.where(f0 != 0)[0]
+
+    # perform linear interpolation
+    f = interp1d(nz_frames, f0[nz_frames])
+    cont_f0 = f(np.arange(0, f0.shape[0]))
+
+    return uv, cont_f0
+
+
+def get_cont_lf0(f0, frame_period=10.0, lpf=False):
+    uv, cont_f0 = convert_continuos_f0(f0)
+    if lpf:
+        cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20)
+        cont_lf0_lpf = cont_f0_lpf.copy()
+        nonzero_indices = np.nonzero(cont_lf0_lpf)
+        cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices])
+        # cont_lf0_lpf = np.log(cont_f0_lpf)
+        return uv, cont_lf0_lpf
+    else:
+        nonzero_indices = np.nonzero(cont_f0)
+        cont_lf0 = cont_f0.copy()
+        cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0])
+        return uv, cont_lf0
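The conversion inside get_converted_lf0uv is a Gaussian re-standardization of voiced log-f0: lf0_vc = (lf0_src - mu_src) / sigma_src * sigma_trg + mu_trg, so the source pitch contour keeps its shape but takes on the target speaker's mean and variance. A small numeric check of just that step, with made-up target statistics:

import numpy as np
from utils.f0_utils import compute_mean_std, f02lf0

f0_src = np.array([0., 120., 130., 0., 125.], dtype=np.float32)  # toy contour, 0 = unvoiced
lf0_src = f02lf0(f0_src)
mu_s, sd_s = compute_mean_std(lf0_src)    # stats over voiced (nonzero) frames only
mu_t, sd_t = np.log(220.0), 0.05          # hypothetical target-speaker stats
voiced = f0_src > 0
lf0_vc = lf0_src.copy()
lf0_vc[voiced] = (lf0_src[voiced] - mu_s) / sd_s * sd_t + mu_t
print(np.exp(lf0_vc[voiced]))             # converted f0, now centered near 220 Hz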
utils/load_yaml.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import yaml
+
+
+def load_hparams(filename):
+    stream = open(filename, 'r')
+    docs = yaml.safe_load_all(stream)
+    hparams_dict = dict()
+    for doc in docs:
+        for k, v in doc.items():
+            hparams_dict[k] = v
+    return hparams_dict
+
+def merge_dict(user, default):
+    if isinstance(user, dict) and isinstance(default, dict):
+        for k, v in default.items():
+            if k not in user:
+                user[k] = v
+            else:
+                user[k] = merge_dict(user[k], v)
+    return user
+
+class Dotdict(dict):
+    """
+    a dictionary that supports dot notation
+    as well as dictionary access notation
+    usage: d = Dotdict() or d = Dotdict({'val1': 'first'})
+    set attributes: d.val2 = 'second' or d['val2'] = 'second'
+    get attributes: d.val2 or d['val2']
+    """
+    __getattr__ = dict.__getitem__
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+    def __init__(self, dct=None):
+        dct = dict() if not dct else dct
+        for key, value in dct.items():
+            if hasattr(value, 'keys'):
+                value = Dotdict(value)
+            self[key] = value
+
+class HpsYaml(Dotdict):
+    def __init__(self, yaml_file):
+        super(Dotdict, self).__init__()
+        hps = load_hparams(yaml_file)
+        hp_dict = Dotdict(hps)
+        for k, v in hp_dict.items():
+            setattr(self, k, v)
+
+    __getattr__ = Dotdict.__getitem__
+    __setattr__ = Dotdict.__setitem__
+    __delattr__ = Dotdict.__delitem__

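HpsYaml simply loads the YAML documents and wraps nested dicts in Dotdict, which is what train.py means by making "the config dict dot visitable". A sketch; hparams.yaml and its keys are hypothetical.

from utils.load_yaml import HpsYaml

# hparams.yaml (hypothetical):
#   model:
#     hidden_dim: 256
hps = HpsYaml("hparams.yaml")
print(hps.model.hidden_dim)        # dot access; nested dicts are Dotdict too
print(hps["model"]["hidden_dim"])  # plain dict access still works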
@@ -11,7 +11,6 @@ def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path: Path):

     # If none of the paths exist, remind the user to download models if needed
     print("********************************************************************************")
-    print("Error: Model files not found. Follow these instructions to get and install the models:")
-    print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models")
+    print("Error: Model files not found. Please download the models")
     print("********************************************************************************\n")
     quit(-1)
utils/util.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import matplotlib
+matplotlib.use('Agg')
+import time
+
+class Timer():
+    ''' Timer for recording training time distribution. '''
+    def __init__(self):
+        self.prev_t = time.time()
+        self.clear()
+
+    def set(self):
+        self.prev_t = time.time()
+
+    def cnt(self, mode):
+        self.time_table[mode] += time.time()-self.prev_t
+        self.set()
+        if mode == 'bw':
+            self.click += 1
+
+    def show(self):
+        total_time = sum(self.time_table.values())
+        self.time_table['avg'] = total_time/self.click
+        self.time_table['rd'] = 100*self.time_table['rd']/total_time
+        self.time_table['fw'] = 100*self.time_table['fw']/total_time
+        self.time_table['bw'] = 100*self.time_table['bw']/total_time
+        msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format(
+            **self.time_table)
+        self.clear()
+        return msg
+
+    def clear(self):
+        self.time_table = {'rd': 0, 'fw': 0, 'bw': 0}
+        self.click = 0
+
+# Reference: https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168
+
+def human_format(num):
+    magnitude = 0
+    while num >= 1000:
+        magnitude += 1
+        num /= 1000.0
+    # add more suffixes if you need them
+    return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
+
+
+# provide easy attribute access on a dict, e.g. abc.key
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
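human_format repeatedly divides by 1000 and picks a suffix, which keeps logged parameter counts compact:

from utils.util import human_format

print(human_format(1234))        # '1.2K'
print(human_format(48_000_000))  # '48.0M' -- handy for model parameter counts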
@@ -1,13 +1,6 @@
 import os
 import shutil
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-

 def build_env(config, config_name, path):
     t_path = os.path.join(path, config_name)
     if config != t_path:
@@ -3,14 +3,11 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import json
 import torch
-from scipy.io.wavfile import write
-from vocoder.hifigan.env import AttrDict
-from vocoder.hifigan.meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
+from utils.util import AttrDict
 from vocoder.hifigan.models import Generator
-import soundfile as sf


 generator = None # type: Generator
+output_sample_rate = None
 _device = None

@@ -22,16 +19,23 @@ def load_checkpoint(filepath, device):
     return checkpoint_dict


-def load_model(weights_fpath, verbose=True):
-    global generator, _device
+def load_model(weights_fpath, config_fpath=None, verbose=True):
+    global generator, _device, output_sample_rate

     if verbose:
         print("Building hifigan")

-    with open("./vocoder/hifigan/config_16k_.json") as f:
+    if config_fpath == None:
+        model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
+        if len(model_config_fpaths) > 0:
+            config_fpath = model_config_fpaths[0]
+        else:
+            config_fpath = "./vocoder/hifigan/config_16k_.json"
+    with open(config_fpath) as f:
         data = f.read()
     json_config = json.loads(data)
     h = AttrDict(json_config)
+    output_sample_rate = h.sampling_rate
     torch.manual_seed(h.seed)

     if torch.cuda.is_available():
@@ -66,5 +70,5 @@ def infer_waveform(mel, progress_callback=None):
     audio = y_g_hat.squeeze()
     audio = audio.cpu().numpy()

-    return audio
+    return audio, output_sample_rate

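With these two hunks, callers get the vocoder's native sample rate back instead of hard-coding it, and the JSON config is resolved from the directory next to the weights. A usage sketch, assuming this module is importable as vocoder.hifigan.inference and that mel.npy holds an 80-band mel from the synthesizer (both assumptions):

from pathlib import Path
import numpy as np
import soundfile as sf
import vocoder.hifigan.inference as gan_vocoder

mel = np.load("mel.npy")                                    # hypothetical mel input
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
wav, sample_rate = gan_vocoder.infer_waveform(mel)          # rate comes from the JSON config
sf.write("output.wav", wav, sample_rate)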
@@ -71,6 +71,24 @@ class ResBlock2(torch.nn.Module):
         for l in self.convs:
             remove_weight_norm(l)

+class InterpolationBlock(torch.nn.Module):
+    def __init__(self, scale_factor, mode='nearest', align_corners=None, downsample=False):
+        super(InterpolationBlock, self).__init__()
+        self.downsample = downsample
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        outputs = torch.nn.functional.interpolate(
+            x,
+            size=x.shape[-1] * self.scale_factor \
+                if not self.downsample else x.shape[-1] // self.scale_factor,
+            mode=self.mode,
+            align_corners=self.align_corners,
+            recompute_scale_factor=False
+        )
+        return outputs
+
 class Generator(torch.nn.Module):
     def __init__(self, h):
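InterpolationBlock replaces learned transposed-convolution upsampling with fixed nearest-neighbour interpolation, followed in the Generator hunk below by an ordinary Conv1d; interpolation-plus-convolution is a common way to avoid transposed-convolution checkerboard artifacts, though the diff itself does not state the motivation. A shape check, assuming the class is importable from vocoder.hifigan.models:

import torch
from vocoder.hifigan.models import InterpolationBlock

x = torch.randn(1, 128, 50)               # [batch, channels, frames]
up = InterpolationBlock(scale_factor=5)   # nearest-neighbour upsampling
print(up(x).shape)                        # torch.Size([1, 128, 250])
down = InterpolationBlock(scale_factor=5, downsample=True)
print(down(x).shape)                      # torch.Size([1, 128, 10])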
@@ -82,14 +100,27 @@ class Generator(torch.nn.Module):
         resblock = ResBlock1 if h.resblock == '1' else ResBlock2

         self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
-            # self.ups.append(weight_norm(
-            #     ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
-            #                     k, u, padding=(k-u)//2)))
-            self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
-                                                        h.upsample_initial_channel//(2**(i+1)),
-                                                        k, u, padding=(u//2 + u%2), output_padding=u%2)))
+        if h.sampling_rate == 24000:
+            for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+                self.ups.append(
+                    torch.nn.Sequential(
+                        InterpolationBlock(u),
+                        weight_norm(torch.nn.Conv1d(
+                            h.upsample_initial_channel//(2**i),
+                            h.upsample_initial_channel//(2**(i+1)),
+                            k, padding=(k-1)//2,
+                        ))
+                    )
+                )
+        else:
+            for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+                self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
+                                                            h.upsample_initial_channel//(2**(i+1)),
+                                                            k, u, padding=(u//2 + u%2), output_padding=u%2)))
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel//(2**(i+1))
@@ -121,7 +152,10 @@ class Generator(torch.nn.Module):
     def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
-            remove_weight_norm(l)
+            if self.h.sampling_rate == 24000:
+                remove_weight_norm(l[-1])
+            else:
+                remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
@@ -12,7 +12,6 @@ from torch.utils.data import DistributedSampler, DataLoader
 import torch.multiprocessing as mp
 from torch.distributed import init_process_group
 from torch.nn.parallel import DistributedDataParallel
-from vocoder.hifigan.env import AttrDict, build_env
 from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
 from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
     discriminator_loss
@@ -23,11 +22,11 @@ torch.backends.cudnn.benchmark = True

 def train(rank, a, h):

     a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
     a.checkpoint_path.mkdir(exist_ok=True)
     a.training_epochs = 3100
     a.stdout_interval = 5
-    a.checkpoint_interval = 25000
+    a.checkpoint_interval = a.backup_every
     a.summary_interval = 5000
     a.validation_interval = 1000
     a.fine_tuning = True
@@ -185,12 +184,10 @@ def train(rank, a, h):
                    checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
-                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
+                    checkpoint_path = "{}/do_{:08d}.pt".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
-                                    {'mpd': (mpd.module if h.num_gpus > 1
-                                             else mpd).state_dict(),
-                                     'msd': (msd.module if h.num_gpus > 1
-                                             else msd).state_dict(),
+                                    {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
+                                     'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
                                      'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                      'epoch': epoch})
@@ -198,6 +195,19 @@ def train(rank, a, h):
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

+                # save intermediate hifigan model under a fixed name
+                if steps % a.save_every == 0:
+                    checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
+                    save_checkpoint(checkpoint_path,
+                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
+                    checkpoint_path = "{}/do_hifigan.pt".format(a.checkpoint_path)
+                    save_checkpoint(checkpoint_path,
+                                    {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
+                                     'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
+                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
+                                     'epoch': epoch})
+
                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
@@ -50,7 +50,7 @@ def save_checkpoint(filepath, obj):


 def scan_checkpoint(cp_dir, prefix):
-    pattern = os.path.join(cp_dir, prefix + '????????')
+    pattern = os.path.join(cp_dir, prefix + 'hifigan.pt')
     cp_list = glob.glob(pattern)
     if len(cp_list) == 0:
         return None
Some files were not shown because too many files have changed in this diff.