55 Commits

Author SHA1 Message Date
babysor00
3a2d50c862 Add readme 2022-03-05 00:51:55 +08:00
babysor00
d786e78121 Add UI usage of PPG-vc 2022-03-03 23:34:47 +08:00
babysor00
6befb700e9 Fix sample issues 2022-03-02 23:15:37 +08:00
babysor00
dd3abebc4d Fix bug of preparing fid 2022-02-27 13:25:58 +08:00
babysor00
eeee32f3e3 Fix length issue 2022-02-26 17:26:27 +08:00
babysor00
8ef5e1411d Update __init__.py
Allow to gen audio
2022-02-24 09:46:24 +08:00
babysor00
20bea3546b Merge branch 'main' into ppg-vc-init 2022-02-24 00:31:13 +08:00
babysor00
0536874dec Add config file for pretrained 2022-02-23 09:37:39 +08:00
babysor00
fad5023fca FIx known issues 2022-02-20 11:56:58 +08:00
babysor00
19eaa68202 add preprocess and training 2022-02-13 11:28:41 +08:00
李子
4529479091 Pin the librosa version (#378)
* Support the data_aishell (SLR33) dataset

* Update readme

* Pin the librosa version
2022-02-10 20:47:26 +08:00
babysor00
379fd2b9fd Init ppg extractor and ppg2mel 2022-02-09 00:44:43 +08:00
babysor00
8ad9ba2b60 change naming logic of saving trained file for synthesizer to allow shorter interval 2022-01-15 17:56:14 +08:00
D-Blue
b56ec5ee1b Fix a UserWarning (#273)
Fix a UserWarning in synthesizer/synthesizer_dataset.py, because of converting list of numpy array to torch tensor at Ln.85.
2021-12-20 20:33:12 +08:00
CrystalRays
0bc34a5bc9 Fix TypeError at line 459 in toolbox/ui.py when both PySide6(PyQt6) and PyQt5 installed (#255)
### Error Info Screenshot
![](https://cdn.jsdelivr.net/gh/CrystalRays/CDN@main/img/16389623959301638962395845.png)

### Error Reason
Matplotlib's `backends/qt_compat.py` picks the Qt binding in three steps: first from the bindings already present in `sys.modules`, then from the `QT_API` environment variable, and finally by trying PyQt6, PySide6, PyQt5, PySide2, and so on in that order. Because `ui.py` imported matplotlib before PyQt5, no Qt binding was in `sys.modules` and no environment variable was set, so matplotlib chose PyQt6 or PySide6 ahead of PyQt5 whenever one of them was installed.
2021-12-15 12:41:10 +08:00
Wings Music
875fe15069 Update readme for training encoder (#250) 2021-12-07 19:10:29 +08:00
zzxiang
4728863f9d Fix inference on cpu device (#241) 2021-11-29 21:10:07 +08:00
hertz
a4daf42868 1k steps to save tmp hifigan model (#240) 2021-11-29 21:09:54 +08:00
harian
b50c7984ab tacotron.py-Multi GPU with DataParallel (#231) 2021-11-27 20:53:08 +08:00
babysor00
26fe4a047d Differentiate GST token 2021-11-18 22:55:13 +08:00
babysor00
aff1b5313b Order of declared pytorch module matters 2021-11-17 00:12:27 +08:00
babysor
7dca74e032 Change default to use speaker embed for reference 2021-11-13 10:57:45 +08:00
babysor00
a37b26a89c Improve compatibility between pretrained models and the code base #209 2021-11-10 23:23:13 +08:00
babysor00
902e1eb537 Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese 2021-11-09 21:08:33 +08:00
babysor00
5c0e53a29a Fix #205 2021-11-09 21:08:28 +08:00
DragonDreamer
4edebdfeba Fix incorrect comment in synthesizer/models/tacotron.Encoder (#203)
fix Issue#202
2021-11-09 13:59:19 +08:00
babysor00
6c8f3f4515 Allow to select vocoder in web 2021-11-08 23:55:16 +08:00
babysor00
2bd323b7df Update readme 2021-11-07 21:59:03 +08:00
babysor00
3674d8b5c6 Use speaker embedding anyway even with default style 2021-11-07 21:48:15 +08:00
babysor00
80aaf32164 Add max steps control in toolbox 2021-11-06 13:27:11 +08:00
babysor00
c396792b22 Upload new models 2021-10-27 20:19:50 +08:00
babysor00
7c58fe01d1 Concat GST output instead of adding directly with original output 2021-10-23 10:28:32 +08:00
Vega
724194a4de Add code to control finetune layers (#154) 2021-10-23 10:25:43 +08:00
babysor00
31bc6656c3 Fix bug of importing GST and add more parameters in toolbox 2021-10-21 00:40:00 +08:00
洛竹
aa35fb3139 docs: this repo -> 本代码库 (#157)
Co-authored-by: 洛竹 <youngjuning@aliyun.com>
2021-10-20 22:54:31 +08:00
babysor00
727eafc51b Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese 2021-10-20 00:27:19 +08:00
babysor00
d328ecba81 Reconstruct UI of toolbox 2021-10-20 00:27:13 +08:00
Vega
fad574118c Update README-CN.md 2021-10-18 13:50:19 +08:00
babysor00
b0c156a537 Add new dataset support to preprocess parameter 2021-10-17 17:21:49 +08:00
Vega
724809abf4 Update README.md 2021-10-15 14:34:29 +08:00
Vega
05cd1a54ea Add new pretrain model with gst 2021-10-14 01:26:23 +08:00
李子
245099c740 Support the data_aishell (SLR33) dataset (#141)
* Support the data_aishell (SLR33) dataset

* Update readme
2021-10-12 23:40:27 +08:00
babysor00
6dd2af49fe Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese 2021-10-12 20:02:05 +08:00
babysor00
8b43ec9a64 Fix bug pre-processing magicdata 2021-10-12 20:01:37 +08:00
Vega
2a99f0ff05 Add gst (#137)
* Commit with working GST

* Make it backward compatible

* Add readme
2021-10-12 19:43:29 +08:00
babysor00
a824b54122 Expand the preprocessing documentation 2021-10-12 09:22:10 +08:00
weida wang
81befb91b0 Update ui.py (#136)
Add minimize and maximize button of window
2021-10-11 17:17:36 +08:00
babysor00
e2017d0314 Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese into main 2021-10-05 10:48:58 +08:00
babysor00
547ac816df Update demo and training param
A
2021-10-05 10:48:54 +08:00
Ji Zhang
6b4ab39601 add alternative download source for dataset (google drive) (#112) 2021-10-03 10:10:40 +08:00
babysor00
b46e7a7866 New web with selecting wav files 2021-10-01 22:13:39 +08:00
babysor00
8a384a1191 Merge branch 'main' of https://github.com/babysor/Realtime-Voice-Clone-Chinese into main 2021-10-01 09:33:31 +08:00
Nemo
11154783d8 web tool box update UI (#111)
* web tool box update UI

* update img
2021-10-01 00:32:29 +08:00
AkifSaeed20
d52db0444e Update launch.json (#109) 2021-10-01 00:22:43 +08:00
babysor00
790d11a58b Allow to train encoder 2021-10-01 00:01:33 +08:00
81 changed files with 7006 additions and 311 deletions

9
.gitignore vendored

@@ -15,7 +15,8 @@
*.toc
*.wav
*.sh
synthesizer/saved_models/*
vocoder/saved_models/*
cp_hifigan/*
!vocoder/saved_models/pretrained/*
*/saved_models
!vocoder/saved_models/pretrained/**
!encoder/saved_models/pretrained.pt
wavs
log

36
.vscode/launch.json vendored

@@ -17,7 +17,7 @@
"request": "launch",
"program": "vocoder_preprocess.py",
"console": "integratedTerminal",
"args": ["..\\..\\chs1"]
"args": ["..\\audiodata"]
},
{
"name": "Python: Vocoder Train",
@@ -25,15 +25,41 @@
"request": "launch",
"program": "vocoder_train.py",
"console": "integratedTerminal",
"args": ["dev", "..\\..\\chs1"]
"args": ["dev", "..\\audiodata"]
},
{
"name": "Python: demo box",
"name": "Python: Demo Box",
"type": "python",
"request": "launch",
"program": "demo_toolbox.py",
"console": "integratedTerminal",
"args": ["-d", "..\\..\\chs"]
}
"args": ["-d","..\\audiodata"]
},
{
"name": "Python: Demo Box VC",
"type": "python",
"request": "launch",
"program": "demo_toolbox.py",
"console": "integratedTerminal",
"args": ["-d","..\\audiodata","-vc"]
},
{
"name": "Python: Synth Train",
"type": "python",
"request": "launch",
"program": "synthesizer_train.py",
"console": "integratedTerminal",
"args": ["my_run", "..\\"]
},
{
"name": "Python: PPG Convert",
"type": "python",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
]
},
]
}


@@ -5,10 +5,10 @@
### [English](README.md) | Chinese
### [DEMO VIDEO](https://www.bilibili.com/video/BV1sA411P7wM/)
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki tutorial](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) [Training tutorial](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)
## Features
🌍 **Chinese** Supports Mandarin and has been tested with multiple Chinese datasets: aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, etc.
🌍 **Chinese** Supports Mandarin and has been tested with multiple Chinese datasets: aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell
🤩 **PyTorch** Works with PyTorch; tested on version 1.9.0 (the latest as of August 2021) with Tesla T4 and GTX 2060 GPUs
@@ -18,6 +18,7 @@
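🌍 **Webserver Ready** Serves your training results for remote calls
## Getting Started
### 1. Installation requirements
> Follow the original repository to check that your environment is fully set up.
**Python 3.7 or higher** is required to run the toolbox.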
@@ -31,11 +32,21 @@
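### 2. Prepare pretrained models
Either train your own model or download one trained by others in the community:
> A [Zhihu column](https://www.zhihu.com/column/c_1425605280340504576) was created recently and will be updated from time to time with training tips and notes; questions are welcome.
#### 2.1 Train the synthesizer model with your own dataset (choose either this or 2.2)
#### 2.1 Train the encoder model with your own dataset (optional)
* Preprocess the audio and mel spectrograms:
`python encoder_preprocess.py <datasets_root>`
Use `-d {dataset}` to specify the datasets; librispeech_other, voxceleb1, and aidatatang_200zh are supported. Separate multiple datasets with commas.
* Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
> Encoder training uses visdom. You can add `--no_visdom` to disable it, but the visualization is worth having. Run "visdom" in a separate command line/process to start the visdom server.
#### 2.2 Train the synthesizer model with your own dataset (choose either this or 2.3)
* Download a dataset and unzip it: make sure you can access all audio files (e.g. .wav) in the *train* folder
* Preprocess the audio and mel spectrograms:
`python pre.py <datasets_root>`
Pass --dataset `{dataset}`; aidatatang_200zh, magicdata, and aishell3 are supported
`python pre.py <datasets_root> -d {dataset} -n {number}`
Available parameters:
* `-d {dataset}` specifies the dataset; aidatatang_200zh, magicdata, aishell3, and data_aishell are supported, and the default is aidatatang_200zh
* `-n {number}` specifies the number of parallel processes; 10 ran without problems on CPU 11770k + 32GB
> If the `aidatatang_200zh` files you downloaded are on drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`
* Train the synthesizer: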
@@ -43,16 +54,17 @@
* When the attention line appears and the loss meets your needs in the training folder *synthesizer/saved_models/*, move on to the `Launch` step.
#### 2.2 Use a synthesizer pretrained by the community (choose either this or 2.1)
#### 2.3 Use a synthesizer pretrained by the community (choose either this or 2.2)
> If you have no suitable hardware or do not want to spend time tuning, you can use models contributed by the community (further sharing is welcome):
| Author | Download link | Preview | Info |
| --- | ----------- | ----- | ----- |
| Author | https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA [Baidu Pan link](https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA) code: i183 | | 200k steps, trained only on aidatatang_200zh
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu Pan link](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) code: 1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, Taiwanese accent
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps; old version, needs the fix described in the [issue](https://github.com/babysor/MockingBird/issues/37)
| Author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu Pan link](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps, trained on a mix of 3 open-source datasets
| Author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu Pan link](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code: om7f | | 25k steps, trained on a mix of 3 open-source datasets; switch to tag v0.0.1 to use it
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu Pan link](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) code: 1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, Taiwanese accent; switch to tag v0.0.1 to use it
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps; note: apply the fix described in the [issue](https://github.com/babysor/MockingBird/issues/37) and switch to tag v0.0.1 to use it
#### 2.3 Train the vocoder (optional)
#### 2.4 Train the vocoder (optional)
The vocoder has little effect on the result, and 3 are already bundled. If you want to train your own, refer to the following commands.
* Preprocess the data:
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
@@ -73,7 +85,7 @@
### 3.1 Launch the web app
`python web.py`
Once it is running, open the address in a browser; the default is `http://localhost:8080`
<img width="578" alt="bd64cd80385754afa599e3840504f45" src="https://user-images.githubusercontent.com/7423248/134275205-c95e6bd8-4f41-4eb5-9143-0390627baee1.png">
![123](https://user-images.githubusercontent.com/12797292/135494044-ae59181c-fe3a-406f-9c7d-d21d12fdb4cb.png)
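> The interface is currently rather buggy:
> * After the first click on `Record`, wait a few seconds for the browser to start recording properly, otherwise the audio will be doubled
> * When you finish recording, click `Stop` rather than `Record` again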
@@ -86,48 +98,54 @@
<img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">
## File structure (for developers)
```
├─archived_untest_files deprecated files
├─encoder encoder model
│ ├─data_objects
└─saved_models pretrained models
├─samples sample audio
├─synthesizer synthesizer model
│ ├─models
│ ├─saved_models pretrained models
│ └─utils utility library
├─toolbox GUI toolbox
├─utils utility library
├─vocoder vocoder models, currently hifi-gan and wavrnn
│ ├─hifigan
│ ├─saved_models pretrained models
│ └─wavernn
└─web
├─api
│ └─web API
├─config
│ └─ web configuration files
├─static front-end static scripts
│ └─js
├─templates front-end templates
└─__init__.py web entry point
```
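### 4. Extra: Voice Conversion (PPG based)
Ever imagined holding Conan's voice changer and speaking in Kogoro Mouri's voice? Building on PPG-VC, this project now adds two extra modules, PPG extractor + PPG2Mel, to provide voice conversion. (The documentation is incomplete, especially the training part, and is being filled in.)
#### 4.0 Prepare the environment
* Make sure the environment described above is installed, then run `pip install -r requirements.txt` to install the remaining required packages.
* Download the following models:
* The vocoder (hifigan) dedicated to the 24K sampling rate, to *vocoder\saved_models\xxx*
* The pretrained PPG feature encoder (ppg_extractor), to *ppg_extractor\saved_models\xxx*
* The pretrained PPG2Mel, to *ppg2mel\saved_models\xxx*
#### 4.1 Train the PPG2Mel model with your own dataset (optional)
* Download the aidatatang_200zh dataset and unzip it: make sure you can access all audio files (e.g. .wav) in the *train* folder
* Preprocess the audio and mel spectrograms:
`python pre4ppg.py <datasets_root> -d {dataset} -n {number}`
Available parameters:
* `-d {dataset}` specifies the dataset; aidatatang_200zh is supported and is also the default
* `-n {number}` specifies the number of parallel processes; with 8 on CPU 11770k, the run takes 12 to 18 hours (to be optimized)
> If the `aidatatang_200zh` files you downloaded are on drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`
* Train the synthesizer. Make sure you downloaded `ppg2mel.yaml` in the previous step, and edit the paths in it to point to the pretrained folders:
`python ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
* To resume a previous training run, pass `--load .\ppg2mel\saved_models\<old_pt_file>` to specify a pretrained model file
#### 4.2 Launch the toolbox in VC mode
You can try the following command:
`python demo_toolbox.py -vc -d <datasets_root>`
> Specify a valid dataset path. If a supported dataset is found there, it is loaded automatically for debugging, and the same directory is used to store manually recorded audio.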
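## References and papers
> This repository was originally forked from [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning), which supports English only; thanks to its author.
| URL | Designation | Title | Implementation source |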
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |
## Frequently asked questions (FQ&A)
#### 1. Where can I download the datasets?
[aidatatang_200zh](http://www.openslr.org/62/), [magicdata](http://www.openslr.org/68/), [aishell3](http://www.openslr.org/93/)
| Dataset | OpenSLR link | Other sources (Google Drive, Baidu Pan, etc.) |
| --- | ----------- | ---------------|
| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
> After unzipping aidatatang_200zh, you also need to select and unzip all the files under `aidatatang_200zh\corpus\train`
#### 2. What does `<datasets_root>` mean?


@@ -6,7 +6,7 @@
> English | [中文](README-CN.md)
## Features
🌍 **Chinese** Supports Mandarin and has been tested with multiple datasets: aidatatang_200zh, magicdata, aishell3, etc.
🌍 **Chinese** Supports Mandarin and has been tested with multiple datasets: aidatatang_200zh, magicdata, aishell3, data_aishell, etc.
🤩 **PyTorch** Works with PyTorch; tested on version 1.9.0 (the latest in August 2021) with Tesla T4 and GTX 2060 GPUs
@@ -16,7 +16,7 @@
🌍 **Webserver Ready** to serve your result with remote calling
### [DEMO VIDEO](https://www.bilibili.com/video/BV1sA411P7wM/)
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
## Quick Start
@@ -32,27 +32,37 @@
> Note that we are using the pretrained encoder/vocoder but not the pretrained synthesizer, since the original model is incompatible with Chinese symbols. This means demo_cli does not work at the moment.
### 2. Prepare your models
You can either train your models or use existing ones:
#### 2.1. Train synthesizer with your dataset
#### 2.1 Train encoder with your dataset (Optional)
* Preprocess with the audios and the mel spectrograms:
`python encoder_preprocess.py <datasets_root>` Pass `--datasets {dataset}` to choose the datasets you want to preprocess. Only the train set of these datasets will be used. Possible names: librispeech_other, voxceleb1, voxceleb2, aidatatang_200zh. Use commas to separate multiple datasets.
* Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
> For training, the encoder uses visdom. You can disable it with `--no_visdom`, but it's nice to have. Run "visdom" in a separate CLI/process to start your visdom server.
#### 2.2 Train synthesizer with your dataset
* Download dataset and unzip: make sure you can access all .wav in folder
* Preprocess with the audios and the mel spectrograms:
`python pre.py <datasets_root>`
Pass `--dataset {dataset}` to select aidatatang_200zh, magicdata, aishell3, etc.
Pass `--dataset {dataset}` to select aidatatang_200zh, magicdata, aishell3, data_aishell, etc. If this parameter is not passed, the default dataset is aidatatang_200zh.
* Train the synthesizer:
`python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
* Go to the next step when the attention line appears and the loss meets your needs in the training folder *synthesizer/saved_models/*.
#### 2.2 Use pretrained model of synthesizer
#### 2.3 Use pretrained model of synthesizer
> Thanks to the community, some models will be shared:
| author | Download link | Preview Video | Info |
| --- | ----------- | ----- |----- |
| @myself | https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA [Baidu](https://pan.baidu.com/s/1VHSKIbxXQejtxi2at9IrpA) code: i183 | | 200k steps, trained only on aidatatang_200zh
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu Pan](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) code: 1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, with a Taiwanese accent
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/
| @author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps, trained on multiple datasets
| @author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code: om7f | | 25k steps, trained on multiple datasets; only works under version 0.0.1
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, with a Taiwanese accent; only works under version 0.0.1
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
#### 2.3 Train vocoder (Optional)
#### 2.4 Train vocoder (Optional)
> note: vocoder has little difference in effect, so you may not need to train a new one.
* Preprocess the data:
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
@@ -77,6 +87,7 @@ You can then try the toolbox:
| URL | Designation | Title | Implementation source |
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
@@ -85,7 +96,12 @@ You can then try the toolbox:
## F Q&A
#### 1.Where can I download the dataset?
[aidatatang_200zh](http://www.openslr.org/62/)、[magicdata](http://www.openslr.org/68/)、[aishell3](http://www.openslr.org/93/)
| Dataset | Original Source | Alternative Sources |
| --- | ----------- | ---------------|
| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
> After unzipping aidatatang_200zh, you also need to unzip all the files under `aidatatang_200zh\corpus\train`
#### 2. What is `<datasets_root>`?
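The Chinese README in this comparison gives a concrete example for this question: when the `train` folder sits at `D:\data\aidatatang_200zh\corpus\train`, `datasets_root` is `D:\data\`. A minimal Python sketch of that relationship follows; `datasets_root_for` is a hypothetical helper written for illustration and is not part of the repository.

```python
from pathlib import PureWindowsPath


def datasets_root_for(train_dir: str) -> PureWindowsPath:
    # Hypothetical helper: train_dir is <datasets_root>\aidatatang_200zh\corpus\train,
    # so the datasets root is three directory levels above it.
    return PureWindowsPath(train_dir).parents[2]


root = datasets_root_for(r"D:\data\aidatatang_200zh\corpus\train")
print(root)
```

In other words, `<datasets_root>` is the folder that contains the dataset folder itself, not the dataset folder or its `train` subfolder.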


@@ -15,12 +15,18 @@ if __name__ == '__main__':
parser.add_argument("-d", "--datasets_root", type=Path, help= \
"Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
"supported datasets.", default=None)
parser.add_argument("-vc", "--vc_mode", action="store_true",
help="Voice Conversion Mode(PPG based)")
parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
help="Directory containing saved encoder models")
parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
help="Directory containing saved synthesizer models")
parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
help="Directory containing saved vocoder models")
parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
help="Directory containing saved extractor models")
parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
help="Directory containing saved convertor models")
parser.add_argument("--cpu", action="store_true", help=\
"If True, processing is done on CPU, even when a GPU is available.")
parser.add_argument("--seed", type=int, default=None, help=\
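To see how the new voice-conversion options behave on the command line, here is a small standalone sketch. It reproduces only a few of the options visible in this hunk, and `build_parser` is a hypothetical wrapper used for illustration; the real script builds its parser inline and defines more options.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical wrapper: mirrors a subset of the options shown in the diff.
    parser = argparse.ArgumentParser(description="Sketch of the toolbox CLI")
    parser.add_argument("-d", "--datasets_root", type=Path, default=None)
    parser.add_argument("-vc", "--vc_mode", action="store_true",
                        help="Voice Conversion Mode (PPG based)")
    parser.add_argument("-ex", "--extractor_models_dir", type=Path,
                        default="ppg_extractor/saved_models")
    parser.add_argument("-cv", "--convertor_models_dir", type=Path,
                        default="ppg2mel/saved_models")
    return parser


# Same arguments as the "Python: Demo Box VC" launch configuration, with a POSIX-style path.
args = build_parser().parse_args(["-d", "../audiodata", "-vc"])
print(args.vc_mode, args.datasets_root, args.extractor_models_dir)
```

Because `-vc` is a `store_true` flag, VC mode stays off unless the flag is given, and the two model directories fall back to their defaults.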


@@ -34,8 +34,16 @@ def load_model(weights_fpath: Path, device=None):
_model.load_state_dict(checkpoint["model_state"])
_model.eval()
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
return _model
def set_model(model, device=None):
    global _model, _device
    _model = model
    # Fall back to CUDA when available, otherwise CPU, if no device was given
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _device = device
    _model.to(device)
def is_loaded():
return _model is not None
@@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
min_pad_coverage=0.75, overlap=0.5):
min_pad_coverage=0.75, overlap=0.5, rate=None):
"""
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
@@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
assert 0 <= overlap < 1
assert 0 < min_pad_coverage <= 1
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
if rate is not None:
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
else:
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
assert 0 < frame_step, "The rate is too high"
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
(sampling_rate / (samples_per_frame * partials_n_frames))
# Compute the slices
wav_slices, mel_slices = [], []
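The new `rate` argument changes how the hop between partial utterances is chosen. The following standalone sketch isolates that calculation. The default values for `sampling_rate`, `mel_window_step`, and `partials_n_frames` are assumptions for illustration, since the real values come from the encoder's parameter module, which is not part of this diff; `frame_step_for` is a hypothetical name, and Python's built-in `round` stands in for `np.round`.

```python
def frame_step_for(rate=None, overlap=0.5,
                   sampling_rate=16000, mel_window_step=10, partials_n_frames=160):
    """Return the hop, in mel frames, between consecutive partial utterances."""
    # Number of audio samples covered by one mel frame (assumed 10 ms window step).
    samples_per_frame = int(sampling_rate * mel_window_step / 1000)
    if rate is not None:
        # rate = partial utterances per second, so the hop is (samples per partial) / (samples per frame)
        frame_step = int(round((sampling_rate / rate) / samples_per_frame))
    else:
        # Original behaviour: hop derived from the overlap ratio
        frame_step = max(int(round(partials_n_frames * (1 - overlap))), 1)
    assert 0 < frame_step, "The rate is too high"
    assert frame_step <= partials_n_frames, "The rate is too low"
    return frame_step


print(frame_step_for())         # overlap-based hop
print(frame_step_for(rate=2))   # rate-based hop
```

Under these assumed values, a rate below 0.625 partials per second would make the hop longer than one partial utterance, which is what the second assertion guards against.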


@@ -117,6 +117,15 @@ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir,
logger.finalize()
print("Done preprocessing %s.\n" % dataset_name)
def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False):
dataset_name = "aidatatang_200zh"
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
if not dataset_root:
return
# Preprocess all speakers
speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*"))
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
skip_existing, logger)
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
for dataset_name in librispeech_datasets["train"]["other"]:
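The new `preprocess_aidatatang_200zh` function treats every entry under `corpus/train` as one speaker directory. A rough sketch of that directory walk on a temporary tree is shown here; the speaker folder names and the `.wav` lookup inside each speaker folder are assumptions for illustration, because `_preprocess_speaker_dirs` itself is not shown in this diff.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    dataset_root = Path(tmp) / "aidatatang_200zh"
    # Build a tiny fake dataset: two hypothetical speakers with one utterance each.
    for speaker in ("G0002", "G0003"):
        speaker_dir = dataset_root / "corpus" / "train" / speaker
        speaker_dir.mkdir(parents=True)
        (speaker_dir / "utt.wav").touch()

    # Same expression as in the diff: one directory per speaker.
    speaker_dirs = sorted(dataset_root.joinpath("corpus", "train").glob("*"))
    names = [p.name for p in speaker_dirs]
    # Assumed lookup of the utterances by extension inside each speaker folder.
    wav_counts = [len(list(p.glob("*.wav"))) for p in speaker_dirs]

print(names, wav_counts)
```

This is also why the README asks you to unzip everything under `aidatatang_200zh\corpus\train` first: the glob only sees folders that already exist on disk.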

Binary file not shown.


@@ -1,4 +1,4 @@
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh
from utils.argutils import print_args
from pathlib import Path
import argparse
@@ -10,17 +10,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
"writes them to the disk. This will allow you to train the encoder. The "
"datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
"Ideally, you should have all three. You should extract them as they are "
"after having downloaded them and put them in a same directory, e.g.:\n"
"-[datasets_root]\n"
" -LibriSpeech\n"
" -train-other-500\n"
" -VoxCeleb1\n"
" -wav\n"
" -vox1_meta.csv\n"
" -VoxCeleb2\n"
" -dev",
"datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
formatter_class=MyFormatter
)
parser.add_argument("datasets_root", type=Path, help=\
@@ -29,7 +19,7 @@ if __name__ == "__main__":
"Path to the output directory that will contain the mel spectrograms. If left out, "
"defaults to <datasets_root>/SV2TTS/encoder/")
parser.add_argument("-d", "--datasets", type=str,
default="librispeech_other,voxceleb1,voxceleb2", help=\
default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
"Comma-separated list of the name of the datasets you want to preprocess. Only the train "
"set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
"voxceleb2, aidatatang_200zh.")
@@ -63,6 +53,7 @@ if __name__ == "__main__":
"librispeech_other": preprocess_librispeech,
"voxceleb1": preprocess_voxceleb1,
"voxceleb2": preprocess_voxceleb2,
"aidatatang_200zh": preprocess_aidatatang_200zh,
}
args = vars(args)
for dataset in args.pop("datasets"):
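The script maps each dataset name to a preprocessing function and runs them in the order given by the comma-separated `--datasets` value. The sketch below illustrates that dispatch pattern with stub functions; `_make_stub`, `run`, and `calls` are hypothetical names, and the real functions take more arguments than shown here.

```python
from pathlib import Path

calls = []


def _make_stub(name):
    # Hypothetical stand-in for a real preprocess_* function.
    def _stub(datasets_root, out_dir):
        calls.append((name, out_dir))
    return _stub


preprocess_func = {
    name: _make_stub(name)
    for name in ("librispeech_other", "voxceleb1", "voxceleb2", "aidatatang_200zh")
}


def run(datasets: str, datasets_root: Path) -> None:
    # Default output directory described in the script's help text.
    out_dir = datasets_root / "SV2TTS" / "encoder"
    for dataset in datasets.split(","):
        if dataset not in preprocess_func:
            raise ValueError("Unknown dataset: %s" % dataset)
        preprocess_func[dataset](datasets_root, out_dir)


run("voxceleb1,aidatatang_200zh", Path("datasets"))
print([name for name, _ in calls])
```

Adding a dataset therefore takes two steps, both visible in this diff: write a `preprocess_*` function and register it in the dictionary.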

206
ppg2mel/__init__.py Normal file

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
# Copyright 2020 Songxiang Liu
# Apache 2.0
from typing import List
import torch
import torch.nn.functional as F
import numpy as np
from .utils.abs_model import AbsMelDecoder
from .rnn_decoder_mol import Decoder
from .utils.cnn_postnet import Postnet
from .utils.vc_utils import get_mask_from_lengths
from utils.load_yaml import HpsYaml
class MelDecoderMOLv2(AbsMelDecoder):
"""Use an encoder to preprocess ppg."""
def __init__(
self,
num_speakers: int,
spk_embed_dim: int,
bottle_neck_feature_dim: int,
encoder_dim: int = 256,
encoder_downsample_rates: List = [2, 2],
attention_rnn_dim: int = 512,
decoder_rnn_dim: int = 512,
num_decoder_rnn_layer: int = 1,
concat_context_to_last: bool = True,
prenet_dims: List = [256, 128],
num_mixtures: int = 5,
frames_per_step: int = 2,
mask_padding: bool = True,
):
super().__init__()
self.mask_padding = mask_padding
self.bottle_neck_feature_dim = bottle_neck_feature_dim
self.num_mels = 80
self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1]
self.frames_per_step = frames_per_step
self.use_spk_dvec = True
input_dim = bottle_neck_feature_dim
# Downsampling convolution
self.bnf_prenet = torch.nn.Sequential(
torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
torch.nn.LeakyReLU(0.1),
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
torch.nn.Conv1d(
encoder_dim, encoder_dim,
kernel_size=2*encoder_downsample_rates[0],
stride=encoder_downsample_rates[0],
padding=encoder_downsample_rates[0]//2,
),
torch.nn.LeakyReLU(0.1),
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
torch.nn.Conv1d(
encoder_dim, encoder_dim,
kernel_size=2*encoder_downsample_rates[1],
stride=encoder_downsample_rates[1],
padding=encoder_downsample_rates[1]//2,
),
torch.nn.LeakyReLU(0.1),
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
)
decoder_enc_dim = encoder_dim
self.pitch_convs = torch.nn.Sequential(
torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
torch.nn.LeakyReLU(0.1),
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
torch.nn.Conv1d(
encoder_dim, encoder_dim,
kernel_size=2*encoder_downsample_rates[0],
stride=encoder_downsample_rates[0],
padding=encoder_downsample_rates[0]//2,
),
torch.nn.LeakyReLU(0.1),
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
torch.nn.Conv1d(
encoder_dim, encoder_dim,
kernel_size=2*encoder_downsample_rates[1],
stride=encoder_downsample_rates[1],
padding=encoder_downsample_rates[1]//2,
),
torch.nn.LeakyReLU(0.1),
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
)
self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)
# Decoder
self.decoder = Decoder(
enc_dim=decoder_enc_dim,
num_mels=self.num_mels,
frames_per_step=frames_per_step,
attention_rnn_dim=attention_rnn_dim,
decoder_rnn_dim=decoder_rnn_dim,
num_decoder_rnn_layer=num_decoder_rnn_layer,
prenet_dims=prenet_dims,
num_mixtures=num_mixtures,
use_stop_tokens=True,
concat_context_to_last=concat_context_to_last,
encoder_down_factor=self.encoder_down_factor,
)
# Mel-Spec Postnet: some residual CNN layers
self.postnet = Postnet()
def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[1].data.masked_fill_(mask, 0.0)
return outputs
def forward(
self,
bottle_neck_features: torch.Tensor,
feature_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
logf0_uv: torch.Tensor = None,
spembs: torch.Tensor = None,
output_att_ws: bool = False,
):
decoder_inputs = self.bnf_prenet(
bottle_neck_features.transpose(1, 2)
).transpose(1, 2)
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
decoder_inputs = decoder_inputs + logf0_uv
assert spembs is not None
spk_embeds = F.normalize(
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
decoder_inputs = self.reduce_proj(decoder_inputs)
# T_dec: (B,) memory lengths after encoder downsampling
T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
mel_outputs, predicted_stop, alignments = self.decoder(
decoder_inputs, speech, T_dec)
## Post-processing
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
if output_att_ws:
return self.parse_output(
[mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
else:
return self.parse_output(
[mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)
# return mel_outputs, mel_outputs_postnet
def inference(
self,
bottle_neck_features: torch.Tensor,
logf0_uv: torch.Tensor = None,
spembs: torch.Tensor = None,
):
decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
decoder_inputs = decoder_inputs + logf0_uv
assert spembs is not None
spk_embeds = F.normalize(
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
bottle_neck_features = self.reduce_proj(bottle_neck_features)
## Decoder
if bottle_neck_features.size(0) > 1:
mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
else:
mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
## Post-processing
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
# outputs = mel_outputs_postnet[0]
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
def load_model(train_config, model_file, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = HpsYaml(train_config)
ppg2mel_model = MelDecoderMOLv2(
**model_config["model"]
).to(device)
ckpt = torch.load(model_file, map_location=device)
ppg2mel_model.load_state_dict(ckpt["model"])
ppg2mel_model.eval()
return ppg2mel_model
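
The strided `Conv1d` layers above use `kernel_size=2*rate`, `stride=rate` and `padding=rate//2`, and `forward` derives `T_dec` by floor-dividing `feature_lengths` by `encoder_down_factor`. The sketch below checks with the standard 1-D convolution length formula (dilation 1) that the two agree. The rates `[2, 2]` and the idea that `encoder_down_factor` is the product of the rates are assumptions, since neither is shown in this excerpt; the helper names are hypothetical.

```python
def conv1d_out_len(length, kernel_size, stride, padding):
    """Output length of a 1-D convolution with dilation 1 (standard formula)."""
    return (length + 2 * padding - kernel_size) // stride + 1


def downsampled_len(length, rates):
    """Apply one strided conv per rate, configured like the layers above."""
    for rate in rates:
        length = conv1d_out_len(
            length, kernel_size=2 * rate, stride=rate, padding=rate // 2
        )
    return length


rates = [2, 2]  # assumed example value for encoder_downsample_rates
down_factor = 1
for rate in rates:
    down_factor *= rate  # assumption: encoder_down_factor is the product

for n_frames in (100, 101, 103):
    # Conv-stack length next to the floor division used for T_dec
    print(n_frames, downsampled_len(n_frames, rates), n_frames // down_factor)
```

For even rates the two lengths match. For odd rates they can differ by one frame (for example, a rate of 3 on 9 frames gives 2 from the convolution but 3 from floor division), so the mask length and the memory length would need a separate check in that case.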

112
ppg2mel/preprocess.py Normal file

@@ -0,0 +1,112 @@
import os
import torch
import numpy as np
from typing import Any
from tqdm import tqdm
from pathlib import Path
import soundfile
import resampy

from ppg_extractor import load_model
import encoder.inference as Encoder
from encoder.audio import preprocess_wav
from encoder import audio
from utils.f0_utils import compute_f0

from torch.multiprocessing import Pool, cpu_count
from functools import partial

SAMPLE_RATE = 16000


def _compute_bnf(
    wav: np.ndarray,
    output_fpath: str,
    device: torch.device,
    ppg_model_local: Any,
):
    """
    Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
    """
    ppg_model_local.to(device)
    wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
    wav_length = torch.LongTensor([wav.shape[0]]).to(device)
    with torch.no_grad():
        bnf = ppg_model_local(wav_tensor, wav_length)
    bnf_npy = bnf.squeeze(0).cpu().numpy()
    np.save(output_fpath, bnf_npy, allow_pickle=False)
    return bnf_npy, len(bnf_npy)


def _compute_f0_from_wav(wav, output_fpath):
    """Compute merged f0 values."""
    f0 = compute_f0(wav, SAMPLE_RATE)
    np.save(output_fpath, f0, allow_pickle=False)
    return f0, len(f0)


def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
    """Compute the L2-normalized utterance-level speaker embedding."""
    Encoder.set_model(encoder_model_local)
    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = Encoder.embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)
    np.save(output_fpath, embed, allow_pickle=False)
    return embed, len(embed)


def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
    # wav = preprocess_wav(wav_path)
    # try:
    wav, sr = soundfile.read(wav_path)
    # Skip utterances shorter than one second
    if len(wav) < sr:
        return None, sr, len(wav)
    if sr != SAMPLE_RATE:
        wav = resampy.resample(wav, sr, SAMPLE_RATE)
        sr = SAMPLE_RATE
    # Path.stem drops only the final extension; str.rstrip(".wav") strips a
    # character set and can remove trailing 'w', 'a', 'v' or '.' from the ID.
    utt_id = Path(wav_path).stem

    _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
    _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
    _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)


def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
    # Glob wav files
    wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
    print(f"Globbed {len(wav_file_list)} wav files.")

    out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
    encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
    if n_processes is None:
        n_processes = cpu_count()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
    job = Pool(n_processes).imap(func, wav_file_list)
    list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))

    # Finish processing and write the train/dev/eval file-ID lists
    t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
    d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
    e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
    for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
        utt_id = os.path.basename(file).split(".f0.npy")[0]
        if utt_id.endswith("01"):
            d_fid_file.write(utt_id + "\n")
        elif utt_id.endswith("09"):
            e_fid_file.write(utt_id + "\n")
        else:
            t_fid_file.write(utt_id + "\n")
    t_fid_file.close()
    d_fid_file.close()
    e_fid_file.close()
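
The preprocessing code derives an utterance ID from each wav file name and later routes IDs to the train, dev or eval list by their suffix. A minimal sketch of both steps follows; `utt_id_from_path` and `split_name` are hypothetical helper names, and the file names are made up. It also shows why `str.rstrip` is a risky way to drop an extension: it removes a set of characters, not a suffix.

```python
from pathlib import Path


def utt_id_from_path(wav_path):
    """Return the file name without its final extension."""
    return Path(wav_path).stem


def split_name(utt_id):
    """Route an utterance ID to a split using the suffix rule from the code."""
    if utt_id.endswith("01"):
        return "dev"
    if utt_id.endswith("09"):
        return "eval"
    return "train"


for name in ("spk1/utt_001.wav", "spk1/utt_009.wav", "spk2/sample_wav.wav"):
    utt_id = utt_id_from_path(name)
    print(utt_id, split_name(utt_id))

# str.rstrip removes any trailing '.', 'w', 'a' or 'v', so it eats into the ID.
print("sample_wav.wav".rstrip(".wav"))  # prints "sample_"
```

With this suffix rule the dev and eval sets are whatever IDs happen to end in `01` or `09`, so their size depends on the naming scheme of the corpus.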

374
ppg2mel/rnn_decoder_mol.py Normal file

@@ -0,0 +1,374 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils.mol_attention import MOLAttention
from .utils.basic_layers import Linear
from .utils.vc_utils import get_mask_from_lengths
class DecoderPrenet(nn.Module):
def __init__(self, in_dim, sizes):
super().__init__()
in_sizes = [in_dim] + sizes[:-1]
self.layers = nn.ModuleList(
[Linear(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_sizes, sizes)])
def forward(self, x):
for linear in self.layers:
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
return x
class Decoder(nn.Module):
"""Mixture of Logistic (MoL) attention-based RNN Decoder."""
def __init__(
self,
enc_dim,
num_mels,
frames_per_step,
attention_rnn_dim,
decoder_rnn_dim,
prenet_dims,
num_mixtures,
encoder_down_factor=1,
num_decoder_rnn_layer=1,
use_stop_tokens=False,
concat_context_to_last=False,
):
super().__init__()
self.enc_dim = enc_dim
self.encoder_down_factor = encoder_down_factor
self.num_mels = num_mels
self.frames_per_step = frames_per_step
self.attention_rnn_dim = attention_rnn_dim
self.decoder_rnn_dim = decoder_rnn_dim
self.prenet_dims = prenet_dims
self.use_stop_tokens = use_stop_tokens
self.num_decoder_rnn_layer = num_decoder_rnn_layer
self.concat_context_to_last = concat_context_to_last
# Mel prenet
self.prenet = DecoderPrenet(num_mels, prenet_dims)
self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)
# Attention RNN
self.attention_rnn = nn.LSTMCell(
prenet_dims[-1] + enc_dim,
attention_rnn_dim
)
# Attention
self.attention_layer = MOLAttention(
attention_rnn_dim,
r=frames_per_step/encoder_down_factor,
M=num_mixtures,
)
# Decoder RNN
self.decoder_rnn_layers = nn.ModuleList()
for i in range(num_decoder_rnn_layer):
if i == 0:
self.decoder_rnn_layers.append(
nn.LSTMCell(
enc_dim + attention_rnn_dim,
decoder_rnn_dim))
else:
self.decoder_rnn_layers.append(
nn.LSTMCell(
decoder_rnn_dim,
decoder_rnn_dim))
# self.decoder_rnn = nn.LSTMCell(
# 2 * enc_dim + attention_rnn_dim,
# decoder_rnn_dim
# )
if concat_context_to_last:
self.linear_projection = Linear(
enc_dim + decoder_rnn_dim,
num_mels * frames_per_step
)
else:
self.linear_projection = Linear(
decoder_rnn_dim,
num_mels * frames_per_step
)
# Stop-token layer
if self.use_stop_tokens:
if concat_context_to_last:
self.stop_layer = Linear(
enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
)
else:
self.stop_layer = Linear(
decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
)
def get_go_frame(self, memory):
B = memory.size(0)
go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
device=memory.device)
return go_frame
def initialize_decoder_states(self, memory, mask):
device = next(self.parameters()).device
B = memory.size(0)
# attention rnn states
self.attention_hidden = torch.zeros(
(B, self.attention_rnn_dim), device=device)
self.attention_cell = torch.zeros(
(B, self.attention_rnn_dim), device=device)
# decoder rnn states
self.decoder_hiddens = []
self.decoder_cells = []
for i in range(self.num_decoder_rnn_layer):
self.decoder_hiddens.append(
torch.zeros((B, self.decoder_rnn_dim),
device=device)
)
self.decoder_cells.append(
torch.zeros((B, self.decoder_rnn_dim),
device=device)
)
# self.decoder_hidden = torch.zeros(
# (B, self.decoder_rnn_dim), device=device)
# self.decoder_cell = torch.zeros(
# (B, self.decoder_rnn_dim), device=device)
self.attention_context = torch.zeros(
(B, self.enc_dim), device=device)
self.memory = memory
# self.processed_memory = self.attention_layer.memory_layer(memory)
self.mask = mask
def parse_decoder_inputs(self, decoder_inputs):
"""Prepare decoder inputs, i.e. gt mel
Args:
decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training.
"""
decoder_inputs = decoder_inputs.reshape(
decoder_inputs.size(0),
int(decoder_inputs.size(1)/self.frames_per_step), -1)
# (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
decoder_inputs = decoder_inputs.transpose(0, 1)
# (T_out//r, B, num_mels)
decoder_inputs = decoder_inputs[:,:,-self.num_mels:]
return decoder_inputs
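def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
""" Prepares decoder outputs for return
Args:
mel_outputs: list of T_out//r tensors, each (B, num_mels*r)
alignments: list of T_out//r tensors, each (B, T_enc)
stop_outputs: list of T_out//r stop logits, or None
"""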
# (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
alignments = torch.stack(alignments).transpose(0, 1)
# (T_out//r, B) -> (B, T_out//r)
if stop_outputs is not None:
if alignments.size(0) == 1:
stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
else:
stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
stop_outputs = stop_outputs.contiguous()
# (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
# decouple frames per step
# (B, T_out, num_mels)
mel_outputs = mel_outputs.view(
mel_outputs.size(0), -1, self.num_mels)
return mel_outputs, alignments, stop_outputs
def attend(self, decoder_input):
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_context, attention_weights = self.attention_layer(
self.attention_hidden, self.memory, None, self.mask)
decoder_rnn_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
return decoder_rnn_input, self.attention_context, attention_weights
def decode(self, decoder_input):
for i in range(self.num_decoder_rnn_layer):
if i == 0:
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
else:
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
return self.decoder_hiddens[-1]
def forward(self, memory, mel_inputs, memory_lengths):
""" Decoder forward pass for training
Args:
memory: (B, T_enc, enc_dim) Encoder outputs
mel_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
memory_lengths: (B, ) Encoder output lengths for attention masking.
Returns:
mel_outputs: (B, T, num_mels) mel outputs from the decoder
alignments: (B, T//r, T_enc) attention weights.
"""
# [1, B, num_mels]
go_frame = self.get_go_frame(memory).unsqueeze(0)
# [T//r, B, num_mels]
mel_inputs = self.parse_decoder_inputs(mel_inputs)
# [T//r + 1, B, num_mels]
mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
# [T//r + 1, B, prenet_dim]
decoder_inputs = self.prenet(mel_inputs)
# decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)
self.initialize_decoder_states(
memory, mask=~get_mask_from_lengths(memory_lengths),
)
self.attention_layer.init_states(memory)
# self.attention_layer_pitch.init_states(memory_pitch)
mel_outputs, alignments = [], []
if self.use_stop_tokens:
stop_outputs = []
else:
stop_outputs = None
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
# decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]
decoder_rnn_input, context, attention_weights = self.attend(decoder_input)
decoder_rnn_output = self.decode(decoder_rnn_input)
if self.concat_context_to_last:
decoder_rnn_output = torch.cat(
(decoder_rnn_output, context), dim=1)
mel_output = self.linear_projection(decoder_rnn_output)
if self.use_stop_tokens:
stop_output = self.stop_layer(decoder_rnn_output)
stop_outputs += [stop_output.squeeze()]
mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze
alignments += [attention_weights]
# alignments_pitch += [attention_weights_pitch]
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
mel_outputs, alignments, stop_outputs)
if stop_outputs is None:
return mel_outputs, alignments
else:
return mel_outputs, stop_outputs, alignments
def inference(self, memory, stop_threshold=0.5):
""" Decoder inference
Args:
memory: (1, T_enc, D_enc) Encoder outputs
Returns:
mel_outputs: mel outputs from the decoder
alignments: sequence of attention weights from the decoder
"""
# [1, num_mels]
decoder_input = self.get_go_frame(memory)
self.initialize_decoder_states(memory, mask=None)
self.attention_layer.init_states(memory)
mel_outputs, alignments = [], []
# NOTE(sx): heuristic
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
while True:
decoder_input = self.prenet(decoder_input)
decoder_input_final, context, alignment = self.attend(decoder_input)
#mel_output, stop_output, alignment = self.decode(decoder_input)
decoder_rnn_output = self.decode(decoder_input_final)
if self.concat_context_to_last:
decoder_rnn_output = torch.cat(
(decoder_rnn_output, context), dim=1)
mel_output = self.linear_projection(decoder_rnn_output)
stop_output = self.stop_layer(decoder_rnn_output)
mel_outputs += [mel_output.squeeze(1)]
alignments += [alignment]
if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
break
if len(mel_outputs) >= max_decoder_step:
# print("Warning! Decoding steps reaches max decoder steps.")
break
decoder_input = mel_output[:,-self.num_mels:]
mel_outputs, alignments, _ = self.parse_decoder_outputs(
mel_outputs, alignments, None)
return mel_outputs, alignments
def inference_batched(self, memory, stop_threshold=0.5):
""" Batched decoder inference
Args:
memory: (B, T_enc, D_enc) Encoder outputs
Returns:
mel_outputs: mel outputs from the decoder
alignments: sequence of attention weights from the decoder
"""
# [B, num_mels]
decoder_input = self.get_go_frame(memory)
self.initialize_decoder_states(memory, mask=None)
self.attention_layer.init_states(memory)
mel_outputs, alignments = [], []
stop_outputs = []
# NOTE(sx): heuristic
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
while True:
decoder_input = self.prenet(decoder_input)
decoder_input_final, context, alignment = self.attend(decoder_input)
#mel_output, stop_output, alignment = self.decode(decoder_input)
decoder_rnn_output = self.decode(decoder_input_final)
if self.concat_context_to_last:
decoder_rnn_output = torch.cat(
(decoder_rnn_output, context), dim=1)
mel_output = self.linear_projection(decoder_rnn_output)
# (B, 1)
stop_output = self.stop_layer(decoder_rnn_output)
stop_outputs += [stop_output.squeeze()]
# stop_outputs.append(stop_output)
mel_outputs += [mel_output.squeeze(1)]
alignments += [alignment]
# print(stop_output.shape)
if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
and len(mel_outputs) >= min_decoder_step:
break
if len(mel_outputs) >= max_decoder_step:
# print("Warning! Decoding steps reaches max decoder steps.")
break
decoder_input = mel_output[:,-self.num_mels:]
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
mel_outputs, alignments, stop_outputs)
mel_outputs_stacked = []
for mel, stop_logit in zip(mel_outputs, stop_outputs):
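stop_idx = torch.nonzero(torch.sigmoid(stop_logit) > stop_threshold)
# Fall back to the full sequence if this item never crossed the threshold
idx = stop_idx[0].item() if stop_idx.numel() > 0 else mel.size(0)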
mel_outputs_stacked.append(mel[:idx,:])
mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
return mel_outputs, alignments
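
Both inference loops bound the number of decoder steps with a heuristic: the cap is `T_enc * encoder_down_factor // frames_per_step`, and a stop prediction is ignored until the step count is within 5 of that cap. The framework-free sketch below replays that rule on a list of stop logits. The function names and the example logits are hypothetical; only the arithmetic is taken from the code above.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def step_bounds(t_enc, encoder_down_factor, frames_per_step, slack=5):
    """Min and max decoder steps, following the heuristic in `inference`."""
    max_step = t_enc * encoder_down_factor // frames_per_step
    return max_step - slack, max_step


def steps_taken(stop_logits, min_step, max_step, stop_threshold=0.5):
    """Count decoder steps until the stop rule or the step cap fires."""
    for n, logit in enumerate(stop_logits, start=1):
        if sigmoid(logit) > stop_threshold and n >= min_step:
            return n
        if n >= max_step:
            return n
    return len(stop_logits)


min_step, max_step = step_bounds(t_enc=10, encoder_down_factor=4, frames_per_step=2)
# An early stop at step 3 is ignored; the stop at step 18 is accepted.
logits = [-5.0] * 2 + [5.0] + [-5.0] * 14 + [5.0] * 5
print(min_step, max_step, steps_taken(logits, min_step, max_step))
```

The lower bound guards against a premature stop token early in the utterance, and the cap guarantees termination when the stop token never fires.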

67
ppg2mel/train.py Normal file

@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
        'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ckpt/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    ###
    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict accessible with dot notation
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)


if __name__ == "__main__":
    main()
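
The training script exposes negative switches (`--cpu`, `--no-pin`, `--no-msg`) and then derives the positive attributes `gpu`, `pin_memory` and `verbose` that the solver reads. A reduced sketch of that pattern with only four of the arguments; the `parse` wrapper is a hypothetical helper added so the parser can be called with an explicit argument list.

```python
import argparse


def parse(argv):
    parser = argparse.ArgumentParser(description="Flag-inversion sketch.")
    parser.add_argument("--config", type=str)
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--no-pin", action="store_true")
    parser.add_argument("--no-msg", action="store_true")
    paras = parser.parse_args(argv)
    # argparse stores "--no-pin" as the attribute "no_pin".
    paras.gpu = not paras.cpu
    paras.pin_memory = not paras.no_pin
    paras.verbose = not paras.no_msg
    return paras


paras = parse(["--config", "config/vc.yaml", "--cpu"])
print(paras.gpu, paras.pin_memory, paras.verbose)  # prints "False True True"
```

Keeping the command-line switches negative means the defaults (GPU on, pinned memory on, messages on) need no flags at all.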


@@ -0,0 +1 @@
#

50
ppg2mel/train/loss.py Normal file

@@ -0,0 +1,50 @@
from typing import Dict
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.nets_utils import make_pad_mask


class MaskedMSELoss(nn.Module):
    def __init__(self, frames_per_step):
        super().__init__()
        self.frames_per_step = frames_per_step
        self.mel_loss_criterion = nn.MSELoss(reduction='none')
        # self.loss = nn.MSELoss()
        self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')

    def get_mask(self, lengths, max_len=None):
        # lengths: [B,]
        if max_len is None:
            max_len = torch.max(lengths)
        batch_size = lengths.size(0)
        seq_range = torch.arange(0, max_len).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
        seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
        return (seq_range_expand < seq_length_expand).float()

    def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
                stop_target, stop_pred):
        ## process stop_target
        B = stop_target.size(0)
        stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
        stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
        stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))

        mel_trg.requires_grad = False
        # (B, T, 1)
        mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
        # (B, T, D)
        mel_mask = mel_mask.expand_as(mel_trg)
        mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
        mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()

        mel_loss = mel_loss_pre + mel_loss_post

        # stop token loss
        stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()

        return mel_loss, stop_loss
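
`MaskedMSELoss` averages the squared error over valid frames only: the element-wise loss is multiplied by a length mask, summed, and divided by the mask sum. The sketch below reproduces that arithmetic on nested Python lists instead of tensors, so it is an illustration of the masking idea and not the class itself; `length_mask`, `masked_mse` and `stop_lengths` are hypothetical names.

```python
import math


def length_mask(lengths, max_len):
    """1.0 for valid frames, 0.0 for padding; one row per utterance."""
    return [[1.0 if t < n else 0.0 for t in range(max_len)] for n in lengths]


def masked_mse(pred, target, lengths):
    """Mean squared error over valid frames only.

    pred/target: nested lists shaped (B, T, D); lengths: valid frames per item.
    """
    mask = length_mask(lengths, len(target[0]))
    total, count = 0.0, 0.0
    for b, row in enumerate(mask):
        for t, m in enumerate(row):
            for p, y in zip(pred[b][t], target[b][t]):
                total += m * (p - y) ** 2
                count += m  # every mel bin of a valid frame counts once
    return total / count


def stop_lengths(lengths, frames_per_step):
    """Number of decoder steps per item, rounded up as in `forward`."""
    return [math.ceil(n / frames_per_step) for n in lengths]


target = [[[0.0, 0.0]] * 3, [[0.0, 0.0]] * 3]
pred = [
    [[1.0, 1.0]] * 3,                              # length 3: all frames valid
    [[1.0, 1.0], [100.0, 100.0], [100.0, 100.0]],  # length 1: rest is padding
]
print(masked_mse(pred, target, lengths=[3, 1]))  # padded frames do not count
print(stop_lengths([3, 1, 5], frames_per_step=2))
```

Without the mask, the large values in the padded frames of the second item would dominate the loss, which is the situation the mask is there to prevent.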

45
ppg2mel/train/optim.py Normal file

@@ -0,0 +1,45 @@
import torch
import numpy as np


class Optimizer():
    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
                 **kwargs):
        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
            self.opt = opt(parameters, lr=1.0)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: 1e-8 better?

    def get_opt_state_dict(self):
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        else:
            cur_lr = self.init_lr
        self.opt.zero_grad()
        return cur_lr

    def step(self):
        self.opt.step()

    def create_msg(self):
        return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
                .format(self.opt_type, self.init_lr, self.sch_type)]
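
With `lr_scheduler == 'warmup'`, the learning rate rises linearly for 4000 steps, peaks at the configured `lr`, and then decays with the inverse square root of the step. The sketch below evaluates the same formula with plain Python floats; `warmup_lr` is a hypothetical name and `1e-3` is an arbitrary example value for `lr`.

```python
import math


def warmup_lr(step, init_lr, warmup_step=4000.0):
    """Learning rate at a zero-based step, as in the 'warmup' branch."""
    n = step + 1
    return init_lr * warmup_step ** 0.5 * min(n * warmup_step ** -1.5, n ** -0.5)


init_lr = 1e-3  # example value, not taken from any config in this commit
for step in (0, 1999, 3999, 15999):
    print(step, warmup_lr(step, init_lr))

# The peak equals init_lr at the end of warmup, and the rate halves after
# four times as many steps.
assert math.isclose(warmup_lr(3999, init_lr), init_lr)
assert math.isclose(warmup_lr(15999, init_lr), init_lr / 2)
```

Because `pre_step` overwrites `param_group['lr']` on every call, the `lr=1.0` passed to the torch optimizer in this branch is only a placeholder.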

10
ppg2mel/train/option.py Normal file

@@ -0,0 +1,10 @@
# Default parameters which will be imported by solver
default_hparas = {
    'GRAD_CLIP': 5.0,          # Grad. clip threshold
    'PROGRESS_STEP': 100,      # Std. output refresh freq.
    # Decode steps for objective validation (step = ratio*input_txt_len)
    'DEV_STEP_RATIO': 1.2,
    # Number of examples (alignment/text) to show in tensorboard
    'DEV_N_EXAMPLE': 4,
    'TB_FLUSH_FREQ': 180       # Update frequency of tensorboard (secs)
}
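
These defaults reach the solver as plain attributes: `BaseSolver.__init__` loops over the dictionary and calls `setattr` for each entry, which is why later code can write `self.GRAD_CLIP` or `self.TB_FLUSH_FREQ`. A minimal sketch of that mechanism, using a hypothetical `TinySolver` stand-in and no other solver logic:

```python
default_hparas = {
    "GRAD_CLIP": 5.0,
    "PROGRESS_STEP": 100,
    "DEV_STEP_RATIO": 1.2,
    "DEV_N_EXAMPLE": 4,
    "TB_FLUSH_FREQ": 180,
}


class TinySolver:  # hypothetical stand-in for BaseSolver
    def __init__(self):
        # Copy every default onto the instance as an upper-case attribute.
        for key, value in default_hparas.items():
            setattr(self, key, value)


solver = TinySolver()
print(solver.GRAD_CLIP, solver.TB_FLUSH_FREQ)  # prints "5.0 180"
```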

216
ppg2mel/train/solver.py Normal file

@@ -0,0 +1,216 @@
import os
import sys
import abc
import math
import yaml
import torch
from torch.utils.tensorboard import SummaryWriter
from .option import default_hparas
from utils.util import human_format, Timer
from utils.load_yaml import HpsYaml
class BaseSolver():
'''
Prototype Solver for all kinds of tasks
Arguments
config - yaml-styled config
paras - argparse outcome
mode - "train"/"test"
'''
def __init__(self, config, paras, mode="train"):
# General Settings
self.config = config # load from yaml file
self.paras = paras # command line args
self.mode = mode # 'train' or 'test'
for k, v in default_hparas.items():
setattr(self, k, v)
self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
else torch.device('cpu')
# Name experiment
self.exp_name = paras.name
if self.exp_name is None:
if 'exp_name' in self.config:
self.exp_name = self.config.exp_name
else:
# By default, exp is named after config file
self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
if mode == 'train':
self.exp_name += '_seed{}'.format(paras.seed)
if mode == 'train':
# Filepath setup
os.makedirs(paras.ckpdir, exist_ok=True)
self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
os.makedirs(self.ckpdir, exist_ok=True)
# Logger settings
self.logdir = os.path.join(paras.logdir, self.exp_name)
self.log = SummaryWriter(
self.logdir, flush_secs=self.TB_FLUSH_FREQ)
self.timer = Timer()
# Hyper-parameters
self.step = 0
self.valid_step = config.hparas.valid_step
self.max_step = config.hparas.max_step
self.verbose('Exp. name : {}'.format(self.exp_name))
self.verbose('Loading data... a large corpus may take a while.')
# elif mode == 'test':
# # Output path
# os.makedirs(paras.outdir, exist_ok=True)
# self.ckpdir = os.path.join(paras.outdir, self.exp_name)
# Load training config to get acoustic feat and build model
# self.src_config = HpsYaml(config.src.config)
# self.paras.load = config.src.ckpt
# self.verbose('Evaluating result of tr. config @ {}'.format(
# config.src.config))
def backward(self, loss):
'''
Standard backward step with self.timer and debugger
Arguments
loss - the loss to perform loss.backward()
'''
self.timer.set()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.GRAD_CLIP)
if math.isnan(grad_norm):
self.verbose('Error : grad norm is NaN @ step '+str(self.step))
else:
self.optimizer.step()
self.timer.cnt('bw')
return grad_norm
def load_ckpt(self):
''' Load ckpt if --load option is specified '''
if self.paras.load is not None:
if self.paras.warm_start:
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
ckpt = torch.load(
self.paras.load, map_location=self.device if self.mode == 'train'
else 'cpu')
model_dict = ckpt['model']
if len(self.config.model.ignore_layers) > 0:
model_dict = {k:v for k, v in model_dict.items()
if k not in self.config.model.ignore_layers}
dummy_dict = self.model.state_dict()
dummy_dict.update(model_dict)
model_dict = dummy_dict
self.model.load_state_dict(model_dict)
else:
# Load weights
ckpt = torch.load(
self.paras.load, map_location=self.device if self.mode == 'train'
else 'cpu')
self.model.load_state_dict(ckpt['model'])
# Load task-dependent items
if self.mode == 'train':
self.step = ckpt['global_step']
self.optimizer.load_opt_state_dict(ckpt['optimizer'])
self.verbose('Load ckpt from {}, restarting at step {}'.format(
self.paras.load, self.step))
else:
for k, v in ckpt.items():
if type(v) is float:
metric, score = k, v
self.model.eval()
self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
self.paras.load, metric, score))
def verbose(self, msg):
''' Verbose function for printing information to stdout'''
if self.paras.verbose:
if type(msg) == list:
for m in msg:
print('[INFO]', m.ljust(100))
else:
print('[INFO]', msg.ljust(100))
def progress(self, msg):
''' Verbose function for updating progress on stdout (do not include newline) '''
if self.paras.verbose:
sys.stdout.write("\033[K") # Clear line
print('[{}] {}'.format(human_format(self.step), msg), end='\r')
def write_log(self, log_name, log_dict):
'''
Write log to TensorBoard
log_name - <str> Name of tensorboard variable
log_dict - <dict>/<array> Value of variable (e.g. dict of losses); nothing is written if it is None
'''
if type(log_dict) is dict:
log_dict = {key: val for key, val in log_dict.items() if (
val is not None and not math.isnan(val))}
if log_dict is None:
pass
elif len(log_dict) > 0:
if 'align' in log_name or 'spec' in log_name:
img, form = log_dict
self.log.add_image(
log_name, img, global_step=self.step, dataformats=form)
elif 'text' in log_name or 'hyp' in log_name:
self.log.add_text(log_name, log_dict, self.step)
else:
self.log.add_scalars(log_name, log_dict, self.step)
def save_checkpoint(self, f_name, metric, score, show_msg=True):
'''
Ckpt saver
f_name - <str> the name of ckpt file (w/o prefix) to store, overwritten if it exists
metric - <str> the name of the metric used to evaluate the model
score - <float> the value of the metric used to evaluate the model
'''
ckpt_path = os.path.join(self.ckpdir, f_name)
full_dict = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.get_opt_state_dict(),
"global_step": self.step,
metric: score
}
torch.save(full_dict, ckpt_path)
if show_msg:
self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
format(human_format(self.step), metric, score, ckpt_path))
# ----------------------------------- Abstract Methods ------------------------------------------ #
@abc.abstractmethod
def load_data(self):
'''
Called by main to load all data
After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
No return value
'''
raise NotImplementedError
@abc.abstractmethod
def set_model(self):
'''
Called by main to set models
After this call, model related attributes should be setup (e.g. self.l2_loss)
The following MUST be set up
- self.model (torch.nn.Module)
- self.optimizer (src.Optimizer),
init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
Loading pre-trained model should also be performed here
No return value
'''
raise NotImplementedError
@abc.abstractmethod
def exec(self):
'''
Called by main to execute training/inference
'''
raise NotImplementedError
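
With `--warm_start`, `load_ckpt` drops the checkpoint entries listed in `config.model.ignore_layers` and fills the gaps from the freshly initialized model before calling `load_state_dict`. The sketch below shows that merge on plain dictionaries, with strings standing in for tensors; `warm_start_state` and the layer names are hypothetical.

```python
def warm_start_state(current_state, ckpt_state, ignore_layers):
    """Merge checkpoint weights into the current state, skipping ignored keys."""
    kept = {k: v for k, v in ckpt_state.items() if k not in ignore_layers}
    merged = dict(current_state)  # start from the freshly initialized weights
    merged.update(kept)           # overwrite with everything else from the ckpt
    return merged


current = {"encoder.weight": "init", "decoder.weight": "init"}
ckpt = {"encoder.weight": "trained", "decoder.weight": "trained"}
merged = warm_start_state(current, ckpt, ignore_layers=["decoder.weight"])
print(merged)
```

Ignored layers keep their initial values, so the merged dictionary still contains every key the model expects and a strict `load_state_dict` call does not fail on missing entries.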


@@ -0,0 +1,288 @@
import os, sys
# sys.path.append('/home/shaunxliu/projects/nnsp')
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import torch
from torch.utils.data import DataLoader
import numpy as np
from .solver import BaseSolver
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
# from src.rnn_ppg2mel import BiRnnPpg2MelModel
# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
from .loss import MaskedMSELoss
from .optim import Optimizer
from utils.util import human_format
from ppg2mel import MelDecoderMOLv2
class Solver(BaseSolver):
"""Customized Solver."""
def __init__(self, config, paras, mode):
super().__init__(config, paras, mode)
self.num_att_plots = 5
self.att_ws_dir = f"{self.logdir}/att_ws"
os.makedirs(self.att_ws_dir, exist_ok=True)
self.best_loss = np.inf
def fetch_data(self, data):
"""Move data to device"""
data = [i.to(self.device) for i in data]
return data
def load_data(self):
""" Load data for training/validation/plotting."""
train_dataset = OneshotVcDataset(
meta_file=self.config.data.train_fid_list,
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
libri_ppg_dir=self.config.data.libri_ppg_dir,
vctk_f0_dir=self.config.data.vctk_f0_dir,
libri_f0_dir=self.config.data.libri_f0_dir,
vctk_wav_dir=self.config.data.vctk_wav_dir,
libri_wav_dir=self.config.data.libri_wav_dir,
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
ppg_file_ext=self.config.data.ppg_file_ext,
min_max_norm_mel=self.config.data.min_max_norm_mel,
mel_min=self.config.data.mel_min,
mel_max=self.config.data.mel_max,
)
dev_dataset = OneshotVcDataset(
meta_file=self.config.data.dev_fid_list,
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
libri_ppg_dir=self.config.data.libri_ppg_dir,
vctk_f0_dir=self.config.data.vctk_f0_dir,
libri_f0_dir=self.config.data.libri_f0_dir,
vctk_wav_dir=self.config.data.vctk_wav_dir,
libri_wav_dir=self.config.data.libri_wav_dir,
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
ppg_file_ext=self.config.data.ppg_file_ext,
min_max_norm_mel=self.config.data.min_max_norm_mel,
mel_min=self.config.data.mel_min,
mel_max=self.config.data.mel_max,
)
self.train_dataloader = DataLoader(
train_dataset,
num_workers=self.paras.njobs,
shuffle=True,
batch_size=self.config.hparas.batch_size,
pin_memory=False,
drop_last=True,
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
use_spk_dvec=True),
)
self.dev_dataloader = DataLoader(
dev_dataset,
num_workers=self.paras.njobs,
shuffle=False,
batch_size=self.config.hparas.batch_size,
pin_memory=False,
drop_last=False,
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
use_spk_dvec=True),
)
self.plot_dataloader = DataLoader(
dev_dataset,
num_workers=self.paras.njobs,
shuffle=False,
batch_size=1,
pin_memory=False,
drop_last=False,
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
use_spk_dvec=True,
give_uttids=True),
)
msg = "Have prepared training set and dev set."
self.verbose(msg)
def load_pretrained_params(self):
print("Load pretrained model from: ", self.config.data.pretrain_model_file)
ignore_layer_prefixes = ["speaker_embedding_table"]
pretrain_model_file = self.config.data.pretrain_model_file
pretrain_ckpt = torch.load(
pretrain_model_file, map_location=self.device
)["model"]
model_dict = self.model.state_dict()
print(self.model)
# 1. filter out unnecessary keys
for prefix in ignore_layer_prefixes:
pretrain_ckpt = {k : v
for k, v in pretrain_ckpt.items() if not k.startswith(prefix)
}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrain_ckpt)
# 3. load the new state dict
self.model.load_state_dict(model_dict)
def set_model(self):
"""Setup model and optimizer"""
# Model
print("[INFO] Model name: ", self.config["model_name"])
self.model = MelDecoderMOLv2(
**self.config["model"]
).to(self.device)
# self.load_pretrained_params()
# model_params = [{'params': self.model.spk_embedding.weight}]
model_params = [{'params': self.model.parameters()}]
# Loss criterion
self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)
# Optimizer
self.optimizer = Optimizer(model_params, **self.config["hparas"])
self.verbose(self.optimizer.create_msg())
# Automatically load pre-trained model if self.paras.load is given
self.load_ckpt()
def exec(self):
self.verbose("Total training steps {}.".format(
human_format(self.max_step)))
mel_loss = None
n_epochs = 0
# Set as current time
self.timer.set()
while self.step < self.max_step:
for data in self.train_dataloader:
# Pre-step: update lr_rate and do zero_grad
lr_rate = self.optimizer.pre_step(self.step)
total_loss = 0
# data to device
ppgs, lf0_uvs, mels, in_lengths, \
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
self.timer.cnt("rd")
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
ppgs,
in_lengths,
mels,
out_lengths,
lf0_uvs,
spk_ids
)
mel_loss, stop_loss = self.loss_criterion(
mel_outputs,
mel_outputs_postnet,
mels,
out_lengths,
stop_tokens,
predicted_stop
)
loss = mel_loss + stop_loss
self.timer.cnt("fw")
# Back-prop
grad_norm = self.backward(loss)
self.step += 1
# Logger
if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
.format(loss.cpu().item(), mel_loss.cpu().item(),
stop_loss.cpu().item(), grad_norm, self.timer.show()))
self.write_log('loss', {'tr/loss': loss,
'tr/mel-loss': mel_loss,
'tr/stop-loss': stop_loss})
# Validation
if (self.step == 1) or (self.step % self.valid_step == 0):
self.validate()
# End of step
# https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
torch.cuda.empty_cache()
self.timer.set()
if self.step >= self.max_step:
break
n_epochs += 1
self.log.close()
def validate(self):
self.model.eval()
dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0
for i, data in enumerate(self.dev_dataloader):
self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
# Fetch data
ppgs, lf0_uvs, mels, in_lengths, \
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
with torch.no_grad():
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
ppgs,
in_lengths,
mels,
out_lengths,
lf0_uvs,
spk_ids
)
mel_loss, stop_loss = self.loss_criterion(
mel_outputs,
mel_outputs_postnet,
mels,
out_lengths,
stop_tokens,
predicted_stop
)
loss = mel_loss + stop_loss
dev_loss += loss.cpu().item()
dev_mel_loss += mel_loss.cpu().item()
dev_stop_loss += stop_loss.cpu().item()
dev_loss = dev_loss / (i + 1)
dev_mel_loss = dev_mel_loss / (i + 1)
dev_stop_loss = dev_stop_loss / (i + 1)
self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
if dev_loss < self.best_loss:
self.best_loss = dev_loss
self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
self.write_log('loss', {'dv/loss': dev_loss,
'dv/mel-loss': dev_mel_loss,
'dv/stop-loss': dev_stop_loss})
# plot attention
for i, data in enumerate(self.plot_dataloader):
if i == self.num_att_plots:
break
# Fetch data
ppgs, lf0_uvs, mels, in_lengths, \
out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
fid = data[-1][0]
with torch.no_grad():
_, _, _, att_ws = self.model(
ppgs,
in_lengths,
mels,
out_lengths,
lf0_uvs,
spk_ids,
output_att_ws=True
)
att_ws = att_ws.squeeze(0).cpu().numpy()
att_ws = att_ws[None]
w, h = plt.figaspect(1.0 / len(att_ws))
fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
axes = fig.subplots(1, len(att_ws))
if len(att_ws) == 1:
axes = [axes]
for ax, aw in zip(axes, att_ws):
ax.imshow(aw.astype(np.float32), aspect="auto")
ax.set_title(f"{fid}")
ax.set_xlabel("Input")
ax.set_ylabel("Output")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
fig.savefig(fig_name)
# Resume training
self.model.train()
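
The warm-start logic in `load_pretrained_params` above drops checkpoint entries whose keys begin with an ignored prefix, then overlays the remaining entries on the model's current state dict. A minimal sketch of that filtering step follows, with plain dictionaries standing in for tensors. The helper name `merge_pretrained` and the sample keys are illustrative assumptions, not part of the repository.

```python
def merge_pretrained(model_state, pretrained_state, ignore_prefixes):
    """Overlay pretrained entries on model_state, skipping ignored prefixes.

    Plain dicts stand in for state dicts; values would be tensors in practice.
    """
    kept = {
        key: value
        for key, value in pretrained_state.items()
        # str.startswith accepts a tuple, so one pass covers every prefix
        if not key.startswith(tuple(ignore_prefixes))
    }
    merged = dict(model_state)  # copy so the caller's dict is left untouched
    merged.update(kept)
    return merged


# Hypothetical keys, chosen only to mirror the ignored prefix used above
model_state = {"encoder.weight": 0, "speaker_embedding_table.weight": 0}
pretrained = {"encoder.weight": 1, "speaker_embedding_table.weight": 1}
merged = merge_pretrained(model_state, pretrained, ["speaker_embedding_table"])
print(merged)
```

Only `encoder.weight` takes the pretrained value here; the speaker embedding entry keeps the model's own initial value.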


@@ -0,0 +1,23 @@
from abc import ABC
from abc import abstractmethod
import torch
class AbsMelDecoder(torch.nn.Module, ABC):
"""The abstract PPG-based voice conversion class
This "model" is one of mediator objects for "Task" class.
"""
@abstractmethod
def forward(
self,
bottle_neck_features: torch.Tensor,
feature_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
logf0_uv: torch.Tensor = None,
spembs: torch.Tensor = None,
styleembs: torch.Tensor = None,
) -> torch.Tensor:
raise NotImplementedError
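
Because `AbsMelDecoder` inherits from `ABC` and marks `forward` as abstract, a subclass has to provide `forward` before it can be instantiated. The sketch below shows that behaviour with a plain-Python stand-in. It leaves out the `torch.nn.Module` base so it runs without PyTorch, and the class names `AbsDecoder` and `EchoDecoder` are hypothetical.

```python
from abc import ABC, abstractmethod


class AbsDecoder(ABC):
    """Stand-in for AbsMelDecoder without the torch.nn.Module base."""

    @abstractmethod
    def forward(self, bottle_neck_features, feature_lengths):
        raise NotImplementedError


class EchoDecoder(AbsDecoder):
    """Hypothetical concrete decoder that returns its input unchanged."""

    def forward(self, bottle_neck_features, feature_lengths):
        return bottle_neck_features


try:
    AbsDecoder()  # forward is still abstract, so instantiation fails
    instantiated = True
except TypeError:
    instantiated = False

decoder = EchoDecoder()
result = decoder.forward([0.1, 0.2], [2])
print(instantiated, result)
```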


@@ -0,0 +1,79 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
class Linear(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(Linear, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class Conv1d(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
super(Conv1d, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1)/2)
self.conv = torch.nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation,
bias=bias)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
def forward(self, x):
# x: BxDxT
return self.conv(x)
def tile(x, count, dim=0):
"""
Tiles x on dimension dim count times.
"""
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = x.view(batch, -1) \
.transpose(0, 1) \
.repeat(count, 1) \
.transpose(0, 1) \
.contiguous() \
.view(*out_size)
if dim != 0:
x = x.permute(perm).contiguous()
return x
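
Tracing the `view`/`transpose`/`repeat` sequence in `tile` by hand suggests that the copies of each slice end up adjacent along the tiled dimension (the ordering of `torch.repeat_interleave`, not of `Tensor.repeat`). The sketch below reproduces that ordering for `dim=0` on nested lists. `tile_rows` is a hypothetical helper used only for illustration.

```python
def tile_rows(rows, count):
    """Return a new list in which every row appears `count` times in a row."""
    return [list(row) for row in rows for _ in range(count)]


x = [[1, 2], [3, 4]]
tiled = tile_rows(x, 2)
print(tiled)  # [[1, 2], [1, 2], [3, 4], [3, 4]]
```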


@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .basic_layers import Linear, Conv1d
class Postnet(nn.Module):
"""Postnet
- Five 1-d convolution with 512 channels and kernel size 5
"""
def __init__(self, num_mels=80,
num_layers=5,
hidden_dim=512,
kernel_size=5):
super(Postnet, self).__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
nn.Sequential(
Conv1d(
num_mels, hidden_dim,
kernel_size=kernel_size, stride=1,
padding=int((kernel_size - 1) / 2),
dilation=1, w_init_gain='tanh'),
nn.BatchNorm1d(hidden_dim)))
for i in range(1, num_layers - 1):
self.convolutions.append(
nn.Sequential(
Conv1d(
hidden_dim,
hidden_dim,
kernel_size=kernel_size, stride=1,
padding=int((kernel_size - 1) / 2),
dilation=1, w_init_gain='tanh'),
nn.BatchNorm1d(hidden_dim)))
self.convolutions.append(
nn.Sequential(
Conv1d(
hidden_dim, num_mels,
kernel_size=kernel_size, stride=1,
padding=int((kernel_size - 1) / 2),
dilation=1, w_init_gain='linear'),
nn.BatchNorm1d(num_mels)))
def forward(self, x):
# x: (B, num_mels, T_dec)
for i in range(len(self.convolutions) - 1):
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
return x
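
Each Postnet convolution uses `padding=int((kernel_size - 1) / 2)` with stride 1 and dilation 1, which keeps the time dimension unchanged for odd kernel sizes. The sketch below checks this against the output-length formula documented for `torch.nn.Conv1d`. The helper names are assumptions made for this example.

```python
def conv1d_out_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (formula from the Conv1d docs)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size, dilation=1):
    """Padding that preserves length at stride 1; needs an odd kernel size."""
    assert kernel_size % 2 == 1
    return dilation * (kernel_size - 1) // 2


# Postnet defaults: kernel_size=5, stride=1, dilation=1
print(conv1d_out_length(100, 5, padding=same_padding(5)))  # 100
```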


@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class MOLAttention(nn.Module):
""" Discretized Mixture of Logistic (MOL) attention.
C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
"""
def __init__(
self,
query_dim,
r=1,
M=5,
):
"""
Args:
query_dim: attention_rnn_dim.
M: number of mixtures.
"""
super().__init__()
if r < 1:
self.r = float(r)
else:
self.r = int(r)
self.M = M
self.score_mask_value = 0.0 # -float("inf")
self.eps = 1e-5
# Position array for encoder time steps
self.J = None
# Query layer: predicts [w, sigma, Delta] for each of the M mixtures
self.query_layer = torch.nn.Sequential(
nn.Linear(query_dim, 256, bias=True),
nn.ReLU(),
nn.Linear(256, 3*M, bias=True)
)
self.mu_prev = None
self.initialize_bias()
def initialize_bias(self):
"""Initialize sigma and Delta."""
# sigma
torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
# Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
# softplus(-0.432) = 0.5003
if self.r == 2:
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
elif self.r == 4:
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
elif self.r == 1:
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
else:
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)
def init_states(self, memory):
"""Initialize mu_prev and J.
This function should be called by the decoder before decoding one batch.
Args:
memory: (B, T, D_enc) encoder output.
"""
B, T_enc, _ = memory.size()
device = memory.device
self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage
# self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float)
self.mu_prev = torch.zeros(B, self.M).to(device)
def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
"""
att_rnn_h: attention RNN hidden state.
memory: encoder outputs (B, T_enc, D).
mask: binary mask for padded data (B, T_enc).
"""
# [B, 3M]
mixture_params = self.query_layer(att_rnn_h)
# [B, M]
w_hat = mixture_params[:, :self.M]
sigma_hat = mixture_params[:, self.M:2*self.M]
Delta_hat = mixture_params[:, 2*self.M:3*self.M]
# print("w_hat: ", w_hat)
# print("sigma_hat: ", sigma_hat)
# print("Delta_hat: ", Delta_hat)
# Dropout to de-correlate attention heads
w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed?
# Mixture parameters
w = torch.softmax(w_hat, dim=-1) + self.eps
sigma = F.softplus(sigma_hat) + self.eps
Delta = F.softplus(Delta_hat)
mu_cur = self.mu_prev + Delta
# print("w:", w)
j = self.J[:memory.size(1) + 1]
# Attention weights
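# Per-mixture cumulative term, evaluated at the bin edges in j.
# NOTE: the textbook logistic CDF is sigmoid((j - mu) / sigma); the expression
# below computes 1 / (1 + sigmoid((mu - j) / sigma)) and is kept as in the source.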
phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
(mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))
# print("phi_t:", phi_t)
# Discretize attention weights
# (B, T_enc + 1)
alpha_t = torch.sum(phi_t, dim=1)
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
alpha_t[alpha_t == 0] = self.eps
# print("alpha_t: ", alpha_t.size())
# Apply masking
if mask is not None:
alpha_t.data.masked_fill_(mask, self.score_mask_value)
context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
if memory_pitch is not None:
context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)
self.mu_prev = mu_cur
if memory_pitch is not None:
return context, context_pitch, alpha_t
return context, alpha_t
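
The discretization used by MOL attention can be illustrated without PyTorch. The sketch below assumes the textbook logistic CDF, `sigmoid((x - mu) / sigma)`, and bins centred on each encoder index, so it is a simplified illustration of the technique and not a line-for-line copy of `MOLAttention.forward` (which uses a different cumulative term and bin-edge offset). The function names and sample parameters are assumptions.

```python
import math


def logistic_cdf(x, mu, sigma):
    """Textbook CDF of a logistic distribution with location mu, scale sigma."""
    return 1.0 / (1.0 + math.exp(-(x - mu) / sigma))


def mol_attention_weights(weights, mus, sigmas, num_steps):
    """Discretized mixture-of-logistics weights over encoder steps 0..num_steps-1.

    Step j receives the mixture probability mass falling in [j - 0.5, j + 0.5].
    """
    alphas = []
    for j in range(num_steps):
        total = 0.0
        for w, mu, sigma in zip(weights, mus, sigmas):
            total += w * (logistic_cdf(j + 0.5, mu, sigma)
                          - logistic_cdf(j - 0.5, mu, sigma))
        alphas.append(total)
    return alphas


# Two mixtures centred near encoder steps 3 and 4 (illustrative values)
alphas = mol_attention_weights([0.7, 0.3], [3.0, 4.0], [0.5, 1.0], num_steps=10)
print([round(a, 3) for a in alphas])
```

The weights are non-negative and sum to slightly less than one, because some probability mass lies outside the ten encoder steps.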

ppg2mel/utils/nets_utils.py

@@ -0,0 +1,451 @@
# -*- coding: utf-8 -*-
"""Network related utility tools."""
import logging
from typing import Dict
import numpy as np
import torch
def to_device(m, x):
"""Send tensor into the device of the module.
Args:
m (torch.nn.Module): Torch module.
x (Tensor): Torch tensor.
Returns:
Tensor: Torch tensor located in the same place as torch module.
"""
assert isinstance(m, torch.nn.Module)
device = next(m.parameters()).device
return x.to(device)
def pad_list(xs, pad_value):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max(x.size(0) for x in xs)
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(slice(None) if i in (0, length_dim) else None
for i in range(xs.dim()))
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
Returns:
Tensor: Mask tensor containing indices of non-padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return ~make_pad_mask(lengths, xs, length_dim)
def mask_by_length(xs, lengths, fill=0):
"""Mask tensor according to length.
Args:
xs (Tensor): Batch of input tensor (B, `*`).
lengths (LongTensor or List): Batch of lengths (B,).
fill (int or float): Value to fill masked part.
Returns:
Tensor: Batch of masked input tensor (B, `*`).
Examples:
>>> x = torch.arange(5).repeat(3, 1) + 1
>>> x
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]])
>>> lengths = [5, 3, 2]
>>> mask_by_length(x, lengths)
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 0, 0],
[1, 2, 0, 0, 0]])
"""
assert xs.size(0) == len(lengths)
ret = xs.data.new(*xs.size()).fill_(fill)
for i, l in enumerate(lengths):
ret[i, :l] = xs[i, :l]
return ret
def th_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0),
pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
def to_torch_tensor(x):
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
Args:
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
Returns:
Tensor or ComplexTensor: Type converted inputs.
Examples:
>>> xs = np.ones(3, dtype=np.float32)
>>> xs = to_torch_tensor(xs)
tensor([1., 1., 1.])
>>> xs = torch.ones(3, 4, 5)
>>> assert to_torch_tensor(xs) is xs
>>> xs = {'real': xs, 'imag': xs}
>>> to_torch_tensor(xs)
ComplexTensor(
Real:
tensor([1., 1., 1.])
Imag:
tensor([1., 1., 1.])
)
"""
# If numpy, change to torch tensor
if isinstance(x, np.ndarray):
if x.dtype.kind == 'c':
# Dynamically importing because torch_complex requires python3
from torch_complex.tensor import ComplexTensor
return ComplexTensor(x)
else:
return torch.from_numpy(x)
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
elif isinstance(x, dict):
# Dynamically importing because torch_complex requires python3
from torch_complex.tensor import ComplexTensor
if 'real' not in x or 'imag' not in x:
raise ValueError("x must have 'real' and 'imag' keys, but got: {}".format(list(x)))
# Relative importing because of using python3 syntax
return ComplexTensor(x['real'], x['imag'])
# If torch.Tensor, as it is
elif isinstance(x, torch.Tensor):
return x
else:
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
"but got {}".format(type(x)))
try:
from torch_complex.tensor import ComplexTensor
except Exception:
# If PY2
raise ValueError(error)
else:
# If PY3
if isinstance(x, ComplexTensor):
return x
else:
raise ValueError(error)
def get_subsample(train_args, mode, arch):
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
Args:
train_args: argument Namespace containing options.
mode: one of ('asr', 'mt', 'st')
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
Returns:
np.ndarray / List[np.ndarray]: subsampling factors.
"""
if arch == 'transformer':
return np.array([1])
elif mode == 'mt' and arch == 'rnn':
# +1 means input (+1) and layers outputs (train_args.elayer)
subsample = np.ones(train_args.elayers + 1, dtype=int)  # np.int was removed in NumPy 1.24
logging.warning('Subsampling is not performed for machine translation.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
return subsample
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
(mode == 'mt' and arch == 'rnn') or \
(mode == 'st' and arch == 'rnn'):
subsample = np.ones(train_args.elayers + 1, dtype=int)
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
ss = train_args.subsample.split("_")
for j in range(min(train_args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
return subsample
elif mode == 'asr' and arch == 'rnn_mix':
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=int)
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
ss = train_args.subsample.split("_")
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
return subsample
elif mode == 'asr' and arch == 'rnn_mulenc':
subsample_list = []
for idx in range(train_args.num_encs):
subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
ss = train_args.subsample[idx].split("_")
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Encoder %d: Subsampling is not performed for vgg*. '
'It is performed in max pooling layers at CNN.', idx + 1)
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
subsample_list.append(subsample)
return subsample_list
else:
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
"""Replace keys of old prefix with new prefix in state dict."""
# need this list not to break the dict iterator
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
if len(old_keys) > 0:
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
for k in old_keys:
v = state_dict.pop(k)
# Swap only the leading prefix; str.replace would also rewrite later occurrences
new_k = new_prefix + k[len(old_prefix):]
state_dict[new_k] = v
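
The padding-mask helpers earlier in this file (`make_pad_mask` and `make_non_pad_mask`) reduce to one comparison per position: index `t` is padding when `t >= length`. A pure-Python sketch of the lengths-only case follows. The function names are hypothetical, and nested lists stand in for tensors.

```python
def make_pad_mask_lists(lengths, max_len=None):
    """Return a B x max_len nested list that is True at padded positions."""
    if max_len is None:
        max_len = max(lengths)
    return [[t >= length for t in range(max_len)] for length in lengths]


def invert_mask(mask):
    """Flip a padding mask into a non-padding mask."""
    return [[not flag for flag in row] for row in mask]


pad_mask = make_pad_mask_lists([5, 3, 2])
non_pad_mask = invert_mask(pad_mask)
for row in pad_mask:
    print([int(flag) for flag in row])
# [0, 0, 0, 0, 0]
# [0, 0, 0, 1, 1]
# [0, 0, 1, 1, 1]
```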

22
ppg2mel/utils/vc_utils.py Normal file

@@ -0,0 +1,22 @@
import torch
def gcd(a, b):
"""Greatest common divisor."""
a, b = (a, b) if a >= b else (b, a)
if a % b == 0:
return b
else:
return gcd(b, a % b)
def lcm(a, b):
"""Least common multiple"""
return a * b // gcd(a, b)
def get_mask_from_lengths(lengths, max_len=None):
if max_len is None:
max_len = torch.max(lengths).item()
# Build the index range on the same device as `lengths` so this also works on CPU
ids = torch.arange(0, max_len, device=lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool()
return mask
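
The recursive `gcd` above divides by `b`, so it raises `ZeroDivisionError` when the smaller argument is zero. As a hedged alternative sketch (not the repository's code), the standard library's `math.gcd` handles that case and can back the same `lcm` formula:

```python
import math


def lcm(a, b):
    """Least common multiple of two positive integers via math.gcd."""
    return a * b // math.gcd(a, b)


print(math.gcd(12, 18))  # 6
print(lcm(4, 6))         # 12
print(math.gcd(5, 0))    # 5 (no division by zero)
```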

67
ppg2mel_train.py Normal file

@@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
# Arguments
parser = argparse.ArgumentParser(description=
'Training PPG2Mel VC model.')
parser.add_argument('--config', type=str,
help='Path to experiment config, e.g., config/vc.yaml')
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
parser.add_argument('--logdir', default='log/', type=str,
help='Logging path.', required=False)
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
help='Checkpoint path.', required=False)
parser.add_argument('--outdir', default='result/', type=str,
help='Decode output path.', required=False)
parser.add_argument('--load', default=None, type=str,
help='Load pre-trained model (for training only)', required=False)
parser.add_argument('--warm_start', action='store_true',
help='Load model weights only, ignore specified layers.')
parser.add_argument('--seed', default=0, type=int,
help='Random seed for reproducible results.', required=False)
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true',
help='Disable pin-memory for dataloader')
parser.add_argument('--test', action='store_true', help='Test the model.')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
parser.add_argument('--finetune', action='store_true', help='Finetune model')
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
###
paras = parser.parse_args()
setattr(paras, 'gpu', not paras.cpu)
setattr(paras, 'pin_memory', not paras.no_pin)
setattr(paras, 'verbose', not paras.no_msg)
# Make the config dict dot visitable
config = HpsYaml(paras.config)
np.random.seed(paras.seed)
torch.manual_seed(paras.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(paras.seed)
print(">>> OneShot VC training ...")
mode = "train"
solver = Solver(config, paras, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> Oneshot VC train finished!")
sys.exit(0)
if __name__ == "__main__":
main()
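
The script exposes negative switches (`--cpu`, `--no-pin`, `--no-msg`) and then derives the positive attributes `gpu`, `pin_memory`, and `verbose` from them. A reduced sketch of that pattern follows, assuming only those three flags. `build_parser` and `parse` are hypothetical helpers introduced for this example.

```python
import argparse


def build_parser():
    """Parser holding only the three negative switches from the script above."""
    parser = argparse.ArgumentParser(description="Sketch of negated boolean flags.")
    parser.add_argument("--cpu", action="store_true", help="Disable GPU training.")
    parser.add_argument("--no-pin", action="store_true", help="Disable pin-memory.")
    parser.add_argument("--no-msg", action="store_true", help="Hide all messages.")
    return parser


def parse(argv):
    """Parse argv and attach the derived positive attributes."""
    paras = build_parser().parse_args(argv)
    # argparse stores "--no-pin" as paras.no_pin (dashes become underscores)
    paras.gpu = not paras.cpu
    paras.pin_memory = not paras.no_pin
    paras.verbose = not paras.no_msg
    return paras


paras = parse(["--cpu", "--no-msg"])
print(paras.gpu, paras.pin_memory, paras.verbose)  # False True False
```

Note that the data loaders in the solver shown earlier pass `pin_memory=False` directly, so the derived `pin_memory` attribute does not appear to affect them in this commit.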

102
ppg_extractor/__init__.py Normal file

@@ -0,0 +1,102 @@
import argparse
import torch
from pathlib import Path
import yaml
from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder
_model = None # type: PPGModel
_device = None
class PPGModel(torch.nn.Module):
def __init__(
self,
frontend,
normalizer,
encoder,
):
super().__init__()
self.frontend = frontend
self.normalize = normalizer
self.encoder = encoder
def forward(self, speech, speech_lengths):
"""
Args:
speech (tensor): (B, L)
speech_lengths (tensor): (B, )
Returns:
bottle_neck_feats (tensor): (B, L//hop_size, 144)
"""
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
feats, feats_lengths = self.normalize(feats, feats_lengths)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
):
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def extract_from_wav(self, src_wav):
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
return self(src_wav_tensor, src_wav_lengths)
def build_model(args):
normalizer = UtteranceMVN(**args.normalize_conf)
frontend = DefaultFrontend(**args.frontend_conf)
encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
model = PPGModel(frontend, normalizer, encoder)
return model
def load_model(model_file, device=None):
global _model, _device
if device is None:
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
_device = device
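# Search for a config file next to the checkpoint
model_file = Path(model_file)
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
if not model_config_fpaths:
raise FileNotFoundError(f"No .yaml config found under {model_file.parent}")
config_file = model_config_fpaths[0]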
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = build_model(args)
model_state_dict = model.state_dict()
ckpt_state_dict = torch.load(model_file, map_location=_device)
ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k}
model_state_dict.update(ckpt_state_dict)
model.load_state_dict(model_state_dict)
_model = model.eval().to(_device)
return _model
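
`load_model` locates its configuration by searching the checkpoint's directory tree for `*.yaml` files and taking the first match. The sketch below isolates that lookup with the standard library only. `find_config` is a hypothetical helper; sorting the matches and raising an explicit error are choices made for this example, not behaviour confirmed by the source.

```python
import tempfile
from pathlib import Path


def find_config(model_file):
    """Return the first .yaml file found under the checkpoint's directory."""
    model_file = Path(model_file)  # accept both str and Path
    candidates = sorted(model_file.parent.rglob("*.yaml"))  # sorted for determinism
    if not candidates:
        raise FileNotFoundError(f"No .yaml config found under {model_file.parent}")
    return candidates[0]


# Demonstrate on a throwaway directory holding a fake checkpoint and config
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "model.pt"
    ckpt.touch()
    (Path(tmp) / "config.yaml").write_text("encoder_conf: {}\n", encoding="utf-8")
    found_name = find_config(str(ckpt)).name

# An empty directory has no config, so the lookup reports the problem
with tempfile.TemporaryDirectory() as empty:
    try:
        find_config(Path(empty) / "model.pt")
        missing_raises = False
    except FileNotFoundError:
        missing_raises = True

print(found_name, missing_raises)
```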


@@ -0,0 +1,398 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Common functions for ASR."""
import argparse
import editdistance
import json
import logging
import numpy as np
import six
import sys
from itertools import groupby
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
Described in Eq. (50) of S. Watanabe et al.,
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
:param ended_hyps:
:param i:
:param M:
:param D_end:
:return:
"""
if len(ended_hyps) == 0:
return False
count = 0
best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
for m in six.moves.range(M):
# get the ended hypotheses whose length is i - m
hyp_length = i - m
hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
if len(hyps_same_length) > 0:
best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
count += 1
if count == M:
return True
else:
return False
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.
    :param odim: output dimension (size of the label set)
    :param lsm_type: smoothing type; only 'unigram' is supported
    :param blank: index of the blank symbol
    :param transcript: path to a JSON file with an 'utts' entry
    :return: label distribution (float32 array of size odim)
    """
    if transcript is not None:
        with open(transcript, 'rb') as f:
            trans_json = json.load(f)['utts']
    if lsm_type == 'unigram':
        assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        # count <eos>: one per utterance (len(transcript) would be the length of the path string)
        labelcount[odim - 1] = len(trans_json)
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error(
            "Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()
    return labeldist
def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
    """Return the output size of the VGG frontend.
    :param idim: input feature size (summed over all input channels)
    :param in_channel: input channel size
    :param out_channel: output channel size
    :param downsample: whether the two max-pooling stages are applied
    :return: output size
    :rtype: int
    """
    idim = idim / in_channel
    if downsample:
        idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 1st max pooling
        idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 2nd max pooling
    return int(idim) * out_channel  # number of channels
class ErrorCalculator(object):
"""Calculate CER and WER for E2E_ASR and CTC models during training.
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list:
:param sym_space:
:param sym_blank:
:return:
"""
def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
trans_type="char"):
"""Construct an ErrorCalculator object."""
super(ErrorCalculator, self).__init__()
self.report_cer = report_cer
self.report_wer = report_wer
self.trans_type = trans_type
self.char_list = char_list
self.space = sym_space
self.blank = sym_blank
self.idx_blank = self.char_list.index(self.blank)
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad, is_ctc=False):
"""Calculate sentence-level WER/CER score.
:param torch.Tensor ys_hat: prediction (batch, seqlen)
:param torch.Tensor ys_pad: reference (batch, seqlen)
:param bool is_ctc: calculate CER score for CTC
:return: sentence-level CER score (None if not computed)
:rtype float
:return: sentence-level WER score (None if not computed)
:rtype float
"""
cer, wer = None, None
if is_ctc:
return self.calculate_cer_ctc(ys_hat, ys_pad)
elif not self.report_cer and not self.report_wer:
return cer, wer
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
def calculate_cer_ctc(self, ys_hat, ys_pad):
"""Calculate sentence-level CER score for CTC.
:param torch.Tensor ys_hat: prediction (batch, seqlen)
:param torch.Tensor ys_pad: reference (batch, seqlen)
:return: average sentence-level CER score
:rtype float
"""
cers, char_ref_lens = [], []
for i, y in enumerate(ys_hat):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat, seq_true = [], []
for idx in y_hat:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_hat.append(self.char_list[int(idx)])
for idx in y_true:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_true.append(self.char_list[int(idx)])
if self.trans_type == "char":
hyp_chars = "".join(seq_hat)
ref_chars = "".join(seq_true)
else:
hyp_chars = " ".join(seq_hat)
ref_chars = " ".join(seq_true)
if len(ref_chars) > 0:
cers.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
return cer_ctc
def convert_to_char(self, ys_hat, ys_pad):
"""Convert index to character.
:param torch.Tensor ys_hat: prediction (batch, seqlen)
:param torch.Tensor ys_pad: reference (batch, seqlen)
:return: token list of prediction
:rtype list
:return: token list of reference
:rtype list
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
# y_hat is not padded with -1, so the eos position found in y_true
# is used to truncate y_hat; otherwise the WER would come out higher
# than the one obtained from decoding.
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
# seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
seq_hat_text = seq_hat_text.replace(self.blank, '')
# seq_true_text = "".join(seq_true).replace(self.space, ' ')
seq_true_text = " ".join(seq_true).replace(self.space, ' ')
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
def calculate_cer(self, seqs_hat, seqs_true):
"""Calculate sentence-level CER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level CER score
:rtype float
"""
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
def calculate_wer(self, seqs_hat, seqs_true):
"""Calculate sentence-level WER score.
:param list seqs_hat: prediction
:param list seqs_true: reference
:return: average sentence-level WER score
:rtype float
"""
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
class ErrorCalculatorTrans(object):
"""Calculate CER and WER for transducer models.
Args:
decoder (nn.Module): decoder module
args (Namespace): argument Namespace containing options
report_cer (boolean): compute CER option
report_wer (boolean): compute WER option
"""
def __init__(self, decoder, args, report_cer=False, report_wer=False):
"""Construct an ErrorCalculator object for transducer model."""
super(ErrorCalculatorTrans, self).__init__()
self.dec = decoder
recog_args = {'beam_size': args.beam_size,
'nbest': args.nbest,
'space': args.sym_space,
'score_norm_transducer': args.score_norm_transducer}
self.recog_args = argparse.Namespace(**recog_args)
self.char_list = args.char_list
self.space = args.sym_space
self.blank = args.sym_blank
self.report_cer = args.report_cer
self.report_wer = args.report_wer
def __call__(self, hs_pad, ys_pad):
"""Calculate sentence-level WER/CER score for transducer models.
Args:
hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D)
ys_pad (torch.Tensor): reference (batch, seqlen)
Returns:
(float): sentence-level CER score
(float): sentence-level WER score
"""
cer, wer = None, None
if not self.report_cer and not self.report_wer:
return cer, wer
batchsize = int(hs_pad.size(0))
batch_nbest = []
for b in six.moves.range(batchsize):
if self.recog_args.beam_size == 1:
nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
else:
nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
batch_nbest.append(nbest_hyps)
ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
def convert_to_char(self, ys_hat, ys_pad):
"""Convert index to character.
Args:
ys_hat (torch.Tensor): prediction (batch, seqlen)
ys_pad (torch.Tensor): reference (batch, seqlen)
Returns:
(list): token list of prediction
(list): token list of reference
"""
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
seq_hat_text = seq_hat_text.replace(self.blank, '')
seq_true_text = "".join(seq_true).replace(self.space, ' ')
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
def calculate_cer(self, seqs_hat, seqs_true):
"""Calculate sentence-level CER score for transducer model.
Args:
seqs_hat (list): predicted sentences as strings
seqs_true (list): reference sentences as strings
Returns:
(float): average sentence-level CER score
"""
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
def calculate_wer(self, seqs_hat, seqs_true):
"""Calculate sentence-level WER score for transducer model.
Args:
seqs_hat (list): predicted sentences as strings
seqs_true (list): reference sentences as strings
Returns:
(float): average sentence-level WER score
"""
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
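
Both calculators above report an error rate as the summed edit distance divided by the summed reference length, at character level for CER and word level for WER. The sketch below reproduces that arithmetic on two toy sentences. It is only an illustration: `levenshtein` is a pure-Python stand-in for `editdistance.eval`, and `error_rate` is a hypothetical helper, not part of the file.

```python
def levenshtein(hyp, ref):
    """Edit distance between two sequences (strings or lists of words)."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        cur = [i]
        for j, r in enumerate(ref, start=1):
            cur.append(min(prev[j] + 1,              # drop h from the hypothesis
                           cur[j - 1] + 1,           # insert r into the hypothesis
                           prev[j - 1] + (h != r)))  # substitute or match
        prev = cur
    return prev[-1]


def error_rate(hyps, refs, unit="char"):
    """Sum of edit distances divided by the total reference length."""
    edits, ref_len = 0, 0
    for hyp, ref in zip(hyps, refs):
        if unit == "char":
            h, r = hyp.replace(" ", ""), ref.replace(" ", "")
        else:
            h, r = hyp.split(), ref.split()
        edits += levenshtein(h, r)
        ref_len += len(r)
    return float(edits) / ref_len


hyps = ["the cat sat", "hello word"]
refs = ["the cat sat", "hello world"]
cer = error_rate(hyps, refs, unit="char")  # 1 edit over 19 reference characters
wer = error_rate(hyps, refs, unit="word")  # 1 edit over 5 reference words
```

Like the methods above, the sketch divides by the total reference length, so it raises `ZeroDivisionError` when every reference is empty.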


@@ -0,0 +1,183 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""
import math
import numpy
import torch
from torch import nn
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
:param int n_head: the number of heads
:param int n_feat: the number of features
:param float dropout_rate: dropout rate
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct a MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(self, query, key, value):
"""Transform query, key and value.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:return torch.Tensor transformed query, key and value
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
:param torch.Tensor value: (batch, head, time2, size)
:param torch.Tensor scores: (batch, head, time1, time2)
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
:return torch.Tensor transformed `value` (batch, time1, d_model)
weighted by the attention score (batch, time1, time2)
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, mask):
"""Compute 'Scaled Dot Product Attention'.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
:return torch.Tensor: attention output (batch, time1, d_model)
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
:param int n_head: the number of heads
:param int n_feat: the number of features
:param float dropout_rate: dropout rate
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct a RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x, zero_triu=False):
"""Compute relative positional encoding.
:param torch.Tensor x: (batch, head, time1, time2)
:param bool zero_triu: return the lower triangular part of the matrix
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device, dtype=x.dtype)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:param torch.Tensor pos_emb: (batch, time1, size)
:param torch.Tensor mask: (batch, time1, time2)
:return torch.Tensor: attention output (batch, time1, d_model)
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
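
The `rel_shift` method above relies on a pad, reshape and drop trick that is easier to follow on a single small matrix. The sketch below applies the same three steps to one `(time1, time2)` matrix held as nested lists. `rel_shift_2d` is a hypothetical name, and the batch and head dimensions are left out for brevity.

```python
def rel_shift_2d(x):
    """Apply the pad/reshape/drop steps of rel_shift to one (time1, time2) matrix."""
    t1, t2 = len(x), len(x[0])
    # 1) prepend a zero column -> (time1, time2 + 1), flattened row-major
    flat = [v for row in x for v in [0] + row]
    # 2) view as (time2 + 1, time1) and drop the first row, i.e. the first time1 values
    flat = flat[t1:]
    # 3) view the remaining time1 * time2 values as (time1, time2) again
    return [flat[r * t2:(r + 1) * t2] for r in range(t1)]


shifted = rel_shift_2d([[1, 2, 3],
                        [4, 5, 6],
                        [7, 8, 9]])
# shifted == [[3, 0, 4],
#             [5, 6, 0],
#             [7, 8, 9]]
```

The last row is unchanged and each earlier row is shifted one more position to the left. The values that wrap in after the inserted zero (the 4 in the first row) lie above the diagonal, which is the region the `zero_triu` option multiplies by zero.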


@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import logging
import torch
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from .convolution import ConvolutionModule
from .encoder_layer import EncoderLayer
from ..nets_utils import get_activation, make_pad_mask
from .vgg import VGG2L
from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding
from .layer_norm import LayerNorm
from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d
from .positionwise_feed_forward import PositionwiseFeedForward
from .repeat import repeat
from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling
class ConformerEncoder(torch.nn.Module):
"""Conformer encoder module.
:param int input_size: input dim
:param int attention_dim: dimension of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
:param int num_blocks: the number of encoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate in attention
:param float positional_dropout_rate: dropout rate after adding positional encoding
:param str or torch.nn.Module input_layer: input layer type
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
:param str positionwise_layer_type: linear, conv1d or conv1d-linear
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
:param str pos_enc_layer_type: encoder positional encoding layer type
:param str selfattention_layer_type: encoder attention layer type
:param str activation_type: encoder activation function type
:param bool macaron_style: whether to use macaron style for positionwise layer
:param bool use_cnn_module: whether to use convolution module
:param int cnn_module_kernel: kernel size of convolution module
:param int padding_idx: padding_idx for input_layer=embed
:param bool no_subsample: use Conv2dNoSubsampling for input_layer=conv2d
:param bool subsample_by_2: passed on to Conv2dSubsampling for input_layer=conv2d
"""
def __init__(
self,
input_size,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
macaron_style=False,
pos_enc_layer_type="abs_pos",
selfattention_layer_type="selfattn",
activation_type="swish",
use_cnn_module=False,
cnn_module_kernel=31,
padding_idx=-1,
no_subsample=False,
subsample_by_2=False,
):
"""Construct an Encoder object."""
super().__init__()
self._output_size = attention_dim
idim = input_size
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == "rel_pos":
assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(idim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "conv2d":
logging.info("Encoder input layer type: conv2d")
if no_subsample:
self.embed = Conv2dNoSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
self.embed = Conv2dSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
subsample_by_2, # NOTE(Sx): added by songxiang
)
elif input_layer == "vgg2l":
self.embed = VGG2L(idim, attention_dim)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + str(input_layer))
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
attention_dim,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear, conv1d or conv1d-linear.")
if selfattention_layer_type == "selfattn":
logging.info("encoder self-attention layer type = self-attention")
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
attention_dim,
attention_dropout_rate,
)
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
attention_dim,
attention_dropout_rate,
)
else:
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
attention_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Encode the input sequence.
Args:
xs_pad: input tensor (B, L, D)
ilens: input lengths (B)
prev_states: Not to be used now.
Returns:
Encoded tensor, output lengths (B) and None (placeholder for states)
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)):
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
xs_pad, masks = self.encoders(xs_pad, masks)
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
# def forward(self, xs, masks):
# """Encode input sequence.
# :param torch.Tensor xs: input tensor
# :param torch.Tensor masks: input mask
# :return: position embedded tensor and mask
# :rtype Tuple[torch.Tensor, torch.Tensor]:
# """
# if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
# xs, masks = self.embed(xs, masks)
# else:
# xs = self.embed(xs)
# xs, masks = self.encoders(xs, masks)
# if isinstance(xs, tuple):
# xs = xs[0]
# if self.normalize_before:
# xs = self.after_norm(xs)
# return xs, masks
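
The active `forward` above derives a non-padding mask from the input lengths and later recovers the output lengths by summing that mask over time. A small pure-Python sketch of the idea follows. The `make_pad_mask` here is a simplified stand-in, assumed (from the `~` in the code above) to mark padded positions with `True`. The real helper comes from `nets_utils` and is not shown in this diff.

```python
def make_pad_mask(lengths):
    """Stand-in helper: (batch, max_len) nested lists, True marks padded frames."""
    max_len = max(lengths)
    return [[t >= n for t in range(max_len)] for n in lengths]


def non_pad_mask(lengths):
    """Invert the padding mask so that True marks valid frames."""
    return [[not padded for padded in row] for row in make_pad_mask(lengths)]


def output_lengths(mask):
    """Count the valid frames per utterance, as masks.squeeze(1).sum(1) does."""
    return [sum(row) for row in mask]


ilens = [5, 3, 1]
masks = non_pad_mask(ilens)
olens = output_lengths(masks)
```

Without subsampling the recovered lengths equal the input lengths. With a subsampling input layer the mask returned by the embedding is shorter, so the same sum yields the reduced lengths.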


@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""ConvolutionModule definition."""
from torch import nn
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    :param int channels: channels of cnn
    :param int kernel_size: kernel size of cnn
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.
        :param torch.Tensor x: (batch, time, size)
        :return torch.Tensor: convolved output (batch, time, size)
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)
        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)
        return x.transpose(1, 2)
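
Two details of the convolution module above can be checked with plain arithmetic. The GLU halves the channel count by gating one half with the sigmoid of the other, and an odd kernel with padding `(kernel_size - 1) // 2` preserves the time length. The helpers below are hypothetical pure-Python illustrations of those two facts (stride as given, dilation 1). They do not replace the PyTorch layers.

```python
import math


def glu(values):
    """Gated linear unit on one vector: first half times sigmoid(second half)."""
    half = len(values) // 2
    a, b = values[:half], values[half:]
    return [x / (1.0 + math.exp(-g)) for x, g in zip(a, b)]


def conv1d_out_len(t, kernel_size, padding, stride=1):
    """Output length of a 1-D convolution with dilation 1."""
    return (t + 2 * padding - kernel_size) // stride + 1


gated = glu([2.0, -4.0, 0.0, 0.0])  # gates are sigmoid(0) = 0.5
same_len = conv1d_out_len(100, kernel_size=31, padding=(31 - 1) // 2)
```

With the default `cnn_module_kernel=31` the padding is 15, so a 100-frame input stays 100 frames long. An even kernel could not be padded symmetrically this way, which is what the assertion in `__init__` guards against.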


@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Positional Encoding Module."""
import math
import torch
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.
    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
:param reverse: whether to reverse the input position
"""
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
"""Construct a PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.reverse = reverse
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1)]
return self.dropout(x)
class ScaledPositionalEncoding(PositionalEncoding):
"""Scaled positional encoding module.
See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def reset_parameters(self):
"""Reset parameters."""
self.alpha.data = torch.tensor(1.0)
def forward(self, x):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
"""
self.extend_pe(x)
x = x + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(x)
class RelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def forward(self, x):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
torch.Tensor: x. Its shape is (batch, time, ...)
torch.Tensor: pos_emb. Its shape is (1, time, ...)
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, : x.size(1)]
return self.dropout(x), self.dropout(pos_emb)
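
The table built in `extend_pe` above interleaves sines and cosines whose frequencies decay geometrically with the feature index, and `reverse=True` (used by the relative encoding) simply lists the positions from last to first. The following pure-Python sketch recomputes a tiny table with `math` instead of `torch`. `sinusoid_table` is a hypothetical name used only for this illustration.

```python
import math


def sinusoid_table(length, d_model, reverse=False):
    """Return a length x d_model table of sinusoidal position encodings."""
    positions = range(length - 1, -1, -1) if reverse else range(length)
    table = []
    for pos in positions:
        row = [0.0] * d_model
        for i in range(0, d_model, 2):
            # same frequency term as div_term in extend_pe
            angle = pos * math.exp(-i * math.log(10000.0) / d_model)
            row[i] = math.sin(angle)          # even feature index
            if i + 1 < d_model:
                row[i + 1] = math.cos(angle)  # odd feature index
        table.append(row)
    return table


pe = sinusoid_table(3, 4)
pe_rev = sinusoid_table(3, 4, reverse=True)
```

Position 0 always encodes to alternating 0 and 1, and the reversed table holds the same rows in the opposite order.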


@@ -0,0 +1,217 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import logging
import torch
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.nets_utils import get_activation
from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
class Encoder(torch.nn.Module):
"""Conformer encoder module.
:param int idim: input dim
:param int attention_dim: dimention of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
:param int num_blocks: the number of decoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate in attention
:param float positional_dropout_rate: dropout rate after adding positional encoding
:param str or torch.nn.Module input_layer: input layer type
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
:param str positionwise_layer_type: linear of conv1d
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
:param str encoder_attn_layer_type: encoder attention layer type
:param str activation_type: encoder activation function type
:param bool macaron_style: whether to use macaron style for positionwise layer
:param bool use_cnn_module: whether to use convolution module
:param int cnn_module_kernel: kernerl size of convolution module
:param int padding_idx: padding_idx for input_layer=embed
"""
def __init__(
self,
idim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
macaron_style=False,
pos_enc_layer_type="abs_pos",
selfattention_layer_type="selfattn",
activation_type="swish",
use_cnn_module=False,
cnn_module_kernel=31,
padding_idx=-1,
):
"""Construct an Encoder object."""
super(Encoder, self).__init__()
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == "rel_pos":
assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(idim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "vgg2l":
self.embed = VGG2L(idim, attention_dim)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + str(input_layer))
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
attention_dim,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
if selfattention_layer_type == "selfattn":
logging.info("encoder self-attention layer type = self-attention")
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
attention_dim,
attention_dropout_rate,
)
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
attention_dim,
attention_dropout_rate,
)
else:
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
attention_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
def forward(self, xs, masks):
"""Encode input sequence.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
if isinstance(xs, tuple):
xs = xs[0]
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks
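The constructor above accepts only certain pairings of `pos_enc_layer_type` and `selfattention_layer_type`. A minimal standalone sketch of that validation follows, with no torch dependency. The helper name `check_encoder_options` is hypothetical, and it raises `ValueError` where the source uses `assert`.

```python
def check_encoder_options(pos_enc_layer_type="abs_pos",
                          selfattention_layer_type="selfattn"):
    """Mirror the option checks of the conformer Encoder constructor (sketch)."""
    if pos_enc_layer_type not in ("abs_pos", "scaled_abs_pos", "rel_pos"):
        raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
    if selfattention_layer_type not in ("selfattn", "rel_selfattn"):
        raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
    # Relative positional encoding and relative self-attention go together:
    # enabling only one of the two is rejected.
    uses_rel_pos = pos_enc_layer_type == "rel_pos"
    uses_rel_attn = selfattention_layer_type == "rel_selfattn"
    if uses_rel_pos != uses_rel_attn:
        raise ValueError("rel_pos and rel_selfattn must be enabled together")
    return pos_enc_layer_type, selfattention_layer_type
```

Calling `check_encoder_options("rel_pos", "rel_selfattn")` passes, while `check_encoder_options("rel_pos", "selfattn")` is rejected.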


@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from .layer_norm import LayerNorm
class EncoderLayer(nn.Module):
"""Encoder layer module.
:param int size: input dim
:param espnet.nets.pytorch_backend.transformer.attention.
MultiHeadedAttention or RelPositionMultiHeadedAttention self_attn:
self attention module
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
PositionwiseFeedForward feed_forward:
feed forward module
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
PositionwiseFeedForward feed_forward_macaron:
additional feed forward module for macaron style (None to disable)
:param espnet.nets.pytorch_backend.conformer.convolution.
ConvolutionModule conv_module:
convolution module (None to disable)
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
feed_forward,
feed_forward_macaron,
conv_module,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = LayerNorm(size) # for the FFN module
self.norm_mha = LayerNorm(size) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = LayerNorm(size)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = LayerNorm(size) # for the CNN module
self.norm_final = LayerNorm(size) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
def forward(self, x_input, mask, cache=None):
"""Compute encoded features.
:param torch.Tensor x_input: encoded source features, w/o pos_emb
tuple((batch, max_time_in, size), (1, max_time_in, size))
or (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
if isinstance(x_input, tuple):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if pos_emb is not None:
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x = residual + self.dropout(self.conv_module(x))
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
if pos_emb is not None:
return (x, pos_emb), mask
return x, mask
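Every sub-block in `forward` above follows the same residual pattern, with `normalize_before` choosing between pre-norm and post-norm order and `ff_scale` halving the feed-forward contribution in macaron style. A minimal sketch of that pattern on plain floats, with dropout omitted; `residual_sublayer` and the toy `double`/`halve` stand-ins are hypothetical names used only for illustration:

```python
def residual_sublayer(x, sublayer, norm, normalize_before=True, scale=1.0):
    """Apply one residual sub-block in pre-norm or post-norm order."""
    residual = x
    if normalize_before:
        x = norm(x)                      # pre-norm: normalize the sub-layer input
    x = residual + scale * sublayer(x)   # residual connection with optional scaling
    if not normalize_before:
        x = norm(x)                      # post-norm: normalize the sum
    return x


# Toy scalar stand-ins for the real modules (assumptions, not the real layers).
def double(v):
    return 2.0 * v


def halve(v):
    return 0.5 * v


pre = residual_sublayer(4.0, double, halve, normalize_before=True, scale=0.5)
post = residual_sublayer(4.0, double, halve, normalize_before=False, scale=0.5)
```

With these stand-ins the two orders give different results, which is why a checkpoint trained with one setting should not be loaded with the other.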


@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""
import torch
class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.
:param int nout: output dim size
:param int dim: dimension to be normalized
"""
def __init__(self, nout, dim=-1):
"""Construct a LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim
def forward(self, x):
"""Apply layer normalization.
:param torch.Tensor x: input tensor
:return: layer normalized tensor
:rtype: torch.Tensor
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return super(LayerNorm, self).forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)
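For reference, the arithmetic that layer normalization performs over the normalized axis can be sketched in plain Python. This assumes the affine weight and bias are at their initial values (1 and 0) and uses the same `eps=1e-12` as the class above; `layer_norm` is a hypothetical helper, not part of the source.

```python
import math


def layer_norm(values, eps=1e-12):
    """Normalize a list of floats to zero mean and unit (biased) variance."""
    mean = sum(values) / len(values)
    # Biased variance (divide by N), which is what layer normalization uses.
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


normalized = layer_norm([1.0, 2.0, 3.0, 4.0])
```

The `dim` argument of the class only decides which axis plays the role of `values`; the transpose moves that axis to the last position and back.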


@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
import torch
class MultiLayeredConv1d(torch.nn.Module):
"""Multi-layered conv1d for Transformer block.
This is a module of multi-layered conv1d designed
to replace positionwise feed-forward network
in Transformer block, which is introduced in
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
"""
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize MultiLayeredConv1d module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(MultiLayeredConv1d, self).__init__()
self.w_1 = torch.nn.Conv1d(
in_chans,
hidden_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
)
self.w_2 = torch.nn.Conv1d(
hidden_chans,
in_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
)
self.dropout = torch.nn.Dropout(dropout_rate)
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., in_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
class Conv1dLinear(torch.nn.Module):
"""Conv1D + Linear for Transformer block.
A variant of MultiLayeredConv1d that replaces the second conv layer with a linear layer.
"""
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize Conv1dLinear module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(Conv1dLinear, self).__init__()
self.w_1 = torch.nn.Conv1d(
in_chans,
hidden_chans,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
)
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
self.dropout = torch.nn.Dropout(dropout_rate)
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., in_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.dropout(x))
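Both modules above pad with `(kernel_size - 1) // 2`, which keeps the time axis unchanged only for odd kernel sizes. A small sketch of the standard convolution length formula (dilation 1) makes this visible; the helper names are hypothetical.

```python
def conv1d_out_len(n, kernel_size, stride=1, padding=0):
    """Output length of a 1-D convolution with dilation 1."""
    return (n + 2 * padding - kernel_size) // stride + 1


def same_padding_len(n, kernel_size):
    """Length after one conv padded with (kernel_size - 1) // 2, as in the modules above."""
    return conv1d_out_len(n, kernel_size, stride=1, padding=(kernel_size - 1) // 2)


lengths = {k: same_padding_len(50, k) for k in (1, 3, 4, 5)}
```

Under this arithmetic an even kernel shortens the sequence by one frame per convolution, so the output would no longer line up with a residual connection; odd kernel sizes avoid that.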


@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Positionwise feed forward layer definition."""
import torch
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
:param int idim: input dimension
:param int hidden_units: number of hidden units
:param float dropout_rate: dropout rate
:param torch.nn.Module activation: activation function applied after the first linear layer
"""
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units, idim)
self.dropout = torch.nn.Dropout(dropout_rate)
self.activation = activation
def forward(self, x):
"""Forward function."""
return self.w_2(self.dropout(self.activation(self.w_1(x))))


@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Repeat the same layer definition."""
import torch
class MultiSequential(torch.nn.Sequential):
"""Multi-input multi-output torch.nn.Sequential."""
def forward(self, *args):
"""Repeat."""
for m in self:
args = m(*args)
return args
def repeat(N, fn):
"""Repeat module N times.
:param int N: repeat time
:param function fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
"""
return MultiSequential(*[fn(n) for n in range(N)])
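`MultiSequential` differs from a plain `torch.nn.Sequential` in that each layer receives the previous layer's return value unpacked as several positional arguments. The same control flow can be sketched without torch; `chain` and `make_layer` are hypothetical stand-ins for `MultiSequential` and the `fn` passed to `repeat`.

```python
def chain(*layers):
    """Return a callable that threads a tuple of arguments through all layers."""
    def run(*args):
        for layer in layers:
            args = layer(*args)  # each layer must return a tuple for the next one
        return args
    return run


def make_layer(n):
    """Build a toy layer that adds its index to x and passes the mask through."""
    def layer(x, mask):
        return x + n, mask
    return layer


# Equivalent in spirit to repeat(3, make_layer).
stack = chain(*[make_layer(n) for n in range(3)])
out, mask = stack(10, "mask")
```

This is why every `EncoderLayer.forward` returns `(x, mask)` rather than `x` alone: the return value is the argument list of the next layer.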


@@ -0,0 +1,218 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import logging
import torch
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
class Conv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length or 1/2 length).
:param int idim: input dim
:param int odim: output dim
:param float dropout_rate: dropout rate
:param torch.nn.Module pos_enc: custom position encoding layer
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None,
subsample_by_2=False,
):
"""Construct a Conv2dSubsampling object."""
super(Conv2dSubsampling, self).__init__()
self.subsample_by_2 = subsample_by_2
if subsample_by_2:
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (idim // 2), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
else:
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (idim // 4), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
def forward(self, x, x_mask):
"""Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
if self.subsample_by_2:
return x, x_mask[:, :, ::2]
else:
return x, x_mask[:, :, ::2][:, :, ::2]
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class Conv2dNoSubsampling(torch.nn.Module):
"""Convolutional 2D without subsampling.
:param int idim: input dim
:param int odim: output dim
:param float dropout_rate: dropout rate
:param torch.nn.Module pos_enc: custom position encoding layer
"""
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct a Conv2dNoSubsampling object."""
super().__init__()
logging.info("Encoder does not do down-sample on mel-spectrogram.")
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * idim, odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)
def forward(self, x, x_mask):
"""Apply the convolutions to x without subsampling.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: convolved x and the unchanged mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask
def __getitem__(self, key):
"""Get item.
When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.
"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]
class Conv2dSubsampling6(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/6 length).
:param int idim: input dim
:param int odim: output dim
:param float dropout_rate: dropout rate
"""
def __init__(self, idim, odim, dropout_rate):
"""Construct a Conv2dSubsampling6 object."""
super(Conv2dSubsampling6, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
PositionalEncoding(odim, dropout_rate),
)
def forward(self, x, x_mask):
"""Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/8 length).
:param int idim: input dim
:param int odim: output dim
:param float dropout_rate: dropout rate
"""
def __init__(self, idim, odim, dropout_rate):
"""Construct a Conv2dSubsampling8 object."""
super(Conv2dSubsampling8, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
PositionalEncoding(odim, dropout_rate),
)
def forward(self, x, x_mask):
"""Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
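Each subsampling class slices the mask so that its length follows the time axis of the convolved features. The sketch below recomputes both lengths with the standard convolution length formula (dilation 1) and Python slicing on `range`, using the kernel, stride and padding values from the classes above. The helper names are hypothetical, and this is a check of the arithmetic only, not of the modules themselves.

```python
def conv_len(n, kernel, stride, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv2d_subsampling_lengths(t):
    """(feature length, mask length) for Conv2dSubsampling with subsample_by_2=False."""
    feat = conv_len(conv_len(t, 4, 2, 1), 4, 2, 1)
    mask = len(range(t)[::2][::2])
    return feat, mask


def conv2d_subsampling6_lengths(t):
    """(feature length, mask length) for Conv2dSubsampling6."""
    feat = conv_len(conv_len(t, 3, 2), 5, 3)
    mask = len(range(t)[:-2:2][:-4:3])
    return feat, mask


def conv2d_subsampling8_lengths(t):
    """(feature length, mask length) for Conv2dSubsampling8."""
    feat = conv_len(conv_len(conv_len(t, 3, 2), 3, 2), 3, 2)
    mask = len(range(t)[:-2:2][:-2:2][:-2:2])
    return feat, mask
```

For `t = 100` all three pairs agree. Under this arithmetic the padded `Conv2dSubsampling` variant can give a mask one frame longer than the features when `t` is not a multiple of 4 (for example `t = 101`), so padding inputs to a multiple of 4 may be a reasonable precaution.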


@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Swish() activation function for Conformer."""
import torch
class Swish(torch.nn.Module):
"""Swish activation function module."""
def forward(self, x):
"""Return Swish activation function."""
return x * torch.sigmoid(x)
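The same function on a single float, as a plain-Python sketch for checking values by hand. The `sigmoid` helper is written in a numerically stable form and is an illustration, not the torch implementation.

```python
import math


def sigmoid(x):
    """Logistic function, evaluated without overflowing for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def swish(x):
    """Swish activation: x * sigmoid(x)."""
    return x * sigmoid(x)
```

Swish is close to the identity for large positive inputs and close to zero for large negative inputs, with a smooth transition around zero.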


@@ -0,0 +1,77 @@
"""VGG2L definition for transformer-transducer."""
import torch
class VGG2L(torch.nn.Module):
"""VGG2L module for transformer-transducer encoder."""
def __init__(self, idim, odim):
"""Construct a VGG2L object.
Args:
idim (int): dimension of inputs
odim (int): dimension of outputs
"""
super(VGG2L, self).__init__()
self.vgg2l = torch.nn.Sequential(
torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((3, 2)),
torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((2, 2)),
)
self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
def forward(self, x, x_mask):
"""VGG2L forward for x.
Args:
x (torch.Tensor): input tensor (B, T, idim)
x_mask (torch.Tensor): (B, 1, T)
Returns:
x (torch.Tensor): output tensor (B, sub(T), attention_dim)
x_mask (torch.Tensor): (B, 1, sub(T))
"""
x = x.unsqueeze(1)
x = self.vgg2l(x)
b, c, t, f = x.size()
x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
else:
x_mask = self.create_new_mask(x_mask, x)
return x, x_mask
def create_new_mask(self, x_mask, x):
"""Create a subsampled version of x_mask.
Args:
x_mask (torch.Tensor): (B, 1, T)
x (torch.Tensor): (B, sub(T), attention_dim)
Returns:
x_mask (torch.Tensor): (B, 1, sub(T))
"""
x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
x_mask = x_mask[:, :, :x_t1][:, :, ::3]
x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
x_mask = x_mask[:, :, :x_t2][:, :, ::2]
return x_mask
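`create_new_mask` mirrors the two max-pooling steps, which divide the time axis by 3 and then by 2 with floor rounding. A list-based sketch of the same slicing, checked against the pooled length; `pooled_len` and `new_mask` are hypothetical names, and the pooling is assumed to use its default stride (equal to the kernel size) without ceil mode.

```python
def pooled_len(t):
    """Time length after MaxPool2d((3, 2)) followed by MaxPool2d((2, 2))."""
    return (t // 3) // 2


def new_mask(mask):
    """Subsample a 1-D mask the way create_new_mask does along the time axis."""
    t1 = len(mask) - len(mask) % 3   # drop frames that do not fill a window of 3
    mask = mask[:t1][::3]
    t2 = len(mask) - len(mask) % 2   # drop a trailing frame that does not fill a window of 2
    return mask[:t2][::2]


example = new_mask([True] * 11)
```

Truncating before striding is what keeps the mask length equal to the floor-rounded pooled length for every input length.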

ppg_extractor/encoders.py Normal file

@@ -0,0 +1,298 @@
import logging
import six
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from .e2e_asr_common import get_vgg2l_odim
from .nets_utils import make_pad_mask, to_device
class RNNP(torch.nn.Module):
"""RNN with projection layer module
:param int idim: dimension of inputs
:param int elayers: number of encoder layers
:param int cdim: number of rnn units (results in cdim * 2 if bidirectional)
:param int hdim: number of projection units
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param str typ: The RNN type
"""
def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
super(RNNP, self).__init__()
bidir = typ[0] == "b"
for i in six.moves.range(elayers):
if i == 0:
inputdim = idim
else:
inputdim = hdim
rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
batch_first=True) if "lstm" in typ \
else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
# bottleneck layer to merge
if bidir:
setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
else:
setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))
self.elayers = elayers
self.cdim = cdim
self.subsample = subsample
self.typ = typ
self.bidir = bidir
def forward(self, xs_pad, ilens, prev_state=None):
"""RNNP forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, hdim)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
elayer_states = []
for layer in six.moves.range(self.elayers):
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
rnn.flatten_parameters()
if prev_state is not None and rnn.bidirectional:
prev_state = reset_backward_rnn_state(prev_state)
ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
elayer_states.append(states)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
sub = self.subsample[layer + 1]
if sub > 1:
ys_pad = ys_pad[:, ::sub]
ilens = [int(i + 1) // sub for i in ilens]
# (sum _utt frame_utt) x dim
projected = getattr(self, 'bt' + str(layer)
)(ys_pad.contiguous().view(-1, ys_pad.size(2)))
if layer == self.elayers - 1:
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
else:
xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
return xs_pad, ilens, elayer_states # x: utt list of frame x dim
class RNN(torch.nn.Module):
"""RNN module
:param int idim: dimension of inputs
:param int elayers: number of encoder layers
:param int cdim: number of rnn units (results in cdim * 2 if bidirectional)
:param int hdim: number of final projection units
:param float dropout: dropout rate
:param str typ: The RNN type
"""
def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
super(RNN, self).__init__()
bidir = typ[0] == "b"
self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
dropout=dropout, bidirectional=bidir) if "lstm" in typ \
else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
bidirectional=bidir)
if bidir:
self.l_last = torch.nn.Linear(cdim * 2, hdim)
else:
self.l_last = torch.nn.Linear(cdim, hdim)
self.typ = typ
def forward(self, xs_pad, ilens, prev_state=None):
"""RNN forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
self.nbrnn.flatten_parameters()
if prev_state is not None and self.nbrnn.bidirectional:
# We assume that when previous state is passed, it means that we're streaming the input
# and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
prev_state = reset_backward_rnn_state(prev_state)
ys, states = self.nbrnn(xs_pack, hx=prev_state)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
# (sum _utt frame_utt) x dim
projected = torch.tanh(self.l_last(
ys_pad.contiguous().view(-1, ys_pad.size(2))))
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
return xs_pad, ilens, states # x: utt list of frame x dim
def reset_backward_rnn_state(states):
"""Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs"""
if isinstance(states, (list, tuple)):
for state in states:
state[1::2] = 0.
else:
states[1::2] = 0.
return states
class VGG2L(torch.nn.Module):
"""VGG-like module
:param int in_channel: number of input channels
"""
def __init__(self, in_channel=1, downsample=True):
super(VGG2L, self).__init__()
# CNN layer (VGG motivated)
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
self.in_channel = in_channel
self.downsample = downsample
if downsample:
self.stride = 2
else:
self.stride = 1
def forward(self, xs_pad, ilens, **kwargs):
"""VGG2L forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
# x: utt x frame x dim
# xs_pad = F.pad_sequence(xs_pad)
# x: utt x 1 (input channel num) x frame x dim
xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
xs_pad.size(2) // self.in_channel).transpose(1, 2)
# NOTE: max_pool1d ?
xs_pad = F.relu(self.conv1_1(xs_pad))
xs_pad = F.relu(self.conv1_2(xs_pad))
if self.downsample:
xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
xs_pad = F.relu(self.conv2_1(xs_pad))
xs_pad = F.relu(self.conv2_2(xs_pad))
if self.downsample:
xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
if torch.is_tensor(ilens):
ilens = ilens.cpu().numpy()
else:
ilens = np.array(ilens, dtype=np.float32)
if self.downsample:
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
ilens = np.array(
np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
# x: utt_list of frame (remove zero-padded frames) x (input channel num x dim)
xs_pad = xs_pad.transpose(1, 2)
xs_pad = xs_pad.contiguous().view(
xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
return xs_pad, ilens, None # no state in this layer
class Encoder(torch.nn.Module):
"""Encoder module
:param str etype: type of encoder network
:param int idim: number of dimensions of encoder network
:param int elayers: number of layers of encoder network
:param int eunits: number of lstm units of encoder network
:param int eprojs: number of projection units of encoder network
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param int in_channel: number of input channels
"""
def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
super(Encoder, self).__init__()
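# str.lstrip/rstrip strip character sets, not prefixes ("vgggru".lstrip("vgg") == "ru"),
# so remove the "vgg" prefix and the trailing "p" explicitly.
typ = etype[3:] if etype.startswith("vgg") else etype
typ = typ[:-1] if typ.endswith("p") else typ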
if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
logging.error("Error: need to specify an appropriate encoder architecture")
if etype.startswith("vgg"):
if etype[-1] == "p":
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
eprojs,
subsample, dropout, typ=typ)])
logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
else:
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
eprojs,
dropout, typ=typ)])
logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
else:
if etype[-1] == "p":
self.enc = torch.nn.ModuleList(
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
logging.info(typ.upper() + ' with every-layer projection for encoder')
else:
self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
logging.info(typ.upper() + ' without projection for encoder')
def forward(self, xs_pad, ilens, prev_states=None):
"""Encoder forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
if prev_states is None:
prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)
current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)
# make mask to remove bias value in padded part
mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
def encoder_for(args, idim, subsample):
"""Instantiates an encoder module given the program arguments
:param Namespace args: The arguments
:param int or List of integer idim: dimension of input, e.g. 83, or
List of dimensions of inputs, e.g. [83,83]
:param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]]
:rtype: torch.nn.Module
:return: The encoder module
"""
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
if num_encs == 1:
# compatible with single encoder asr mode
return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
elif num_encs > 1:
enc_list = torch.nn.ModuleList()
for idx in range(num_encs):
enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx],
args.dropout_rate[idx])
enc_list.append(enc)
return enc_list
else:
raise ValueError("Number of encoders needs to be at least one. {}".format(num_encs))
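The `etype` strings handled by `Encoder` encode three things: an optional `vgg` prefix, the RNN type, and an optional trailing `p` for per-layer projection. A standalone sketch of parsing them with explicit prefix and suffix checks follows; `parse_etype` is a hypothetical helper, and unlike the class above it raises on an unknown type instead of logging an error.

```python
def parse_etype(etype):
    """Split an encoder type such as 'vggblstmp' into (use_vgg, rnn_type, projection)."""
    use_vgg = etype.startswith("vgg")
    typ = etype[len("vgg"):] if use_vgg else etype
    projection = typ.endswith("p")
    if projection:
        typ = typ[:-1]
    if typ not in ("lstm", "gru", "blstm", "bgru"):
        raise ValueError("unsupported encoder type: " + etype)
    return use_vgg, typ, projection
```

Explicit checks are used here because `str.lstrip("vgg")` removes any leading run of the characters `v` and `g`, so it would also eat the `g` of `gru`.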

ppg_extractor/frontend.py Normal file

@@ -0,0 +1,115 @@
import copy
from typing import Tuple
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from .log_mel import LogMel
from .stft import Stft
class DefaultFrontend(torch.nn.Module):
"""Conventional frontend structure for ASR
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 1024,
win_length: int = 800,
hop_length: int = 160,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
norm=1,
frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend),
kaldi_padding_mode=False,
downsample_rate: int = 1,
):
super().__init__()
self.downsample_rate = downsample_rate
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
pad_mode=pad_mode,
normalized=normalized,
onesided=onesided,
kaldi_padding_mode=kaldi_padding_mode
)
if frontend_conf is not None:
# NOTE: Frontend is not imported in this module, so this branch raises NameError as written.
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
)
self.n_mels = n_mels
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
# Select 1ch randomly
ch = np.random.randint(input_stft.size(2))
input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
# NOTE(sx): pad
max_len = input_feats.size(1)
if self.downsample_rate > 1 and max_len % self.downsample_rate != 0:
padding = self.downsample_rate - max_len % self.downsample_rate
# print("Logmel: ", input_feats.size())
input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding),
"constant", 0)
# print("Logmel(after padding): ",input_feats.size())
feats_lens[torch.argmax(feats_lens)] = max_len + padding
return input_feats, feats_lens
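The padding step at the end of `forward` is plain modular arithmetic. A minimal sketch of it, assuming a hypothetical helper name `pad_to_multiple` that does not exist in the repository:

```python
def pad_to_multiple(max_len: int, downsample_rate: int) -> int:
    """Return the number of zero frames to append so that max_len
    becomes divisible by downsample_rate (0 if none are needed)."""
    if downsample_rate > 1 and max_len % downsample_rate != 0:
        return downsample_rate - max_len % downsample_rate
    return 0


# Illustration only: frame counts before and after padding for rate 4.
for frames in (100, 101, 103):
    extra = pad_to_multiple(frames, 4)
    print(frames, "->", frames + extra)
```

Note that in the code above only the longest utterance in the batch has its length entry updated to the padded length.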

74
ppg_extractor/log_mel.py Normal file

@@ -0,0 +1,74 @@
import librosa
import numpy as np
import torch
from typing import Tuple
from .nets_utils import make_pad_mask
class LogMel(torch.nn.Module):
"""Convert STFT to fbank feats
The arguments are the same as those of librosa.filters.mel
Args:
fs: number > 0 [scalar] sampling rate of the incoming signal
n_fft: int > 0 [scalar] number of FFT components
n_mels: int > 0 [scalar] number of Mel bands to generate
fmin: float >= 0 [scalar] lowest frequency (in Hz)
fmax: float >= 0 [scalar] highest frequency (in Hz).
If `None`, use `fmax = fs / 2.0`
htk: use HTK formula instead of Slaney
norm: {None, 1, np.inf} [scalar]
if 1, divide the triangular mel weights by the width of the mel band
(area normalization). Otherwise, leave all the triangles aiming for
a peak value of 1.0
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = None,
fmax: float = None,
htk: bool = False,
norm=1,
):
super().__init__()
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
_mel_options = dict(
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
)
self.mel_options = _mel_options
# Note(kamo): The mel matrix of librosa is different from kaldi.
melmat = librosa.filters.mel(**_mel_options)
# melmat: (D2, D1) -> (D1, D2)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
inv_mel = np.linalg.pinv(melmat)
self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float())
def extra_repr(self):
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
def forward(
self, feat: torch.Tensor, ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
mel_feat = torch.matmul(feat, self.melmat)
logmel_feat = (mel_feat + 1e-20).log()
# Zero padding
if ilens is not None:
logmel_feat = logmel_feat.masked_fill(
make_pad_mask(ilens, logmel_feat, 1), 0.0
)
else:
ilens = feat.new_full(
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
)
return logmel_feat, ilens
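The core of `LogMel.forward` can be followed without PyTorch. The sketch below is a pure-Python illustration for a single utterance, assuming a hypothetical function `log_mel` and a toy filterbank. It is not the librosa matrix the class actually builds.

```python
import math


def log_mel(power_frames, melmat, length=None, floor=1e-20):
    """Project power-spectrum frames (T x D1) through a mel matrix (D1 x D2)
    and take the log; frames at index >= length are zero-filled."""
    n_mels = len(melmat[0])
    out = []
    for t, frame in enumerate(power_frames):
        if length is not None and t >= length:
            out.append([0.0] * n_mels)  # padded frame: zeroed after the log
            continue
        row = []
        for j in range(n_mels):
            energy = sum(frame[i] * melmat[i][j] for i in range(len(frame)))
            row.append(math.log(energy + floor))  # floor avoids log(0)
        out.append(row)
    return out


# Toy example: 3 frequency bins, 2 mel bands, last frame is padding.
toy_melmat = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_frames = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [5.0, 5.0, 5.0]]
features = log_mel(toy_frames, toy_melmat, length=2)
```

The `1e-20` floor matters for silent frames: a frame of all zeros maps to `log(1e-20)` rather than negative infinity.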

465
ppg_extractor/nets_utils.py Normal file

@@ -0,0 +1,465 @@
# -*- coding: utf-8 -*-
"""Network related utility tools."""
import logging
from typing import Dict
import numpy as np
import torch
def to_device(m, x):
"""Send tensor into the device of the module.
Args:
m (torch.nn.Module): Torch module.
x (Tensor): Torch tensor.
Returns:
Tensor: Torch tensor located in the same place as torch module.
"""
assert isinstance(m, torch.nn.Module)
device = next(m.parameters()).device
return x.to(device)
def pad_list(xs, pad_value):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max(x.size(0) for x in xs)
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(slice(None) if i in (0, length_dim) else None
for i in range(xs.dim()))
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
Returns:
ByteTensor: Mask tensor containing indices of non-padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return ~make_pad_mask(lengths, xs, length_dim)
def mask_by_length(xs, lengths, fill=0):
"""Mask tensor according to length.
Args:
xs (Tensor): Batch of input tensor (B, `*`).
lengths (LongTensor or List): Batch of lengths (B,).
fill (int or float): Value to fill masked part.
Returns:
Tensor: Batch of masked input tensor (B, `*`).
Examples:
>>> x = torch.arange(5).repeat(3, 1) + 1
>>> x
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]])
>>> lengths = [5, 3, 2]
>>> mask_by_length(x, lengths)
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 0, 0],
[1, 2, 0, 0, 0]])
"""
assert xs.size(0) == len(lengths)
ret = xs.data.new(*xs.size()).fill_(fill)
for i, l in enumerate(lengths):
ret[i, :l] = xs[i, :l]
return ret
def th_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0),
pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
def to_torch_tensor(x):
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
Args:
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
Returns:
Tensor or ComplexTensor: Type converted inputs.
Examples:
>>> xs = np.ones(3, dtype=np.float32)
>>> xs = to_torch_tensor(xs)
tensor([1., 1., 1.])
>>> xs = torch.ones(3, 4, 5)
>>> assert to_torch_tensor(xs) is xs
>>> xs = {'real': xs, 'imag': xs}
>>> to_torch_tensor(xs)
ComplexTensor(
Real:
tensor([1., 1., 1.])
Imag;
tensor([1., 1., 1.])
)
"""
# If numpy, change to torch tensor
if isinstance(x, np.ndarray):
if x.dtype.kind == 'c':
# Dynamically importing because torch_complex requires python3
from torch_complex.tensor import ComplexTensor
return ComplexTensor(x)
else:
return torch.from_numpy(x)
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
elif isinstance(x, dict):
# Dynamically importing because torch_complex requires python3
from torch_complex.tensor import ComplexTensor
if 'real' not in x or 'imag' not in x:
raise ValueError("x must have 'real' and 'imag' keys: {}".format(list(x)))
# Relative importing because of using python3 syntax
return ComplexTensor(x['real'], x['imag'])
# If torch.Tensor, as it is
elif isinstance(x, torch.Tensor):
return x
else:
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
"but got {}".format(type(x)))
try:
from torch_complex.tensor import ComplexTensor
except Exception:
# If PY2
raise ValueError(error)
else:
# If PY3
if isinstance(x, ComplexTensor):
return x
else:
raise ValueError(error)
def get_subsample(train_args, mode, arch):
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
Args:
train_args: argument Namespace containing options.
mode: one of ('asr', 'mt', 'st')
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
Returns:
np.ndarray / List[np.ndarray]: subsampling factors.
"""
if arch == 'transformer':
return np.array([1])
elif mode == 'mt' and arch == 'rnn':
# +1 means input (+1) and layers outputs (train_args.elayers)
subsample = np.ones(train_args.elayers + 1, dtype=int)
logging.warning('Subsampling is not performed for machine translation.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
return subsample
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
(mode == 'mt' and arch == 'rnn') or \
(mode == 'st' and arch == 'rnn'):
subsample = np.ones(train_args.elayers + 1, dtype=int)
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
ss = train_args.subsample.split("_")
for j in range(min(train_args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
return subsample
elif mode == 'asr' and arch == 'rnn_mix':
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=int)
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
ss = train_args.subsample.split("_")
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
return subsample
elif mode == 'asr' and arch == 'rnn_mulenc':
subsample_list = []
for idx in range(train_args.num_encs):
subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
ss = train_args.subsample[idx].split("_")
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Encoder %d: Subsampling is not performed for vgg*. '
'It is performed in max pooling layers at CNN.', idx + 1)
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
subsample_list.append(subsample)
return subsample_list
else:
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
"""Replace keys of old prefix with new prefix in state dict."""
# need this list not to break the dict iterator
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
if len(old_keys) > 0:
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
for k in old_keys:
v = state_dict.pop(k)
new_k = k.replace(old_prefix, new_prefix)
state_dict[new_k] = v
def get_activation(act):
"""Return activation function."""
# Lazy load to avoid unused import
from .encoder.swish import Swish
activation_funcs = {
"hardtanh": torch.nn.Hardtanh,
"relu": torch.nn.ReLU,
"selu": torch.nn.SELU,
"swish": Swish,
}
return activation_funcs[act]()
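The masking convention used throughout this file (True marks padding) can be shown for the lengths-only case with plain lists. This is a sketch under that assumption; `pad_mask` and `non_pad_mask` are hypothetical names, and the reference-tensor broadcasting of `make_pad_mask` is not covered.

```python
from typing import List, Optional


def pad_mask(lengths: List[int], maxlen: Optional[int] = None) -> List[List[bool]]:
    """True where the time index falls in the padded part of each sequence."""
    if maxlen is None:
        maxlen = max(lengths)
    return [[t >= n for t in range(maxlen)] for n in lengths]


def non_pad_mask(lengths: List[int], maxlen: Optional[int] = None) -> List[List[bool]]:
    """Logical inverse of pad_mask: True where real data is present."""
    return [[not flag for flag in row] for row in pad_mask(lengths, maxlen)]


mask = pad_mask([5, 3, 2])
```

Passing a larger `maxlen` plays the role of the reference tensor's length dimension: every row is extended with True entries up to that length.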

118
ppg_extractor/stft.py Normal file

@@ -0,0 +1,118 @@
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from .nets_utils import make_pad_mask
class Stft(torch.nn.Module):
def __init__(
self,
n_fft: int = 512,
win_length: Union[int, None] = 512,
hop_length: int = 128,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: bool = True,
kaldi_padding_mode=False,
):
super().__init__()
self.n_fft = n_fft
if win_length is None:
self.win_length = n_fft
else:
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.pad_mode = pad_mode
self.normalized = normalized
self.onesided = onesided
self.kaldi_padding_mode = kaldi_padding_mode
if self.kaldi_padding_mode:
self.win_length = 400
def extra_repr(self):
return (
f"n_fft={self.n_fft}, "
f"win_length={self.win_length}, "
f"hop_length={self.hop_length}, "
f"center={self.center}, "
f"pad_mode={self.pad_mode}, "
f"normalized={self.normalized}, "
f"onesided={self.onesided}"
)
def forward(
self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""STFT forward function.
Args:
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
ilens: (Batch)
Returns:
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
"""
bs = input.size(0)
if input.dim() == 3:
multi_channel = True
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
input = input.transpose(1, 2).reshape(-1, input.size(1))
else:
multi_channel = False
# output: (Batch, Freq, Frames, 2=real_imag)
# or (Batch, Channel, Freq, Frames, 2=real_imag)
if not self.kaldi_padding_mode:
output = torch.stft(
input,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
pad_mode=self.pad_mode,
normalized=self.normalized,
onesided=self.onesided,
return_complex=False
)
else:
# NOTE(sx): Use Kaldi-style padding; this may be incorrect
num_pads = self.n_fft - self.win_length
input = torch.nn.functional.pad(input, (num_pads, 0))
output = torch.stft(
input,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=False,
pad_mode=self.pad_mode,
normalized=self.normalized,
onesided=self.onesided,
return_complex=False
)
# output: (Batch, Freq, Frames, 2=real_imag)
# -> (Batch, Frames, Freq, 2=real_imag)
output = output.transpose(1, 2)
if multi_channel:
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
1, 2
)
if ilens is not None:
if self.center:
pad = self.win_length // 2
ilens = ilens + 2 * pad
olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1
# olens = ilens - self.win_length // self.hop_length + 1
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
else:
olens = None
return output, olens
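The output-length bookkeeping at the end of `Stft.forward` reduces to integer arithmetic. A minimal sketch that mirrors those lines, assuming a hypothetical helper name `stft_num_frames`:

```python
def stft_num_frames(n_samples: int, win_length: int, hop_length: int,
                    center: bool = True) -> int:
    """Number of STFT frames, following the length computation shown above:
    with center=True the signal is treated as padded by win_length // 2
    on both sides before framing."""
    if center:
        n_samples = n_samples + 2 * (win_length // 2)
    return (n_samples - win_length) // hop_length + 1


# One second of 16 kHz audio with the DefaultFrontend defaults (800 / 160).
frames_centered = stft_num_frames(16000, 800, 160, center=True)
frames_plain = stft_num_frames(16000, 800, 160, center=False)
```

This only reproduces the arithmetic used for `olens`; it makes no claim about how `torch.stft` pads internally.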


@@ -0,0 +1,82 @@
from typing import Tuple
import torch
from .nets_utils import make_pad_mask
class UtteranceMVN(torch.nn.Module):
def __init__(
self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20,
):
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
def extra_repr(self):
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
def forward(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
Args:
x: (B, L, ...)
ilens: (B,)
"""
return utterance_mvn(
x,
ilens,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
eps=self.eps,
)
def utterance_mvn(
x: torch.Tensor,
ilens: torch.Tensor = None,
norm_means: bool = True,
norm_vars: bool = False,
eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply utterance mean and variance normalization
Args:
x: (B, T, D), assumed zero padded
ilens: (B,)
norm_means:
norm_vars:
eps:
"""
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
# Zero padding
if x.requires_grad:
x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
else:
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
# mean: (B, 1, D)
mean = x.sum(dim=1, keepdim=True) / ilens_
if norm_means:
x -= mean
if norm_vars:
var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
std = torch.clamp(var.sqrt(), min=eps)
x = x / std
return x, ilens
else:
if norm_vars:
y = x - mean
y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
std = torch.clamp(var.sqrt(), min=eps)
x /= std
return x, ilens
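Mean normalization over the valid frames of one utterance can be sketched in pure Python. The function name `utterance_mean_norm` is hypothetical. Unlike the in-place `x -= mean` above, which also shifts the padded frames, this sketch deliberately leaves padding at zero.

```python
from typing import List


def utterance_mean_norm(frames: List[List[float]], length: int) -> List[List[float]]:
    """Subtract the per-dimension mean computed over the first `length`
    frames; frames beyond `length` are treated as padding and kept at zero."""
    dim = len(frames[0])
    mean = [sum(frames[t][d] for t in range(length)) / length for d in range(dim)]
    out = []
    for t, frame in enumerate(frames):
        if t < length:
            out.append([frame[d] - mean[d] for d in range(dim)])
        else:
            out.append([0.0] * dim)
    return out


normalized = utterance_mean_norm([[1.0, 2.0], [3.0, 6.0], [0.0, 0.0]], length=2)
```

Dividing by the true length rather than the padded length is what keeps the statistics independent of how much padding a batch happens to contain.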

5
pre.py

@@ -12,7 +12,8 @@ import argparse
recognized_datasets = [
"aidatatang_200zh",
"magicdata",
"aishell3"
"aishell3",
"data_aishell"
]
if __name__ == "__main__":
@@ -40,7 +41,7 @@ if __name__ == "__main__":
"Use this option when dataset does not include alignments\
(these are used to split long audio files into sub-utterances.)")
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
"Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3.")
"Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.")
parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\
"Path your trained encoder model.")
parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\

49
pre4ppg.py Normal file

@@ -0,0 +1,49 @@
from pathlib import Path
import argparse
from ppg2mel.preprocess import preprocess_dataset
recognized_datasets = [
"aidatatang_200zh",
"aidatatang_200zh_s", # sample
]
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Preprocesses audio files from datasets, to be used by the "
"ppg2mel model for training.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("datasets_root", type=Path, help=\
"Path to the directory containing your datasets.")
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
"Name of the dataset to process, allowed values: aidatatang_200zh, aidatatang_200zh_s.")
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
"Path to the output directory that will contain the mel spectrograms, the audios and the "
"embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
"Number of processes in parallel.")
# parser.add_argument("-s", "--skip_existing", action="store_true", help=\
# "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
# "interrupted. ")
# parser.add_argument("--hparams", type=str, default="", help=\
# "Hyperparameter overrides as a comma-separated list of name-value pairs")
# parser.add_argument("--no_trim", action="store_true", help=\
# "Preprocess audio without trimming silences (not recommended).")
parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
"Path to your trained PPG encoder model.")
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
"Path to your trained speaker encoder model.")
args = parser.parse_args()
assert args.dataset in recognized_datasets, f'{args.dataset} is not supported, file an issue to propose a new one'
# Create directories
assert args.datasets_root.exists()
if not hasattr(args, "out_dir"):
args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
args.out_dir.mkdir(exist_ok=True, parents=True)
preprocess_dataset(**vars(args))
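The script relies on `argparse.SUPPRESS` so that `out_dir` is absent from the namespace unless the user passes `-o`, and then fills in a default derived from `datasets_root`. A minimal sketch of that pattern, wrapped in a hypothetical function `resolve_out_dir` so it can be called with an explicit argument list:

```python
import argparse
from pathlib import Path


def resolve_out_dir(argv):
    """Parse argv and return the output directory, defaulting to
    <datasets_root>/PPGVC/ppg2mel when -o/--out_dir is not given."""
    parser = argparse.ArgumentParser()
    parser.add_argument("datasets_root", type=Path)
    # SUPPRESS: the attribute is not created at all unless the flag is passed.
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS)
    args = parser.parse_args(argv)
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    return args.out_dir


default_dir = resolve_out_dir(["data"])
custom_dir = resolve_out_dir(["data", "-o", "out"])
```

The benefit over a plain `default=None` is small here; either approach lets the default depend on another argument.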


@@ -1,6 +1,6 @@
umap-learn
visdom
librosa>=0.8.0
librosa==0.8.1
matplotlib>=3.3.0
numpy==1.19.3; platform_system == "Windows"
numpy==1.19.4; platform_system != "Windows"
@@ -17,6 +17,10 @@ webrtcvad; platform_system != "Windows"
pypinyin
flask
flask_wtf
flask_cors
flask_cors==3.0.10
gevent==21.8.0
flask_restx
flask_restx
tensorboard
PyYAML==5.4.1
torch_complex
espnet


@@ -0,0 +1,13 @@
class GSTHyperparameters():
E = 512
# reference encoder
ref_enc_filters = [32, 32, 64, 64, 128, 128]
# style token layer
token_num = 10
# token_emb_size = 256
num_heads = 8
n_mels = 256 # Number of Mel banks to generate


@@ -1,5 +1,6 @@
import ast
import pprint
import json
class HParams(object):
def __init__(self, **kwargs): self.__dict__.update(kwargs)
@@ -18,6 +19,18 @@ class HParams(object):
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
return self
def loadJson(self, dict):
print("\nLoading the json with {}\n".format(dict))
for k in dict.keys():
self.__dict__[k] = dict[k]
return self
def dumpJson(self, fp):
print("\nSaving the json to {}\n".format(fp))
with fp.open("w", encoding="utf-8") as f:
json.dump(self.__dict__, f)
return self
hparams = HParams(
### Signal Processing (used in both synthesizer and vocoder)
sample_rate = 16000,
@@ -49,19 +62,24 @@ hparams = HParams(
# frame that has all values < -3.4
### Tacotron Training
tts_schedule = [(2, 1e-3, 20_000, 24), # Progressive training schedule
(2, 5e-4, 40_000, 24), # (r, lr, step, batch_size)
(2, 2e-4, 80_000, 24), #
(2, 1e-4, 160_000, 24), # r = reduction factor (# of mel frames
(2, 3e-5, 320_000, 24), # synthesized for each decoder iteration)
(2, 1e-5, 640_000, 24)], # lr = learning rate
tts_schedule = [(2, 1e-3, 10_000, 12), # Progressive training schedule
(2, 5e-4, 15_000, 12), # (r, lr, step, batch_size)
(2, 2e-4, 20_000, 12), #
(2, 1e-4, 30_000, 12), #
(2, 5e-5, 40_000, 12), #
(2, 1e-5, 60_000, 12), #
(2, 5e-6, 160_000, 12), # r = reduction factor (# of mel frames
(2, 3e-6, 320_000, 12), # synthesized for each decoder iteration)
(2, 1e-6, 640_000, 12)], # lr = learning rate
tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
# Set to -1 to generate after completing epoch, or 0 to disable
tts_eval_num_samples = 1, # Makes this number of samples
## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj
tts_finetune_layers = [],
### Data Preprocessing
max_mel_frames = 900,
rescale = True,
@@ -86,4 +104,6 @@ hparams = HParams(
speaker_embedding_size = 256, # Dimension for the speaker embedding
silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
use_gst = True, # Whether to use global style token
use_ser_for_gst = True, # Whether to use speaker embedding referenced for global style token
)
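The `loadJson` / `dumpJson` pair added in this diff amounts to a JSON round trip of the instance dictionary. The sketch below illustrates this with a hypothetical stand-in class `MiniHParams` (not the repository's `HParams`) and a temporary file.

```python
import json
import tempfile
from pathlib import Path


class MiniHParams:
    """Hypothetical stand-in for HParams, reduced to the JSON methods."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def load_json(self, values):
        # Overwrite or add attributes from a plain dict.
        for key, value in values.items():
            self.__dict__[key] = value
        return self

    def dump_json(self, fp):
        with fp.open("w", encoding="utf-8") as f:
            json.dump(self.__dict__, f)
        return self


def round_trip(hp):
    """Dump hp to a temporary JSON file and load it into a fresh object."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "config.json"
        hp.dump_json(path)
        with path.open("r", encoding="utf-8") as f:
            return MiniHParams().load_json(json.load(f))


original = MiniHParams(sample_rate=16000, tts_schedule=[(2, 1e-3, 10_000, 12)])
restored = round_trip(original)
```

One consequence worth knowing: JSON has no tuple type, so schedule entries such as `(2, 1e-3, 10_000, 12)` come back as lists after loading.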


@@ -10,6 +10,7 @@ from typing import Union, List
import numpy as np
import librosa
from utils import logmmse
import json
from pypinyin import lazy_pinyin, Style
class Synthesizer:
@@ -44,6 +45,11 @@ class Synthesizer:
return self._model is not None
def load(self):
# Try to scan config file
model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
"""
Instantiates and loads the model given the weights file that was passed in the constructor.
"""
@@ -62,7 +68,7 @@ class Synthesizer:
stop_threshold=hparams.tts_stop_threshold,
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
self._model.load(self.model_fpath)
self._model.load(self.model_fpath, self.device)
self._model.eval()
if self.verbose:
@@ -70,7 +76,7 @@ class Synthesizer:
def synthesize_spectrograms(self, texts: List[str],
embeddings: Union[np.ndarray, List[np.ndarray]],
return_alignments=False):
return_alignments=False, style_idx=0, min_stop_token=5, steps=2000):
"""
Synthesizes mel spectrograms from texts and speaker embeddings.
@@ -125,7 +131,7 @@ class Synthesizer:
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
# Inference
_, mels, alignments = self._model.generate(chars, speaker_embeddings)
_, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token, steps=steps)
mels = mels.detach().cpu().numpy()
for m in mels:
# Trim silence from end of each spectrogram


@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as tFunctional
from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
from synthesizer.hparams import hparams
class GlobalStyleToken(nn.Module):
"""
inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
speaker_embedding: speaker embedding [batch_size, speaker_embedding_dim]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, speaker_embedding_dim=None):
super().__init__()
self.encoder = ReferenceEncoder()
self.stl = STL(speaker_embedding_dim)
def forward(self, inputs, speaker_embedding=None):
enc_out = self.encoder(inputs)
# concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
if hparams.use_ser_for_gst and speaker_embedding is not None:
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
style_embed = self.stl(enc_out)
return style_embed
class ReferenceEncoder(nn.Module):
'''
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
'''
def __init__(self):
super().__init__()
K = len(hp.ref_enc_filters)
filters = [1] + hp.ref_enc_filters
convs = [nn.Conv2d(in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)) for i in range(K)]
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)])
out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K)
self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
hidden_size=hp.E // 2,
batch_first=True)
def forward(self, inputs):
N = inputs.size(0)
out = inputs.view(N, 1, -1, hp.n_mels) # [N, 1, Ty, n_mels]
for conv, bn in zip(self.convs, self.bns):
out = conv(out)
out = bn(out)
out = tFunctional.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
T = out.size(1)
N = out.size(0)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
memory, out = self.gru(out) # out --- [1, N, E//2]
return out.squeeze(0)
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
for i in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
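`calculate_channels` applies the standard convolution output-size formula once per layer. A standalone sketch, assuming a hypothetical function name `conv_output_size` and the kernel, stride, and padding values used by the reference encoder above:

```python
def conv_output_size(size: int, kernel_size: int = 3, stride: int = 2,
                     pad: int = 1, n_convs: int = 1) -> int:
    """Size of one spatial dimension after n_convs identical conv layers."""
    for _ in range(n_convs):
        size = (size - kernel_size + 2 * pad) // stride + 1
    return size


# With n_mels = 256 and the six filters in GSTHyperparameters, the mel axis
# shrinks 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4.
mel_axis = conv_output_size(256, n_convs=6)
gru_input_size = 128 * mel_axis  # last filter count times remaining mel bins
```

With these settings each layer halves the mel axis when the size is even, which is why the GRU input size depends directly on `n_mels`.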
class STL(nn.Module):
'''
inputs --- [N, E//2]
'''
def __init__(self, speaker_embedding_dim=None):
super().__init__()
self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
d_q = hp.E // 2
d_k = hp.E // hp.num_heads
# self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
d_q += speaker_embedding_dim
self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
init.normal_(self.embed, mean=0, std=0.5)
def forward(self, inputs):
N = inputs.size(0)
query = inputs.unsqueeze(1) # [N, 1, E//2]
keys = tFunctional.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
style_embed = self.attention(query, keys)
return style_embed
class MultiHeadAttention(nn.Module):
'''
input:
query --- [N, T_q, query_dim]
key --- [N, T_k, key_dim]
output:
out --- [N, T_q, num_units]
'''
def __init__(self, query_dim, key_dim, num_units, num_heads):
super().__init__()
self.num_units = num_units
self.num_heads = num_heads
self.key_dim = key_dim
self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
def forward(self, query, key):
querys = self.W_query(query) # [N, T_q, num_units]
keys = self.W_key(key) # [N, T_k, num_units]
values = self.W_value(key)
split_size = self.num_units // self.num_heads
querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim ** 0.5)
scores = tFunctional.softmax(scores, dim=3)
# out = score * V
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
return out
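The `calculate_channels` helper in `ReferenceEncoder` applies the usual convolution output-size formula once per conv layer. A minimal standalone sketch of that arithmetic follows. The name `conv_output_size` and the example values (80 mel bins, six conv layers) are assumptions for illustration only, since the GST hyperparameters are not part of this diff.

```python
def conv_output_size(length, kernel_size, stride, pad, n_convs):
    """Size of one spatial axis after n_convs identical conv layers."""
    for _ in range(n_convs):
        length = (length - kernel_size + 2 * pad) // stride + 1
    return length


# Assumed values: 80 mel bins, six 3x3 convs with stride 2 and padding 1.
mel_bins_after_convs = conv_output_size(80, kernel_size=3, stride=2, pad=1, n_convs=6)
print(mel_bins_after_convs)  # 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2
```

In the diff, this value is multiplied by the last entry of `hp.ref_enc_filters` to obtain the GRU input size.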

@@ -3,8 +3,9 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Union
from synthesizer.models.global_style_token import GlobalStyleToken
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
from synthesizer.hparams import hparams
class HighwayNetwork(nn.Module):
@@ -60,7 +61,7 @@ class Encoder(nn.Module):
idx = 1
# Start by making a copy of each speaker embedding to match the input text length
# The output of this has size (batch_size, num_chars * tts_embed_dims)
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
speaker_embedding_size = speaker_embedding.size()[idx]
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
@@ -126,7 +127,7 @@ class CBHG(nn.Module):
# Although we `_flatten_parameters()` on init, when using DataParallel
# the model gets replicated, making it no longer guaranteed that the
# weights are contiguous in GPU memory. Hence, we must call it again
self._flatten_parameters()
self.rnn.flatten_parameters()
# Save these for later
residual = x
@@ -213,7 +214,7 @@ class LSA(nn.Module):
self.attention = None
def init_attention(self, encoder_seq_proj):
device = next(self.parameters()).device # use same device as parameters
device = encoder_seq_proj.device # use the same device as the input tensor
b, t, c = encoder_seq_proj.size()
self.cumulative = torch.zeros(b, t, device=device)
self.attention = torch.zeros(b, t, device=device)
@@ -255,16 +256,17 @@ class Decoder(nn.Module):
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
dropout=dropout)
self.attn_net = LSA(decoder_dims)
if hparams.use_gst:
speaker_embedding_size += gst_hp.E
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
def zoneout(self, prev, current, p=0.1):
device = next(self.parameters()).device # Use same device as parameters
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
def zoneout(self, prev, current, device, p=0.1):
mask = torch.zeros(prev.size(),device=device).bernoulli_(p)
return prev * mask + current * (1 - mask)
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
@@ -272,7 +274,7 @@ class Decoder(nn.Module):
# Need this for reshaping mels
batch_size = encoder_seq.size(0)
device = encoder_seq.device
# Unpack the hidden and cell states
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
rnn1_cell, rnn2_cell = cell_states
@@ -298,7 +300,7 @@ class Decoder(nn.Module):
# Compute first Residual RNN
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
if self.training:
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device)
else:
rnn1_hidden = rnn1_hidden_next
x = x + rnn1_hidden
@@ -306,7 +308,7 @@ class Decoder(nn.Module):
# Compute second Residual RNN
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
if self.training:
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
else:
rnn2_hidden = rnn2_hidden_next
x = x + rnn2_hidden
@@ -337,7 +339,12 @@ class Tacotron(nn.Module):
self.speaker_embedding_size = speaker_embedding_size
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
encoder_K, num_highways, dropout)
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
project_dims = encoder_dims + speaker_embedding_size
if hparams.use_gst:
project_dims += gst_hp.E
self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
if hparams.use_gst:
self.gst = GlobalStyleToken(speaker_embedding_size)
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
dropout, speaker_embedding_size)
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
@@ -357,12 +364,19 @@ class Tacotron(nn.Module):
@r.setter
def r(self, value):
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
@staticmethod
def _concat_speaker_embedding(outputs, speaker_embeddings):
speaker_embeddings_ = speaker_embeddings.expand(
outputs.size(0), outputs.size(1), -1)
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
return outputs
def forward(self, x, m, speaker_embedding):
device = next(self.parameters()).device # use same device as parameters
def forward(self, texts, mels, speaker_embedding):
device = texts.device # use the same device as the input tensor
self.step += 1
batch_size, _, steps = m.size()
batch_size, _, steps = mels.size()
# Initialise all hidden states and pack into tuple
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
@@ -379,11 +393,20 @@ class Tacotron(nn.Module):
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
# Need an initial context vector
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
size = self.encoder_dims + self.speaker_embedding_size
if hparams.use_gst:
size += gst_hp.E
context_vec = torch.zeros(batch_size, size, device=device)
# SV2TTS: Run the encoder with the speaker embedding
# The projection avoids unnecessary matmuls in the decoder loop
encoder_seq = self.encoder(x, speaker_embedding)
encoder_seq = self.encoder(texts, speaker_embedding)
# put after encoder
if hparams.use_gst and self.gst is not None:
style_embed = self.gst(speaker_embedding, speaker_embedding) # during training, the speaker embedding serves as both the style input and the reference
# style_embed = style_embed.expand_as(encoder_seq)
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
encoder_seq_proj = self.encoder_proj(encoder_seq)
# Need a couple of lists for outputs
@@ -391,10 +414,10 @@ class Tacotron(nn.Module):
# Run the decoder loop
for t in range(0, steps, self.r):
prenet_in = m[:, :, t - 1] if t > 0 else go_frame
prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
hidden_states, cell_states, context_vec, t, x)
hidden_states, cell_states, context_vec, t, texts)
mel_outputs.append(mel_frames)
attn_scores.append(scores)
stop_outputs.extend([stop_tokens] * self.r)
@@ -414,9 +437,9 @@ class Tacotron(nn.Module):
return mel_outputs, linear, attn_scores, stop_outputs
def generate(self, x, speaker_embedding=None, steps=2000):
def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
self.eval()
device = next(self.parameters()).device # use same device as parameters
device = x.device # use the same device as the input tensor
batch_size, _ = x.size()
@@ -435,11 +458,30 @@ class Tacotron(nn.Module):
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
# Need an initial context vector
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
size = self.encoder_dims + self.speaker_embedding_size
if hparams.use_gst:
size += gst_hp.E
context_vec = torch.zeros(batch_size, size, device=device)
# SV2TTS: Run the encoder with the speaker embedding
# The projection avoids unnecessary matmuls in the decoder loop
encoder_seq = self.encoder(x, speaker_embedding)
# put after encoder
if hparams.use_gst and self.gst is not None:
if style_idx >= 0 and style_idx < 10:
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
if device.type == 'cuda':
query = query.cuda()
gst_embed = torch.tanh(self.gst.stl.embed)
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
style_embed = self.gst.stl.attention(query, key)
else:
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
# style_embed = style_embed.expand_as(encoder_seq)
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
encoder_seq_proj = self.encoder_proj(encoder_seq)
# Need a couple of lists for outputs
@@ -455,7 +497,7 @@ class Tacotron(nn.Module):
attn_scores.append(scores)
stop_outputs.extend([stop_tokens] * self.r)
# Stop the loop when all stop tokens in batch exceed threshold
if (stop_tokens > 0.5).all() and t > 10: break
if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
# Concat the mel outputs into sequence
mel_outputs = torch.cat(mel_outputs, dim=2)
@@ -479,6 +521,15 @@ class Tacotron(nn.Module):
for p in self.parameters():
if p.dim() > 1: nn.init.xavier_uniform_(p)
def finetune_partial(self, whitelist_layers):
self.zero_grad()
for name, child in self.named_children():
if name in whitelist_layers:
print("Trainable Layer: %s" % name)
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
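# NOTE: despite the "Trainable" log messages above, this disables gradients for the whitelisted layers
for param in child.parameters():
param.requires_grad = False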
def get_step(self):
return self.step.data.item()
@@ -490,11 +541,10 @@ class Tacotron(nn.Module):
with open(path, "a") as f:
print(msg, file=f)
def load(self, path, optimizer=None):
def load(self, path, device, optimizer=None):
# Use device of model params as location for loaded state
device = next(self.parameters()).device
checkpoint = torch.load(str(path), map_location=device)
self.load_state_dict(checkpoint["model_state"])
self.load_state_dict(checkpoint["model_state"], strict=False)
if "optimizer_state" in checkpoint and optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer_state"])
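The new stop condition in `generate` scales the stop-token output by 10 before comparing it with `min_stop_token`, so the default of 5 corresponds to the previous fixed threshold of 0.5. A minimal sketch on plain floats is shown here. The helper name `should_stop` and the `min_steps` parameter are hypothetical and only mirror the `t > 10` guard in the diff.

```python
def should_stop(stop_tokens, t, min_stop_token=5, min_steps=10):
    """Return True once every stop token in the batch passes the threshold.

    stop_tokens: per-utterance stop probabilities in [0, 1] (plain floats here).
    t: current decoder step; the loop never stops at or before min_steps.
    """
    return t > min_steps and all(s * 10 > min_stop_token for s in stop_tokens)


# A higher min_stop_token delays stopping; a lower one ends generation sooner.
print(should_stop([0.6, 0.9], t=11))                    # True
print(should_stop([0.6, 0.9], t=11, min_stop_token=8))  # False
```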

@@ -7,7 +7,7 @@ from tqdm import tqdm
import numpy as np
from encoder import inference as encoder
from synthesizer.preprocess_speaker import preprocess_speaker_general
from synthesizer.preprocess_transcript import preprocess_transcript_aishell3
from synthesizer.preprocess_transcript import preprocess_transcript_aishell3, preprocess_transcript_magicdata
data_info = {
"aidatatang_200zh": {
@@ -18,13 +18,19 @@ data_info = {
"magicdata": {
"subfolders": ["train"],
"trans_filepath": "train/TRANS.txt",
"speak_func": preprocess_speaker_general
"speak_func": preprocess_speaker_general,
"transcript_func": preprocess_transcript_magicdata,
},
"aishell3":{
"subfolders": ["train/wav"],
"trans_filepath": "train/content.txt",
"speak_func": preprocess_speaker_general,
"transcript_func": preprocess_transcript_aishell3,
},
"data_aishell":{
"subfolders": ["wav/train"],
"trans_filepath": "transcript/aishell_transcript_v0.8.txt",
"speak_func": preprocess_speaker_general
}
}

@@ -6,4 +6,13 @@ def preprocess_transcript_aishell3(dict_info, dict_transcript):
transList = []
for i in range(2, len(v), 2):
transList.append(v[i])
dict_info[v[0]] = " ".join(transList)
dict_info[v[0]] = " ".join(transList)
def preprocess_transcript_magicdata(dict_info, dict_transcript):
for v in dict_transcript:
if not v:
continue
v = v.strip().replace("\n","").replace("\t"," ").split(" ")
dict_info[v[0]] = " ".join(v[2:])
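`preprocess_transcript_magicdata` treats each transcript line as tab- or space-separated fields, uses the first field as the key, skips the second, and joins the rest as the text. The runnable copy below shows this on made-up input. The sample lines and the reading of the fields as utterance ID, speaker ID and text are assumptions, not taken from the dataset.

```python
def preprocess_transcript_magicdata(dict_info, dict_transcript):
    for line in dict_transcript:
        if not line:
            continue  # skip empty lines
        fields = line.strip().replace("\n", "").replace("\t", " ").split(" ")
        # fields[0] is the key, fields[1] is dropped, the rest is the transcript
        dict_info[fields[0]] = " ".join(fields[2:])


# Hypothetical sample lines in the assumed "id<TAB>speaker<TAB>text" layout.
sample_lines = [
    "utt_0001.wav\tspk_01\thello world\n",
    "",
    "utt_0002.wav\tspk_01\tgood morning\n",
]
transcripts = {}
preprocess_transcript_magicdata(transcripts, sample_lines)
print(transcripts)
```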

@@ -45,7 +45,7 @@ def run_synthesis(in_dir, out_dir, model_dir, hparams):
model_dir = Path(model_dir)
model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
print("\nLoading weights at %s" % model_fpath)
model.load(model_fpath)
model.load(model_fpath, device)
print("Tacotron weights loaded from step %d" % model.step)
# Synthesize using same reduction factor as the model is currently trained

@@ -73,6 +73,7 @@ def collate_synthesizer(batch):
# Speaker embedding (SV2TTS)
embeds = [x[2] for x in batch]
embeds = np.stack(embeds)
# Index (for vocoder preprocessing)
indices = [x[3] for x in batch]

@@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text
from vocoder.display import *
from datetime import datetime
import json
import numpy as np
from pathlib import Path
import sys
@@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
if num_chars != loaded_shape[0]:
print("WARNING: you are using compatible mode due to wrong symbols length, please modify variable _characters in `utils/symbols.py`")
num_chars = loaded_shape[0]
# Try to scan config file
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
else: # save a config
hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))
model = Tacotron(embed_dims=hparams.tts_embed_dims,
@@ -93,7 +101,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
# Initialize the optimizer
optimizer = optim.Adam(model.parameters())
optimizer = optim.Adam(model.parameters(), amsgrad=True)
# Load the weights
if force_restart or not weights_fpath.exists():
@@ -111,7 +119,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
else:
print("\nLoading weights at %s" % weights_fpath)
model.load(weights_fpath, optimizer)
model.load(weights_fpath, device, optimizer)
print("Tacotron weights loaded from step %d" % model.step)
# Initialize the dataset
@@ -146,7 +154,6 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
continue
model.r = r
# Begin the training
simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
("Batch Size", batch_size),
@@ -155,6 +162,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
for p in optimizer.param_groups:
p["lr"] = lr
if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0:
model.finetune_partial(hparams.tts_finetune_layers)
data_loader = DataLoader(dataset,
collate_fn=collate_synthesizer,
@@ -221,7 +230,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
# Backup or save model as appropriate
if backup_every != 0 and step % backup_every == 0 :
backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
backup_fpath = Path("{}/{}_{}.pt".format(str(weights_fpath.parent), run_id, step))
model.save(backup_fpath, optimizer)
if save_every != 0 and step % save_every == 0 :
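The training hunk above scans the model directory for a JSON config, loads the first one it finds, and otherwise writes the current hyperparameters next to the weights. A minimal sketch of that load-or-save pattern with a plain dict is given here. The function name `load_or_save_config` and the dict-based config are assumptions; the repository uses `hparams.loadJson` and `hparams.dumpJson`, whose internals are not shown in this diff. The sketch also sorts the matches for determinism, which the original does not do.

```python
import json
from pathlib import Path


def load_or_save_config(model_dir, run_id, defaults):
    """Load the first *.json found under model_dir, else write defaults there."""
    model_dir = Path(model_dir)
    config_paths = sorted(model_dir.rglob("*.json"))
    if config_paths:
        with config_paths[0].open("r", encoding="utf-8") as f:
            return json.load(f)
    # No config yet: save one named after the run, as the diff does.
    config_path = model_dir.joinpath(run_id).with_suffix(".json")
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(defaults, f, indent=2)
    return dict(defaults)
```

One consequence of this design is that an existing config silently overrides the in-code hyperparameters on later runs.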

@@ -3,16 +3,17 @@ from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder.wavernn import inference as rnn_vocoder
from vocoder.hifigan import inference as gan_vocoder
import ppg_extractor as extractor
import ppg2mel as convertor
from pathlib import Path
from time import perf_counter as timer
from toolbox.utterance import Utterance
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
import numpy as np
import traceback
import sys
import torch
import librosa
import re
from audioread.exceptions import NoBackendError
# Use WaveRNN by default
vocoder = rnn_vocoder
@@ -49,14 +50,20 @@ recognized_datasets = [
MAX_WAVES = 15
class Toolbox:
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
self.no_mp3_support = no_mp3_support
self.vc_mode = vc_mode
sys.excepthook = self.excepthook
self.datasets_root = datasets_root
self.utterances = set()
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
self.synthesizer = None # type: Synthesizer
# for ppg-based voice conversion
self.extractor = None
self.convertor = None # ppg2mel
self.current_wav = None
self.waves_list = []
self.waves_count = 0
@@ -70,8 +77,9 @@ class Toolbox:
self.trim_silences = False
# Initialize the events and the interface
self.ui = UI()
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
self.ui = UI(vc_mode)
self.style_idx = 0
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
self.setup_events()
self.ui.start()
@@ -95,7 +103,11 @@ class Toolbox:
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
def func():
self.synthesizer = None
self.ui.synthesizer_box.currentIndexChanged.connect(func)
if self.vc_mode:
self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
else:
self.ui.synthesizer_box.currentIndexChanged.connect(func)
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
# Utterance selection
@@ -108,6 +120,11 @@ class Toolbox:
self.ui.stop_button.clicked.connect(self.ui.stop)
self.ui.record_button.clicked.connect(self.record)
# Source Utterance selection
if self.vc_mode:
func = lambda: self.load_soruce_button(self.ui.selected_utterance)
self.ui.load_soruce_button.clicked.connect(func)
#Audio
self.ui.setup_audio_devices(Synthesizer.sample_rate)
@@ -119,12 +136,17 @@ class Toolbox:
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
# Generation
func = lambda: self.synthesize() or self.vocode()
self.ui.generate_button.clicked.connect(func)
self.ui.synthesize_button.clicked.connect(self.synthesize)
self.ui.vocode_button.clicked.connect(self.vocode)
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
if self.vc_mode:
func = lambda: self.convert() or self.vocode()
self.ui.convert_button.clicked.connect(func)
else:
func = lambda: self.synthesize() or self.vocode()
self.ui.generate_button.clicked.connect(func)
self.ui.synthesize_button.clicked.connect(self.synthesize)
# UMAP legend
self.ui.clear_button.clicked.connect(self.clear_utterances)
@@ -137,9 +159,9 @@ class Toolbox:
def replay_last_wav(self):
self.ui.play(self.current_wav, Synthesizer.sample_rate)
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
self.ui.populate_gen_options(seed, self.trim_silences)
def load_from_browser(self, fpath=None):
@@ -170,7 +192,10 @@ class Toolbox:
self.ui.log("Loaded %s" % name)
self.add_real_utterance(wav, name, speaker_name)
def load_soruce_button(self, utterance: Utterance):
self.selected_source_utterance = utterance
def record(self):
wav = self.ui.record_one(encoder.sampling_rate, 5)
if wav is None:
@@ -195,7 +220,7 @@ class Toolbox:
# Add the utterance
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
self.utterances.add(utterance)
self.ui.register_utterance(utterance)
self.ui.register_utterance(utterance, self.vc_mode)
# Plot it
self.ui.draw_embed(embed, name, "current")
@@ -233,7 +258,8 @@ class Toolbox:
texts = processed_texts
embed = self.ui.selected_utterance.embed
embeds = [embed] * len(texts)
specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
min_token = int(self.ui.token_slider.value())
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token, steps=int(self.ui.length_slider.value())*200)
breaks = [spec.shape[1] for spec in specs]
spec = np.concatenate(specs, axis=1)
@@ -267,7 +293,7 @@ class Toolbox:
self.ui.set_loading(i, seq_len)
if self.ui.current_vocoder_fpath is not None:
self.ui.log("")
wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
else:
self.ui.log("Waveform generation with Griffin-Lim... ")
wav = Synthesizer.griffin_lim(spec)
@@ -278,7 +304,7 @@ class Toolbox:
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
b_starts = np.concatenate(([0], b_ends[:-1]))
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
# Trim excessive silences
@@ -287,7 +313,7 @@ class Toolbox:
# Play it
wav = wav / np.abs(wav).max() * 0.97
self.ui.play(wav, Synthesizer.sample_rate)
self.ui.play(wav, sample_rate)
# Name it (history displayed in combobox)
# TODO better naming for the combobox items?
@@ -329,6 +355,68 @@ class Toolbox:
self.ui.draw_embed(embed, name, "generated")
self.ui.draw_umap_projections(self.utterances)
def convert(self):
self.ui.log("Extracting PPG and converting...")
self.ui.set_loading(1)
# Init
if self.convertor is None:
self.init_convertor()
if self.extractor is None:
self.init_extractor()
src_wav = self.selected_source_utterance.wav
# Compute the ppg
if self.extractor is not None:
ppg = self.extractor.extract_from_wav(src_wav)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ref_wav = self.ui.selected_utterance.wav
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
min_len = min(ppg.shape[1], len(lf0_uv))
ppg = ppg[:, :min_len]
lf0_uv = lf0_uv[:min_len]
_, mel_pred, att_ws = self.convertor.inference(
ppg,
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
)
mel_pred= mel_pred.transpose(0, 1)
breaks = [mel_pred.shape[1]]
mel_pred= mel_pred.detach().cpu().numpy()
self.ui.draw_spec(mel_pred, "generated")
self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
self.ui.set_loading(0)
def init_extractor(self):
if self.ui.current_extractor_fpath is None:
return
model_fpath = self.ui.current_extractor_fpath
self.ui.log("Loading the extractor %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
self.extractor = extractor.load_model(model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
def init_convertor(self):
if self.ui.current_convertor_fpath is None:
return
model_fpath = self.ui.current_convertor_fpath
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.yaml"))
if len(model_config_fpaths) == 0:
return
model_config_fpath = model_config_fpaths[0]
self.ui.log("Loading the convertor %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
self.convertor = convertor.load_model(model_config_fpath, model_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
def init_encoder(self):
model_fpath = self.ui.current_encoder_fpath
@@ -356,12 +444,16 @@ class Toolbox:
# Case of Griffin-lim
if model_fpath is None:
return
# Select vocoder based on model name
model_config_fpath = None
if model_fpath.name[0] == "g":
vocoder = gan_vocoder
self.ui.log("set hifigan as vocoder")
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
if len(model_config_fpaths) == 0:
return
model_config_fpath = model_config_fpaths[0]
else:
vocoder = rnn_vocoder
self.ui.log("set wavernn as vocoder")
@@ -369,7 +461,7 @@ class Toolbox:
self.ui.log("Loading the vocoder %s... " % model_fpath)
self.ui.set_loading(1)
start = timer()
vocoder.load_model(model_fpath)
vocoder.load_model(model_fpath, model_config_fpath)
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
self.ui.set_loading(0)
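In `vocode`, the generated waveform is cut at the spectrogram break points and a 0.15 s pause is appended to each segment, now sized with the sample rate returned by the vocoder. The sketch below reproduces that splitting on plain Python lists. The name `split_and_pad` and its parameters are hypothetical, and NumPy is replaced by `itertools.accumulate` to keep the example dependency-free.

```python
from itertools import accumulate


def split_and_pad(wav, breaks, hop_size, sample_rate, pause_s=0.15):
    """Split wav at cumulative frame breaks and add a pause after each part.

    breaks: number of spectrogram frames per segment.
    hop_size: audio samples per spectrogram frame.
    """
    ends = list(accumulate(n * hop_size for n in breaks))
    starts = [0] + ends[:-1]
    silence = [0.0] * int(pause_s * sample_rate)
    out = []
    for start, end in zip(starts, ends):
        out.extend(wav[start:end])
        out.extend(silence)
    return out


# Two segments of 1 and 2 frames, 2 samples per frame, 0.5 s pause at 4 Hz.
print(split_and_pad([1, 2, 3, 4, 5, 6], breaks=[1, 2], hop_size=2,
                    sample_rate=4, pause_s=0.5))
```

Using the vocoder's own sample rate matters when the vocoder outputs audio at a rate different from `Synthesizer.sample_rate`, since the pause length and playback speed both depend on it.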

BIN toolbox/assets/mb.png Normal file. Binary file not shown. Size: 5.6 KiB

@@ -1,8 +1,9 @@
from PyQt5.QtCore import Qt, QStringListModel
from PyQt5 import QtGui
from PyQt5.QtWidgets import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtCore import Qt, QStringListModel
from PyQt5.QtWidgets import *
from encoder.inference import plot_embedding_as_heatmap
from toolbox.utterance import Utterance
from pathlib import Path
@@ -325,30 +326,51 @@ class UI(QDialog):
def current_vocoder_fpath(self):
return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
@property
def current_extractor_fpath(self):
return self.extractor_box.itemData(self.extractor_box.currentIndex())
@property
def current_convertor_fpath(self):
return self.convertor_box.itemData(self.convertor_box.currentIndex())
def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
vocoder_models_dir: Path):
vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool):
# Encoder
encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
if len(encoder_fpaths) == 0:
raise Exception("No encoder models found in %s" % encoder_models_dir)
self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])
# Synthesizer
synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
if len(synthesizer_fpaths) == 0:
raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
if vc_mode:
# Extractor
extractor_fpaths = list(extractor_models_dir.glob("*.pt"))
if len(extractor_fpaths) == 0:
self.log("No extractor models found in %s" % extractor_models_dir)
self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths])
# Convertor
convertor_fpaths = list(convertor_models_dir.glob("*.pth"))
if len(convertor_fpaths) == 0:
self.log("No convertor models found in %s" % convertor_models_dir)
self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths])
else:
# Synthesizer
synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
if len(synthesizer_fpaths) == 0:
raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
# Vocoder
vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
self.repopulate_box(self.vocoder_box, vocoder_items)
@property
def selected_utterance(self):
return self.utterance_history.itemData(self.utterance_history.currentIndex())
def register_utterance(self, utterance: Utterance):
def register_utterance(self, utterance: Utterance, vc_mode):
self.utterance_history.blockSignals(True)
self.utterance_history.insertItem(0, utterance.name, utterance)
self.utterance_history.setCurrentIndex(0)
@@ -358,8 +380,11 @@ class UI(QDialog):
self.utterance_history.removeItem(self.max_saved_utterances)
self.play_button.setDisabled(False)
self.generate_button.setDisabled(False)
self.synthesize_button.setDisabled(False)
if vc_mode:
self.convert_button.setDisabled(False)
else:
self.generate_button.setDisabled(False)
self.synthesize_button.setDisabled(False)
def log(self, line, mode="newline"):
if mode == "newline":
@@ -401,7 +426,7 @@ class UI(QDialog):
else:
self.seed_textbox.setEnabled(False)
def reset_interface(self):
def reset_interface(self, vc_mode):
self.draw_embed(None, None, "current")
self.draw_embed(None, None, "generated")
self.draw_spec(None, "current")
@@ -409,18 +434,24 @@ class UI(QDialog):
self.draw_umap_projections(set())
self.set_loading(0)
self.play_button.setDisabled(True)
self.generate_button.setDisabled(True)
self.synthesize_button.setDisabled(True)
if vc_mode:
self.convert_button.setDisabled(True)
else:
self.generate_button.setDisabled(True)
self.synthesize_button.setDisabled(True)
self.vocode_button.setDisabled(True)
self.replay_wav_button.setDisabled(True)
self.export_wav_button.setDisabled(True)
[self.log("") for _ in range(self.max_log_lines)]
def __init__(self):
def __init__(self, vc_mode):
## Initialize the application
self.app = QApplication(sys.argv)
super().__init__(None)
self.setWindowTitle("SV2TTS toolbox")
self.setWindowTitle("MockingBird GUI")
self.setWindowIcon(QtGui.QIcon('toolbox\\assets\\mb.png'))
self.setWindowFlag(Qt.WindowMinimizeButtonHint, True)
self.setWindowFlag(Qt.WindowMaximizeButtonHint, True)
## Main layouts
@@ -430,21 +461,24 @@ class UI(QDialog):
# Browser
browser_layout = QGridLayout()
root_layout.addLayout(browser_layout, 0, 0, 1, 2)
root_layout.addLayout(browser_layout, 0, 0, 1, 8)
# Generation
gen_layout = QVBoxLayout()
root_layout.addLayout(gen_layout, 0, 2, 1, 2)
# Projections
self.projections_layout = QVBoxLayout()
root_layout.addLayout(self.projections_layout, 1, 0, 1, 1)
root_layout.addLayout(gen_layout, 0, 8)
# Visualizations
vis_layout = QVBoxLayout()
root_layout.addLayout(vis_layout, 1, 1, 1, 3)
root_layout.addLayout(vis_layout, 1, 0, 2, 8)
# Output
output_layout = QGridLayout()
vis_layout.addLayout(output_layout, 0)
# Projections
self.projections_layout = QVBoxLayout()
root_layout.addLayout(self.projections_layout, 1, 8, 2, 2)
## Projections
# UMap
fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
@@ -458,84 +492,102 @@ class UI(QDialog):
## Browser
# Dataset, speaker and utterance selection
i = 0
self.dataset_box = QComboBox()
browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
browser_layout.addWidget(self.dataset_box, i + 1, 0)
self.speaker_box = QComboBox()
browser_layout.addWidget(QLabel("<b>Speaker</b>"), i, 1)
browser_layout.addWidget(self.speaker_box, i + 1, 1)
self.utterance_box = QComboBox()
browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
browser_layout.addWidget(self.utterance_box, i + 1, 2)
self.browser_load_button = QPushButton("Load")
browser_layout.addWidget(self.browser_load_button, i + 1, 3)
i += 2
# Random buttons
source_groupbox = QGroupBox('Source(源音频)')
source_layout = QGridLayout()
source_groupbox.setLayout(source_layout)
browser_layout.addWidget(source_groupbox, i, 0, 1, 5)
self.dataset_box = QComboBox()
source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0)
source_layout.addWidget(self.dataset_box, i, 1)
self.random_dataset_button = QPushButton("Random")
browser_layout.addWidget(self.random_dataset_button, i, 0)
source_layout.addWidget(self.random_dataset_button, i, 2)
i += 1
self.speaker_box = QComboBox()
source_layout.addWidget(QLabel("Speaker(说话者)"), i, 0)
source_layout.addWidget(self.speaker_box, i, 1)
self.random_speaker_button = QPushButton("Random")
browser_layout.addWidget(self.random_speaker_button, i, 1)
source_layout.addWidget(self.random_speaker_button, i, 2)
i += 1
self.utterance_box = QComboBox()
source_layout.addWidget(QLabel("Utterance(音频):"), i, 0)
source_layout.addWidget(self.utterance_box, i, 1)
self.random_utterance_button = QPushButton("Random")
browser_layout.addWidget(self.random_utterance_button, i, 2)
source_layout.addWidget(self.random_utterance_button, i, 2)
i += 1
source_layout.addWidget(QLabel("<b>Use(使用):</b>"), i, 0)
self.browser_load_button = QPushButton("Load Above(加载上面)")
source_layout.addWidget(self.browser_load_button, i, 1, 1, 2)
self.auto_next_checkbox = QCheckBox("Auto select next")
self.auto_next_checkbox.setChecked(True)
browser_layout.addWidget(self.auto_next_checkbox, i, 3)
i += 1
source_layout.addWidget(self.auto_next_checkbox, i+1, 1)
self.browser_browse_button = QPushButton("Browse(打开本地)")
source_layout.addWidget(self.browser_browse_button, i, 3)
self.record_button = QPushButton("Record(录音)")
source_layout.addWidget(self.record_button, i+1, 3)
i += 2
# Utterance box
browser_layout.addWidget(QLabel("<b>Use embedding from:</b>"), i, 0)
browser_layout.addWidget(QLabel("<b>Current(当前):</b>"), i, 0)
self.utterance_history = QComboBox()
browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
i += 1
# Random & next utterance buttons
self.browser_browse_button = QPushButton("Browse")
browser_layout.addWidget(self.browser_browse_button, i, 0)
self.record_button = QPushButton("Record")
browser_layout.addWidget(self.record_button, i, 1)
self.play_button = QPushButton("Play")
browser_layout.addWidget(self.utterance_history, i, 1)
self.play_button = QPushButton("Play(播放)")
browser_layout.addWidget(self.play_button, i, 2)
self.stop_button = QPushButton("Stop")
self.stop_button = QPushButton("Stop(暂停)")
browser_layout.addWidget(self.stop_button, i, 3)
i += 1
if vc_mode:
self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)")
browser_layout.addWidget(self.load_soruce_button, i, 4)
i += 1
model_groupbox = QGroupBox('Models(模型选择)')
model_layout = QHBoxLayout()
model_groupbox.setLayout(model_layout)
browser_layout.addWidget(model_groupbox, i, 0, 2, 5)
# Model and audio output selection
self.encoder_box = QComboBox()
browser_layout.addWidget(QLabel("<b>Encoder</b>"), i, 0)
browser_layout.addWidget(self.encoder_box, i + 1, 0)
model_layout.addWidget(QLabel("Encoder:"))
model_layout.addWidget(self.encoder_box)
self.synthesizer_box = QComboBox()
browser_layout.addWidget(QLabel("<b>Synthesizer</b>"), i, 1)
browser_layout.addWidget(self.synthesizer_box, i + 1, 1)
if vc_mode:
self.extractor_box = QComboBox()
model_layout.addWidget(QLabel("Extractor:"))
model_layout.addWidget(self.extractor_box)
self.convertor_box = QComboBox()
model_layout.addWidget(QLabel("Convertor:"))
model_layout.addWidget(self.convertor_box)
else:
model_layout.addWidget(QLabel("Synthesizer:"))
model_layout.addWidget(self.synthesizer_box)
self.vocoder_box = QComboBox()
browser_layout.addWidget(QLabel("<b>Vocoder</b>"), i, 2)
browser_layout.addWidget(self.vocoder_box, i + 1, 2)
self.audio_out_devices_cb=QComboBox()
browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 3)
browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3)
i += 2
model_layout.addWidget(QLabel("Vocoder:"))
model_layout.addWidget(self.vocoder_box)
#Replay & Save Audio
browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
i = 0
output_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
self.waves_cb = QComboBox()
self.waves_cb_model = QStringListModel()
self.waves_cb.setModel(self.waves_cb_model)
self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
browser_layout.addWidget(self.waves_cb, i, 1)
output_layout.addWidget(self.waves_cb, i, 1)
self.replay_wav_button = QPushButton("Replay")
self.replay_wav_button.setToolTip("Replay last generated vocoder")
browser_layout.addWidget(self.replay_wav_button, i, 2)
output_layout.addWidget(self.replay_wav_button, i, 2)
self.export_wav_button = QPushButton("Export")
self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
browser_layout.addWidget(self.export_wav_button, i, 3)
output_layout.addWidget(self.export_wav_button, i, 3)
self.audio_out_devices_cb=QComboBox()
i += 1
output_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 0)
output_layout.addWidget(self.audio_out_devices_cb, i, 1)
## Embed & spectrograms
vis_layout.addStretch()
# TODO: add spectrograms for source
gridspec_kw = {"width_ratios": [1, 4]}
fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
gridspec_kw=gridspec_kw)
@@ -552,21 +604,27 @@ class UI(QDialog):
for side in ["top", "right", "bottom", "left"]:
ax.spines[side].set_visible(False)
## Generation
self.text_prompt = QPlainTextEdit(default_text)
gen_layout.addWidget(self.text_prompt, stretch=1)
self.generate_button = QPushButton("Synthesize and vocode")
gen_layout.addWidget(self.generate_button)
layout = QHBoxLayout()
self.synthesize_button = QPushButton("Synthesize only")
layout.addWidget(self.synthesize_button)
if vc_mode:
layout = QHBoxLayout()
self.convert_button = QPushButton("Extract and Convert")
layout.addWidget(self.convert_button)
gen_layout.addLayout(layout)
else:
self.generate_button = QPushButton("Synthesize and vocode")
gen_layout.addWidget(self.generate_button)
layout = QHBoxLayout()
self.synthesize_button = QPushButton("Synthesize only")
layout.addWidget(self.synthesize_button)
self.vocode_button = QPushButton("Vocode only")
layout.addWidget(self.vocode_button)
gen_layout.addLayout(layout)
layout_seed = QGridLayout()
self.random_seed_checkbox = QCheckBox("Random seed:")
self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
@@ -578,6 +636,45 @@ class UI(QDialog):
self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
" This feature requires `webrtcvad` to be installed.")
layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
self.style_slider = QSlider(Qt.Horizontal)
self.style_slider.setTickInterval(1)
self.style_slider.setFocusPolicy(Qt.NoFocus)
self.style_slider.setSingleStep(1)
self.style_slider.setRange(-1, 9)
self.style_value_label = QLabel("-1")
self.style_slider.setValue(-1)
layout_seed.addWidget(QLabel("Style:"), 1, 0)
self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
layout_seed.addWidget(self.style_value_label, 1, 1)
layout_seed.addWidget(self.style_slider, 1, 3)
self.token_slider = QSlider(Qt.Horizontal)
self.token_slider.setTickInterval(1)
self.token_slider.setFocusPolicy(Qt.NoFocus)
self.token_slider.setSingleStep(1)
self.token_slider.setRange(3, 9)
self.token_value_label = QLabel("4")
self.token_slider.setValue(4)
layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0)
self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
layout_seed.addWidget(self.token_value_label, 2, 1)
layout_seed.addWidget(self.token_slider, 2, 3)
self.length_slider = QSlider(Qt.Horizontal)
self.length_slider.setTickInterval(1)
self.length_slider.setFocusPolicy(Qt.NoFocus)
self.length_slider.setSingleStep(1)
self.length_slider.setRange(1, 10)
self.length_value_label = QLabel("2")
self.length_slider.setValue(2)
layout_seed.addWidget(QLabel("MaxLength(最大句长):"), 3, 0)
self.length_slider.valueChanged.connect(lambda s: self.length_value_label.setNum(s))
layout_seed.addWidget(self.length_value_label, 3, 1)
layout_seed.addWidget(self.length_slider, 3, 3)
gen_layout.addLayout(layout_seed)
self.loading_bar = QProgressBar()
@@ -591,11 +688,11 @@ class UI(QDialog):
## Set the size of the window and of the elements
max_size = QDesktopWidget().availableGeometry(self).size() * 0.8
max_size = QDesktopWidget().availableGeometry(self).size() * 0.5
self.resize(max_size)
## Finalize the display
self.reset_interface()
self.reset_interface(vc_mode)
self.show()
def start(self):

60
utils/audio_utils.py Normal file

@@ -0,0 +1,60 @@
import torch
import torch.utils.data
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def _dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def _spectral_normalize_torch(magnitudes):
output = _dynamic_range_compression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(
y,
n_fft,
num_mels,
sampling_rate,
hop_size,
win_size,
fmin,
fmax,
center=False,
output_energy=False,
):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global mel_basis, hann_window
if str(fmax)+'_'+str(y.device) not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
mel_spec = _spectral_normalize_torch(mel_spec)
if output_energy:
energy = torch.norm(spec, dim=1)
return mel_spec, energy
else:
return mel_spec

214
utils/data_load.py Normal file

@@ -0,0 +1,214 @@
import random
import numpy as np
import torch
from utils.f0_utils import get_cont_lf0
import resampy
from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram
from librosa.util import normalize
import os
SAMPLE_RATE=16000
def read_fids(fid_list_f):
with open(fid_list_f, 'r') as f:
fids = [l.strip().split()[0] for l in f if l.strip()]
return fids
class OneshotVcDataset(torch.utils.data.Dataset):
def __init__(
self,
meta_file: str,
vctk_ppg_dir: str,
libri_ppg_dir: str,
vctk_f0_dir: str,
libri_f0_dir: str,
vctk_wav_dir: str,
libri_wav_dir: str,
vctk_spk_dvec_dir: str,
libri_spk_dvec_dir: str,
min_max_norm_mel: bool = False,
mel_min: float = None,
mel_max: float = None,
ppg_file_ext: str = "ling_feat.npy",
f0_file_ext: str = "f0.npy",
wav_file_ext: str = "wav",
):
self.fid_list = read_fids(meta_file)
self.vctk_ppg_dir = vctk_ppg_dir
self.libri_ppg_dir = libri_ppg_dir
self.vctk_f0_dir = vctk_f0_dir
self.libri_f0_dir = libri_f0_dir
self.vctk_wav_dir = vctk_wav_dir
self.libri_wav_dir = libri_wav_dir
self.vctk_spk_dvec_dir = vctk_spk_dvec_dir
self.libri_spk_dvec_dir = libri_spk_dvec_dir
self.ppg_file_ext = ppg_file_ext
self.f0_file_ext = f0_file_ext
self.wav_file_ext = wav_file_ext
self.min_max_norm_mel = min_max_norm_mel
if min_max_norm_mel:
print("[INFO] Min-Max normalize Melspec.")
assert mel_min is not None
assert mel_max is not None
self.mel_max = mel_max
self.mel_min = mel_min
random.seed(1234)
random.shuffle(self.fid_list)
print(f'[INFO] Got {len(self.fid_list)} samples.')
def __len__(self):
return len(self.fid_list)
def get_spk_dvec(self, fid):
spk_name = fid
if spk_name.startswith("p"):
spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy"
else:
spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy"
return torch.from_numpy(np.load(spk_dvec_path))
def compute_mel(self, wav_path):
audio, sr = load_wav(wav_path)
if sr != SAMPLE_RATE:
audio = resampy.resample(audio, sr, SAMPLE_RATE)
audio = audio / MAX_WAV_VALUE
audio = normalize(audio) * 0.95
audio = torch.FloatTensor(audio).unsqueeze(0)
melspec = mel_spectrogram(
audio,
n_fft=1024,
num_mels=80,
sampling_rate=SAMPLE_RATE,
hop_size=160,
win_size=1024,
fmin=80,
fmax=8000,
)
return melspec.squeeze(0).numpy().T
def bin_level_min_max_norm(self, melspec):
# frequency bin level min-max normalization to [-4, 4]
mel = (melspec - self.mel_min) / (self.mel_max - self.mel_min) * 8.0 - 4.0
return np.clip(mel, -4., 4.)
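The normalization in `bin_level_min_max_norm` maps each value linearly from `[mel_min, mel_max]` to `[-4, 4]` and clips whatever falls outside. A scalar sketch of the same arithmetic, assuming made-up bounds of -12 and 2 (the real `mel_min` and `mel_max` come from the training configuration and are typically per-bin arrays); `min_max_norm` is a hypothetical helper name:

```python
def min_max_norm(value, v_min, v_max, lo=-4.0, hi=4.0):
    """Map value from [v_min, v_max] to [lo, hi], clipping anything outside."""
    scaled = (value - v_min) / (v_max - v_min) * (hi - lo) + lo
    return max(lo, min(hi, scaled))


# Illustrative bounds only; they are not taken from the repository.
V_MIN, V_MAX = -12.0, 2.0
at_min = min_max_norm(-12.0, V_MIN, V_MAX)   # lower bound maps to -4
midway = min_max_norm(-5.0, V_MIN, V_MAX)    # midpoint maps to 0
clipped = min_max_norm(5.0, V_MIN, V_MAX)    # above the range, clipped to 4
```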
def __getitem__(self, index):
fid = self.fid_list[index]
# 1. Load features
if fid.startswith("p"):
# vctk
sub = fid.split("_")[0]
ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
else:
# aidatatang
sub = fid[5:10]
ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
if self.min_max_norm_mel:
mel = self.bin_level_min_max_norm(mel)
f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid)
spk_dvec = self.get_spk_dvec(fid)
# 2. Convert f0 to continuous log-f0 and u/v flags
uv, cont_lf0 = get_cont_lf0(f0, 10.0, False)
# cont_lf0 = (cont_lf0 - np.amin(cont_lf0)) / (np.amax(cont_lf0) - np.amin(cont_lf0))
# cont_lf0 = self.utt_mvn(cont_lf0)
lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
# uv, cont_f0 = convert_continuous_f0(f0)
# cont_f0 = (cont_f0 - np.amin(cont_f0)) / (np.amax(cont_f0) - np.amin(cont_f0))
# lf0_uv = np.concatenate([cont_f0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
# 3. Convert numpy array to torch.tensor
ppg = torch.from_numpy(ppg)
lf0_uv = torch.from_numpy(lf0_uv)
mel = torch.from_numpy(mel)
return (ppg, lf0_uv, mel, spk_dvec, fid)
def check_lengths(self, f0, ppg, mel, fid):
LEN_THRESH = 10
assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \
f"{abs(len(ppg) - len(f0))}: for file {fid}"
assert abs(len(mel) - len(f0)) <= LEN_THRESH, \
f"{abs(len(mel) - len(f0))}: for file {fid}"
def _adjust_lengths(self, f0, ppg, mel, fid):
self.check_lengths(f0, ppg, mel, fid)
min_len = min(
len(f0),
len(ppg),
len(mel),
)
f0 = f0[:min_len]
ppg = ppg[:min_len]
mel = mel[:min_len]
return f0, ppg, mel
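`check_lengths` and `_adjust_lengths` together tolerate a small frame-count mismatch between the three feature streams and then truncate all of them to the shortest one. A rough sketch of that behaviour on plain lists; `align_lengths` is a hypothetical name, and it raises `ValueError` where the original code uses `assert`:

```python
LEN_THRESH = 10  # same tolerance as in check_lengths


def align_lengths(f0, ppg, mel):
    """Reject badly mismatched streams, then truncate all to the shortest."""
    for name, seq in (("ppg", ppg), ("mel", mel)):
        diff = abs(len(seq) - len(f0))
        if diff > LEN_THRESH:
            raise ValueError(f"{name} and f0 differ by {diff} frames")
    n = min(len(f0), len(ppg), len(mel))
    return f0[:n], ppg[:n], mel[:n]


# Three streams that disagree by a few frames end up with the same length.
f0, ppg, mel = align_lengths(list(range(100)), list(range(98)), list(range(101)))
```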
class MultiSpkVcCollate():
"""Zero-pads model inputs and targets based on number of frames per step
"""
def __init__(self, n_frames_per_step=1, give_uttids=False,
f02ppg_length_ratio=1, use_spk_dvec=False):
self.n_frames_per_step = n_frames_per_step
self.give_uttids = give_uttids
self.f02ppg_length_ratio = f02ppg_length_ratio
self.use_spk_dvec = use_spk_dvec
def __call__(self, batch):
batch_size = len(batch)
# Prepare different features
ppgs = [x[0] for x in batch]
lf0_uvs = [x[1] for x in batch]
mels = [x[2] for x in batch]
fids = [x[-1] for x in batch]
if len(batch[0]) == 5:
spk_ids = [x[3] for x in batch]
if self.use_spk_dvec:
# use d-vector
spk_ids = torch.stack(spk_ids).float()
else:
# use one-hot ids
spk_ids = torch.LongTensor(spk_ids)
# Pad features into chunk
ppg_lengths = [x.shape[0] for x in ppgs]
mel_lengths = [x.shape[0] for x in mels]
max_ppg_len = max(ppg_lengths)
max_mel_len = max(mel_lengths)
if max_mel_len % self.n_frames_per_step != 0:
max_mel_len += (self.n_frames_per_step - max_mel_len % self.n_frames_per_step)
ppg_dim = ppgs[0].shape[1]
mel_dim = mels[0].shape[1]
ppgs_padded = torch.FloatTensor(batch_size, max_ppg_len, ppg_dim).zero_()
mels_padded = torch.FloatTensor(batch_size, max_mel_len, mel_dim).zero_()
lf0_uvs_padded = torch.FloatTensor(batch_size, self.f02ppg_length_ratio * max_ppg_len, 2).zero_()
stop_tokens = torch.FloatTensor(batch_size, max_mel_len).zero_()
for i in range(batch_size):
cur_ppg_len = ppgs[i].shape[0]
cur_mel_len = mels[i].shape[0]
ppgs_padded[i, :cur_ppg_len, :] = ppgs[i]
lf0_uvs_padded[i, :self.f02ppg_length_ratio*cur_ppg_len, :] = lf0_uvs[i]
mels_padded[i, :cur_mel_len, :] = mels[i]
stop_tokens[i, cur_ppg_len-self.n_frames_per_step:] = 1
if len(batch[0]) == 5:
ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \
torch.LongTensor(mel_lengths), spk_ids, stop_tokens)
if self.give_uttids:
return ret_tup + (fids, )
else:
return ret_tup
else:
ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \
torch.LongTensor(mel_lengths), stop_tokens)
if self.give_uttids:
return ret_tup + (fids, )
else:
return ret_tup
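The collate function rounds the padded mel length up to a multiple of `n_frames_per_step` and marks the stop-token target as 1 from the last decoder step of each utterance onwards. A list-based sketch of just that bookkeeping, without torch; `pad_batch` is a hypothetical helper, and it assumes every utterance is at least `n_frames_per_step` frames long:

```python
def pad_batch(lengths, n_frames_per_step=1):
    """Return the padded length and one stop-token row per utterance."""
    max_len = max(lengths)
    if max_len % n_frames_per_step != 0:
        # round up to the next multiple of the decoder step size
        max_len += n_frames_per_step - max_len % n_frames_per_step
    stop_tokens = []
    for cur_len in lengths:
        row = [0.0] * max_len
        # 1 from the final decoder step onwards, including the padding
        for t in range(cur_len - n_frames_per_step, max_len):
            row[t] = 1.0
        stop_tokens.append(row)
    return max_len, stop_tokens


max_len, stop_tokens = pad_batch([5, 3], n_frames_per_step=2)
```

With lengths 5 and 3 and a step of 2, the padded length becomes 6 and the shorter utterance is flagged as finished from frame 1.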

124
utils/f0_utils.py Normal file

@@ -0,0 +1,124 @@
import logging
import numpy as np
import pyworld
from scipy.interpolate import interp1d
from scipy.signal import firwin, get_window, lfilter
def compute_mean_std(lf0):
nonzero_indices = np.nonzero(lf0)
mean = np.mean(lf0[nonzero_indices])
std = np.std(lf0[nonzero_indices])
return mean, std
def compute_f0(wav, sr=16000, frame_period=10.0):
"""Compute f0 from wav using pyworld harvest algorithm."""
wav = wav.astype(np.float64)
f0, _ = pyworld.harvest(
wav, sr, frame_period=frame_period, f0_floor=80.0, f0_ceil=600.0)
return f0.astype(np.float32)
def f02lf0(f0):
lf0 = f0.copy()
nonzero_indices = np.nonzero(f0)
lf0[nonzero_indices] = np.log(f0[nonzero_indices])
return lf0
def get_converted_lf0uv(
wav,
lf0_mean_trg,
lf0_std_trg,
convert=True,
):
f0_src = compute_f0(wav)
if not convert:
uv, cont_lf0 = get_cont_lf0(f0_src)
lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
return lf0_uv
lf0_src = f02lf0(f0_src)
lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src)
lf0_vc = lf0_src.copy()
lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg
f0_vc = lf0_vc.copy()
f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0])
uv, cont_lf0_vc = get_cont_lf0(f0_vc)
lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1)
return lf0_uv
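The conversion branch of `get_converted_lf0uv` is a mean-variance transform in the log-F0 domain: voiced frames are standardized with the source speaker's statistics and rescaled with the target's, while unvoiced frames (zeros) are left untouched. A standard-library sketch of that step alone, stopping before the continuous-F0 interpolation; `convert_f0` is a hypothetical name and the F0 values are invented:

```python
import math
import statistics


def convert_f0(f0_src, lf0_mean_trg, lf0_std_trg):
    """Shift and scale voiced F0 (Hz) towards target log-F0 statistics."""
    voiced = [math.log(f) for f in f0_src if f > 0]
    mean_src = statistics.fmean(voiced)
    std_src = statistics.pstdev(voiced)   # population std, like np.std's default
    out = []
    for f in f0_src:
        if f > 0:
            lf0 = (math.log(f) - mean_src) / std_src * lf0_std_trg + lf0_mean_trg
            out.append(math.exp(lf0))
        else:
            out.append(0.0)               # unvoiced frames stay at zero
    return out


# Invented contour: two unvoiced frames and three voiced ones.
converted = convert_f0([0.0, 100.0, 200.0, 0.0, 400.0],
                       lf0_mean_trg=math.log(300.0),
                       lf0_std_trg=0.5)
```

When the target standard deviation equals the source one, the transform reduces to multiplying every voiced frame by the ratio of the geometric means.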
def low_pass_filter(x, fs, cutoff=70, padding=True):
"""FUNCTION TO APPLY LOW PASS FILTER
Args:
x (ndarray): Waveform sequence
fs (int): Sampling frequency
cutoff (float): Cutoff frequency of low pass filter
Return:
(ndarray): Low pass filtered waveform sequence
"""
nyquist = fs // 2
norm_cutoff = cutoff / nyquist
# low-pass FIR filter
numtaps = 255
fil = firwin(numtaps, norm_cutoff)
x_pad = np.pad(x, (numtaps, numtaps), 'edge')
lpf_x = lfilter(fil, 1, x_pad)
lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2]
return lpf_x
def convert_continuos_f0(f0):
"""CONVERT F0 TO CONTINUOUS F0
Args:
f0 (ndarray): original f0 sequence with the shape (T)
Return:
(ndarray): continuous f0 with the shape (T)
"""
# get uv information as binary
uv = np.float32(f0 != 0)
# get start and end of f0
if (f0 == 0).all():
logging.warning("all of the f0 values are 0.")
return uv, f0
start_f0 = f0[f0 != 0][0]
end_f0 = f0[f0 != 0][-1]
# padding start and end of f0 sequence
start_idx = np.where(f0 == start_f0)[0][0]
end_idx = np.where(f0 == end_f0)[0][-1]
f0[:start_idx] = start_f0
f0[end_idx:] = end_f0
# get non-zero frame index
nz_frames = np.where(f0 != 0)[0]
# perform linear interpolation
f = interp1d(nz_frames, f0[nz_frames])
cont_f0 = f(np.arange(0, f0.shape[0]))
return uv, cont_f0
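`convert_continuos_f0` produces a voiced/unvoiced flag per frame, extends the first and last voiced values to the edges, and linearly interpolates across unvoiced gaps. A pure-Python sketch of the same idea, using explicit loops in place of `scipy.interpolate.interp1d`; `continuous_f0` is a hypothetical name:

```python
def continuous_f0(f0):
    """Return (uv flags, F0 with unvoiced gaps filled by linear interpolation)."""
    uv = [1.0 if f != 0 else 0.0 for f in f0]
    voiced = [i for i, f in enumerate(f0) if f != 0]
    if not voiced:
        return uv, list(f0)               # nothing to interpolate from
    cont = list(f0)
    first, last = voiced[0], voiced[-1]
    for i in range(first):                # pad the leading unvoiced frames
        cont[i] = f0[first]
    for i in range(last + 1, len(f0)):    # pad the trailing unvoiced frames
        cont[i] = f0[last]
    for a, b in zip(voiced, voiced[1:]):  # fill each interior gap
        for i in range(a + 1, b):
            w = (i - a) / (b - a)
            cont[i] = f0[a] + w * (f0[b] - f0[a])
    return uv, cont


uv, cont = continuous_f0([0.0, 0.0, 100.0, 0.0, 120.0, 0.0])
```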
def get_cont_lf0(f0, frame_period=10.0, lpf=False):
uv, cont_f0 = convert_continuos_f0(f0)
if lpf:
cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20)
cont_lf0_lpf = cont_f0_lpf.copy()
nonzero_indices = np.nonzero(cont_lf0_lpf)
cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices])
# cont_lf0_lpf = np.log(cont_f0_lpf)
return uv, cont_lf0_lpf
else:
nonzero_indices = np.nonzero(cont_f0)
cont_lf0 = cont_f0.copy()
cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0])
return uv, cont_lf0

58
utils/load_yaml.py Normal file

@@ -0,0 +1,58 @@
import yaml
def load_hparams(filename):
stream = open(filename, 'r')
docs = yaml.safe_load_all(stream)
hparams_dict = dict()
for doc in docs:
for k, v in doc.items():
hparams_dict[k] = v
return hparams_dict
def merge_dict(user, default):
if isinstance(user, dict) and isinstance(default, dict):
for k, v in default.items():
if k not in user:
user[k] = v
else:
user[k] = merge_dict(user[k], v)
return user
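As a usage sketch, `merge_dict` fills in missing keys from a defaults dictionary while keeping every value the user already set, recursing into nested dictionaries. The function is restated here with its indentation restored (the nesting is inferred from the flattened diff), and the configuration keys are invented for illustration. Note that it mutates `user` in place and also returns it.

```python
def merge_dict(user, default):
    """Recursively copy keys missing from `user` out of `default`."""
    if isinstance(user, dict) and isinstance(default, dict):
        for k, v in default.items():
            if k not in user:
                user[k] = v
            else:
                user[k] = merge_dict(user[k], v)
    return user


# Hypothetical config values, not taken from the repository.
user = {"lr": 0.001, "model": {"dim": 256}}
default = {"lr": 0.01, "model": {"dim": 128, "layers": 4}, "seed": 1234}
merged = merge_dict(user, default)
```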
class Dotdict(dict):
"""
a dictionary that supports dot notation
as well as dictionary access notation
usage: d = Dotdict() or d = Dotdict({'val1':'first'})
set attributes: d.val2 = 'second' or d['val2'] = 'second'
get attributes: d.val2 or d['val2']
"""
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def __init__(self, dct=None):
dct = dict() if not dct else dct
for key, value in dct.items():
if hasattr(value, 'keys'):
value = Dotdict(value)
self[key] = value
class HpsYaml(Dotdict):
def __init__(self, yaml_file):
super(Dotdict, self).__init__()
hps = load_hparams(yaml_file)
hp_dict = Dotdict(hps)
for k, v in hp_dict.items():
setattr(self, k, v)
__getattr__ = Dotdict.__getitem__
__setattr__ = Dotdict.__setitem__
__delattr__ = Dotdict.__delitem__
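A short usage sketch for `Dotdict`, restated with its indentation restored (inferred from the flattened diff); the configuration keys are invented. One consequence of aliasing `__getattr__` to `dict.__getitem__` is that a missing attribute raises `KeyError` rather than `AttributeError`, so `hasattr()` on a missing key also raises instead of returning `False`.

```python
class Dotdict(dict):
    """Dictionary that also supports attribute-style access."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, dct=None):
        dct = dict() if not dct else dct
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = Dotdict(value)    # wrap nested mappings as well
            self[key] = value


# Hypothetical hyperparameters, for illustration only.
cfg = Dotdict({"model": {"dim": 256}, "lr": 0.001})
cfg.batch_size = 16                       # same as cfg["batch_size"] = 16

try:
    cfg.missing
    missing_error = None
except KeyError as exc:                   # KeyError, not AttributeError
    missing_error = exc
```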


@@ -11,7 +11,6 @@ def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path:
# If none of the paths exist, remind the user to download models if needed
print("********************************************************************************")
print("Error: Model files not found. Follow these instructions to get and install the models:")
print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models")
print("Error: Model files not found. Please download the models")
print("********************************************************************************\n")
quit(-1)

44
utils/util.py Normal file

@@ -0,0 +1,44 @@
import matplotlib
matplotlib.use('Agg')
import time
class Timer():
''' Timer for recording training time distribution. '''
def __init__(self):
self.prev_t = time.time()
self.clear()
def set(self):
self.prev_t = time.time()
def cnt(self, mode):
self.time_table[mode] += time.time()-self.prev_t
self.set()
if mode == 'bw':
self.click += 1
def show(self):
total_time = sum(self.time_table.values())
self.time_table['avg'] = total_time/self.click
self.time_table['rd'] = 100*self.time_table['rd']/total_time
self.time_table['fw'] = 100*self.time_table['fw']/total_time
self.time_table['bw'] = 100*self.time_table['bw']/total_time
msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format(
**self.time_table)
self.clear()
return msg
def clear(self):
self.time_table = {'rd': 0, 'fw': 0, 'bw': 0}
self.click = 0
# Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168
def human_format(num):
magnitude = 0
while num >= 1000:
magnitude += 1
num /= 1000.0
# add more suffixes if you need them
return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
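A quick usage sketch for `human_format`, restated with its indentation restored (inferred from the flattened diff). It divides by 1000 until the number drops below 1000 and then appends the matching suffix; values under 1000 get a single space as suffix, and anything at or above 10^18 would run past the suffix list.

```python
def human_format(num):
    """Format a count with a K/M/G/T/P suffix, e.g. for parameter counts."""
    magnitude = 0
    while num >= 1000:
        magnitude += 1
        num /= 1000.0
    return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])


small = human_format(999)         # '999.0 ' (note the trailing space)
thousands = human_format(1500)    # '1.5K'
millions = human_format(1234567)  # '1.2M'
```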


@@ -3,14 +3,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import json
import torch
from scipy.io.wavfile import write
from vocoder.hifigan.env import AttrDict
from vocoder.hifigan.meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
from vocoder.hifigan.models import Generator
import soundfile as sf
generator = None # type: Generator
output_sample_rate = None
_device = None
@@ -22,16 +19,17 @@ def load_checkpoint(filepath, device):
return checkpoint_dict
def load_model(weights_fpath, verbose=True):
global generator, _device
def load_model(weights_fpath, config_fpath="./vocoder/saved_models/24k/config.json", verbose=True):
global generator, _device, output_sample_rate
if verbose:
print("Building hifigan")
with open("./vocoder/hifigan/config_16k_.json") as f:
with open(config_fpath) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
output_sample_rate = h.sampling_rate
torch.manual_seed(h.seed)
if torch.cuda.is_available():
@@ -66,5 +64,5 @@ def infer_waveform(mel, progress_callback=None):
audio = y_g_hat.squeeze()
audio = audio.cpu().numpy()
return audio
return audio, output_sample_rate


@@ -71,6 +71,24 @@ class ResBlock2(torch.nn.Module):
for l in self.convs:
remove_weight_norm(l)
class InterpolationBlock(torch.nn.Module):
def __init__(self, scale_factor, mode='nearest', align_corners=None, downsample=False):
super(InterpolationBlock, self).__init__()
self.downsample = downsample
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
outputs = torch.nn.functional.interpolate(
x,
size=x.shape[-1] * self.scale_factor \
if not self.downsample else x.shape[-1] // self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=False
)
return outputs
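`InterpolationBlock` resizes the time axis to `length * scale_factor` (or `length // scale_factor` when downsampling) and, in the default `'nearest'` mode, simply repeats or skips samples. A list-based sketch of what that amounts to for integer factors; `nearest_resize` is a hypothetical helper, and it only approximates the behaviour of `torch.nn.functional.interpolate` for this simple case:

```python
def nearest_resize(samples, scale_factor, downsample=False):
    """Nearest-neighbour resize of a 1-D sequence by an integer factor."""
    in_len = len(samples)
    out_len = in_len // scale_factor if downsample else in_len * scale_factor
    # each output index picks the input sample at floor(i * in_len / out_len)
    return [samples[i * in_len // out_len] for i in range(out_len)]


up = nearest_resize([1, 2, 3], 2)                               # [1, 1, 2, 2, 3, 3]
down = nearest_resize([1, 2, 3, 4, 5, 6], 2, downsample=True)   # [1, 3, 5]
```

In the 24 kHz generator path each such block is followed by a regular `Conv1d`, so the upsampling itself carries no learned weights.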
class Generator(torch.nn.Module):
def __init__(self, h):
@@ -82,14 +100,27 @@ class Generator(torch.nn.Module):
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
# self.ups.append(weight_norm(
# ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
# k, u, padding=(k-u)//2)))
self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
h.upsample_initial_channel//(2**(i+1)),
k, u, padding=(u//2 + u%2), output_padding=u%2)))
# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
# # self.ups.append(weight_norm(
# # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
# # k, u, padding=(k-u)//2)))
if h.sampling_rate == 24000:
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
torch.nn.Sequential(
InterpolationBlock(u),
weight_norm(torch.nn.Conv1d(
h.upsample_initial_channel//(2**i),
h.upsample_initial_channel//(2**(i+1)),
k, padding=(k-1)//2,
))
)
)
else:
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
h.upsample_initial_channel//(2**(i+1)),
k, u, padding=(u//2 + u%2), output_padding=u%2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel//(2**(i+1))
@@ -121,7 +152,10 @@ class Generator(torch.nn.Module):
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
if self.h.sampling_rate == 24000:
remove_weight_norm(l[-1])
else:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)


@@ -23,11 +23,11 @@ torch.backends.cudnn.benchmark = True
def train(rank, a, h):
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
a.checkpoint_path.mkdir(exist_ok=True)
a.training_epochs = 3100
a.stdout_interval = 5
a.checkpoint_interval = 25000
a.checkpoint_interval = a.backup_every
a.summary_interval = 5000
a.validation_interval = 1000
a.fine_tuning = True
@@ -186,11 +186,9 @@ def train(rank, a, h):
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1
else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1
else msd).state_dict(),
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
@@ -198,6 +196,19 @@ def train(rank, a, h):
if steps % a.summary_interval == 0:
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
# save temporary hifigan model
if steps % a.save_every == 0:
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
# Validation
if steps % a.validation_interval == 0: # and steps != 0:
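The training loop now keeps two kinds of checkpoints: numbered backups written every `backup_every` steps, and a fixed-name pair (`g_hifigan.pt`, `do_hifigan`) overwritten every `save_every` steps. A small sketch of that schedule; `checkpoint_files` is a hypothetical helper, the interval values are assumptions, and only file names visible in the hunks above are used:

```python
def checkpoint_files(steps, save_every, backup_every):
    """List the checkpoint files that would be written at a given step."""
    files = []
    if steps % backup_every == 0:
        files.append(f"do_{steps:08d}")            # numbered backup, kept
    if steps % save_every == 0:
        files += ["g_hifigan.pt", "do_hifigan"]    # fixed names, overwritten
    return files


# Assumed intervals, for illustration only.
at_1000 = checkpoint_files(1000, save_every=1000, backup_every=25000)
at_1500 = checkpoint_files(1500, save_every=1000, backup_every=25000)
at_25000 = checkpoint_files(25000, save_every=1000, backup_every=25000)
```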


@@ -61,4 +61,4 @@ def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
mel = mel / hp.mel_max_abs_value
mel = torch.from_numpy(mel[None, ...])
wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
return wav
return wav, hp.sample_rate


@@ -9,10 +9,12 @@ from vocoder.wavernn import inference as rnn_vocoder
import numpy as np
import re
from scipy.io.wavfile import write
import librosa
import io
import base64
from flask_cors import CORS
from flask_wtf import CSRFProtect
import webbrowser
def webApp():
# Init and load config
@@ -29,6 +31,7 @@ def webApp():
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
synthesizers_cache = {}
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
rnn_vocoder.load_model(Path("vocoder/saved_models/pretrained/pretrained.pt"))
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
def pcm2float(sig, dtype='float32'):
@@ -65,7 +68,6 @@ def webApp():
@app.route("/api/synthesize", methods=["POST"])
def synthesize():
# TODO Implementation with json to support more platform
# Load synthesizer
if "synt_path" in request.form:
synt_path = request.form["synt_path"]
@@ -79,10 +81,16 @@ def webApp():
current_synt = synthesizers_cache[synt_path]
print("using synthesizer model: " + str(synt_path))
# Load input wav
wav_base64 = request.form["upfile_b64"]
wav = base64.b64decode(bytes(wav_base64, 'utf-8'))
wav = pcm2float(np.frombuffer(wav, dtype=np.int16), dtype=np.float32)
encoder_wav = encoder.preprocess_wav(wav, 16000)
if "upfile_b64" in request.form:
wav_base64 = request.form["upfile_b64"]
wav = base64.b64decode(bytes(wav_base64, 'utf-8'))
wav = pcm2float(np.frombuffer(wav, dtype=np.int16), dtype=np.float32)
sample_rate = Synthesizer.sample_rate
else:
wav, sample_rate = librosa.load(request.files['file'])
write("temp.wav", sample_rate, wav) #Make sure we get the correct wav
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
# Load input text
@@ -99,11 +107,15 @@ def webApp():
embeds = [embed] * len(texts)
specs = current_synt.synthesize_spectrograms(texts, embeds)
spec = np.concatenate(specs, axis=1)
wav = gan_vocoder.infer_waveform(spec)
sample_rate = Synthesizer.sample_rate
if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN":
wav = rnn_vocoder.infer_waveform(spec)
else:
wav, sample_rate = gan_vocoder.infer_waveform(spec)
# Return cooked wav
out = io.BytesIO()
write(out, Synthesizer.sample_rate, wav)
write(out, sample_rate, wav.astype(np.float32))
return Response(out, mimetype="audio/wav")
@app.route('/', methods=['GET'])
@@ -112,10 +124,11 @@ def webApp():
host = app.config.get("HOST")
port = app.config.get("PORT")
print(f"Web server: http://{host}:{port}")
web_address = 'http://{}:{}'.format(host, port)
print(f"Web server: {web_address}")
webbrowser.open(web_address)
server = wsgi.WSGIServer((host, port), app)
server.serve_forever()
return app
if __name__ == "__main__":
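
The `upfile_b64` branch in the hunk above base64-decodes the upload, reads it as 16-bit PCM, and scales it to floats before handing it to the encoder. Below is a minimal sketch of that conversion using only the standard library rather than NumPy. `decode_pcm16_b64` is a hypothetical helper name, the payload is assumed to be raw little-endian PCM with no WAV header, and dividing by 32768 is an assumption about how `pcm2float` scales int16 samples.

```python
import array
import base64
import sys
from typing import List


def decode_pcm16_b64(wav_base64: str) -> List[float]:
    """Decode base64 text holding 16-bit PCM into floats in [-1.0, 1.0)."""
    raw = base64.b64decode(wav_base64.encode("utf-8"))
    samples = array.array("h")  # signed 16-bit integers
    samples.frombytes(raw)      # raises ValueError on an odd byte count
    if sys.byteorder == "big":
        samples.byteswap()      # assumption: the payload is little-endian
    # Assumption: scale by 2**15 so that -32768 maps to exactly -1.0
    return [s / 32768.0 for s in samples]


if __name__ == "__main__":
    import struct

    payload = base64.b64encode(struct.pack("<3h", 0, 16384, -32768)).decode("ascii")
    print(decode_pcm16_b64(payload))
```

The multipart branch (`request.files['file']`) skips this step because `librosa.load` already returns floating-point samples together with the sample rate.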

View File

@@ -5,3 +5,4 @@ PORT = 8080
MAX_CONTENT_PATH = 1024 * 1024 * 4 # uploaded mp3 file size must not exceed 4 MB
SECRET_KEY = "mockingbird_key"
WTF_CSRF_SECRET_KEY = "mockingbird_key"
TEMPLATES_AUTO_RELOAD = True

BIN
web/static/img/bird-sm.png Normal file

Binary file not shown. Size: 40 KiB

BIN
web/static/img/bird.png Normal file

Binary file not shown. Size: 39 KiB

Binary file not shown. Size: 89 KiB

View File

@@ -4,8 +4,7 @@
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=0">
<link rel="shortcut icon" type="image/png"
href="https://cdn.jsdelivr.net/gh/xiangyuecn/Recorder@latest/assets/icon.png">
<link rel="shortcut icon" type="image/png" href="../static/img/bird-sm.png">
<title>MockingBird Web Server</title>
@@ -24,50 +23,114 @@
<div class="main">
<div class="mainBox">
<div class="pd btns">
<div class="title" >
<div style="width: 15%;float: left;margin-left: 5%;">
<img src="../static/img/bird.png" style="width: 100%;border-radius:50%;">
</div>
<div style="width: 80% ;height: 15%; margin-left: 15%;overflow: hidden;">
<div style="margin-left: 5%;margin-top: 15px;font-size: xx-large;font-weight: bolder;">
MockingBird Toolbox
</div>
<div style="margin-left: 5%;margin-top: 3px;font-size: large;">
<a href="https://github.com/babysor/MockingBird" target="_blank">https://github.com/babysor/MockingBird</a>
</div>
</div>
</div>
<div style="margin-left: 5%;margin-top: 50px;width: 90%;">
<div style="font-size: larger;font-weight: bolder;">1. Enter Chinese text</div>
<textarea id="user_input_text"
style="border:1px solid #ccc; width: 100%; height: 100px; font-size: 15px; margin-top: 10px;"></textarea>
</div>
<div class="pd btns" style="margin-left: 5%;margin-top: 20px;width: 90%; ">
<!-- <div>
<button onclick="recOpen()" style="margin-right:10px">Open recorder and request permission</button>
<button onclick="recClose()" style="margin-right:0">Close recorder and release resources</button>
</div> -->
<button onclick="recStart()" style="margin-left:100px">Record</button>
<button onclick="recStop()" style="margin-left:100px">Stop</button>
<button onclick="recPlay()" style="margin-left:100px">Play</button>
<div style="font-size: larger;font-weight: bolder;">2. Record directly; click Stop to finish</div>
<button onclick="recStart()" >Record</button>
<button onclick="recStop()">Stop</button>
<button onclick="recPlay()" >Play</button>
</div>
<div class="pd btns" style="margin-left: 5%;margin-top: 20px;width: 90%; ">
<div style="font-size: larger;font-weight: bolder;">Or upload audio</div>
<input type="file" id="fileInput" accept=".wav" />
<label for="fileInput">Choose audio</label>
<div id="audio1"></div>
</div>
<div class="pd btns" style="margin-left: 5%;margin-top: 20px;width: 90%; ">
<div style="font-size: larger;font-weight: bolder;">3. Select Synthesizer model</div>
<span class="box">
<select id="selectSynt">
</select>
</span>
</div>
<div class="pd btns" style="margin-left: 5%;margin-top: 20px;width: 90%; ">
<div style="font-size: larger;font-weight: bolder;">4. Select Vocoder model</div>
<span class="box">
<select id="selectVocoder">
<option>WaveRNN</option>
<option>HifiGAN</option>
</select>
</span>
</div>
<div class="pd btns" style="margin-left: 5%;margin-top: 20px;width: 90%; text-align:right;">
<button id="upload" onclick="recUpload()">Upload and synthesize</button>
</div>
<!-- Waveform drawing area -->
<div class="pd recpower">
<!-- <div class="pd recpower">
<div style="height:40px;width:100%;background:#fff;position:relative;">
<div class="recpowerx" style="height:40px;background:#ff3295;position:absolute;"></div>
<div class="recpowert" style="padding-left:50px; line-height:40px; position: relative;"></div>
</div>
</div>
<div class="pd waveBox" style="height:100px;">
</div> -->
<!-- <div class="pd waveBox" style="height:100px;">
<div style="border:1px solid #ccc;display:inline-block; width: 100%; height: 100px;">
<div style="height:100px; width: 100%; background-color: #FE76B8; position: relative;left: 0px;top: 0px;z-index: 10;"
<div style="height:100px; width: 100%; background-color: #5da1f5; position: relative;left: 0px;top: 0px;z-index: 10;"
class="recwave"></div>
<div
style="background-color: transparent;position: relative;top: -80px;left: 30%;z-index: 20;font-size: 48px;color: #fff;">
Audio preview</div>
</div>
</div>
<div>
<div>Enter text:</div>
<input type="text" id="user_input_text"
style="border:1px solid #ccc; width: 100%; height: 20px; font-size: 18px;" />
</div>
<div class="pd btns">
<button onclick="recUpload()" style="margin-left: 300px; margin-top: 15px;">Upload</button>
</div>
</div>
<!-- Log output area -->
<div class="mainBox">
<div class="reclog"></div>
</div> -->
<div class="reclog" style="margin-left: 5%;margin-top: 20px;width: 90%;"></div>
</div>
</div>
<script>
$("#fileInput").change(function(){
var file = $("#fileInput").get(0).files;
if (file.length > 0) {
var path = URL.createObjectURL(file[0]);
var audio = document.createElement('audio');
audio.src = path;
audio.controls = true;
$('#audio1').empty().append(audio);
}
});
fetch("/api/synthesizers", {
method: 'get',
headers: {
"X-CSRFToken": "{{ csrf_token() }}"
}
}).then(function (res) {
if (!res.ok) throw Error(res.statusText);
return res.json();
}).then(function (data) {
for (var synt of data) {
var option = document.createElement('option');
option.text = synt.name
option.value = synt.path
$("#selectSynt").append(option);
}
}).catch(function (err) {
console.log('Error: ' + err.message);
})
var rec, wave, recBlob;
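/** Call open to start the recorder and request recording permission **/
var recOpen = function () {// usually called when the record button or recording UI is shown, so that recording starts without delay once the user clicks Record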
@@ -78,11 +141,11 @@
type: "wav", bitRate: 16, sampleRate: 16000
, onProcess: function (buffers, powerLevel, bufferDuration, bufferSampleRate, newBufferIdx, asyncEnd) {
// real-time recording callback; called roughly 12 times per second
document.querySelector(".recpowerx").style.width = powerLevel + "%";
document.querySelector(".recpowert").innerText = bufferDuration + " / " + powerLevel;
// document.querySelector(".recpowerx").style.width = powerLevel + "%";
// document.querySelector(".recpowert").innerText = bufferDuration + " / " + powerLevel;
// draw the visualization
wave.input(buffers[buffers.length - 1], powerLevel, bufferSampleRate);
// wave.input(buffers[buffers.length - 1], powerLevel, bufferSampleRate);
}
});
@@ -93,7 +156,7 @@
rec = newRec;
// create the audio visualization here; browser support for it is solid
wave = Recorder.FrequencyHistogramView({ elem: ".recwave" });
// wave = Recorder.FrequencyHistogramView({ elem: ".recwave" });
reclog("Recorder opened; click Record to start recording", 2);
}, function (msg, isUserNotAllow) {// user denied permission, or recording is not supported
@@ -186,15 +249,21 @@
/** Upload **/
function recUpload() {
var blob = recBlob;
var blob
var loadedAudios = $("#fileInput").get(0).files
if (loadedAudios.length > 0) {
blob = loadedAudios[0];
} else {
blob = recBlob;
}
if (!blob) {
reclog("Please record and stop first, then upload", 1);
reclog("Please record (and stop) or choose an audio file first, then upload", 1);
return;
};
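// This example assumes a raw XMLHttpRequest-style request; adapt it to your own request method in practice
// When recording ends you get a blob file object; read its content with FileReader or upload it with FormData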
var api = "http://127.0.0.1:8080/api/synthesize";
var api = "/api/synthesize";
reclog("Uploading to " + api + ", please wait...");
@@ -203,15 +272,23 @@
var csrftoken = "{{ csrf_token() }}";
var user_input_text = document.getElementById("user_input_text");
var input_text = user_input_text.value;
var postData = "";
postData += "mime=" + encodeURIComponent(blob.type);// tell the backend the recording format; can be omitted if both ends are fixed to mp3
postData += "&upfile_b64=" + encodeURIComponent((/.+;\s*base64\s*,\s*(.+)$/i.exec(reader.result) || [])[1]) // recording content; the backend base64-decodes it to binary
postData += "&text=" + encodeURIComponent(input_text);
var postData = new FormData();
postData.append("text", input_text)
postData.append("file", blob)
var syntSelect = document.getElementById("selectSynt");
var path = syntSelect.options[syntSelect.selectedIndex].value;
if (!!path) {
postData.append("synt_path", path);
}
var vocoderSelect = document.getElementById("selectVocoder");
var vocoder = vocoderSelect.options[vocoderSelect.selectedIndex].value;
if (!!vocoder) {
postData.append("vocoder", vocoder);
}
fetch(api, {
method: 'post',
headers: {
"Content-type": "application/x-www-form-urlencoded; charset=UTF-8",
"X-CSRFToken": csrftoken
},
body: postData
@@ -277,7 +354,7 @@
var div = document.createElement("div");
var elem = document.querySelector(".reclog");
elem.insertBefore(div, elem.firstChild);
div.innerHTML = '<div style="color:' + (!color ? "" : color == 1 ? "red" : color == 2 ? "#FE76B8" : color) + '">[' + t + ']' + s + '</div>';
div.innerHTML = '<div style="color:' + (!color ? "" : color == 1 ? "#327de8" : color == 2 ? "#5da1f5" : color) + '">[' + t + ']' + s + '</div>';
};
window.onerror = function (message, url, lineNo, columnNo, error) {
reclog('<span style="color:red">[Uncaught Error] ' + message + '<pre>' + "at:" + lineNo + ":" + columnNo + " url:" + url + "\n" + (error && error.stack || "Unable to get the error stack") + '</pre></span>');
@@ -312,11 +389,11 @@
a {
text-decoration: none;
color: #FE76B8;
color: #327de8;
}
a:hover {
color: #f00;
color: #5da1f5;
}
.main {
@@ -330,7 +407,6 @@
padding: 12px;
border-radius: 6px;
background: #fff;
--border: 1px solid #f60;
box-shadow: 2px 2px 3px #aaa;
}
@@ -340,20 +416,31 @@
cursor: pointer;
border: none;
border-radius: 3px;
background: #FE76B8;
background: #5698c3;
color: #fff;
padding: 0 15px;
margin: 3px 20px 3px 0;
margin: 3px 10px 3px 0;
width: 70px;
line-height: 36px;
height: 36px;
overflow: hidden;
vertical-align: middle;
}
.btns button:active {
background: #fd54a6
.btns #upload {
background: #5698c3;
color: #fff;
width: 100px;
height: 42px;
}
.btns button:active {
background: #5da1f5
}
.btns button:hover {
background: #5da1f5
}
.pd {
padding: 0 0 6px 0;
}
@@ -361,12 +448,74 @@
.lb {
display: inline-block;
vertical-align: middle;
background: #ff3d9b;
background: #327de8;
color: #fff;
font-size: 14px;
padding: 2px 8px;
border-radius: 99px;
}
#fileInput {
width: 0.1px;
height: 0.1px;
opacity: 0;
overflow: hidden;
position: absolute;
z-index: -1;
}
#fileInput + label {
padding: 0 15px;
border-radius: 4px;
color: white;
background-color: #5698c3;
display: inline-block;
width: 70px;
line-height: 36px;
height: 36px;
}
#fileInput + label {
cursor: pointer; /* "hand" cursor */
}
#fileInput:focus + label,
#fileInput + label:hover {
background-color: #5da1f5;
}
.box select {
background-color: #5698c3;
color: white;
padding: 8px;
width: 120px;
border: none;
border-radius: 4px;
font-size: 0.5em;
outline: none;
margin: 3px 10px 3px 0;
}
.box::before {
content: "\f13a";
position: absolute;
top: 0;
right: 0;
width: 20%;
height: 100%;
text-align: center;
font-size: 28px;
line-height: 45px;
color: rgba(255, 255, 255, 0.5);
background-color: rgba(255, 255, 255, 0.1);
pointer-events: none;
}
.box:hover::before {
color: rgba(255, 255, 255, 0.6);
background-color: rgba(255, 255, 255, 0.2);
}
.box select option {
padding: 30px;
}
</style>
</body>
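
The reworked `recUpload()` in this template sends `text`, `file`, `synt_path`, and `vocoder` as multipart form fields instead of a URL-encoded string. As a rough illustration of what such a request body looks like on the wire, here is a Python sketch that builds one with the standard library only. `build_multipart` is a hypothetical helper, the `audio/wav` content type is an assumption, and the CSRF header the page sends is not modelled; nothing is actually sent.

```python
import uuid
from typing import Dict, Optional, Tuple


def build_multipart(fields: Dict[str, str], filename: str, file_bytes: bytes,
                    boundary: Optional[str] = None) -> Tuple[bytes, str]:
    """Return (body, content_type) for a multipart/form-data request."""
    boundary = boundary or uuid.uuid4().hex
    parts = []
    # Plain text fields such as "text", "synt_path", "vocoder"
    for name, value in fields.items():
        parts.append(f"--{boundary}\r\n".encode("utf-8"))
        parts.append(
            f'Content-Disposition: form-data; name="{name}"\r\n\r\n'.encode("utf-8")
        )
        parts.append(value.encode("utf-8") + b"\r\n")
    # The audio goes in a part named "file", matching request.files['file']
    parts.append(f"--{boundary}\r\n".encode("utf-8"))
    parts.append(
        f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'
        .encode("utf-8")
    )
    parts.append(b"Content-Type: audio/wav\r\n\r\n")  # assumed content type
    parts.append(file_bytes + b"\r\n")
    # Closing delimiter
    parts.append(f"--{boundary}--\r\n".encode("utf-8"))
    return b"".join(parts), f"multipart/form-data; boundary={boundary}"


if __name__ == "__main__":
    body, content_type = build_multipart(
        {"text": "hello", "vocoder": "HifiGAN"}, "sample.wav", b"RIFF"
    )
    print(content_type)
    print(len(body), "bytes")
```

Because the boundary is part of the `Content-Type` value, a client that builds the body this way has to send the returned content type along with it. In the browser, `fetch` derives that header from the `FormData` body on its own.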