Mirror of https://github.com/babysor/Realtime-Voice-Clone-Chinese.git (synced 2026-02-04 02:54:07 +08:00)

Compare commits: babysor-pa... fregan (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | 86ea11affd | |
@@ -1,4 +0,0 @@
*/saved_models
!vocoder/saved_models/pretrained/**
!encoder/saved_models/pretrained.pt
/datasets
1 .github/FUNDING.yml vendored
@@ -1 +0,0 @@
github: babysor
17 .github/ISSUE_TEMPLATE/issue.md vendored
@@ -1,17 +0,0 @@
---
name: Issue
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Summary[问题简述(一句话)]**
A clear and concise description of what the issue is.

**Env & To Reproduce[复现与环境]**
描述你用的环境、代码版本、模型

**Screenshots[截图(如有)]**
If applicable, add screenshots to help
17 .gitignore vendored
@@ -13,14 +13,11 @@
*.bbl
*.bcf
*.toc
*.wav
*.sh
data/ckpt/*/*
!data/ckpt/encoder/pretrained.pt
!data/ckpt/vocoder/pretrained/
wavs
log
!/docker-entrypoint.sh
!/datasets_download/*.sh
/datasets
monotonic_align/build
monotonic_align/monotonic_align
synthesizer/saved_models/*
vocoder/saved_models/*
encoder/saved_models/*
cp_hifigan/*
!vocoder/saved_models/pretrained/*
!encoder/saved_models/pretrained.pt
38 .vscode/launch.json vendored
@@ -15,8 +15,7 @@
            "name": "Python: Vocoder Preprocess",
            "type": "python",
            "request": "launch",
            "program": "control\\cli\\vocoder_preprocess.py",
            "cwd": "${workspaceFolder}",
            "program": "vocoder_preprocess.py",
            "console": "integratedTerminal",
            "args": ["..\\audiodata"]
        },
@@ -24,8 +23,7 @@
            "name": "Python: Vocoder Train",
            "type": "python",
            "request": "launch",
            "program": "control\\cli\\vocoder_train.py",
            "cwd": "${workspaceFolder}",
            "program": "vocoder_train.py",
            "console": "integratedTerminal",
            "args": ["dev", "..\\audiodata"]
        },
@@ -34,44 +32,16 @@
            "type": "python",
            "request": "launch",
            "program": "demo_toolbox.py",
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "args": ["-d","..\\audiodata"]
        },
        {
            "name": "Python: Demo Box VC",
            "type": "python",
            "request": "launch",
            "program": "demo_toolbox.py",
            "cwd": "${workspaceFolder}",
            "console": "integratedTerminal",
            "args": ["-d","..\\audiodata","-vc"]
        },
        {
            "name": "Python: Synth Train",
            "type": "python",
            "request": "launch",
            "program": "train.py",
            "program": "synthesizer_train.py",
            "console": "integratedTerminal",
            "args": ["--type", "vits"]
        },
        {
            "name": "Python: PPG Convert",
            "type": "python",
            "request": "launch",
            "program": "run.py",
            "console": "integratedTerminal",
            "args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
                "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
            ]
        },
        {
            "name": "Python: Vits Train",
            "type": "python",
            "request": "launch",
            "program": "train.py",
            "console": "integratedTerminal",
            "args": ["--type", "vits"]
            "args": ["my_run", "..\\"]
        },
    ]
}
17 Dockerfile
@@ -1,17 +0,0 @@
FROM pytorch/pytorch:latest

RUN apt-get update && apt-get install -y build-essential ffmpeg parallel aria2 && apt-get clean

COPY ./requirements.txt /workspace/requirements.txt

RUN pip install -r requirements.txt && pip install webrtcvad-wheels

COPY . /workspace

VOLUME [ "/datasets", "/workspace/synthesizer/saved_models/" ]

ENV DATASET_MIRROR=default FORCE_RETRAIN=false TRAIN_DATASETS=aidatatang_200zh\ magicdata\ aishell3\ data_aishell TRAIN_SKIP_EXISTING=true

EXPOSE 8080

ENTRYPOINT [ "/workspace/docker-entrypoint.sh" ]
183 README-CN.md
@@ -18,12 +18,10 @@

🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用

## 开始
### 1. 安装要求
#### 1.1 通用配置
> 按照原始存储库测试您是否已准备好所有环境。

运行工具箱(demo_toolbox.py)需要 **Python 3.7 或更高版本** 。
**Python 3.7 或更高版本** 需要运行工具箱。

* 安装 [PyTorch](https://pytorch.org/get-started/locally/)。
> 如果在用 pip 方式安装的时候出现 `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)` 这个错误可能是 python 版本过低,3.9 可以安装成功
@@ -31,67 +29,6 @@
* 运行`pip install -r requirements.txt` 来安装剩余的必要包。
* 安装 webrtcvad `pip install webrtcvad-wheels`。

或者
- 用`conda` 或者 `mamba` 安装依赖

  ```conda env create -n env_name -f env.yml```

  ```mamba env create -n env_name -f env.yml```

  会创建新环境安装必须的依赖. 之后用 `conda activate env_name` 切换环境就完成了.
> env.yml只包含了运行时必要的依赖,暂时不包括monotonic-align,如果想要装GPU版本的pytorch可以查看官网教程。

#### 1.2 M1芯片Mac环境配置(Inference Time)
> 以下环境按x86-64搭建,使用原生的`demo_toolbox.py`,可作为在不改代码情况下快速使用的workaround。
>
> 如需使用M1芯片训练,因`demo_toolbox.py`依赖的`PyQt5`不支持M1,则应按需修改代码,或者尝试使用`web.py`。

* 安装`PyQt5`,参考[这个链接](https://stackoverflow.com/a/68038451/20455983)
  * 用Rosetta打开Terminal,参考[这个链接](https://dev.to/courier/tips-and-tricks-to-setup-your-apple-m1-for-development-547g)
  * 用系统Python创建项目虚拟环境
    ```
    /usr/bin/python3 -m venv /PathToMockingBird/venv
    source /PathToMockingBird/venv/bin/activate
    ```
  * 升级pip并安装`PyQt5`
    ```
    pip install --upgrade pip
    pip install pyqt5
    ```
* 安装`pyworld`和`ctc-segmentation`
  > 这里两个文件直接`pip install`的时候找不到wheel,尝试从c里build时找不到`Python.h`报错
  * 安装`pyworld`
    * `brew install python` 通过brew安装python时会自动安装`Python.h`
    * `export CPLUS_INCLUDE_PATH=/opt/homebrew/Frameworks/Python.framework/Headers` 对于M1,brew安装`Python.h`到上述路径。把路径添加到环境变量里
    * `pip install pyworld`

  * 安装`ctc-segmentation`
    > 因上述方法没有成功,选择从[github](https://github.com/lumaku/ctc-segmentation) clone源码手动编译
    * `git clone https://github.com/lumaku/ctc-segmentation.git` 克隆到任意位置
    * `cd ctc-segmentation`
    * `source /PathToMockingBird/venv/bin/activate` 假设一开始未开启,打开MockingBird项目的虚拟环境
    * `cythonize -3 ctc_segmentation/ctc_segmentation_dyn.pyx`
    * `/usr/bin/arch -x86_64 python setup.py build` 要注意明确用x86-64架构编译
    * `/usr/bin/arch -x86_64 python setup.py install --optimize=1 --skip-build` 用x86-64架构安装

* 安装其他依赖
  * `/usr/bin/arch -x86_64 pip install torch torchvision torchaudio` 这里用pip安装`PyTorch`,明确架构是x86
  * `pip install ffmpeg` 安装ffmpeg
  * `pip install -r requirements.txt`

* 运行
  > 参考[这个链接](https://youtrack.jetbrains.com/issue/PY-46290/Allow-running-Python-under-Rosetta-2-in-PyCharm-for-Apple-Silicon),让项目跑在x86架构环境上
  * `vim /PathToMockingBird/venv/bin/pythonM1`
  * 写入以下代码
    ```
    #!/usr/bin/env zsh
    mydir=${0:a:h}
    /usr/bin/arch -x86_64 $mydir/python "$@"
    ```
  * `chmod +x pythonM1` 设为可执行文件
  * 如果使用PyCharm,则把Interpreter指向`pythonM1`,否则也可命令行运行`/PathToMockingBird/venv/bin/pythonM1 demo_toolbox.py`

### 2. 准备预训练模型
考虑训练您自己专属的模型或者下载社区他人训练好的模型:
> 近期创建了[知乎专题](https://www.zhihu.com/column/c_1425605280340504576) 将不定期更新炼丹小技巧or心得,也欢迎提问
@@ -113,7 +50,7 @@
> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`

* 训练合成器:
`python ./control/cli/synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
`python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`

* 当您在训练文件夹 *synthesizer/saved_models/* 中看到注意线显示和损失满足您的需要时,请转到`启动程序`一步。

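A minimal sketch of the `<datasets_root>` convention described in the note above. The concrete paths come from that example; the check itself is only an illustration and is not part of the project:

```python
from pathlib import Path

# Example from the note above: the dataset was extracted to D:\data\aidatatang_200zh,
# so <datasets_root> is its parent directory, D:\data.
dataset_dir = Path(r"D:\data\aidatatang_200zh")
datasets_root = dataset_dir.parent  # -> D:\data

# Hypothetical sanity check: the train split should be reachable under corpus\train.
assert (dataset_dir / "corpus" / "train").is_dir(), "extract aidatatang_200zh\\corpus\\train first"
print(f"pass {datasets_root} as <datasets_root>")
```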
@@ -131,27 +68,33 @@
对效果影响不大,已经预置3款,如果希望自己训练可以参考以下命令。
* 预处理数据:
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
> `<datasets_root>`替换为你的数据集目录,`<synthesizer_model_path>`替换为一个你最好的synthesizer模型目录,例如 *synthesizer\saved_models\xxx*
> `<datasets_root>`替换为你的数据集目录,`<synthesizer_model_path>`替换为一个你最好的synthesizer模型目录,例如 *synthesizer\saved_mode\xxx*

* 训练wavernn声码器:
`python ./control/cli/vocoder_train.py <trainid> <datasets_root>`
`python vocoder_train.py <trainid> <datasets_root>`
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型

* 训练hifigan声码器:
`python ./control/cli/vocoder_train.py <trainid> <datasets_root> hifigan`
`python vocoder_train.py <trainid> <datasets_root> hifigan`
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
* 训练fregan声码器:
`python ./control/cli/vocoder_train.py <trainid> <datasets_root> --config config.json fregan`

* 训练Fre-GAN声码器:
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
* 将GAN声码器的训练切换为多GPU模式:修改GAN文件夹下.json文件中的"num_gpus"参数

### 3. 启动程序或工具箱
您可以尝试使用以下命令:

### 3.1 启动Web程序(v2):
### 3.1 启动Web程序:
`python web.py`
运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`

> 注:目前界面比较buggy,
> * 第一次点击`录制`要等待几秒浏览器正常启动录音,否则会有重音
> * 录制结束不要再点`录制`而是`停止`
> * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
> * 默认使用第一个找到的模型,有动手能力的可以看代码修改 `web\__init__.py`。

### 3.2 启动工具箱:
`python demo_toolbox.py -d <datasets_root>`
@@ -159,35 +102,33 @@

<img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">

### 4. 番外:语音转换Voice Conversion(PPG based)
想像柯南拿着变声器然后发出毛利小五郎的声音吗?本项目现基于PPG-VC,引入额外两个模块(PPG extractor + PPG2Mel), 可以实现变声功能。(文档不全,尤其是训练部分,正在努力补充中)
#### 4.0 准备环境
* 确保项目以上环境已经安装ok,运行`pip install espnet` 来安装剩余的必要包。
* 下载以下模型 链接:https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
  提取码:gh41
  * 24K采样率专用的vocoder(hifigan)到 *vocoder\saved_models\xxx*
  * 预训练的ppg特征encoder(ppg_extractor)到 *ppg_extractor\saved_models\xxx*
  * 预训练的PPG2Mel到 *ppg2mel\saved_models\xxx*

#### 4.1 使用数据集自己训练PPG2Mel模型 (可选)

* 下载aidatatang_200zh数据集并解压:确保您可以访问 *train* 文件夹中的所有音频文件(如.wav)
* 进行音频和梅尔频谱图预处理:
`python ./control/cli/pre4ppg.py <datasets_root> -d {dataset} -n {number}`
可传入参数:
  * `-d {dataset}` 指定数据集,支持 aidatatang_200zh, 不传默认为aidatatang_200zh
  * `-n {number}` 指定并行数,CPU 11700k在8的情况下,需要运行12到18小时!待优化
> 假如你下载的 `aidatatang_200zh`文件放在D盘,`train`文件路径为 `D:\data\aidatatang_200zh\corpus\train` , 你的`datasets_root`就是 `D:\data\`

* 训练合成器, 注意在上一步先下载好`ppg2mel.yaml`, 修改里面的地址指向预训练好的文件夹:
`python ./control/cli/ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
* 如果想要继续上一次的训练,可以通过`--load .\ppg2mel\saved_models\<old_pt_file>` 参数指定一个预训练模型文件。

#### 4.2 启动工具箱VC模式
您可以尝试使用以下命令:
`python demo_toolbox.py -vc -d <datasets_root>`
> 请指定一个可用的数据集文件路径,如果有支持的数据集则会自动加载供调试,也同时会作为手动录制音频的存储目录。
<img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">
## 文件结构(目标读者:开发者)
```
├─archived_untest_files 废弃文件
├─encoder encoder模型
│  ├─data_objects
│  └─saved_models 预训练好的模型
├─samples 样例语音
├─synthesizer synthesizer模型
│  ├─models
│  ├─saved_models 预训练好的模型
│  └─utils 工具类库
├─toolbox 图形化工具箱
├─utils 工具类库
├─vocoder vocoder模型(目前包含hifi-gan、wavrnn)
│  ├─hifigan
│  ├─saved_models 预训练好的模型
│  └─wavernn
└─web
   ├─api
   │  └─Web端接口
   ├─config
   │  └─ Web端配置文件
   ├─static 前端静态脚本
   │  └─js
   ├─templates 前端模板
   └─__init__.py Web端入口文件
```

## 引用及论文
> 该库一开始从仅支持英语的[Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) 分叉出来的,鸣谢作者。
@@ -196,36 +137,35 @@
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | 本代码库 |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | 本代码库 |
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | 本代码库 |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | 本代码库 |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | 本代码库 |

## 常见问题(FAQ)
#### 1.数据集在哪里下载?
## 常見問題(FAQ)
#### 1.數據集哪裡下載?
| 数据集 | OpenSLR地址 | 其他源 (Google Drive, Baidu网盘等) |
| --- | ----------- | ---------------|
| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
> 解压 aidatatang_200zh 后,还需将 `aidatatang_200zh\corpus\train`下的文件全选解压缩
> 解壓 aidatatang_200zh 後,還需將 `aidatatang_200zh\corpus\train`下的檔案全選解壓縮

#### 2.`<datasets_root>`是什麼意思?
假如数据集路径为 `D:\data\aidatatang_200zh`,那么 `<datasets_root>`就是 `D:\data`
假如數據集路徑為 `D:\data\aidatatang_200zh`,那麼 `<datasets_root>`就是 `D:\data`

#### 3.训练模型显存不足
训练合成器时:将 `synthesizer/hparams.py`中的batch_size参数调小
#### 3.訓練模型顯存不足
訓練合成器時:將 `synthesizer/hparams.py`中的batch_size參數調小
```
//调整前
//調整前
tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
                (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
                (2, 2e-4, 80_000, 12), #
                (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
                (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
                (2, 1e-5, 640_000, 12)], # lr = learning rate
//调整后
//調整後
tts_schedule = [(2, 1e-3, 20_000, 8), # Progressive training schedule
                (2, 5e-4, 40_000, 8), # (r, lr, step, batch_size)
                (2, 2e-4, 80_000, 8), #
@@ -234,15 +174,15 @@ tts_schedule = [(2, 1e-3, 20_000, 8), # Progressive training schedule
                (2, 1e-5, 640_000, 8)], # lr = learning rate
```

声码器-预处理数据集时:将 `synthesizer/hparams.py`中的batch_size参数调小
聲碼器-預處理數據集時:將 `synthesizer/hparams.py`中的batch_size參數調小
```
//调整前
//調整前
### Data Preprocessing
    max_mel_frames = 900,
    rescale = True,
    rescaling_max = 0.9,
    synthesis_batch_size = 16, # For vocoder preprocessing and inference.
//调整后
//調整後
### Data Preprocessing
    max_mel_frames = 900,
    rescale = True,
@@ -250,16 +190,16 @@ tts_schedule = [(2, 1e-3, 20_000, 8), # Progressive training schedule
    synthesis_batch_size = 8, # For vocoder preprocessing and inference.
```

声码器-训练声码器时:将 `vocoder/wavernn/hparams.py`中的batch_size参数调小
聲碼器-訓練聲碼器時:將 `vocoder/wavernn/hparams.py`中的batch_size參數調小
```
//调整前
//調整前
# Training
voc_batch_size = 100
voc_lr = 1e-4
voc_gen_at_checkpoint = 5
voc_pad = 2

//调整后
//調整後
# Training
voc_batch_size = 6
voc_lr = 1e-4
@@ -268,16 +208,17 @@ voc_pad =2
```

#### 4.碰到`RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).`
请参照 issue [#37](https://github.com/babysor/MockingBird/issues/37)
請參照 issue [#37](https://github.com/babysor/MockingBird/issues/37)

#### 5.如何改善CPU、GPU占用率?
视情况调整batch_size参数来改善
#### 5.如何改善CPU、GPU佔用率?
適情況調整batch_size參數來改善

#### 6.发生 `页面文件太小,无法完成操作`
请参考这篇[文章](https://blog.csdn.net/qq_17755303/article/details/112564030),将虚拟内存更改为100G(102400),例如:文件放置D盘就更改D盘的虚拟内存
#### 6.發生 `頁面文件太小,無法完成操作`
請參考這篇[文章](https://blog.csdn.net/qq_17755303/article/details/112564030),將虛擬內存更改為100G(102400),例如:档案放置D槽就更改D槽的虚拟内存

#### 7.什么时候算训练完成?
首先一定要出现注意力模型,其次是loss足够低,取决于硬件设备和数据集。拿本人的供参考,我的注意力是在 18k 步之后出现的,并且在 50k 步之后损失变得低于 0.4

79 README.md
@@ -21,7 +21,6 @@
## Quick Start

### 1. Install Requirements
#### 1.1 General Setup
> Follow the original repo to test whether your environment is ready.
**Python 3.7 or higher** is needed to run the toolbox.

@@ -30,74 +29,8 @@
* Install [ffmpeg](https://ffmpeg.org/download.html#get-packages).
* Run `pip install -r requirements.txt` to install the remaining necessary packages.
* Install webrtcvad with `pip install webrtcvad-wheels` (if you need it).

or
- install dependencies with `conda` or `mamba`

  ```conda env create -n env_name -f env.yml```

  ```mamba env create -n env_name -f env.yml```

  This will create a virtual environment where the necessary dependencies are installed. Switch to the new environment with `conda activate env_name` and you are done.
> env.yml only includes the necessary dependencies to run the project, temporarily without monotonic-align. You can check the official website to install the GPU version of pytorch.

#### 1.2 Setup with an M1 Mac
> The following steps are a workaround to directly use the original `demo_toolbox.py` without changing the code.
>
> The major issue is that the PyQt5 package used by `demo_toolbox.py` is not compatible with M1 chips. If you want to train models on an M1 chip, either forgo `demo_toolbox.py` or try `web.py` instead.

##### 1.2.1 Install `PyQt5`, with [ref](https://stackoverflow.com/a/68038451/20455983) here.
* Create and open a Rosetta Terminal, with [ref](https://dev.to/courier/tips-and-tricks-to-setup-your-apple-m1-for-development-547g) here.
* Use system Python to create a virtual environment for the project
  ```
  /usr/bin/python3 -m venv /PathToMockingBird/venv
  source /PathToMockingBird/venv/bin/activate
  ```
* Upgrade pip and install `PyQt5`
  ```
  pip install --upgrade pip
  pip install pyqt5
  ```
##### 1.2.2 Install `pyworld` and `ctc-segmentation`

> Both packages seem to be unique to this project and are not seen in the original [Real-Time Voice Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) project. When installing with `pip install`, both packages lack wheels, so pip tries to compile them directly from C code and cannot find `Python.h`.

* Install `pyworld`
  * `brew install python` installs `Python.h` along with the brew-managed Python.
  * `export CPLUS_INCLUDE_PATH=/opt/homebrew/Frameworks/Python.framework/Headers` The filepath of the brew-installed `Python.h` is unique to M1 macOS and listed above. You need to manually add the path to the environment variables.
  * `pip install pyworld` should then work.

* Install `ctc-segmentation`
  > The same method does not apply to `ctc-segmentation`, so it has to be compiled from the source code on [github](https://github.com/lumaku/ctc-segmentation).
  * `git clone https://github.com/lumaku/ctc-segmentation.git`
  * `cd ctc-segmentation`
  * `source /PathToMockingBird/venv/bin/activate` If the virtual environment isn't already active, activate it.
  * `cythonize -3 ctc_segmentation/ctc_segmentation_dyn.pyx`
  * `/usr/bin/arch -x86_64 python setup.py build` Build with x86 architecture.
  * `/usr/bin/arch -x86_64 python setup.py install --optimize=1 --skip-build` Install with x86 architecture.

##### 1.2.3 Other dependencies
* `/usr/bin/arch -x86_64 pip install torch torchvision torchaudio` Install `PyTorch` with pip, explicitly targeting the x86 architecture.
* `pip install ffmpeg` Install ffmpeg.
* `pip install -r requirements.txt` Install other requirements.

##### 1.2.4 Run the Inference Time (with Toolbox)
> To run the project on x86 architecture, see [ref](https://youtrack.jetbrains.com/issue/PY-46290/Allow-running-Python-under-Rosetta-2-in-PyCharm-for-Apple-Silicon).
* `vim /PathToMockingBird/venv/bin/pythonM1` Create an executable file `pythonM1` that wraps the Python interpreter at `/PathToMockingBird/venv/bin`.
* Write in the following content:
  ```
  #!/usr/bin/env zsh
  mydir=${0:a:h}
  /usr/bin/arch -x86_64 $mydir/python "$@"
  ```
* `chmod +x pythonM1` Set the file as executable.
* If using the PyCharm IDE, configure the project interpreter to `pythonM1` ([steps here](https://www.jetbrains.com/help/pycharm/configuring-python-interpreter.html#add-existing-interpreter)); if using command-line Python, run `/PathToMockingBird/venv/bin/pythonM1 demo_toolbox.py`

> Note that we are using the pretrained encoder/vocoder but synthesizer, since the original model is incompatible with the Chinese symbols. It means the demo_cli is not working at this moment.
### 2. Prepare your models
> Note that we are using the pretrained encoder/vocoder but not synthesizer, since the original model is incompatible with the Chinese symbols. It means the demo_cli is not working at this moment, so additional synthesizer models are required.

You can either train your models or use existing ones:

#### 2.1 Train encoder with your dataset (Optional)
@@ -126,8 +59,8 @@ Allowing parameter `--dataset {dataset}` to support aidatatang_200zh, magicdata,
| --- | ----------- | ----- |----- |
| @author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) 4j5d | | 75k steps trained by multiple datasets
| @author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code:om7f | | 25k steps trained by multiple datasets, only works under version 0.0.1
|@FawenYo | https://yisiou-my.sharepoint.com/:u:/g/personal/lawrence_cheng_yisiou_onmicrosoft_com/EWFWDHzee-NNg9TWdKckCc4BC7bK2j9cCbOWn0-_tK0nOg?e=Cc4EFA https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan, only works under version 0.0.1
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 https://www.aliyundrive.com/s/AwPsbo8mcSP code: z2m0 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing https://u.teknik.io/AYxWf.pt | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps with local accent of Taiwan, only works under version 0.0.1
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code:2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | only works under version 0.0.1

#### 2.4 Train vocoder (Optional)
> Note: the vocoder makes little difference to the result, so you may not need to train a new one.
@@ -149,11 +82,6 @@ You can then try to run:`python web.py` and open it in browser, default as `http
You can then try the toolbox:
`python demo_toolbox.py -d <datasets_root>`

#### 3.3 Using the command line
You can then try the command:
`python gen_voice.py <text_file.txt> your_wav_file.wav`
You may need to install cn2an (`pip install cn2an`) for better handling of digits and numbers; see the sketch below.

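A minimal sketch of what cn2an does for number handling before synthesis. The `an2cn`/`transform` calls below reflect my understanding of the cn2an package; treat the exact API and the example outputs as assumptions:

```python
# pip install cn2an
import cn2an

# Convert Arabic numerals to Chinese numerals so the synthesizer reads them naturally.
print(cn2an.an2cn("123"))                       # -> 一百二十三
print(cn2an.transform("今天是2024年", "an2cn"))  # converts digits inside a sentence (approximate output)
```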
## Reference
> This repository is forked from [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning), which only supports English.

@@ -161,7 +89,6 @@ you may need to install cn2an by "pip install cn2an" for better digital number r
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)

43 analysis.py Normal file
@@ -0,0 +1,43 @@
from scipy.io import wavfile # scipy library to read wav files
import numpy as np

AudioName = "target.wav" # Audio File
fs, Audiodata = wavfile.read(AudioName)

# Plot the audio signal in time
import matplotlib.pyplot as plt
plt.plot(Audiodata)
plt.title('Audio signal in time',size=16)

# spectrum
from scipy.fftpack import fft # fourier transform
n = len(Audiodata)
AudioFreq = fft(Audiodata)
AudioFreq = AudioFreq[0:int(np.ceil((n+1)/2.0))] # Half of the spectrum
MagFreq = np.abs(AudioFreq) # Magnitude
MagFreq = MagFreq / float(n)
# power spectrum
MagFreq = MagFreq**2
if n % 2 > 0: # fft odd
    MagFreq[1:len(MagFreq)] = MagFreq[1:len(MagFreq)] * 2
else: # fft even
    MagFreq[1:len(MagFreq) - 1] = MagFreq[1:len(MagFreq) - 1] * 2

plt.figure()
freqAxis = np.arange(0, int(np.ceil((n+1)/2.0)), 1.0) * (fs / n)
plt.plot(freqAxis/1000.0, 10*np.log10(MagFreq)) # Power spectrum
plt.xlabel('Frequency (kHz)'); plt.ylabel('Power spectrum (dB)')

# Spectrogram
from scipy import signal
N = 512 # Number of points in the fft
f, t, Sxx = signal.spectrogram(Audiodata, fs, window=signal.blackman(N), nfft=N)
plt.figure()
plt.pcolormesh(t, f, 10*np.log10(Sxx)) # dB spectrogram
#plt.pcolormesh(t, f, Sxx) # Linear spectrogram
plt.ylabel('Frequency [Hz]')
plt.xlabel('Time [s]')
plt.title('Spectrogram with scipy.signal', size=16)

plt.show()
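The one-sided power spectrum above is built by slicing the full FFT and doubling the non-edge bins. A sketch of the same idea (my own, not part of analysis.py) using `np.fft.rfft`, which returns the one-sided spectrum directly:

```python
import numpy as np

def one_sided_power(x):
    """One-sided power spectrum, equivalent to the manual slicing in analysis.py."""
    n = len(x)
    spec = np.fft.rfft(x)                      # n//2 + 1 bins, from DC up to Nyquist
    power = (np.abs(spec) / n) ** 2
    # Double every bin except DC (and except Nyquist when n is even),
    # mirroring the odd/even branches above.
    power[1:-1 if n % 2 == 0 else None] *= 2
    return power
```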
@@ -1,9 +1,9 @@
from models.encoder.params_model import model_embedding_size as speaker_embedding_size
from encoder.params_model import model_embedding_size as speaker_embedding_size
from utils.argutils import print_args
from utils.modelutils import check_model_paths
from models.synthesizer.inference import Synthesizer
from models.encoder import inference as encoder
from models.vocoder import inference as vocoder
from synthesizer.inference import Synthesizer
from encoder import inference as encoder
from vocoder import inference as vocoder
from pathlib import Path
import numpy as np
import soundfile as sf

@@ -1,67 +0,0 @@
import sys
import torch
import argparse
import numpy as np
from utils.hparams import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility, comment these may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
        'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducable results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    ###

    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot visitable
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
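The `HpsYaml(paras.config)` call above wraps the YAML config so that nested keys can be read with attribute access ("dot visitable"). A rough illustration of that idea only; it assumes nothing about the real `utils.hparams.HpsYaml` beyond this behaviour, and the YAML keys shown are hypothetical:

```python
import yaml

class DotDict(dict):
    """Tiny stand-in showing what "dot visitable" means: cfg.model.hidden_dim
    instead of cfg["model"]["hidden_dim"]. Not the project's HpsYaml."""
    def __getattr__(self, key):
        value = self[key]
        return DotDict(value) if isinstance(value, dict) else value

with open("ppg2mel/saved_models/ppg2mel.yaml") as f:   # path from the README above
    cfg = DotDict(yaml.safe_load(f))
# e.g. cfg.data_root or cfg.model.hidden_dim, depending on the YAML's actual keys
```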
@@ -1,49 +0,0 @@
from pathlib import Path
import argparse

from models.ppg2mel.preprocess import preprocess_dataset
from pathlib import Path
import argparse

recognized_datasets = [
    "aidatatang_200zh",
    "aidatatang_200zh_s",  # sample
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, to be used by the "
        "ppg2mel model for training.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your datasets.")
    parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
        "Name of the dataset to process, allowing values: aidatatang_200zh.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms, the audios and the "
        "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
    parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
        "Number of processes in parallel.")
    # parser.add_argument("-s", "--skip_existing", action="store_true", help=\
    #     "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
    #     "interrupted. ")
    # parser.add_argument("--hparams", type=str, default="", help=\
    #     "Hyperparameter overrides as a comma-separated list of name-value pairs")
    # parser.add_argument("--no_trim", action="store_true", help=\
    #     "Preprocess audio without trimming silences (not recommended).")
    parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
        "Path your trained ppg encoder model.")
    parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
        "Path your trained speaker encoder model.")
    args = parser.parse_args()

    assert args.dataset in recognized_datasets, 'is not supported, file a issue to propose a new one'

    # Create directories
    assert args.datasets_root.exists()
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
    args.out_dir.mkdir(exist_ok=True, parents=True)

    preprocess_dataset(**vars(args))
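One detail worth noting in the script above: `--out_dir` uses `default=argparse.SUPPRESS`, so the attribute only exists on `args` when the flag is actually passed, which is what the later `hasattr(args, "out_dir")` check relies on. A minimal standalone sketch of that behaviour:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-o", "--out_dir", default=argparse.SUPPRESS)

args = parser.parse_args([])              # flag not given
print(hasattr(args, "out_dir"))           # False: SUPPRESS leaves the attribute out entirely

args = parser.parse_args(["-o", "out"])   # flag given
print(args.out_dir)                       # "out"
```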
@@ -1,66 +0,0 @@
import sys
import torch
import argparse
import numpy as np
from utils.hparams import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility, comment these may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
        'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducable results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')

    ###
    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot visitable
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)

if __name__ == "__main__":
    main()
@@ -1,151 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from models.encoder import inference as encoder
|
||||
import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from control.mkgui.base.components.types import FileContent
|
||||
from models.vocoder.hifigan import inference as gan_vocoder
|
||||
from models.synthesizer.inference import Synthesizer
|
||||
from typing import Any, Tuple
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = f"data{os.sep}samples{os.sep}"
|
||||
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
|
||||
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
|
||||
VOC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}vocoder"
|
||||
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
|
||||
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
|
||||
if not os.path.isdir("wavs"):
|
||||
os.makedirs("wavs")
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(SYN_MODELS_DIRT):
|
||||
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist. 请将模型文件位置移动到上述位置中进行重试!")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
|
||||
)
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="选择语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
record_audio_file: FileContent = Field(default=None, alias="录制语音",
|
||||
description="录音.", is_recorder=True, mime_type="audio/wav")
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
encoder: encoders = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
synthesizer: synthesizers = Field(
|
||||
..., alias="合成模型",
|
||||
description="选择语音合成模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音解码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
|
||||
def synthesize(input: Input) -> Output:
|
||||
"""synthesize(合成)"""
|
||||
# load models
|
||||
encoder.load_model(Path(input.encoder.value))
|
||||
current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.record_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.record_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
elif input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
|
||||
source_spec = Synthesizer.make_spectrogram(wav)
|
||||
|
||||
# preprocess
|
||||
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
|
||||
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
||||
# Load input text
|
||||
texts = filter(None, input.message.split("\n"))
|
||||
punctuation = '!,。、,' # punctuate and split/clean text
|
||||
processed_texts = []
|
||||
for text in texts:
|
||||
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
||||
if processed_text:
|
||||
processed_texts.append(processed_text.strip())
|
||||
texts = processed_texts
|
||||
|
||||
# synthesize and vocode
|
||||
embeds = [embed] * len(texts)
|
||||
specs = current_synt.synthesize_spectrograms(texts, embeds)
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
sample_rate = Synthesizer.sample_rate
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(spec)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
|
||||
@@ -1,166 +0,0 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import librosa
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
import models.ppg2mel as Convertor
|
||||
import models.ppg_extractor as Extractor
|
||||
from control.mkgui.base.components.types import FileContent
|
||||
from models.encoder import inference as speacker_encoder
|
||||
from models.synthesizer.inference import Synthesizer
|
||||
from models.vocoder.hifigan import inference as gan_vocoder
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = f'data{os.sep}samples{os.sep}'
|
||||
EXT_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg_extractor'
|
||||
CONV_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}ppg2mel'
|
||||
VOC_MODELS_DIRT = f'data{os.sep}ckpt{os.sep}vocoder'
|
||||
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
|
||||
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
|
||||
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(vocoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Input(BaseModel):
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
local_audio_file_target: audio_input_selection = Field(
|
||||
..., alias="目标语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
extractor: extractors = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音解码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[AudioEntity, AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, target, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(target.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Target Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
def convert(input: Input) -> Output:
|
||||
"""convert(转换)"""
|
||||
# load models
|
||||
extractor = Extractor.load_model(Path(input.extractor.value))
|
||||
convertor = Convertor.load_model(Path(input.convertor.value))
|
||||
# current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
src_wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
|
||||
|
||||
if input.upload_audio_file_target != None:
|
||||
with open(TEMP_TARGET_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file_target.as_bytes())
|
||||
f.seek(0)
|
||||
ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
|
||||
else:
|
||||
ref_wav, _ = librosa.load(input.local_audio_file_target.value)
|
||||
write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
|
||||
|
||||
ppg = extractor.extract_from_wav(src_wav)
|
||||
# Import necessary dependency of Voice Conversion
|
||||
from utils.f0_utils import (compute_f0, compute_mean_std, f02lf0,
|
||||
get_converted_lf0uv)
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
speacker_encoder.load_model(Path(f"data{os.sep}ckpt{os.sep}encoder{os.sep}pretrained_bak_5805000.pt"))
|
||||
embed = speacker_encoder.embed_utterance(ref_wav)
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_, mel_pred, att_ws = convertor.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
|
||||
)
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
breaks = [mel_pred.shape[1]]
|
||||
mel_pred= mel_pred.detach().cpu().numpy()
|
||||
|
||||
# synthesize and vocode
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_TARGET_AUDIO, "rb") as f:
|
||||
target_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
|
||||
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
|
||||
@@ -1,2 +0,0 @@

from .core import Opyrator
@@ -1 +0,0 @@
from .fastapi_app import create_api
@@ -1,102 +0,0 @@
|
||||
"""Collection of utilities for FastAPI apps."""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Type
|
||||
|
||||
from fastapi import FastAPI, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def as_form(cls: Type[BaseModel]) -> Any:
|
||||
"""Adds an as_form class method to decorated models.
|
||||
|
||||
The as_form class method can be used with FastAPI endpoints
|
||||
"""
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
field.alias,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
default=(Form(field.default) if not field.required else Form(...)),
|
||||
)
|
||||
for field in cls.__fields__.values()
|
||||
]
|
||||
|
||||
async def _as_form(**data): # type: ignore
|
||||
return cls(**data)
|
||||
|
||||
sig = inspect.signature(_as_form)
|
||||
sig = sig.replace(parameters=new_params)
|
||||
_as_form.__signature__ = sig # type: ignore
|
||||
setattr(cls, "as_form", _as_form)
|
||||
return cls
|
||||
|
||||
|
||||
def patch_fastapi(app: FastAPI) -> None:
|
||||
"""Patch function to allow relative url resolution.
|
||||
|
||||
This patch is required to make fastapi fully functional with a relative url path.
|
||||
This code snippet can be copy-pasted to any Fastapi application.
|
||||
"""
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
async def redoc_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
redoc_ui = get_redoc_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Redoc UI",
|
||||
)
|
||||
|
||||
return HTMLResponse(redoc_ui.body.decode("utf-8"))
|
||||
|
||||
async def swagger_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
swagger_ui = get_swagger_ui_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||
)
|
||||
|
||||
# insert request interceptor to have all request run on relativ path
|
||||
request_interceptor = (
|
||||
"requestInterceptor: (e) => {"
|
||||
"\n\t\t\tvar url = window.location.origin + window.location.pathname"
|
||||
'\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
|
||||
"\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
|
||||
"\n\t\t\te.contextUrl = url"
|
||||
"\n\t\t\te.url = url"
|
||||
"\n\t\t\treturn e;}"
|
||||
)
|
||||
|
||||
return HTMLResponse(
|
||||
swagger_ui.body.decode("utf-8").replace(
|
||||
"dom_id: '#swagger-ui',",
|
||||
"dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
|
||||
)
|
||||
)
|
||||
|
||||
# remove old docs route and add our patched route
|
||||
routes_new = []
|
||||
for app_route in app.routes:
|
||||
if app_route.path == "/docs": # type: ignore
|
||||
continue
|
||||
|
||||
if app_route.path == "/redoc": # type: ignore
|
||||
continue
|
||||
|
||||
routes_new.append(app_route)
|
||||
|
||||
app.router.routes = routes_new
|
||||
|
||||
assert app.docs_url is not None
|
||||
app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
|
||||
assert app.redoc_url is not None
|
||||
app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
|
||||
|
||||
# Make graphql realtive
|
||||
from starlette import graphql
|
||||
|
||||
graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
|
||||
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
|
||||
)
|
||||
@@ -1,43 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScoredLabel(BaseModel):
|
||||
label: str
|
||||
score: float
|
||||
|
||||
|
||||
class ClassificationOutput(BaseModel):
|
||||
__root__: List[ScoredLabel]
|
||||
|
||||
def __iter__(self): # type: ignore
|
||||
return iter(self.__root__)
|
||||
|
||||
def __getitem__(self, item): # type: ignore
|
||||
return self.__root__[item]
|
||||
|
||||
def render_output_ui(self, streamlit) -> None: # type: ignore
|
||||
import plotly.express as px
|
||||
|
||||
sorted_predictions = sorted(
|
||||
[prediction.dict() for prediction in self.__root__],
|
||||
key=lambda k: k["score"],
|
||||
)
|
||||
|
||||
num_labels = len(sorted_predictions)
|
||||
if len(sorted_predictions) > 10:
|
||||
num_labels = streamlit.slider(
|
||||
"Maximum labels to show: ",
|
||||
min_value=1,
|
||||
max_value=len(sorted_predictions),
|
||||
value=len(sorted_predictions),
|
||||
)
|
||||
fig = px.bar(
|
||||
sorted_predictions[len(sorted_predictions) - num_labels :],
|
||||
x="score",
|
||||
y="label",
|
||||
orientation="h",
|
||||
)
|
||||
streamlit.plotly_chart(fig, use_container_width=True)
|
||||
# fig.show()
|
||||
@@ -1,46 +0,0 @@
|
||||
import base64
|
||||
from typing import Any, Dict, overload
|
||||
|
||||
|
||||
class FileContent(str):
|
||||
def as_bytes(self) -> bytes:
|
||||
return base64.b64decode(self, validate=True)
|
||||
|
||||
def as_str(self) -> str:
|
||||
return self.as_bytes().decode()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(format="byte")
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Any: # type: ignore
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> "FileContent":
|
||||
if isinstance(value, FileContent):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
return FileContent(value)
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return FileContent(base64.b64encode(value).decode())
|
||||
else:
|
||||
raise Exception("Wrong type")
|
||||
|
||||
# # Temporarily unusable, because browsers do not support selecting a folder
|
||||
# class DirectoryContent(FileContent):
|
||||
# @classmethod
|
||||
# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
# field_schema.update(format="path")
|
||||
|
||||
# @classmethod
|
||||
# def validate(cls, value: Any) -> "DirectoryContent":
|
||||
# if isinstance(value, DirectoryContent):
|
||||
# return value
|
||||
# elif isinstance(value, str):
|
||||
# return DirectoryContent(value)
|
||||
# elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# return DirectoryContent(base64.b64encode(value).decode())
|
||||
# else:
|
||||
# raise Exception("Wrong type")
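A short sketch of the round-trip behaviour of FileContent (the class is assumed to be importable from wherever this module lives; only the methods shown above are used):

raw = b"RIFF....WAVEfmt "                  # any raw payload, e.g. wav bytes
content = FileContent.validate(raw)        # raw bytes are stored base64-encoded
assert isinstance(content, str)            # still a plain str subclass
assert content.as_bytes() == raw           # decodes back to the original bytes
assert content.as_str() == raw.decode()    # or to text, when the payload is text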
|
||||
@@ -1,203 +0,0 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Type, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from pydantic.tools import parse_obj_as
|
||||
|
||||
|
||||
def name_to_title(name: str) -> str:
|
||||
"""Converts a camelCase or snake_case name to title case."""
|
||||
# If camelCase -> convert to snake case
|
||||
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
||||
name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
||||
# Convert to title case
|
||||
return name.replace("_", " ").strip().title()
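A couple of concrete conversions implied by the two substitutions above (the import path is the one used by the GUI code later in this diff):

from control.mkgui.base.core import name_to_title

assert name_to_title("selected_dataset") == "Selected Dataset"   # snake_case
assert name_to_title("datasetsRoot") == "Datasets Root"          # camelCase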
|
||||
|
||||
|
||||
def is_compatible_type(type: Type) -> bool:
|
||||
"""Returns `True` if the type is opyrator-compatible."""
|
||||
try:
|
||||
if issubclass(type, BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# valid list type
|
||||
if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_input_type(func: Callable) -> Type:
|
||||
"""Returns the input type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the input type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid input type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
if "input" not in type_hints:
|
||||
raise ValueError(
|
||||
"The callable MUST have a parameter with the name `input` with typing annotation. "
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
input_type = type_hints["input"]
|
||||
|
||||
if not is_compatible_type(input_type):
|
||||
raise ValueError(
|
||||
"The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
# TODO: return warning if more than one input parameters
|
||||
|
||||
return input_type
|
||||
|
||||
|
||||
def get_output_type(func: Callable) -> Type:
|
||||
"""Returns the output type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the output type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid output type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
if "return" not in type_hints:
|
||||
raise ValueError(
|
||||
"The return type of the callable MUST be annotated with type hints."
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
output_type = type_hints["return"]
|
||||
|
||||
if not is_compatible_type(output_type):
|
||||
raise ValueError(
|
||||
"The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
return output_type
|
||||
|
||||
|
||||
def get_callable(import_string: str) -> Callable:
"""Import a callable from a string."""
callable_separator = ":"
if callable_separator not in import_string:
# Use dot as separator
callable_separator = "."

if callable_separator not in import_string:
raise ValueError("The callable path MUST specify the function.")

mod_name, callable_name = import_string.rsplit(callable_separator, 1)
mod = importlib.import_module(mod_name)
return getattr(mod, callable_name)
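A usage sketch (the target function is one the GUI module imports elsewhere in this diff; both separator styles resolve to the same attribute):

from control.mkgui.base.core import get_callable

synthesize = get_callable("control.mkgui.app:synthesize")   # "module:function" form
same_func = get_callable("control.mkgui.app.synthesize")    # dotted form is also accepted
assert synthesize is same_func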
|
||||
|
||||
|
||||
class Opyrator:
|
||||
def __init__(self, func: Union[Callable, str]) -> None:
|
||||
if isinstance(func, str):
|
||||
# Try to load the function from a string notation
|
||||
self.function = get_callable(func)
|
||||
else:
|
||||
self.function = func
|
||||
|
||||
self._action = "Execute"
|
||||
self._input_type = None
|
||||
self._output_type = None
|
||||
|
||||
if not callable(self.function):
|
||||
raise ValueError("The provided function parameters is not a callable.")
|
||||
|
||||
if inspect.isclass(self.function):
|
||||
raise ValueError(
|
||||
"The provided callable is an uninitialized Class. This is not allowed."
|
||||
)
|
||||
|
||||
if inspect.isfunction(self.function):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function)
|
||||
self._output_type = get_output_type(self.function)
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(self.function.__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get description from function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(self.function, "__call__"):
|
||||
# The provided callable is an instance of a class with __call__
|
||||
self._input_type = get_input_type(self.function.__call__) # type: ignore
|
||||
self._output_type = get_output_type(self.function.__call__) # type: ignore
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(type(self.function).__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get action from
|
||||
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
|
||||
if (
|
||||
not self._action
|
||||
or self._action == "Call"
|
||||
):
|
||||
# Get docstring from class instead of __call__ function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown callable type.")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def action(self) -> str:
|
||||
return self._action
|
||||
|
||||
@property
|
||||
def input_type(self) -> Any:
|
||||
return self._input_type
|
||||
|
||||
@property
|
||||
def output_type(self) -> Any:
|
||||
return self._output_type
|
||||
|
||||
def __call__(self, input: Any, **kwargs: Any) -> Any:
|
||||
|
||||
input_obj = input
|
||||
|
||||
if isinstance(input, str):
|
||||
# Allow json input
|
||||
input_obj = parse_raw_as(self.input_type, input)
|
||||
|
||||
if isinstance(input, dict):
|
||||
# Allow dict input
|
||||
input_obj = parse_obj_as(self.input_type, input)
|
||||
|
||||
return self.function(input_obj, **kwargs)
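To make the wrapping behaviour concrete, a minimal self-contained sketch (the Echo* models and function are hypothetical and exist only for this example; the Opyrator import path is the one used by the UI code later in this diff):

from pydantic import BaseModel
from control.mkgui.base import Opyrator

class EchoInput(BaseModel):
    text: str

class EchoOutput(BaseModel):
    text: str

def echo(input: EchoInput) -> EchoOutput:
    """Echo the given text back."""
    return EchoOutput(text=input.text)

op = Opyrator(echo)
print(op.name)                      # "Echo"  (derived via name_to_title)
print(op.action)                    # the docstring, used as the action label
print(op(input='{"text": "hi"}'))   # JSON strings and dicts are parsed automatically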
|
||||
@@ -1 +0,0 @@
|
||||
from .streamlit_ui import render_streamlit_ui
|
||||
@@ -1,135 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def resolve_reference(reference: str, references: Dict) -> Dict:
|
||||
return references[reference.split("/")[-1]]
|
||||
|
||||
|
||||
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
|
||||
# Ref can either be directly in the properties or the first element of allOf
|
||||
reference = property.get("$ref")
|
||||
if reference is None:
|
||||
reference = property["allOf"][0]["$ref"]
|
||||
return resolve_reference(reference, references)
|
||||
|
||||
|
||||
def is_single_string_property(property: Dict) -> bool:
|
||||
return property.get("type") == "string"
|
||||
|
||||
|
||||
def is_single_datetime_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") in ["date-time", "time", "date"]
|
||||
|
||||
|
||||
def is_single_boolean_property(property: Dict) -> bool:
|
||||
return property.get("type") == "boolean"
|
||||
|
||||
|
||||
def is_single_number_property(property: Dict) -> bool:
|
||||
return property.get("type") in ["integer", "number"]
|
||||
|
||||
|
||||
def is_single_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
# TODO: binary?
|
||||
return property.get("format") == "byte"
|
||||
|
||||
def is_single_audio_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
# TODO: binary?
|
||||
return property.get("format") == "bytes"
|
||||
|
||||
|
||||
def is_single_directory_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") == "path"
|
||||
|
||||
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("uniqueItems") is not True:
|
||||
# Only relevant if it is a set or another data structure with unique items
|
||||
return False
|
||||
|
||||
try:
|
||||
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_enum_property(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
_ = get_single_reference_item(property, references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_dict_property(property: Dict) -> bool:
|
||||
if property.get("type") != "object":
|
||||
return False
|
||||
return "additionalProperties" in property
|
||||
|
||||
|
||||
def is_single_reference(property: Dict) -> bool:
|
||||
if property.get("type") is not None:
|
||||
return False
|
||||
|
||||
return bool(property.get("$ref"))
|
||||
|
||||
|
||||
def is_multi_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO: binary
|
||||
return property["items"]["format"] == "byte"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_object(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
object_reference = get_single_reference_item(property, references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_property_list(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
return property["items"]["type"] in ["string", "number", "integer"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_object_list_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
try:
|
||||
object_reference = resolve_reference(property["items"]["$ref"], references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
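For intuition, a small sketch of the JSON-schema fragments these predicates inspect (pydantic v1 style, matching the schema()/definitions usage in the UI code below; the Example model is hypothetical and the schema_utils package path is inferred from the imports elsewhere in this diff):

from enum import Enum
from typing import Set
from pydantic import BaseModel

from control.mkgui.base.ui import schema_utils   # package path inferred, not shown explicitly

class Color(str, Enum):
    RED = "red"
    BLUE = "blue"

class Example(BaseModel):
    name: str
    colors: Set[Color]

schema = Example.schema(by_alias=True)
props, refs = schema["properties"], schema.get("definitions", {})

assert schema_utils.is_single_string_property(props["name"])        # {"type": "string", ...}
assert schema_utils.is_multi_enum_property(props["colors"], refs)   # array + uniqueItems + enum $ref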
|
||||
@@ -1,933 +0,0 @@
|
||||
import datetime
|
||||
import inspect
|
||||
import mimetypes
|
||||
import sys
|
||||
from os import getcwd, unlink, path
|
||||
from platform import system
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
from PIL import Image
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError, parse_obj_as
|
||||
|
||||
from control.mkgui.base import Opyrator
|
||||
from control.mkgui.base.core import name_to_title
|
||||
from . import schema_utils
|
||||
from .streamlit_utils import CUSTOM_STREAMLIT_CSS
|
||||
|
||||
STREAMLIT_RUNNER_SNIPPET = """
|
||||
from control.mkgui.base.ui import render_streamlit_ui
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# TODO: Make it configurable
|
||||
# Page config can only be setup once
|
||||
st.set_page_config(
|
||||
page_title="MockingBird",
|
||||
page_icon="🧊",
|
||||
layout="wide")
|
||||
|
||||
render_streamlit_ui()
|
||||
"""
|
||||
|
||||
# with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
# opyrator = Opyrator("{opyrator_path}")
|
||||
|
||||
|
||||
def launch_ui(port: int = 8501) -> None:
|
||||
with NamedTemporaryFile(
|
||||
suffix=".py", mode="w", encoding="utf-8", delete=False
|
||||
) as f:
|
||||
f.write(STREAMLIT_RUNNER_SNIPPET)
|
||||
f.seek(0)
|
||||
|
||||
import subprocess
|
||||
|
||||
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
||||
if system() == "Windows":
|
||||
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
|
||||
subprocess.run(
|
||||
f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
subprocess.run(
|
||||
f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
f.close()
|
||||
unlink(f.name)
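A sketch of how the launcher above would be invoked directly (the module path is inferred from the package layout in this diff and the port is arbitrary):

from control.mkgui.base.ui.streamlit_ui import launch_ui   # path inferred, not shown explicitly

# Writes the runner snippet to a temporary file and spawns `streamlit run` on it.
launch_ui(port=8502)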
|
||||
|
||||
|
||||
def function_has_named_arg(func: Callable, parameter: str) -> bool:
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
for param in sig.parameters.values():
|
||||
if param.name == parameter:
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def has_output_ui_renderer(data_item: BaseModel) -> bool:
|
||||
return hasattr(data_item, "render_output_ui")
|
||||
|
||||
|
||||
def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
|
||||
return hasattr(input_class, "render_input_ui")
|
||||
|
||||
|
||||
def is_compatible_audio(mime_type: str) -> bool:
|
||||
return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
|
||||
|
||||
|
||||
def is_compatible_image(mime_type: str) -> bool:
|
||||
return mime_type in ["image/png", "image/jpeg"]
|
||||
|
||||
|
||||
def is_compatible_video(mime_type: str) -> bool:
|
||||
return mime_type in ["video/mp4"]
|
||||
|
||||
|
||||
class InputUI:
|
||||
def __init__(self, session_state, input_class: Type[BaseModel]):
|
||||
self._session_state = session_state
|
||||
self._input_class = input_class
|
||||
|
||||
self._schema_properties = input_class.schema(by_alias=True).get(
|
||||
"properties", {}
|
||||
)
|
||||
self._schema_references = input_class.schema(by_alias=True).get(
|
||||
"definitions", {}
|
||||
)
|
||||
|
||||
def render_ui(self, streamlit_app_root) -> None:
|
||||
if has_input_ui_renderer(self._input_class):
|
||||
# The input model has a rendering function
|
||||
# The rendering also returns the current state of input data
|
||||
self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
|
||||
st, self._session_state.input_data
|
||||
)
|
||||
return
|
||||
|
||||
# print(self._schema_properties)
|
||||
for property_key in self._schema_properties.keys():
|
||||
property = self._schema_properties[property_key]
|
||||
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
|
||||
try:
|
||||
if "input_data" in self._session_state:
|
||||
self._store_value(
|
||||
property_key,
|
||||
self._render_property(streamlit_app_root, property_key, property),
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception!", e)
|
||||
pass
|
||||
|
||||
def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
|
||||
streamlit_kwargs = {
|
||||
"label": property.get("title"),
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if property.get("description"):
|
||||
streamlit_kwargs["help"] = property.get("description")
|
||||
return streamlit_kwargs
|
||||
|
||||
def _store_value(self, key: str, value: Any) -> None:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
data_element[key_element] = value
|
||||
return
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
|
||||
def _get_value(self, key: str) -> Any:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# return the value of this element
|
||||
if key_element not in data_element:
|
||||
return None
|
||||
return data_element[key_element]
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
return None
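For intuition, the dotted widget keys used by the two helpers above map onto nested dictionaries; a sketch with purely hypothetical keys of what session_state.input_data ends up holding:

# After _store_value("encoder", "pretrained.pt") and _store_value("advanced.seed", 42)
# the stored input data is equivalent to:
input_data = {
    "encoder": "pretrained.pt",
    "advanced": {"seed": 42},
}
# and _get_value("advanced.seed") walks the same path back and returns 42.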
|
||||
|
||||
def _render_single_datetime_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("format") == "time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.time_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.date_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date-time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
with streamlit_app.container():
|
||||
streamlit_app.subheader(streamlit_kwargs.get("label"))
|
||||
if streamlit_kwargs.get("description"):
|
||||
streamlit_app.text(streamlit_kwargs.get("description"))
|
||||
selected_date = None
|
||||
selected_time = None
|
||||
date_col, time_col = streamlit_app.columns(2)
|
||||
with date_col:
|
||||
date_kwargs = {"label": "Date", "key": key + "-date-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).date()
|
||||
except Exception:
|
||||
pass
|
||||
selected_date = streamlit_app.date_input(**date_kwargs)
|
||||
|
||||
with time_col:
|
||||
time_kwargs = {"label": "Time", "key": key + "-time-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).time()
|
||||
except Exception:
|
||||
pass
|
||||
selected_time = streamlit_app.time_input(**time_kwargs)
|
||||
return datetime.datetime.combine(selected_date, selected_time)
|
||||
else:
|
||||
streamlit_app.warning(
|
||||
"Date format is not supported: " + str(property.get("format"))
|
||||
)
|
||||
|
||||
def _render_single_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
if "is_recorder" in property:
|
||||
from audio_recorder_streamlit import audio_recorder
|
||||
audio_bytes = audio_recorder()
|
||||
if audio_bytes:
|
||||
streamlit_app.audio(audio_bytes, format="audio/wav")
|
||||
return audio_bytes
|
||||
|
||||
uploaded_file = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
|
||||
)
|
||||
if uploaded_file is None:
|
||||
return None
|
||||
|
||||
bytes = uploaded_file.getvalue()
|
||||
if property.get("mime_type"):
|
||||
if is_compatible_audio(property["mime_type"]):
|
||||
# Show audio
|
||||
streamlit_app.audio(bytes, format=property.get("mime_type"))
|
||||
if is_compatible_image(property["mime_type"]):
|
||||
# Show image
|
||||
streamlit_app.image(bytes)
|
||||
if is_compatible_video(property["mime_type"]):
|
||||
# Show video
|
||||
streamlit_app.video(bytes, format=property.get("mime_type"))
|
||||
return bytes
|
||||
|
||||
def _render_single_audio_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
# streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
from audio_recorder_streamlit import audio_recorder
|
||||
audio_bytes = audio_recorder()
|
||||
if audio_bytes:
|
||||
streamlit_app.audio(audio_bytes, format="audio/wav")
|
||||
return audio_bytes
|
||||
|
||||
# file_extension = None
|
||||
# if "mime_type" in property:
|
||||
# file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
# uploaded_file = streamlit_app.file_uploader(
|
||||
# **streamlit_kwargs, accept_multiple_files=False, type=file_extension
|
||||
# )
|
||||
# if uploaded_file is None:
|
||||
# return None
|
||||
|
||||
# bytes = uploaded_file.getvalue()
|
||||
# if property.get("mime_type"):
|
||||
# if is_compatible_audio(property["mime_type"]):
|
||||
# # Show audio
|
||||
# streamlit_app.audio(bytes, format=property.get("mime_type"))
|
||||
# if is_compatible_image(property["mime_type"]):
|
||||
# # Show image
|
||||
# streamlit_app.image(bytes)
|
||||
# if is_compatible_video(property["mime_type"]):
|
||||
# # Show video
|
||||
# streamlit_app.video(bytes, format=property.get("mime_type"))
|
||||
# return bytes
|
||||
|
||||
def _render_single_string_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
elif property.get("example"):
|
||||
# TODO: also use example for other property types
|
||||
# Use example as value if it is provided
|
||||
streamlit_kwargs["value"] = property.get("example")
|
||||
|
||||
if property.get("maxLength") is not None:
|
||||
streamlit_kwargs["max_chars"] = property.get("maxLength")
|
||||
|
||||
if (
|
||||
property.get("format")
|
||||
or (
|
||||
property.get("maxLength") is not None
|
||||
and int(property.get("maxLength")) < 140 # type: ignore
|
||||
)
|
||||
or property.get("writeOnly")
|
||||
):
|
||||
# If any format is set, use single text input
|
||||
# If max chars is set to less than 140, use single text input
|
||||
# If write only -> password field
|
||||
if property.get("writeOnly"):
|
||||
streamlit_kwargs["type"] = "password"
|
||||
return streamlit_app.text_input(**streamlit_kwargs)
|
||||
else:
|
||||
# Otherwise use multiline text area
|
||||
return streamlit_app.text_area(**streamlit_kwargs)
|
||||
|
||||
def _render_multi_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
# TODO: how to select defaults
|
||||
return streamlit_app.multiselect(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
try:
|
||||
streamlit_kwargs["index"] = reference_item["enum"].index(
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
# Use default selection
|
||||
pass
|
||||
|
||||
return streamlit_app.selectbox(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_dict_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_dict = self._get_value(key)
|
||||
if not current_dict:
|
||||
current_dict = {}
|
||||
|
||||
key_col, value_col = streamlit_app.columns(2)
|
||||
|
||||
with key_col:
|
||||
updated_key = streamlit_app.text_input(
|
||||
"Key", value="", key=key + "-new-key"
|
||||
)
|
||||
|
||||
with value_col:
|
||||
# TODO: also add boolean?
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["additionalProperties"].get("type") == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["additionalProperties"].get("type") == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
updated_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_dict = {}
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and updated_key
|
||||
):
|
||||
current_dict[updated_key] = updated_value
|
||||
|
||||
streamlit_app.write(current_dict)
|
||||
|
||||
return current_dict
|
||||
|
||||
def _render_single_reference(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_property(streamlit_app, key, reference_item)
|
||||
|
||||
def _render_multi_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_files = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=True, type=file_extension
|
||||
)
|
||||
uploaded_files_bytes = []
|
||||
if uploaded_files:
|
||||
for uploaded_file in uploaded_files:
|
||||
uploaded_files_bytes.append(uploaded_file.read())
|
||||
return uploaded_files_bytes
|
||||
|
||||
def _render_single_boolean_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
return streamlit_app.checkbox(**streamlit_kwargs)
|
||||
|
||||
def _render_single_number_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
number_transform = int
|
||||
if property.get("type") == "number":
|
||||
number_transform = float # type: ignore
|
||||
streamlit_kwargs["format"] = "%f"
|
||||
|
||||
if "multipleOf" in property:
|
||||
# Set step size based on the multipleOf parameter
|
||||
streamlit_kwargs["step"] = number_transform(property["multipleOf"])
|
||||
elif number_transform == int:
|
||||
# Set step size to 1 as default
|
||||
streamlit_kwargs["step"] = 1
|
||||
elif number_transform == float:
|
||||
# Set step size to 0.01 as default
|
||||
# TODO: adapt to default value
|
||||
streamlit_kwargs["step"] = 0.01
|
||||
|
||||
if "minimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(property["minimum"])
|
||||
if "exclusiveMinimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(
|
||||
property["exclusiveMinimum"] + streamlit_kwargs["step"]
|
||||
)
|
||||
if "maximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(property["maximum"])
|
||||
|
||||
if "exclusiveMaximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(
|
||||
property["exclusiveMaximum"] - streamlit_kwargs["step"]
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
|
||||
else:
|
||||
if "min_value" in streamlit_kwargs:
|
||||
streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
|
||||
elif number_transform == int:
|
||||
streamlit_kwargs["value"] = 0
|
||||
else:
|
||||
# Set default value to step
|
||||
streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
|
||||
|
||||
if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
|
||||
# TODO: Only if less than X steps
|
||||
return streamlit_app.slider(**streamlit_kwargs)
|
||||
else:
|
||||
return streamlit_app.number_input(**streamlit_kwargs)
|
||||
|
||||
def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
properties = property["properties"]
|
||||
object_inputs = {}
|
||||
for property_key in properties:
|
||||
property = properties[property_key]
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
# construct full key based on key parts -> required later to get the value
|
||||
full_key = key + "." + property_key
|
||||
object_inputs[property_key] = self._render_property(
|
||||
streamlit_app, full_key, property
|
||||
)
|
||||
return object_inputs
|
||||
|
||||
def _render_single_object_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
# Add title and subheader
|
||||
title = property.get("title")
|
||||
streamlit_app.subheader(title)
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
object_reference = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
def _render_property_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["items"]["type"] == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["items"]["type"] == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
new_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and new_value is not None
|
||||
):
|
||||
current_list.append(new_value)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
|
||||
return current_list
|
||||
|
||||
def _render_object_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# TODO: support max_items, and min_items properties
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
object_reference = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
input_data = self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and input_data
|
||||
):
|
||||
current_list.append(input_data)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
return current_list
|
||||
|
||||
def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
if schema_utils.is_single_enum_property(property, self._schema_references):
|
||||
return self._render_single_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_enum_property(property, self._schema_references):
|
||||
return self._render_multi_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_file_property(property):
|
||||
return self._render_single_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_file_property(property):
|
||||
return self._render_multi_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_datetime_property(property):
|
||||
return self._render_single_datetime_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_boolean_property(property):
|
||||
return self._render_single_boolean_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_dict_property(property):
|
||||
return self._render_single_dict_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_number_property(property):
|
||||
return self._render_single_number_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_string_property(property):
|
||||
return self._render_single_string_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_object(property, self._schema_references):
|
||||
return self._render_single_object_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_object_list_property(property, self._schema_references):
|
||||
return self._render_object_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_property_list(property):
|
||||
return self._render_property_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_reference(property):
|
||||
return self._render_single_reference(streamlit_app, key, property)
|
||||
|
||||
streamlit_app.warning(
|
||||
"The type of the following property is currently not supported: "
|
||||
+ str(property.get("title"))
|
||||
)
|
||||
raise Exception("Unsupported property")
|
||||
|
||||
|
||||
class OutputUI:
|
||||
def __init__(self, output_data: Any, input_data: Any):
|
||||
self._output_data = output_data
|
||||
self._input_data = input_data
|
||||
|
||||
def render_ui(self, streamlit_app) -> None:
|
||||
try:
|
||||
if isinstance(self._output_data, BaseModel):
|
||||
self._render_single_output(streamlit_app, self._output_data)
|
||||
return
|
||||
if isinstance(self._output_data, list):
|
||||
self._render_list_output(streamlit_app, self._output_data)
|
||||
return
|
||||
except Exception as ex:
|
||||
streamlit_app.exception(ex)
|
||||
# Fall back to JSON rendering
|
||||
streamlit_app.json(jsonable_encoder(self._output_data))
|
||||
|
||||
def _render_single_text_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
streamlit.code(str(value), language="plain")
|
||||
|
||||
def _render_single_file_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
# TODO: Detect if it is a FileContent instance
|
||||
# TODO: detect if it is base64
|
||||
file_extension = ""
|
||||
if "mime_type" in property_schema:
|
||||
mime_type = property_schema["mime_type"]
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ""
|
||||
|
||||
if is_compatible_audio(mime_type):
|
||||
streamlit.audio(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
if is_compatible_image(mime_type):
|
||||
streamlit.image(value.as_bytes())
|
||||
return
|
||||
|
||||
if is_compatible_video(mime_type):
|
||||
streamlit.video(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
filename = (
|
||||
(property_schema["title"] + file_extension)
|
||||
.lower()
|
||||
.strip()
|
||||
.replace(" ", "-")
|
||||
)
|
||||
streamlit.markdown(
|
||||
f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
def _render_single_complex_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
|
||||
streamlit.json(jsonable_encoder(value))
|
||||
|
||||
def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
|
||||
try:
|
||||
if has_output_ui_renderer(output_data):
|
||||
if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
|
||||
# render method also requests the input data
|
||||
output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
|
||||
else:
|
||||
output_data.render_output_ui(streamlit) # type: ignore
|
||||
return
|
||||
except Exception:
|
||||
# Use default auto-generation methods if the custom rendering throws an exception
|
||||
logger.exception(
|
||||
"Failed to execute custom render_output_ui function. Using auto-generation instead"
|
||||
)
|
||||
|
||||
model_schema = output_data.schema(by_alias=False)
|
||||
model_properties = model_schema.get("properties")
|
||||
definitions = model_schema.get("definitions")
|
||||
|
||||
if model_properties:
|
||||
for property_key in output_data.__dict__:
|
||||
property_schema = model_properties.get(property_key)
|
||||
if not property_schema.get("title"):
|
||||
# Set property key as fallback title
|
||||
property_schema["title"] = property_key
|
||||
|
||||
output_property_value = output_data.__dict__[property_key]
|
||||
|
||||
if has_output_ui_renderer(output_property_value):
|
||||
output_property_value.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
|
||||
if isinstance(output_property_value, BaseModel):
|
||||
# Render output recursively
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
self._render_single_output(streamlit, output_property_value)
|
||||
continue
|
||||
|
||||
if property_schema:
|
||||
if schema_utils.is_single_file_property(property_schema):
|
||||
self._render_single_file_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
schema_utils.is_single_string_property(property_schema)
|
||||
or schema_utils.is_single_number_property(property_schema)
|
||||
or schema_utils.is_single_datetime_property(property_schema)
|
||||
or schema_utils.is_single_boolean_property(property_schema)
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
if definitions and schema_utils.is_single_enum_property(
|
||||
property_schema, definitions
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value.value
|
||||
)
|
||||
continue
|
||||
|
||||
# TODO: render dict as table
|
||||
|
||||
self._render_single_complex_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
return
|
||||
|
||||
def _render_list_output(self, streamlit: st, output_data: List) -> None:
|
||||
try:
|
||||
data_items: List = []
|
||||
for data_item in output_data:
|
||||
if has_output_ui_renderer(data_item):
|
||||
# Render using the render function
|
||||
data_item.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
data_items.append(data_item.dict())
|
||||
# Try to show as dataframe
|
||||
streamlit.table(pd.DataFrame(data_items))
|
||||
except Exception:
|
||||
# Fall back to JSON rendering
|
||||
streamlit.json(jsonable_encoder(output_data))
|
||||
|
||||
|
||||
def getOpyrator(mode: str) -> Opyrator:
if mode is None or mode.startswith('VC'):
from control.mkgui.app_vc import convert
return Opyrator(convert)
if mode is None or mode.startswith('预处理'):
from control.mkgui.preprocess import preprocess
return Opyrator(preprocess)
# Check the more specific "模型训练(VC)" prefix first; otherwise it is shadowed
# by the "模型训练" prefix and the VC training branch can never be reached.
if mode is None or mode.startswith('模型训练(VC)'):
from control.mkgui.train_vc import train_vc
return Opyrator(train_vc)
if mode is None or mode.startswith('模型训练'):
from control.mkgui.train import train
return Opyrator(train)
from control.mkgui.app import synthesize
return Opyrator(synthesize)
|
||||
|
||||
def render_streamlit_ui() -> None:
|
||||
# init
|
||||
session_state = st.session_state
|
||||
session_state.input_data = {}
|
||||
# Add custom css settings
|
||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||
|
||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
session_state.mode = st.sidebar.selectbox(
|
||||
'模式选择',
|
||||
( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
|
||||
)
|
||||
if "mode" in session_state:
|
||||
mode = session_state.mode
|
||||
else:
|
||||
mode = ""
|
||||
opyrator = getOpyrator(mode)
|
||||
title = opyrator.name + mode
|
||||
|
||||
col1, col2, _ = st.columns(3)
|
||||
col2.title(title)
|
||||
col2.markdown("欢迎使用MockingBird Web 2")
|
||||
|
||||
image = Image.open(path.join('control','mkgui', 'static', 'mb.png'))
|
||||
col1.image(image)
|
||||
|
||||
st.markdown("---")
|
||||
left, right = st.columns([0.4, 0.6])
|
||||
|
||||
with left:
|
||||
st.header("Control 控制")
|
||||
# if session_state.mode in ["AI拟音", "VC拟音"] :
|
||||
# from audiorecorder import audiorecorder
|
||||
# audio = audiorecorder("Click to record", "Recording...")
|
||||
# if len(audio) > 0:
|
||||
# # To play audio in frontend:
|
||||
# st.audio(audio.tobytes())
|
||||
|
||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
||||
execute_selected = st.button(opyrator.action)
|
||||
if execute_selected:
|
||||
with st.spinner("Executing operation. Please wait..."):
|
||||
try:
|
||||
input_data_obj = parse_obj_as(
|
||||
opyrator.input_type, session_state.input_data
|
||||
)
|
||||
session_state.output_data = opyrator(input=input_data_obj)
|
||||
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
|
||||
except ValidationError as ex:
|
||||
st.error(ex)
|
||||
else:
|
||||
# st.success("Operation executed successfully.")
|
||||
pass
|
||||
|
||||
with right:
|
||||
st.header("Result 结果")
|
||||
if 'output_data' in session_state:
|
||||
OutputUI(
|
||||
session_state.output_data, session_state.latest_operation_input
|
||||
).render_ui(st)
|
||||
if st.button("Clear"):
|
||||
# Clear all state
|
||||
for key in st.session_state.keys():
|
||||
del st.session_state[key]
|
||||
session_state.input_data = {}
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
# placeholder
|
||||
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
CUSTOM_STREAMLIT_CSS = """
|
||||
div[data-testid="stBlock"] button {
|
||||
width: 100% !important;
|
||||
margin-bottom: 20px !important;
|
||||
border-color: #bfbfbf !important;
|
||||
}
|
||||
section[data-testid="stSidebar"] div {
|
||||
max-width: 10rem;
|
||||
}
|
||||
pre code {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
"""
|
||||
@@ -1,96 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any, Tuple
|
||||
|
||||
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
|
||||
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="目标模型",
|
||||
)
|
||||
dataset: Dataset = Field(
|
||||
Dataset.AIDATATANG_200ZH, title="数据集选择",
|
||||
)
|
||||
datasets_root: str = Field(
|
||||
..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
output_root: str = Field(
|
||||
..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
n_processes: int = Field(
|
||||
2, alias="处理线程数", description="根据CPU线程数来设置",
|
||||
le=32, ge=1
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implemented, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def preprocess(input: Input) -> Output:
|
||||
"""Preprocess(预处理)"""
|
||||
finished = 0
|
||||
if input.model == Model.VC_PPG2MEL:
|
||||
from models.ppg2mel.preprocess import preprocess_dataset
|
||||
finished = preprocess_dataset(
|
||||
datasets_root=Path(input.datasets_root),
|
||||
dataset=input.dataset,
|
||||
out_dir=Path(input.output_root),
|
||||
n_processes=input.n_processes,
|
||||
ppg_encoder_model_fpath=Path(input.extractor.value),
|
||||
speaker_encoder_model=Path(input.encoder.value)
|
||||
)
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, finished))
|
||||
@@ -1,106 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from models.synthesizer.hparams import hparams
|
||||
from models.synthesizer.train import train as synt_train
|
||||
|
||||
# Constants
|
||||
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
|
||||
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
|
||||
|
||||
|
||||
# EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
|
||||
# CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
|
||||
# ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
|
||||
|
||||
# Pre-Load models
|
||||
if os.path.isdir(SYN_MODELS_DIRT):
|
||||
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
class Input(BaseModel):
|
||||
model: Model = Field(
|
||||
Model.DEFAULT, title="模型类型",
|
||||
)
|
||||
# datasets_root: str = Field(
|
||||
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
|
||||
# format=True,
|
||||
# example="..\\trainning_data\\"
|
||||
# )
|
||||
input_root: str = Field(
|
||||
..., alias="输入目录", description="预处理数据根目录",
|
||||
format=True,
|
||||
example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
|
||||
)
|
||||
run_id: str = Field(
|
||||
"", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
|
||||
)
|
||||
synthesizer: synthesizers = Field(
|
||||
..., alias="已有合成模型",
|
||||
description="选择语音合成模型文件."
|
||||
)
|
||||
gpu: bool = Field(
|
||||
True, alias="GPU训练", description="选择“是”,则使用GPU训练",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
True, alias="打印详情", description="选择“是”,输出更多详情",
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
save_every: int = Field(
|
||||
1000, alias="更新间隔", description="每隔n步则更新一次模型",
|
||||
)
|
||||
backup_every: int = Field(
|
||||
10000, alias="保存间隔", description="每隔n步则保存一次模型",
|
||||
)
|
||||
log_every: int = Field(
|
||||
500, alias="打印间隔", description="每隔n步则打印一次训练统计",
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: int
|
||||
|
||||
def render_output_ui(self, streamlit_app) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implemented, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
streamlit_app.subheader(f"Training started with code: {self.__root__}")
|
||||
|
||||
def train(input: Input) -> Output:
|
||||
"""Train(训练)"""
|
||||
|
||||
print(">>> Start training ...")
|
||||
force_restart = len(input.run_id) > 0
|
||||
if not force_restart:
|
||||
input.run_id = Path(input.synthesizer.value).name.split('.')[0]
|
||||
|
||||
synt_train(
|
||||
input.run_id,
|
||||
input.input_root,
|
||||
f"data{os.sep}ckpt{os.sep}synthesizer",
|
||||
input.save_every,
|
||||
input.backup_every,
|
||||
input.log_every,
|
||||
force_restart,
|
||||
hparams
|
||||
)
|
||||
return Output(__root__=0)
|
||||
@@ -1,155 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any, Tuple
|
||||
import numpy as np
|
||||
from utils.hparams import HpsYaml
|
||||
from utils.util import AttrDict
|
||||
import torch
|
||||
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
|
||||
CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"
|
||||
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="模型类型",
|
||||
)
|
||||
# datasets_root: str = Field(
|
||||
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
|
||||
# format=True,
|
||||
# example="..\\trainning_data\\"
|
||||
# )
|
||||
output_root: str = Field(
|
||||
..., alias="输出目录(可选)", description="建议不填,保持默认",
|
||||
format=True,
|
||||
example=""
|
||||
)
|
||||
continue_mode: bool = Field(
|
||||
True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
|
||||
)
|
||||
gpu: bool = Field(
|
||||
True, alias="GPU训练", description="选择“是”,则使用GPU训练",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
True, alias="打印详情", description="选择“是”,输出更多详情",
|
||||
)
|
||||
# TODO: Move to hiden fields by default
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
njobs: int = Field(
|
||||
8, alias="进程数", description="适用于ppg2mel",
|
||||
)
|
||||
seed: int = Field(
|
||||
default=0, alias="初始随机数", description="适用于ppg2mel",
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example="test"
|
||||
)
|
||||
model_config: str = Field(
|
||||
..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: Tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implemented, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
dataset, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {dataset} done, processed a total of {count}")
|
||||
|
||||
def train_vc(input: Input) -> Output:
|
||||
"""Train VC(训练 VC)"""
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
params = AttrDict()
|
||||
params.update({
|
||||
"gpu": input.gpu,
|
||||
"cpu": not input.gpu,
|
||||
"njobs": input.njobs,
|
||||
"seed": input.seed,
|
||||
"verbose": input.verbose,
|
||||
"load": input.convertor.value,
|
||||
"warm_start": False,
|
||||
})
|
||||
if input.continue_mode:
|
||||
# trace old model and config
|
||||
p = Path(input.convertor.value)
|
||||
params.name = p.parent.name
|
||||
# search a config file
|
||||
model_config_fpaths = list(p.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
config = HpsYaml(model_config_fpaths[0])
|
||||
params.ckpdir = p.parent.parent
|
||||
params.config = model_config_fpaths[0]
|
||||
params.logdir = os.path.join(p.parent, "log")
|
||||
else:
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(input.model_config)
|
||||
np.random.seed(input.seed)
|
||||
torch.manual_seed(input.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(input.seed)
|
||||
mode = "train"
|
||||
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
solver = Solver(config, params, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.model_name, 0))  # Input has no "dataset" field; report the model name instead
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 5.6 KiB |
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0002,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [5,5,4,2],
|
||||
"upsample_kernel_sizes": [10,10,8,4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"segment_size": 6400,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 200,
|
||||
"win_size": 800,
|
||||
|
||||
"sampling_rate": 16000,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 7600,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4
|
||||
}
|
||||
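For reference, one quick consistency check on a HiFi-GAN config like the one above: the product of upsample_rates must equal hop_size, since the generator upsamples one mel frame back to exactly hop_size waveform samples (5*5*4*2 = 200 here, matching hop_size at 16 kHz). The snippet below is a minimal sketch; the config path is a placeholder, not necessarily the repo's actual filename.

import json
import math

# Minimal sanity check for a HiFi-GAN style config (path is a placeholder).
with open("vocoder/hifigan/config_16k.json") as f:
    h = json.load(f)

# The generator's total upsampling factor must equal hop_size so that
# one mel frame expands to exactly hop_size waveform samples.
total_upsample = math.prod(h["upsample_rates"])
assert total_upsample == h["hop_size"], (total_upsample, h["hop_size"])

# Implied mel frame rate: 16000 / 200 = 80 mel frames per second.
print(h["sampling_rate"] / h["hop_size"], "mel frames per second")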
Binary file not shown.
@@ -1,8 +0,0 @@
|
||||
https://openslr.magicdatatech.com/resources/62/aidatatang_200zh.tgz
|
||||
out=download/aidatatang_200zh.tgz
|
||||
https://openslr.magicdatatech.com/resources/68/train_set.tar.gz
|
||||
out=download/magicdata.tgz
|
||||
https://openslr.magicdatatech.com/resources/93/data_aishell3.tgz
|
||||
out=download/aishell3.tgz
|
||||
https://openslr.magicdatatech.com/resources/33/data_aishell.tgz
|
||||
out=download/data_aishell.tgz
|
||||
@@ -1,8 +0,0 @@
|
||||
https://openslr.elda.org/resources/62/aidatatang_200zh.tgz
|
||||
out=download/aidatatang_200zh.tgz
|
||||
https://openslr.elda.org/resources/68/train_set.tar.gz
|
||||
out=download/magicdata.tgz
|
||||
https://openslr.elda.org/resources/93/data_aishell3.tgz
|
||||
out=download/aishell3.tgz
|
||||
https://openslr.elda.org/resources/33/data_aishell.tgz
|
||||
out=download/data_aishell.tgz
|
||||
@@ -1,8 +0,0 @@
|
||||
https://us.openslr.org/resources/62/aidatatang_200zh.tgz
|
||||
out=download/aidatatang_200zh.tgz
|
||||
https://us.openslr.org/resources/68/train_set.tar.gz
|
||||
out=download/magicdata.tgz
|
||||
https://us.openslr.org/resources/93/data_aishell3.tgz
|
||||
out=download/aishell3.tgz
|
||||
https://us.openslr.org/resources/33/data_aishell.tgz
|
||||
out=download/data_aishell.tgz
|
||||
@@ -1,4 +0,0 @@
|
||||
0c0ace77fe8ee77db8d7542d6eb0b7ddf09b1bfb880eb93a7fbdbf4611e9984b /datasets/download/aidatatang_200zh.tgz
|
||||
be2507d431ad59419ec871e60674caedb2b585f84ffa01fe359784686db0e0cc /datasets/download/aishell3.tgz
|
||||
a4a0313cde0a933e0e01a451f77de0a23d6c942f4694af5bb7f40b9dc38143fe /datasets/download/data_aishell.tgz
|
||||
1d2647c614b74048cfe16492570cc5146d800afdc07483a43b31809772632143 /datasets/download/magicdata.tgz
|
||||
@@ -1,8 +0,0 @@
|
||||
https://www.openslr.org/resources/62/aidatatang_200zh.tgz
|
||||
out=download/aidatatang_200zh.tgz
|
||||
https://www.openslr.org/resources/68/train_set.tar.gz
|
||||
out=download/magicdata.tgz
|
||||
https://www.openslr.org/resources/93/data_aishell3.tgz
|
||||
out=download/aishell3.tgz
|
||||
https://www.openslr.org/resources/33/data_aishell.tgz
|
||||
out=download/data_aishell.tgz
|
||||
@@ -1,8 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -Eeuo pipefail
|
||||
|
||||
aria2c -x 10 --disable-ipv6 --input-file /workspace/datasets_download/${DATASET_MIRROR}.txt --dir /datasets --continue
|
||||
|
||||
echo "Verifying sha256sum..."
|
||||
parallel --will-cite -a /workspace/datasets_download/datasets.sha256sum "echo -n {} | sha256sum -c"
|
||||
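The same verification step can be done without GNU parallel; the sketch below is a pure-Python equivalent of the checksum pass, assuming the checksum file format shown above (hash, whitespace, path).

import hashlib
from pathlib import Path

# Pure-Python stand-in for the `parallel ... sha256sum -c` step above.
def verify_checksums(sha_file="datasets_download/datasets.sha256sum"):
    for line in Path(sha_file).read_text().splitlines():
        expected, path = line.split()
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        status = "OK" if digest.hexdigest() == expected else "MISMATCH"
        print(f"{path}: {status}")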
@@ -1,29 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -Eeuo pipefail
|
||||
|
||||
mkdir -p /datasets/aidatatang_200zh
|
||||
if [ -z "$(ls -A /datasets/aidatatang_200zh)" ] ; then
|
||||
tar xvz --directory /datasets/ -f /datasets/download/aidatatang_200zh.tgz --exclude 'aidatatang_200zh/corpus/dev/*' --exclude 'aidatatang_200zh/corpus/test/*'
|
||||
cd /datasets/aidatatang_200zh/corpus/train/
|
||||
cat *.tar.gz | tar zxvf - -i
|
||||
rm -f *.tar.gz
|
||||
fi
|
||||
|
||||
mkdir -p /datasets/magicdata
|
||||
if [ -z "$(ls -A /datasets/magicdata)" ] ; then
|
||||
tar xvz --directory /datasets/magicdata -f /datasets/download/magicdata.tgz train/
|
||||
fi
|
||||
|
||||
mkdir -p /datasets/aishell3
|
||||
if [ -z "$(ls -A /datasets/aishell3)" ] ; then
|
||||
tar xvz --directory /datasets/aishell3 -f /datasets/download/aishell3.tgz train/
|
||||
fi
|
||||
|
||||
mkdir -p /datasets/data_aishell
|
||||
if [ -z "$(ls -A /datasets/data_aishell)" ] ; then
|
||||
tar xvz --directory /datasets/ -f /datasets/download/data_aishell.tgz
|
||||
cd /datasets/data_aishell/wav/
|
||||
cat *.tar.gz | tar zxvf - -i --exclude 'dev/*' --exclude 'test/*'
|
||||
rm -f *.tar.gz
|
||||
fi
|
||||
@@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from control.toolbox import Toolbox
|
||||
from toolbox import Toolbox
|
||||
from utils.argutils import print_args
|
||||
from utils.modelutils import check_model_paths
|
||||
import argparse
|
||||
@@ -15,18 +15,12 @@ if __name__ == '__main__':
|
||||
parser.add_argument("-d", "--datasets_root", type=Path, help= \
|
||||
"Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
|
||||
"supported datasets.", default=None)
|
||||
parser.add_argument("-vc", "--vc_mode", action="store_true",
|
||||
help="Voice Conversion Mode(PPG based)")
|
||||
parser.add_argument("-e", "--enc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}encoder",
|
||||
parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
|
||||
help="Directory containing saved encoder models")
|
||||
parser.add_argument("-s", "--syn_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}synthesizer",
|
||||
parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
|
||||
help="Directory containing saved synthesizer models")
|
||||
parser.add_argument("-v", "--voc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}vocoder",
|
||||
parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
|
||||
help="Directory containing saved vocoder models")
|
||||
parser.add_argument("-ex", "--extractor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg_extractor",
|
||||
help="Directory containing saved extrator models")
|
||||
parser.add_argument("-cv", "--convertor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg2mel",
|
||||
help="Directory containing saved convert models")
|
||||
parser.add_argument("--cpu", action="store_true", help=\
|
||||
"If True, processing is done on CPU, even when a GPU is available.")
|
||||
parser.add_argument("--seed", type=int, default=None, help=\
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
server:
|
||||
image: mockingbird:latest
|
||||
build: .
|
||||
volumes:
|
||||
- ./datasets:/datasets
|
||||
- ./synthesizer/saved_models:/workspace/synthesizer/saved_models
|
||||
environment:
|
||||
- DATASET_MIRROR=US
|
||||
- FORCE_RETRAIN=false
|
||||
- TRAIN_DATASETS=aidatatang_200zh magicdata aishell3 data_aishell
|
||||
- TRAIN_SKIP_EXISTING=true
|
||||
ports:
|
||||
- 8080:8080
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
device_ids: [ '0' ]
|
||||
capabilities: [ gpu ]
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
if [ -z "$(ls -A /workspace/synthesizer/saved_models)" ] || [ "$FORCE_RETRAIN" = true ] ; then
|
||||
/workspace/datasets_download/download.sh
|
||||
/workspace/datasets_download/extract.sh
|
||||
for DATASET in ${TRAIN_DATASETS}
|
||||
do
|
||||
if [ "$TRAIN_SKIP_EXISTING" = true ] ; then
|
||||
python pre.py /datasets -d ${DATASET} -n $(nproc) --skip_existing
|
||||
else
|
||||
python pre.py /datasets -d ${DATASET} -n $(nproc)
|
||||
fi
|
||||
done
|
||||
python synthesizer_train.py mandarin /datasets/SV2TTS/synthesizer
|
||||
fi
|
||||
|
||||
python web.py
|
||||
@@ -1,5 +1,5 @@
|
||||
from scipy.ndimage.morphology import binary_dilation
|
||||
from models.encoder.params_data import *
|
||||
from encoder.params_data import *
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from warnings import warn
|
||||
@@ -39,7 +39,7 @@ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
|
||||
|
||||
# Resample the wav if needed
|
||||
if source_sr is not None and source_sr != sampling_rate:
|
||||
wav = librosa.resample(wav, orig_sr = source_sr, target_sr = sampling_rate)
|
||||
wav = librosa.resample(wav, source_sr, sampling_rate)
|
||||
|
||||
# Apply the preprocessing: normalize volume and shorten long silences
|
||||
if normalize:
|
||||
@@ -56,8 +56,8 @@ def wav_to_mel_spectrogram(wav):
|
||||
Note: this is not a log-mel spectrogram.
|
||||
"""
|
||||
frames = librosa.feature.melspectrogram(
|
||||
y=wav,
|
||||
sr=sampling_rate,
|
||||
wav,
|
||||
sampling_rate,
|
||||
n_fft=int(sampling_rate * mel_window_length / 1000),
|
||||
hop_length=int(sampling_rate * mel_window_step / 1000),
|
||||
n_mels=mel_n_channels
|
||||
@@ -99,7 +99,7 @@ def trim_long_silences(wav):
|
||||
return ret[width - 1:] / width
|
||||
|
||||
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
||||
audio_mask = np.round(audio_mask).astype(bool)
|
||||
audio_mask = np.round(audio_mask).astype(np.bool)
|
||||
|
||||
# Dilate the voiced regions
|
||||
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
||||
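As a toy illustration of the VAD smoothing in trim_long_silences above: the per-window voice flags are averaged, binarized, then dilated so short pauses inside speech survive. The widths below follow the usual defaults in encoder/params_data.py (vad_moving_average_width=8, vad_max_silence_length=6) and the moving_average helper mirrors the cumulative-sum version used there; treat both as assumptions.

import numpy as np
from scipy.ndimage import binary_dilation

def moving_average(array, width):
    # Same cumulative-sum trick as in trim_long_silences.
    array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
    ret = np.cumsum(array_padded, dtype=float)
    ret[width:] = ret[width:] - ret[:-width]
    return ret[width - 1:] / width

voice_flags = np.array([0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1], dtype=float)
audio_mask = np.round(moving_average(voice_flags, 8)).astype(bool)
audio_mask = binary_dilation(audio_mask, np.ones(6 + 1))
print(audio_mask.astype(int))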
2
encoder/data_objects/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
||||
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
|
||||
@@ -1,5 +1,5 @@
|
||||
from models.encoder.data_objects.random_cycler import RandomCycler
|
||||
from models.encoder.data_objects.utterance import Utterance
|
||||
from encoder.data_objects.random_cycler import RandomCycler
|
||||
from encoder.data_objects.utterance import Utterance
|
||||
from pathlib import Path
|
||||
|
||||
# Contains the set of utterances of a single speaker
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from models.encoder.data_objects.speaker import Speaker
|
||||
from encoder.data_objects.speaker import Speaker
|
||||
|
||||
class SpeakerBatch:
|
||||
def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
|
||||
@@ -1,7 +1,7 @@
|
||||
from models.encoder.data_objects.random_cycler import RandomCycler
|
||||
from models.encoder.data_objects.speaker_batch import SpeakerBatch
|
||||
from models.encoder.data_objects.speaker import Speaker
|
||||
from models.encoder.params_data import partials_n_frames
|
||||
from encoder.data_objects.random_cycler import RandomCycler
|
||||
from encoder.data_objects.speaker_batch import SpeakerBatch
|
||||
from encoder.data_objects.speaker import Speaker
|
||||
from encoder.params_data import partials_n_frames
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from models.encoder.params_data import *
|
||||
from models.encoder.model import SpeakerEncoder
|
||||
from models.encoder.audio import preprocess_wav # We want to expose this function from here
|
||||
from encoder.params_data import *
|
||||
from encoder.model import SpeakerEncoder
|
||||
from encoder.audio import preprocess_wav # We want to expose this function from here
|
||||
from matplotlib import cm
|
||||
from models.encoder import audio
|
||||
from encoder import audio
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@@ -34,16 +34,8 @@ def load_model(weights_fpath: Path, device=None):
|
||||
_model.load_state_dict(checkpoint["model_state"])
|
||||
_model.eval()
|
||||
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
||||
return _model
|
||||
|
||||
def set_model(model, device=None):
|
||||
global _model, _device
|
||||
_model = model
|
||||
if device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_device = device
|
||||
_model.to(device)
|
||||
|
||||
|
||||
def is_loaded():
|
||||
return _model is not None
|
||||
|
||||
@@ -65,7 +57,7 @@ def embed_frames_batch(frames_batch):
|
||||
|
||||
|
||||
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
||||
min_pad_coverage=0.75, overlap=0.5, rate=None):
|
||||
min_pad_coverage=0.75, overlap=0.5):
|
||||
"""
|
||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
||||
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
||||
@@ -93,18 +85,9 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
|
||||
assert 0 <= overlap < 1
|
||||
assert 0 < min_pad_coverage <= 1
|
||||
|
||||
if rate != None:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
||||
else:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
|
||||
assert 0 < frame_step, "The rate is too high"
|
||||
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
|
||||
(sampling_rate / (samples_per_frame * partials_n_frames))
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
|
||||
# Compute the slices
|
||||
wav_slices, mel_slices = [], []
|
||||
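A worked example of the slicing arithmetic above, assuming the usual defaults from encoder/params_data.py (sampling_rate=16000, mel_window_step=10 ms, partials_n_frames=160) and overlap=0.5; treat these numbers as assumptions rather than the repo's exact configuration.

import numpy as np

sampling_rate = 16000
mel_window_step = 10          # ms per mel frame
partials_n_frames = 160       # frames per partial utterance
overlap = 0.5

n_samples = 5 * sampling_rate                                      # a 5-second utterance
samples_per_frame = int(sampling_rate * mel_window_step / 1000)    # 160 samples per frame
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))       # 501 mel frames
frame_step = max(int(np.round(partials_n_frames * (1 - overlap))), 1)  # 80 frames

# Each partial covers 160 frames and consecutive partials start 80 frames apart
# (50% overlap), so roughly (501 - 160) / 80 + 1 ~ 5 slices plus a padded tail.
print(samples_per_frame, n_frames, frame_step)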
@@ -1,5 +1,5 @@
|
||||
from models.encoder.params_model import *
|
||||
from models.encoder.params_data import *
|
||||
from encoder.params_model import *
|
||||
from encoder.params_data import *
|
||||
from scipy.interpolate import interp1d
|
||||
from sklearn.metrics import roc_curve
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
@@ -1,8 +1,8 @@
|
||||
from multiprocess.pool import ThreadPool
|
||||
from models.encoder.params_data import *
|
||||
from models.encoder.config import librispeech_datasets, anglophone_nationalites
|
||||
from encoder.params_data import *
|
||||
from encoder.config import librispeech_datasets, anglophone_nationalites
|
||||
from datetime import datetime
|
||||
from models.encoder import audio
|
||||
from encoder import audio
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
@@ -22,7 +22,7 @@ class DatasetLog:
|
||||
self._log_params()
|
||||
|
||||
def _log_params(self):
|
||||
from models.encoder import params_data
|
||||
from encoder import params_data
|
||||
self.write_line("Parameter values:")
|
||||
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
||||
value = getattr(params_data, param_name)
|
||||
@@ -1,7 +1,7 @@
|
||||
from models.encoder.visualizations import Visualizations
|
||||
from models.encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
||||
from models.encoder.params_model import *
|
||||
from models.encoder.model import SpeakerEncoder
|
||||
from encoder.visualizations import Visualizations
|
||||
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
||||
from encoder.params_model import *
|
||||
from encoder.model import SpeakerEncoder
|
||||
from utils.profiler import Profiler
|
||||
from pathlib import Path
|
||||
import torch
|
||||
@@ -1,4 +1,4 @@
|
||||
from models.encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
||||
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
||||
from datetime import datetime
|
||||
from time import perf_counter as timer
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -21,7 +21,7 @@ colormap = np.array([
|
||||
[33, 0, 127],
|
||||
[0, 0, 0],
|
||||
[183, 183, 183],
|
||||
], dtype=float) / 255
|
||||
], dtype=np.float) / 255
|
||||
|
||||
|
||||
class Visualizations:
|
||||
@@ -65,8 +65,8 @@ class Visualizations:
|
||||
def log_params(self):
|
||||
if self.disabled:
|
||||
return
|
||||
from models.encoder import params_data
|
||||
from models.encoder import params_model
|
||||
from encoder import params_data
|
||||
from encoder import params_model
|
||||
param_string = "<b>Model parameters</b>:<br>"
|
||||
for param_name in (p for p in dir(params_model) if not p.startswith("__")):
|
||||
value = getattr(params_model, param_name)
|
||||
@@ -1,10 +1,7 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from models.encoder.preprocess import (preprocess_aidatatang_200zh,
|
||||
preprocess_librispeech, preprocess_voxceleb1,
|
||||
preprocess_voxceleb2)
|
||||
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh
|
||||
from utils.argutils import print_args
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
|
||||
@@ -1,5 +1,5 @@
|
||||
from utils.argutils import print_args
|
||||
from models.encoder.train import train
|
||||
from encoder.train import train
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
BIN
fmcc_result.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 188 KiB |
BIN
fmcc_source.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 195 KiB |
120
gen_voice.py
@@ -1,120 +0,0 @@
|
||||
from models.synthesizer.inference import Synthesizer
|
||||
from models.encoder import inference as encoder
|
||||
from models.vocoder.hifigan import inference as gan_vocoder
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import cn2an
|
||||
|
||||
vocoder = gan_vocoder
|
||||
|
||||
def gen_one_wav(synthesizer, in_fpath, embed, texts, file_name, seq):
|
||||
embeds = [embed] * len(texts)
|
||||
# If you know what the attention layer alignments are, you can retrieve them here by
|
||||
# passing return_alignments=True
|
||||
specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=-1, min_stop_token=4, steps=400)
|
||||
#spec = specs[0]
|
||||
breaks = [spec.shape[1] for spec in specs]
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
|
||||
# If seed is specified, reset torch seed and reload vocoder
|
||||
# Synthesizing the waveform is fairly straightforward. Remember that the longer the
|
||||
# spectrogram, the more time-efficient the vocoder.
|
||||
generated_wav, output_sample_rate = vocoder.infer_waveform(spec)
|
||||
|
||||
# Add breaks
|
||||
b_ends = np.cumsum(np.array(breaks) * synthesizer.hparams.hop_size)
|
||||
b_starts = np.concatenate(([0], b_ends[:-1]))
|
||||
wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
||||
breaks = [np.zeros(int(0.15 * synthesizer.sample_rate))] * len(breaks)
|
||||
generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
||||
|
||||
## Post-generation
|
||||
# There's a bug with sounddevice that makes the audio cut one second earlier, so we
|
||||
# pad it.
|
||||
|
||||
# Trim excess silences to compensate for gaps in spectrograms (issue #53)
|
||||
generated_wav = encoder.preprocess_wav(generated_wav)
|
||||
generated_wav = generated_wav / np.abs(generated_wav).max() * 0.97
|
||||
|
||||
# Save it on the disk
|
||||
model=os.path.basename(in_fpath)
|
||||
filename = "%s_%d_%s.wav" %(file_name, seq, model)
|
||||
sf.write(filename, generated_wav, synthesizer.sample_rate)
|
||||
|
||||
print("\nSaved output as %s\n\n" % filename)
|
||||
|
||||
|
||||
def generate_wav(enc_model_fpath, syn_model_fpath, voc_model_fpath, in_fpath, input_txt, file_name):
|
||||
if torch.cuda.is_available():
|
||||
device_id = torch.cuda.current_device()
|
||||
gpu_properties = torch.cuda.get_device_properties(device_id)
|
||||
## Print some environment information (for debugging purposes)
|
||||
print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
|
||||
"%.1fGb total memory.\n" %
|
||||
(torch.cuda.device_count(),
|
||||
device_id,
|
||||
gpu_properties.name,
|
||||
gpu_properties.major,
|
||||
gpu_properties.minor,
|
||||
gpu_properties.total_memory / 1e9))
|
||||
else:
|
||||
print("Using CPU for inference.\n")
|
||||
|
||||
print("Preparing the encoder, the synthesizer and the vocoder...")
|
||||
encoder.load_model(enc_model_fpath)
|
||||
synthesizer = Synthesizer(syn_model_fpath)
|
||||
vocoder.load_model(voc_model_fpath)
|
||||
|
||||
encoder_wav = synthesizer.load_preprocess_wav(in_fpath)
|
||||
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
||||
texts = input_txt.split("\n")
|
||||
seq=0
|
||||
each_num=1500
|
||||
|
||||
punctuation = '!,。、,' # punctuate and split/clean text
|
||||
processed_texts = []
|
||||
cur_num = 0
|
||||
for text in texts:
|
||||
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
||||
if processed_text:
|
||||
processed_texts.append(processed_text.strip())
|
||||
cur_num += len(processed_text.strip())
|
||||
if cur_num > each_num:
|
||||
seq = seq +1
|
||||
gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
|
||||
processed_texts = []
|
||||
cur_num = 0
|
||||
|
||||
if len(processed_texts)>0:
|
||||
seq = seq +1
|
||||
gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
|
||||
|
||||
if (len(sys.argv)>=3):
|
||||
my_txt = ""
|
||||
print("reading from :", sys.argv[1])
|
||||
with open(sys.argv[1], "r") as f:
|
||||
for line in f.readlines():
|
||||
#line = line.strip('\n')
|
||||
my_txt += line
|
||||
txt_file_name = sys.argv[1]
|
||||
wav_file_name = sys.argv[2]
|
||||
|
||||
output = cn2an.transform(my_txt, "an2cn")
|
||||
print(output)
|
||||
generate_wav(
|
||||
Path("encoder/saved_models/pretrained.pt"),
|
||||
Path("synthesizer/saved_models/mandarin.pt"),
|
||||
Path("vocoder/saved_models/pretrained/g_hifigan.pt"), wav_file_name, output, txt_file_name
|
||||
)
|
||||
|
||||
else:
|
||||
print("please input the file name")
|
||||
exit(1)
|
||||
|
||||
|
||||
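The text-chunking logic inside generate_wav can be isolated as below: split on Chinese/ASCII punctuation, then group sentences into batches of at most roughly 1500 characters so each synthesis call stays short. This is a stand-alone sketch that mirrors the code above.

import re

punctuation = '!,。、,'
each_num = 1500

def chunk_text(raw_text):
    batches, current, cur_len = [], [], 0
    for line in raw_text.split("\n"):
        for sent in re.sub(r'[{}]+'.format(punctuation), '\n', line).split('\n'):
            sent = sent.strip()
            if not sent:
                continue
            current.append(sent)
            cur_len += len(sent)
            if cur_len > each_num:
                batches.append(current)
                current, cur_len = [], 0
    if current:
        batches.append(current)
    return batches

print(chunk_text("你好,世界。这是一个测试!"))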
@@ -1,2 +0,0 @@
|
||||
from models.encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
||||
from models.encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
|
||||
@@ -1,209 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2020 Songxiang Liu
|
||||
# Apache 2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .utils.abs_model import AbsMelDecoder
|
||||
from .rnn_decoder_mol import Decoder
|
||||
from .utils.cnn_postnet import Postnet
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
from utils.hparams import HpsYaml
|
||||
|
||||
class MelDecoderMOLv2(AbsMelDecoder):
|
||||
"""Use an encoder to preprocess ppg."""
|
||||
def __init__(
|
||||
self,
|
||||
num_speakers: int,
|
||||
spk_embed_dim: int,
|
||||
bottle_neck_feature_dim: int,
|
||||
encoder_dim: int = 256,
|
||||
encoder_downsample_rates: List = [2, 2],
|
||||
attention_rnn_dim: int = 512,
|
||||
decoder_rnn_dim: int = 512,
|
||||
num_decoder_rnn_layer: int = 1,
|
||||
concat_context_to_last: bool = True,
|
||||
prenet_dims: List = [256, 128],
|
||||
num_mixtures: int = 5,
|
||||
frames_per_step: int = 2,
|
||||
mask_padding: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mask_padding = mask_padding
|
||||
self.bottle_neck_feature_dim = bottle_neck_feature_dim
|
||||
self.num_mels = 80
|
||||
self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1]
|
||||
self.frames_per_step = frames_per_step
|
||||
self.use_spk_dvec = True
|
||||
|
||||
input_dim = bottle_neck_feature_dim
|
||||
|
||||
# Downsampling convolution
|
||||
self.bnf_prenet = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[0],
|
||||
stride=encoder_downsample_rates[0],
|
||||
padding=encoder_downsample_rates[0]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[1],
|
||||
stride=encoder_downsample_rates[1],
|
||||
padding=encoder_downsample_rates[1]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
)
|
||||
decoder_enc_dim = encoder_dim
|
||||
self.pitch_convs = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[0],
|
||||
stride=encoder_downsample_rates[0],
|
||||
padding=encoder_downsample_rates[0]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[1],
|
||||
stride=encoder_downsample_rates[1],
|
||||
padding=encoder_downsample_rates[1]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
)
|
||||
|
||||
self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)
|
||||
|
||||
# Decoder
|
||||
self.decoder = Decoder(
|
||||
enc_dim=decoder_enc_dim,
|
||||
num_mels=self.num_mels,
|
||||
frames_per_step=frames_per_step,
|
||||
attention_rnn_dim=attention_rnn_dim,
|
||||
decoder_rnn_dim=decoder_rnn_dim,
|
||||
num_decoder_rnn_layer=num_decoder_rnn_layer,
|
||||
prenet_dims=prenet_dims,
|
||||
num_mixtures=num_mixtures,
|
||||
use_stop_tokens=True,
|
||||
concat_context_to_last=concat_context_to_last,
|
||||
encoder_down_factor=self.encoder_down_factor,
|
||||
)
|
||||
|
||||
# Mel-Spec Postnet: some residual CNN layers
|
||||
self.postnet = Postnet()
|
||||
|
||||
def parse_output(self, outputs, output_lengths=None):
|
||||
if self.mask_padding and output_lengths is not None:
|
||||
mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
|
||||
mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
|
||||
outputs[0].data.masked_fill_(mask, 0.0)
|
||||
outputs[1].data.masked_fill_(mask, 0.0)
|
||||
return outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
feature_lengths: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
output_att_ws: bool = False,
|
||||
):
|
||||
decoder_inputs = self.bnf_prenet(
|
||||
bottle_neck_features.transpose(1, 2)
|
||||
).transpose(1, 2)
|
||||
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs + logf0_uv
|
||||
|
||||
assert spembs is not None
|
||||
spk_embeds = F.normalize(
|
||||
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
|
||||
decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
|
||||
decoder_inputs = self.reduce_proj(decoder_inputs)
|
||||
|
||||
# (B, num_mels, T_dec)
|
||||
T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
|
||||
mel_outputs, predicted_stop, alignments = self.decoder(
|
||||
decoder_inputs, speech, T_dec)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
if output_att_ws:
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
|
||||
else:
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)
|
||||
|
||||
# return mel_outputs, mel_outputs_postnet
|
||||
|
||||
def inference(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
):
|
||||
decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
|
||||
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs + logf0_uv
|
||||
|
||||
assert spembs is not None
|
||||
spk_embeds = F.normalize(
|
||||
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
|
||||
bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
|
||||
bottle_neck_features = self.reduce_proj(bottle_neck_features)
|
||||
|
||||
## Decoder
|
||||
if bottle_neck_features.size(0) > 1:
|
||||
mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
# outputs = mel_outputs_postnet[0]
|
||||
|
||||
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
|
||||
|
||||
def load_model(model_file, device=None):
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model_config = HpsYaml(model_config_fpaths[0])
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
ckpt = torch.load(model_file, map_location=device)
|
||||
ppg2mel_model.load_state_dict(ckpt["model"])
|
||||
ppg2mel_model.eval()
|
||||
return ppg2mel_model
|
||||
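A hypothetical usage of load_model above; the checkpoint path is a placeholder, and the only requirement is that a *.yaml model config sits somewhere under the checkpoint's parent folder.

from pathlib import Path

# Placeholder checkpoint path; load_model searches its parent folder for *.yaml.
ckpt_path = Path("data/ckpt/ppg2mel/best_loss_step_304000.pth")
ppg2mel = load_model(ckpt_path)          # eval-mode MelDecoderMOLv2 on CPU or GPU
device = next(ppg2mel.parameters()).device

# inference() expects (B, T, bottleneck_dim) PPG features, (B, T, 2) log-F0/UV
# and a (B, spk_embed_dim) speaker embedding; it returns the pre-net mel,
# post-net mel and attention alignments for the first item in the batch.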
@@ -1,113 +0,0 @@
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import soundfile
|
||||
import resampy
|
||||
|
||||
from models.ppg_extractor import load_model
|
||||
import encoder.inference as Encoder
|
||||
from models.encoder.audio import preprocess_wav
|
||||
from models.encoder import audio
|
||||
from utils.f0_utils import compute_f0
|
||||
|
||||
from torch.multiprocessing import Pool, cpu_count
|
||||
from functools import partial
|
||||
|
||||
SAMPLE_RATE=16000
|
||||
|
||||
def _compute_bnf(
|
||||
wav: any,
|
||||
output_fpath: str,
|
||||
device: torch.device,
|
||||
ppg_model_local: any,
|
||||
):
|
||||
"""
|
||||
Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
|
||||
"""
|
||||
ppg_model_local.to(device)
|
||||
wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
|
||||
wav_length = torch.LongTensor([wav.shape[0]]).to(device)
|
||||
with torch.no_grad():
|
||||
bnf = ppg_model_local(wav_tensor, wav_length)
|
||||
bnf_npy = bnf.squeeze(0).cpu().numpy()
|
||||
np.save(output_fpath, bnf_npy, allow_pickle=False)
|
||||
return bnf_npy, len(bnf_npy)
|
||||
|
||||
def _compute_f0_from_wav(wav, output_fpath):
|
||||
"""Compute merged f0 values."""
|
||||
f0 = compute_f0(wav, SAMPLE_RATE)
|
||||
np.save(output_fpath, f0, allow_pickle=False)
|
||||
return f0, len(f0)
|
||||
|
||||
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
|
||||
Encoder.set_model(encoder_model_local)
|
||||
# Compute where to split the utterance into partials and pad if necessary
|
||||
wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
|
||||
max_wave_length = wave_slices[-1].stop
|
||||
if max_wave_length >= len(wav):
|
||||
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
||||
|
||||
# Split the utterance into partials
|
||||
frames = audio.wav_to_mel_spectrogram(wav)
|
||||
frames_batch = np.array([frames[s] for s in mel_slices])
|
||||
partial_embeds = Encoder.embed_frames_batch(frames_batch)
|
||||
|
||||
# Compute the utterance embedding from the partial embeddings
|
||||
raw_embed = np.mean(partial_embeds, axis=0)
|
||||
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
||||
|
||||
np.save(output_fpath, embed, allow_pickle=False)
|
||||
return embed, len(embed)
|
||||
|
||||
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
|
||||
# wav = preprocess_wav(wav_path)
|
||||
# try:
|
||||
wav, sr = soundfile.read(wav_path)
|
||||
if len(wav) < sr:
|
||||
return None, sr, len(wav)
|
||||
if sr != SAMPLE_RATE:
|
||||
wav = resampy.resample(wav, sr, SAMPLE_RATE)
|
||||
sr = SAMPLE_RATE
|
||||
utt_id = os.path.splitext(os.path.basename(wav_path))[0]  # splitext, since rstrip(".wav") strips characters, not the suffix
|
||||
|
||||
_, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
|
||||
_, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
|
||||
_, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
|
||||
|
||||
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
|
||||
# Glob wav files
|
||||
wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
|
||||
print(f"Globbed {len(wav_file_list)} wav files.")
|
||||
|
||||
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
|
||||
ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
|
||||
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
|
||||
if n_processes is None:
|
||||
n_processes = cpu_count()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
|
||||
job = Pool(n_processes).imap(func, wav_file_list)
|
||||
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
|
||||
|
||||
# finish processing and mark
|
||||
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
|
||||
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
|
||||
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
|
||||
for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
|
||||
id = os.path.basename(file).split(".f0.npy")[0]
|
||||
if id.endswith("01"):
|
||||
d_fid_file.write(id + "\n")
|
||||
elif id.endswith("09"):
|
||||
e_fid_file.write(id + "\n")
|
||||
else:
|
||||
t_fid_file.write(id + "\n")
|
||||
t_fid_file.close()
|
||||
d_fid_file.close()
|
||||
e_fid_file.close()
|
||||
return len(wav_file_list)
|
||||
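A hypothetical invocation of preprocess_dataset above; every path and model filename is a placeholder. It writes <out_dir>/{bnf,f0,embed}/*.npy plus train/dev/eval fid lists, routing utterance ids ending in "01" to dev and "09" to eval.

from pathlib import Path

preprocess_dataset(
    datasets_root=Path("/datasets"),
    dataset="aidatatang_200zh",
    out_dir=Path("/datasets/SV2TTS/ppg2mel"),                              # placeholder layout
    n_processes=4,
    ppg_encoder_model_fpath=Path("data/ckpt/ppg_extractor/extractor.pt"),  # placeholder
    speaker_encoder_model=Path("data/ckpt/encoder/pretrained.pt"),         # placeholder
)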
@@ -1,374 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .utils.mol_attention import MOLAttention
|
||||
from .utils.basic_layers import Linear
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
|
||||
class DecoderPrenet(nn.Module):
|
||||
def __init__(self, in_dim, sizes):
|
||||
super().__init__()
|
||||
in_sizes = [in_dim] + sizes[:-1]
|
||||
self.layers = nn.ModuleList(
|
||||
[Linear(in_size, out_size, bias=False)
|
||||
for (in_size, out_size) in zip(in_sizes, sizes)])
|
||||
|
||||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Mixture of Logistic (MoL) attention-based RNN Decoder."""
|
||||
def __init__(
|
||||
self,
|
||||
enc_dim,
|
||||
num_mels,
|
||||
frames_per_step,
|
||||
attention_rnn_dim,
|
||||
decoder_rnn_dim,
|
||||
prenet_dims,
|
||||
num_mixtures,
|
||||
encoder_down_factor=1,
|
||||
num_decoder_rnn_layer=1,
|
||||
use_stop_tokens=False,
|
||||
concat_context_to_last=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.enc_dim = enc_dim
|
||||
self.encoder_down_factor = encoder_down_factor
|
||||
self.num_mels = num_mels
|
||||
self.frames_per_step = frames_per_step
|
||||
self.attention_rnn_dim = attention_rnn_dim
|
||||
self.decoder_rnn_dim = decoder_rnn_dim
|
||||
self.prenet_dims = prenet_dims
|
||||
self.use_stop_tokens = use_stop_tokens
|
||||
self.num_decoder_rnn_layer = num_decoder_rnn_layer
|
||||
self.concat_context_to_last = concat_context_to_last
|
||||
|
||||
# Mel prenet
|
||||
self.prenet = DecoderPrenet(num_mels, prenet_dims)
|
||||
self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)
|
||||
|
||||
# Attention RNN
|
||||
self.attention_rnn = nn.LSTMCell(
|
||||
prenet_dims[-1] + enc_dim,
|
||||
attention_rnn_dim
|
||||
)
|
||||
|
||||
# Attention
|
||||
self.attention_layer = MOLAttention(
|
||||
attention_rnn_dim,
|
||||
r=frames_per_step/encoder_down_factor,
|
||||
M=num_mixtures,
|
||||
)
|
||||
|
||||
# Decoder RNN
|
||||
self.decoder_rnn_layers = nn.ModuleList()
|
||||
for i in range(num_decoder_rnn_layer):
|
||||
if i == 0:
|
||||
self.decoder_rnn_layers.append(
|
||||
nn.LSTMCell(
|
||||
enc_dim + attention_rnn_dim,
|
||||
decoder_rnn_dim))
|
||||
else:
|
||||
self.decoder_rnn_layers.append(
|
||||
nn.LSTMCell(
|
||||
decoder_rnn_dim,
|
||||
decoder_rnn_dim))
|
||||
# self.decoder_rnn = nn.LSTMCell(
|
||||
# 2 * enc_dim + attention_rnn_dim,
|
||||
# decoder_rnn_dim
|
||||
# )
|
||||
if concat_context_to_last:
|
||||
self.linear_projection = Linear(
|
||||
enc_dim + decoder_rnn_dim,
|
||||
num_mels * frames_per_step
|
||||
)
|
||||
else:
|
||||
self.linear_projection = Linear(
|
||||
decoder_rnn_dim,
|
||||
num_mels * frames_per_step
|
||||
)
|
||||
|
||||
|
||||
# Stop-token layer
|
||||
if self.use_stop_tokens:
|
||||
if concat_context_to_last:
|
||||
self.stop_layer = Linear(
|
||||
enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
|
||||
)
|
||||
else:
|
||||
self.stop_layer = Linear(
|
||||
decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
|
||||
)
|
||||
|
||||
|
||||
def get_go_frame(self, memory):
|
||||
B = memory.size(0)
|
||||
go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
|
||||
device=memory.device)
|
||||
return go_frame
|
||||
|
||||
def initialize_decoder_states(self, memory, mask):
|
||||
device = next(self.parameters()).device
|
||||
B = memory.size(0)
|
||||
|
||||
# attention rnn states
|
||||
self.attention_hidden = torch.zeros(
|
||||
(B, self.attention_rnn_dim), device=device)
|
||||
self.attention_cell = torch.zeros(
|
||||
(B, self.attention_rnn_dim), device=device)
|
||||
|
||||
# decoder rnn states
|
||||
self.decoder_hiddens = []
|
||||
self.decoder_cells = []
|
||||
for i in range(self.num_decoder_rnn_layer):
|
||||
self.decoder_hiddens.append(
|
||||
torch.zeros((B, self.decoder_rnn_dim),
|
||||
device=device)
|
||||
)
|
||||
self.decoder_cells.append(
|
||||
torch.zeros((B, self.decoder_rnn_dim),
|
||||
device=device)
|
||||
)
|
||||
# self.decoder_hidden = torch.zeros(
|
||||
# (B, self.decoder_rnn_dim), device=device)
|
||||
# self.decoder_cell = torch.zeros(
|
||||
# (B, self.decoder_rnn_dim), device=device)
|
||||
|
||||
self.attention_context = torch.zeros(
|
||||
(B, self.enc_dim), device=device)
|
||||
|
||||
self.memory = memory
|
||||
# self.processed_memory = self.attention_layer.memory_layer(memory)
|
||||
self.mask = mask
|
||||
|
||||
def parse_decoder_inputs(self, decoder_inputs):
|
||||
"""Prepare decoder inputs, i.e. gt mel
|
||||
Args:
|
||||
decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training.
|
||||
"""
|
||||
decoder_inputs = decoder_inputs.reshape(
|
||||
decoder_inputs.size(0),
|
||||
int(decoder_inputs.size(1)/self.frames_per_step), -1)
|
||||
# (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
|
||||
decoder_inputs = decoder_inputs.transpose(0, 1)
|
||||
# (T_out//r, B, num_mels)
|
||||
decoder_inputs = decoder_inputs[:,:,-self.num_mels:]
|
||||
return decoder_inputs
|
||||
|
||||
def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
|
||||
""" Prepares decoder outputs for output
|
||||
Args:
|
||||
mel_outputs:
|
||||
alignments:
|
||||
"""
|
||||
# (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
|
||||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
# (T_out//r, B) -> (B, T_out//r)
|
||||
if stop_outputs is not None:
|
||||
if alignments.size(0) == 1:
|
||||
stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
|
||||
else:
|
||||
stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
|
||||
stop_outputs = stop_outputs.contiguous()
|
||||
# (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
|
||||
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
|
||||
# decouple frames per step
|
||||
# (B, T_out, num_mels)
|
||||
mel_outputs = mel_outputs.view(
|
||||
mel_outputs.size(0), -1, self.num_mels)
|
||||
return mel_outputs, alignments, stop_outputs
|
||||
|
||||
def attend(self, decoder_input):
|
||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_context, attention_weights = self.attention_layer(
|
||||
self.attention_hidden, self.memory, None, self.mask)
|
||||
|
||||
decoder_rnn_input = torch.cat(
|
||||
(self.attention_hidden, self.attention_context), -1)
|
||||
|
||||
return decoder_rnn_input, self.attention_context, attention_weights
|
||||
|
||||
def decode(self, decoder_input):
|
||||
for i in range(self.num_decoder_rnn_layer):
|
||||
if i == 0:
|
||||
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
|
||||
decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
|
||||
else:
|
||||
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
|
||||
self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
|
||||
return self.decoder_hiddens[-1]
|
||||
|
||||
def forward(self, memory, mel_inputs, memory_lengths):
|
||||
""" Decoder forward pass for training
|
||||
Args:
|
||||
memory: (B, T_enc, enc_dim) Encoder outputs
|
||||
decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
|
||||
memory_lengths: (B, ) Encoder output lengths for attention masking.
|
||||
Returns:
|
||||
mel_outputs: (B, T, num_mels) mel outputs from the decoder
|
||||
alignments: (B, T//r, T_enc) attention weights.
|
||||
"""
|
||||
# [1, B, num_mels]
|
||||
go_frame = self.get_go_frame(memory).unsqueeze(0)
|
||||
# [T//r, B, num_mels]
|
||||
mel_inputs = self.parse_decoder_inputs(mel_inputs)
|
||||
# [T//r + 1, B, num_mels]
|
||||
mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
|
||||
# [T//r + 1, B, prenet_dim]
|
||||
decoder_inputs = self.prenet(mel_inputs)
|
||||
# decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)
|
||||
|
||||
self.initialize_decoder_states(
|
||||
memory, mask=~get_mask_from_lengths(memory_lengths),
|
||||
)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
# self.attention_layer_pitch.init_states(memory_pitch)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
if self.use_stop_tokens:
|
||||
stop_outputs = []
|
||||
else:
|
||||
stop_outputs = None
|
||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||
# decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]
|
||||
|
||||
decoder_rnn_input, context, attention_weights = self.attend(decoder_input)
|
||||
|
||||
decoder_rnn_output = self.decode(decoder_rnn_input)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
if self.use_stop_tokens:
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
stop_outputs += [stop_output.squeeze()]
|
||||
mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze
|
||||
alignments += [attention_weights]
|
||||
# alignments_pitch += [attention_weights_pitch]
|
||||
|
||||
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, stop_outputs)
|
||||
if stop_outputs is None:
|
||||
return mel_outputs, alignments
|
||||
else:
|
||||
return mel_outputs, stop_outputs, alignments
|
||||
|
||||
def inference(self, memory, stop_threshold=0.5):
|
||||
""" Decoder inference
|
||||
Args:
|
||||
memory: (1, T_enc, D_enc) Encoder outputs
|
||||
Returns:
|
||||
mel_outputs: mel outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
# [1, num_mels]
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
# NOTE(sx): heuristic
|
||||
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
|
||||
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
||||
decoder_input_final, context, alignment = self.attend(decoder_input)
|
||||
|
||||
#mel_output, stop_output, alignment = self.decode(decoder_input)
|
||||
decoder_rnn_output = self.decode(decoder_input_final)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
alignments += [alignment]
|
||||
|
||||
if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
|
||||
break
|
||||
if len(mel_outputs) >= max_decoder_step:
|
||||
# print("Warning! Decoding steps reaches max decoder steps.")
|
||||
break
|
||||
|
||||
decoder_input = mel_output[:,-self.num_mels:]
|
||||
|
||||
|
||||
mel_outputs, alignments, _ = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, None)
|
||||
|
||||
return mel_outputs, alignments
|
||||
|
||||
def inference_batched(self, memory, stop_threshold=0.5):
|
||||
""" Decoder inference
|
||||
Args:
|
||||
memory: (B, T_enc, D_enc) Encoder outputs
|
||||
Returns:
|
||||
mel_outputs: mel outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
# [1, num_mels]
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_outputs = []
|
||||
# NOTE(sx): heuristic
|
||||
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
|
||||
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
||||
decoder_input_final, context, alignment = self.attend(decoder_input)
|
||||
|
||||
#mel_output, stop_output, alignment = self.decode(decoder_input)
|
||||
decoder_rnn_output = self.decode(decoder_input_final)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
# (B, 1)
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
stop_outputs += [stop_output.squeeze()]
|
||||
# stop_outputs.append(stop_output)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
alignments += [alignment]
|
||||
# print(stop_output.shape)
|
||||
if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
|
||||
and len(mel_outputs) >= min_decoder_step:
|
||||
break
|
||||
if len(mel_outputs) >= max_decoder_step:
|
||||
# print("Warning! Decoding steps reaches max decoder steps.")
|
||||
break
|
||||
|
||||
decoder_input = mel_output[:,-self.num_mels:]
|
||||
|
||||
|
||||
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, stop_outputs)
|
||||
mel_outputs_stacked = []
|
||||
for mel, stop_logit in zip(mel_outputs, stop_outputs):
|
||||
idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
|
||||
mel_outputs_stacked.append(mel[:idx,:])
|
||||
mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
|
||||
return mel_outputs, alignments
|
||||
@@ -1,62 +0,0 @@
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.hparams import HpsYaml
|
||||
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ckpt/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducible results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
# parser.add_argument('--no-pin', action='store_true',
|
||||
# help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not getattr(paras, 'no_pin', False))  # --no-pin is commented out above, so default to pinning memory
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,50 +0,0 @@
|
||||
from typing import Dict
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class MaskedMSELoss(nn.Module):
|
||||
def __init__(self, frames_per_step):
|
||||
super().__init__()
|
||||
self.frames_per_step = frames_per_step
|
||||
self.mel_loss_criterion = nn.MSELoss(reduction='none')
|
||||
# self.loss = nn.MSELoss()
|
||||
self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
|
||||
|
||||
def get_mask(self, lengths, max_len=None):
|
||||
# lengths: [B,]
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths)
|
||||
batch_size = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
|
||||
seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
|
||||
return (seq_range_expand < seq_length_expand).float()
|
||||
|
||||
def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
|
||||
stop_target, stop_pred):
|
||||
## process stop_target
|
||||
B = stop_target.size(0)
|
||||
stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
|
||||
stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
|
||||
stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))
|
||||
|
||||
mel_trg.requires_grad = False
|
||||
# (B, T, 1)
|
||||
mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
|
||||
# (B, T, D)
|
||||
mel_mask = mel_mask.expand_as(mel_trg)
|
||||
mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
|
||||
mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()
|
||||
|
||||
mel_loss = mel_loss_pre + mel_loss_post
|
||||
|
||||
# stop token loss
|
||||
stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()
|
||||
|
||||
return mel_loss, stop_loss
|
||||
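A minimal sketch exercising MaskedMSELoss with dummy tensors, with the class above in scope; frames_per_step=2 and all shapes are illustrative assumptions. Note that stop_pred carries one logit per decoder step (T // frames_per_step), while the mel tensors are per frame.

import torch

criterion = MaskedMSELoss(frames_per_step=2)

B, T, D = 2, 8, 80                      # T must be a multiple of frames_per_step
mel_trg = torch.randn(B, T, D)
mel_pred = torch.randn(B, T, D)
mel_pred_postnet = torch.randn(B, T, D)
lengths = torch.tensor([8, 6])          # valid frames per utterance
stop_target = torch.zeros(B, T)         # 1.0 marks the stop frames
stop_target[0, -2:] = 1.0
stop_target[1, 4:6] = 1.0
stop_pred = torch.randn(B, T // 2)      # one stop logit per decoder step

mel_loss, stop_loss = criterion(mel_pred, mel_pred_postnet, mel_trg,
                                lengths, stop_target, stop_pred)
print(float(mel_loss), float(stop_loss))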
@@ -1,45 +0,0 @@
import torch
import numpy as np


class Optimizer():
    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
                 **kwargs):

        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step + 1) * warmup_step ** -1.5, (step + 1) ** -0.5)
            self.opt = opt(parameters, lr=1.0)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: 1e-8 better?

    def get_opt_state_dict(self):
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        else:
            cur_lr = self.init_lr
        self.opt.zero_grad()
        return cur_lr

    def step(self):
        self.opt.step()

    def create_msg(self):
        return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
                .format(self.opt_type, self.init_lr, self.sch_type)]
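The 'warmup' branch above is a Noam-style schedule: the learning rate ramps up roughly linearly until warmup_step and then decays as step**-0.5. A quick sanity check of the formula, assuming an illustrative lr of 1e-3:

import numpy as np

init_lr, warmup_step = 1e-3, 4000.0

def warmup_lr(step):
    return init_lr * warmup_step ** 0.5 * \
        np.minimum((step + 1) * warmup_step ** -1.5, (step + 1) ** -0.5)

for step in (0, 999, 3999, 15999):
    print(step, round(float(warmup_lr(step)), 6))
# rises to init_lr around step 4000, then decays as 1/sqrt(step)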
@@ -1,10 +0,0 @@
# Default parameters which will be imported by solver
default_hparas = {
    'GRAD_CLIP': 5.0,        # Grad. clip threshold
    'PROGRESS_STEP': 100,    # Std. output refresh freq.
    # Decode steps for objective validation (step = ratio*input_txt_len)
    'DEV_STEP_RATIO': 1.2,
    # Number of examples (alignment/text) to show in tensorboard
    'DEV_N_EXAMPLE': 4,
    'TB_FLUSH_FREQ': 180     # Update frequency of tensorboard (secs)
}
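These defaults are not read directly by training code; the BaseSolver below copies them onto itself one attribute at a time with setattr. A tiny standalone illustration of that pattern:

class _Holder:
    pass

h = _Holder()
for k, v in default_hparas.items():
    setattr(h, k, v)
print(h.GRAD_CLIP, h.TB_FLUSH_FREQ)   # 5.0 180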
@@ -1,216 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import abc
|
||||
import math
|
||||
import yaml
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .option import default_hparas
|
||||
from utils.util import human_format, Timer
|
||||
|
||||
|
||||
class BaseSolver():
|
||||
'''
|
||||
Prototype Solver for all kinds of tasks
|
||||
Arguments
|
||||
config - yaml-styled config
|
||||
paras - argparse outcome
|
||||
mode - "train"/"test"
|
||||
'''
|
||||
|
||||
def __init__(self, config, paras, mode="train"):
|
||||
# General Settings
|
||||
self.config = config # load from yaml file
|
||||
self.paras = paras # command line args
|
||||
self.mode = mode # 'train' or 'test'
|
||||
for k, v in default_hparas.items():
|
||||
setattr(self, k, v)
|
||||
self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
|
||||
else torch.device('cpu')
|
||||
|
||||
# Name experiment
|
||||
self.exp_name = paras.name
|
||||
if self.exp_name is None:
|
||||
if 'exp_name' in self.config:
|
||||
self.exp_name = self.config.exp_name
|
||||
else:
|
||||
# By default, exp is named after config file
|
||||
self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
|
||||
if mode == 'train':
|
||||
self.exp_name += '_seed{}'.format(paras.seed)
|
||||
|
||||
|
||||
if mode == 'train':
|
||||
# Filepath setup
|
||||
os.makedirs(paras.ckpdir, exist_ok=True)
|
||||
self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
|
||||
os.makedirs(self.ckpdir, exist_ok=True)
|
||||
|
||||
# Logger settings
|
||||
self.logdir = os.path.join(paras.logdir, self.exp_name)
|
||||
self.log = SummaryWriter(
|
||||
self.logdir, flush_secs=self.TB_FLUSH_FREQ)
|
||||
self.timer = Timer()
|
||||
|
||||
# Hyper-parameters
|
||||
self.step = 0
|
||||
self.valid_step = config.hparas.valid_step
|
||||
self.max_step = config.hparas.max_step
|
||||
|
||||
self.verbose('Exp. name : {}'.format(self.exp_name))
|
||||
self.verbose('Loading data... a large corpus may take a while.')
|
||||
|
||||
# elif mode == 'test':
|
||||
# # Output path
|
||||
# os.makedirs(paras.outdir, exist_ok=True)
|
||||
# self.ckpdir = os.path.join(paras.outdir, self.exp_name)
|
||||
|
||||
# Load training config to get acoustic feat and build model
|
||||
# self.src_config = HpsYaml(config.src.config)
|
||||
# self.paras.load = config.src.ckpt
|
||||
|
||||
# self.verbose('Evaluating result of tr. config @ {}'.format(
|
||||
# config.src.config))
|
||||
|
||||
def backward(self, loss):
|
||||
'''
|
||||
Standard backward step with self.timer and debugger
|
||||
Arguments
|
||||
loss - the loss to perform loss.backward()
|
||||
'''
|
||||
self.timer.set()
|
||||
loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.GRAD_CLIP)
|
||||
if math.isnan(grad_norm):
|
||||
self.verbose('Error : grad norm is NaN @ step '+str(self.step))
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.timer.cnt('bw')
|
||||
return grad_norm
|
||||
|
||||
def load_ckpt(self):
|
||||
''' Load ckpt if --load option is specified '''
|
||||
print(self.paras)
|
||||
if self.paras.load is not None:
|
||||
if self.paras.warm_start:
|
||||
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
|
||||
ckpt = torch.load(
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
model_dict = ckpt['model']
|
||||
if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
|
||||
model_dict = {k:v for k, v in model_dict.items()
|
||||
if k not in self.config.model.ignore_layers}
|
||||
dummy_dict = self.model.state_dict()
|
||||
dummy_dict.update(model_dict)
|
||||
model_dict = dummy_dict
|
||||
self.model.load_state_dict(model_dict)
|
||||
else:
|
||||
# Load weights
|
||||
ckpt = torch.load(
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
self.model.load_state_dict(ckpt['model'])
|
||||
|
||||
# Load task-dependent items
|
||||
if self.mode == 'train':
|
||||
self.step = ckpt['global_step']
|
||||
self.optimizer.load_opt_state_dict(ckpt['optimizer'])
|
||||
self.verbose('Load ckpt from {}, restarting at step {}'.format(
|
||||
self.paras.load, self.step))
|
||||
else:
|
||||
for k, v in ckpt.items():
|
||||
if type(v) is float:
|
||||
metric, score = k, v
|
||||
self.model.eval()
|
||||
self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
|
||||
self.paras.load, metric, score))
|
||||
|
||||
def verbose(self, msg):
|
||||
''' Verbose function for printing information to stdout '''
|
||||
if self.paras.verbose:
|
||||
if type(msg) == list:
|
||||
for m in msg:
|
||||
print('[INFO]', m.ljust(100))
|
||||
else:
|
||||
print('[INFO]', msg.ljust(100))
|
||||
|
||||
def progress(self, msg):
|
||||
''' Verbose function for updating progress on stdout (do not include newline) '''
|
||||
if self.paras.verbose:
|
||||
sys.stdout.write("\033[K") # Clear line
|
||||
print('[{}] {}'.format(human_format(self.step), msg), end='\r')
|
||||
|
||||
def write_log(self, log_name, log_dict):
|
||||
'''
|
||||
Write log to TensorBoard
|
||||
log_name - <str> Name of tensorboard variable
|
||||
log_dict - <dict>/<array> Value of the variable (e.g. dict of losses); entries whose value is None or NaN are skipped
|
||||
'''
|
||||
if type(log_dict) is dict:
|
||||
log_dict = {key: val for key, val in log_dict.items() if (
|
||||
val is not None and not math.isnan(val))}
|
||||
if log_dict is None:
|
||||
pass
|
||||
elif len(log_dict) > 0:
|
||||
if 'align' in log_name or 'spec' in log_name:
|
||||
img, form = log_dict
|
||||
self.log.add_image(
|
||||
log_name, img, global_step=self.step, dataformats=form)
|
||||
elif 'text' in log_name or 'hyp' in log_name:
|
||||
self.log.add_text(log_name, log_dict, self.step)
|
||||
else:
|
||||
self.log.add_scalars(log_name, log_dict, self.step)
|
||||
|
||||
def save_checkpoint(self, f_name, metric, score, show_msg=True):
|
||||
'''
|
||||
Ckpt saver
|
||||
f_name - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
|
||||
score - <float> The value of metric used to evaluate model
|
||||
'''
|
||||
ckpt_path = os.path.join(self.ckpdir, f_name)
|
||||
full_dict = {
|
||||
"model": self.model.state_dict(),
|
||||
"optimizer": self.optimizer.get_opt_state_dict(),
|
||||
"global_step": self.step,
|
||||
metric: score
|
||||
}
|
||||
|
||||
torch.save(full_dict, ckpt_path)
|
||||
if show_msg:
|
||||
self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
|
||||
format(human_format(self.step), metric, score, ckpt_path))
|
||||
|
||||
|
||||
# ----------------------------------- Abstract Methods ------------------------------------------ #
|
||||
@abc.abstractmethod
|
||||
def load_data(self):
|
||||
'''
|
||||
Called by main to load all data
|
||||
After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
|
||||
No return value
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_model(self):
|
||||
'''
|
||||
Called by main to set models
|
||||
After this call, model related attributes should be setup (e.g. self.l2_loss)
|
||||
The followings MUST be setup
|
||||
- self.model (torch.nn.Module)
|
||||
- self.optimizer (src.Optimizer),
|
||||
init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
|
||||
Loading pre-trained model should also be performed here
|
||||
No return value
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exec(self):
|
||||
'''
|
||||
Called by main to execute training/inference
|
||||
'''
|
||||
raise NotImplementedError
|
||||
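For reference, the checkpoint dictionary written by save_checkpoint above has four entries: the model and optimizer state dicts, the global step, and one float-valued metric. A round-trip sketch with a tiny stand-in model (the file name and metric key are illustrative):

import torch

model = torch.nn.Linear(4, 2)
full_dict = {
    "model": model.state_dict(),
    "optimizer": torch.optim.Adam(model.parameters()).state_dict(),
    "global_step": 1000,
    "loss": 0.123,            # the metric entry; its key name varies
}
torch.save(full_dict, "demo_ckpt.pth")

ckpt = torch.load("demo_ckpt.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])
start_step = ckpt["global_step"]
metrics = {k: v for k, v in ckpt.items() if isinstance(v, float)}  # how load_ckpt finds the score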
@@ -1,288 +0,0 @@
|
||||
import os, sys
|
||||
# sys.path.append('/home/shaunxliu/projects/nnsp')
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
from .solver import BaseSolver
|
||||
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
|
||||
# from src.rnn_ppg2mel import BiRnnPpg2MelModel
|
||||
# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
|
||||
from .loss import MaskedMSELoss
|
||||
from .optim import Optimizer
|
||||
from utils.util import human_format
|
||||
from models.ppg2mel import MelDecoderMOLv2
|
||||
|
||||
|
||||
class Solver(BaseSolver):
|
||||
"""Customized Solver."""
|
||||
def __init__(self, config, paras, mode):
|
||||
super().__init__(config, paras, mode)
|
||||
self.num_att_plots = 5
|
||||
self.att_ws_dir = f"{self.logdir}/att_ws"
|
||||
os.makedirs(self.att_ws_dir, exist_ok=True)
|
||||
self.best_loss = np.inf
|
||||
|
||||
def fetch_data(self, data):
|
||||
"""Move data to device"""
|
||||
data = [i.to(self.device) for i in data]
|
||||
return data
|
||||
|
||||
def load_data(self):
|
||||
""" Load data for training/validation/plotting."""
|
||||
train_dataset = OneshotVcDataset(
|
||||
meta_file=self.config.data.train_fid_list,
|
||||
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
|
||||
libri_ppg_dir=self.config.data.libri_ppg_dir,
|
||||
vctk_f0_dir=self.config.data.vctk_f0_dir,
|
||||
libri_f0_dir=self.config.data.libri_f0_dir,
|
||||
vctk_wav_dir=self.config.data.vctk_wav_dir,
|
||||
libri_wav_dir=self.config.data.libri_wav_dir,
|
||||
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
|
||||
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
|
||||
ppg_file_ext=self.config.data.ppg_file_ext,
|
||||
min_max_norm_mel=self.config.data.min_max_norm_mel,
|
||||
mel_min=self.config.data.mel_min,
|
||||
mel_max=self.config.data.mel_max,
|
||||
)
|
||||
dev_dataset = OneshotVcDataset(
|
||||
meta_file=self.config.data.dev_fid_list,
|
||||
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
|
||||
libri_ppg_dir=self.config.data.libri_ppg_dir,
|
||||
vctk_f0_dir=self.config.data.vctk_f0_dir,
|
||||
libri_f0_dir=self.config.data.libri_f0_dir,
|
||||
vctk_wav_dir=self.config.data.vctk_wav_dir,
|
||||
libri_wav_dir=self.config.data.libri_wav_dir,
|
||||
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
|
||||
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
|
||||
ppg_file_ext=self.config.data.ppg_file_ext,
|
||||
min_max_norm_mel=self.config.data.min_max_norm_mel,
|
||||
mel_min=self.config.data.mel_min,
|
||||
mel_max=self.config.data.mel_max,
|
||||
)
|
||||
self.train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=True,
|
||||
batch_size=self.config.hparas.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=True,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True),
|
||||
)
|
||||
self.dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=False,
|
||||
batch_size=self.config.hparas.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True),
|
||||
)
|
||||
self.plot_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True,
|
||||
give_uttids=True),
|
||||
)
|
||||
msg = "Have prepared training set and dev set."
|
||||
self.verbose(msg)
|
||||
|
||||
def load_pretrained_params(self):
|
||||
print("Load pretrained model from: ", self.config.data.pretrain_model_file)
|
||||
ignore_layer_prefixes = ["speaker_embedding_table"]
|
||||
pretrain_model_file = self.config.data.pretrain_model_file
|
||||
pretrain_ckpt = torch.load(
|
||||
pretrain_model_file, map_location=self.device
|
||||
)["model"]
|
||||
model_dict = self.model.state_dict()
|
||||
print(self.model)
|
||||
|
||||
# 1. filter out unnecessary keys
|
||||
for prefix in ignore_layer_prefixes:
|
||||
pretrain_ckpt = {k : v
|
||||
for k, v in pretrain_ckpt.items() if not k.startswith(prefix)
|
||||
}
|
||||
# 2. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrain_ckpt)
|
||||
|
||||
# 3. load the new state dict
|
||||
self.model.load_state_dict(model_dict)
|
||||
|
||||
def set_model(self):
|
||||
"""Setup model and optimizer"""
|
||||
# Model
|
||||
print("[INFO] Model name: ", self.config["model_name"])
|
||||
self.model = MelDecoderMOLv2(
|
||||
**self.config["model"]
|
||||
).to(self.device)
|
||||
# self.load_pretrained_params()
|
||||
|
||||
# model_params = [{'params': self.model.spk_embedding.weight}]
|
||||
model_params = [{'params': self.model.parameters()}]
|
||||
|
||||
# Loss criterion
|
||||
self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)
|
||||
|
||||
# Optimizer
|
||||
self.optimizer = Optimizer(model_params, **self.config["hparas"])
|
||||
self.verbose(self.optimizer.create_msg())
|
||||
|
||||
# Automatically load pre-trained model if self.paras.load is given
|
||||
self.load_ckpt()
|
||||
|
||||
def exec(self):
|
||||
self.verbose("Total training steps {}.".format(
|
||||
human_format(self.max_step)))
|
||||
|
||||
mel_loss = None
|
||||
n_epochs = 0
|
||||
# Set as current time
|
||||
self.timer.set()
|
||||
|
||||
while self.step < self.max_step:
|
||||
for data in self.train_dataloader:
|
||||
# Pre-step: update lr_rate and do zero_grad
|
||||
lr_rate = self.optimizer.pre_step(self.step)
|
||||
total_loss = 0
|
||||
# data to device
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
|
||||
self.timer.cnt("rd")
|
||||
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids
|
||||
)
|
||||
mel_loss, stop_loss = self.loss_criterion(
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mels,
|
||||
out_lengths,
|
||||
stop_tokens,
|
||||
predicted_stop
|
||||
)
|
||||
loss = mel_loss + stop_loss
|
||||
|
||||
self.timer.cnt("fw")
|
||||
|
||||
# Back-prop
|
||||
grad_norm = self.backward(loss)
|
||||
self.step += 1
|
||||
|
||||
# Logger
|
||||
if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
|
||||
self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
|
||||
.format(loss.cpu().item(), mel_loss.cpu().item(),
|
||||
stop_loss.cpu().item(), grad_norm, self.timer.show()))
|
||||
self.write_log('loss', {'tr/loss': loss,
|
||||
'tr/mel-loss': mel_loss,
|
||||
'tr/stop-loss': stop_loss})
|
||||
|
||||
# Validation
|
||||
if (self.step == 1) or (self.step % self.valid_step == 0):
|
||||
self.validate()
|
||||
|
||||
# End of step
|
||||
# https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
|
||||
torch.cuda.empty_cache()
|
||||
self.timer.set()
|
||||
if self.step > self.max_step:
|
||||
break
|
||||
n_epochs += 1
|
||||
self.log.close()
|
||||
|
||||
def validate(self):
|
||||
self.model.eval()
|
||||
dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0
|
||||
|
||||
for i, data in enumerate(self.dev_dataloader):
|
||||
self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
|
||||
# Fetch data
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
|
||||
with torch.no_grad():
|
||||
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids
|
||||
)
|
||||
mel_loss, stop_loss = self.loss_criterion(
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mels,
|
||||
out_lengths,
|
||||
stop_tokens,
|
||||
predicted_stop
|
||||
)
|
||||
loss = mel_loss + stop_loss
|
||||
|
||||
dev_loss += loss.cpu().item()
|
||||
dev_mel_loss += mel_loss.cpu().item()
|
||||
dev_stop_loss += stop_loss.cpu().item()
|
||||
|
||||
dev_loss = dev_loss / (i + 1)
|
||||
dev_mel_loss = dev_mel_loss / (i + 1)
|
||||
dev_stop_loss = dev_stop_loss / (i + 1)
|
||||
self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
|
||||
if dev_loss < self.best_loss:
|
||||
self.best_loss = dev_loss
|
||||
self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
|
||||
self.write_log('loss', {'dv/loss': dev_loss,
|
||||
'dv/mel-loss': dev_mel_loss,
|
||||
'dv/stop-loss': dev_stop_loss})
|
||||
|
||||
# plot attention
|
||||
for i, data in enumerate(self.plot_dataloader):
|
||||
if i == self.num_att_plots:
|
||||
break
|
||||
# Fetch data
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
|
||||
fid = data[-1][0]
|
||||
with torch.no_grad():
|
||||
_, _, _, att_ws = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids,
|
||||
output_att_ws=True
|
||||
)
|
||||
att_ws = att_ws.squeeze(0).cpu().numpy()
|
||||
att_ws = att_ws[None]
|
||||
w, h = plt.figaspect(1.0 / len(att_ws))
|
||||
fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
|
||||
axes = fig.subplots(1, len(att_ws))
|
||||
if len(att_ws) == 1:
|
||||
axes = [axes]
|
||||
|
||||
for ax, aw in zip(axes, att_ws):
|
||||
ax.imshow(aw.astype(np.float32), aspect="auto")
|
||||
ax.set_title(f"{fid}")
|
||||
ax.set_xlabel("Input")
|
||||
ax.set_ylabel("Output")
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
|
||||
fig.savefig(fig_name)
|
||||
|
||||
# Resume training
|
||||
self.model.train()
|
||||
|
||||
@@ -1,23 +0,0 @@
from abc import ABC
from abc import abstractmethod

import torch


class AbsMelDecoder(torch.nn.Module, ABC):
    """The abstract PPG-based voice conversion class.

    This "model" is one of the mediator objects for the "Task" class.
    """

    @abstractmethod
    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        styleembs: torch.Tensor = None,
    ) -> torch.Tensor:
        raise NotImplementedError
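A minimal concrete subclass, shown only to illustrate the required forward signature; the single Linear layer and the 144-dim PPG / 80-dim mel sizes are assumptions for this sketch, not the project's real decoder (which is MelDecoderMOLv2):

class DummyMelDecoder(AbsMelDecoder):
    def __init__(self, ppg_dim=144, mel_dim=80):
        super().__init__()
        self.proj = torch.nn.Linear(ppg_dim, mel_dim)

    def forward(
        self,
        bottle_neck_features: torch.Tensor,   # (B, T, ppg_dim)
        feature_lengths: torch.Tensor,        # (B,)
        speech: torch.Tensor,                 # (B, T_mel, mel_dim) targets, unused here
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        styleembs: torch.Tensor = None,
    ) -> torch.Tensor:
        # Frame-wise projection of PPGs to mel bins; a real decoder is autoregressive.
        return self.proj(bottle_neck_features)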
@@ -1,79 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1) \
         .transpose(0, 1) \
         .repeat(count, 1) \
         .transpose(0, 1) \
         .contiguous() \
         .view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x


class Linear(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class Conv1d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(Conv1d, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        # x: BxDxT
        return self.conv(x)
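A quick check of tile() on a small tensor (values chosen arbitrarily): each slice along the chosen dimension is repeated count times consecutively, i.e. the same result as torch.repeat_interleave:

x = torch.tensor([[1, 2], [3, 4]])
print(tile(x, 2, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])
# equivalent to torch.repeat_interleave(x, 2, dim=0)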
@@ -1,52 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .basic_layers import Linear, Conv1d


class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """
    def __init__(self, num_mels=80,
                 num_layers=5,
                 hidden_dim=512,
                 kernel_size=5):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                Conv1d(
                    num_mels, hidden_dim,
                    kernel_size=kernel_size, stride=1,
                    padding=int((kernel_size - 1) / 2),
                    dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hidden_dim)))

        for i in range(1, num_layers - 1):
            self.convolutions.append(
                nn.Sequential(
                    Conv1d(
                        hidden_dim,
                        hidden_dim,
                        kernel_size=kernel_size, stride=1,
                        padding=int((kernel_size - 1) / 2),
                        dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hidden_dim)))

        self.convolutions.append(
            nn.Sequential(
                Conv1d(
                    hidden_dim, num_mels,
                    kernel_size=kernel_size, stride=1,
                    padding=int((kernel_size - 1) / 2),
                    dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(num_mels)))

    def forward(self, x):
        # x: (B, num_mels, T_dec)
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
        return x
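A shape check for the Postnet above with illustrative sizes (batch 2, 80 mel bins, 37 decoder frames). The residual connection mel + postnet(mel) is typically applied by the calling decoder, not inside the module:

postnet = Postnet(num_mels=80, num_layers=5, hidden_dim=512, kernel_size=5)
mel = torch.randn(2, 80, 37)     # (B, num_mels, T_dec)
residual = postnet(mel)
print(residual.shape)            # torch.Size([2, 80, 37])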
@@ -1,123 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class MOLAttention(nn.Module):
    """ Discretized Mixture of Logistic (MOL) attention.
    C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
    the GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
    """
    def __init__(
        self,
        query_dim,
        r=1,
        M=5,
    ):
        """
        Args:
            query_dim: attention_rnn_dim.
            M: number of mixtures.
        """
        super().__init__()
        if r < 1:
            self.r = float(r)
        else:
            self.r = int(r)
        self.M = M
        self.score_mask_value = 0.0  # -float("inf")
        self.eps = 1e-5
        # Position array for encoder time steps
        self.J = None
        # Query layer: [w, sigma, Delta]
        self.query_layer = torch.nn.Sequential(
            nn.Linear(query_dim, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, 3 * M, bias=True)
        )
        self.mu_prev = None
        self.initialize_bias()

    def initialize_bias(self):
        """Initialize sigma and Delta."""
        # sigma
        torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
        # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
        # softplus(-0.432) = 0.5003
        if self.r == 2:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
        elif self.r == 4:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
        elif self.r == 1:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
        else:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)

    def init_states(self, memory):
        """Initialize mu_prev and J.
        This function should be called by the decoder before decoding one batch.
        Args:
            memory: (B, T, D_enc) encoder output.
        """
        B, T_enc, _ = memory.size()
        device = memory.device
        self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5  # NOTE: for discretize usage
        # self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float)
        self.mu_prev = torch.zeros(B, self.M).to(device)

    def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
        """
        att_rnn_h: attention rnn hidden state.
        memory: encoder outputs (B, T_enc, D).
        mask: binary mask for padded data (B, T_enc).
        """
        # [B, 3M]
        mixture_params = self.query_layer(att_rnn_h)

        # [B, M]
        w_hat = mixture_params[:, :self.M]
        sigma_hat = mixture_params[:, self.M:2*self.M]
        Delta_hat = mixture_params[:, 2*self.M:3*self.M]

        # print("w_hat: ", w_hat)
        # print("sigma_hat: ", sigma_hat)
        # print("Delta_hat: ", Delta_hat)

        # Dropout to de-correlate attention heads
        w_hat = F.dropout(w_hat, p=0.5, training=self.training)  # NOTE(sx): needed?

        # Mixture parameters
        w = torch.softmax(w_hat, dim=-1) + self.eps
        sigma = F.softplus(sigma_hat) + self.eps
        Delta = F.softplus(Delta_hat)
        mu_cur = self.mu_prev + Delta
        # print("w:", w)
        j = self.J[:memory.size(1) + 1]

        # Attention weights
        # CDF of logistic distribution
        phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
            (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))
        # print("phi_t:", phi_t)

        # Discretize attention weights
        # (B, T_enc + 1)
        alpha_t = torch.sum(phi_t, dim=1)
        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
        alpha_t[alpha_t == 0] = self.eps
        # print("alpha_t: ", alpha_t.size())
        # Apply masking
        if mask is not None:
            alpha_t.data.masked_fill_(mask, self.score_mask_value)

        context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
        if memory_pitch is not None:
            context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)

        self.mu_prev = mu_cur

        if memory_pitch is not None:
            return context, context_pitch, alpha_t
        return context, alpha_t
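A single decoding step with the attention module above, using made-up sizes (batch 2, 11 encoder frames, 512-dim memory, 256-dim attention-RNN state); init_states must be called once per batch before the first step:

att = MOLAttention(query_dim=256, r=4, M=5)
memory = torch.randn(2, 11, 512)      # (B, T_enc, D_enc)
att.init_states(memory)
att_rnn_h = torch.randn(2, 256)       # attention-RNN hidden state for one step
context, alpha = att(att_rnn_h, memory)
print(context.shape, alpha.shape)     # torch.Size([2, 512]) torch.Size([2, 11])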
@@ -1,451 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Network related utility tools."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(m, x):
|
||||
"""Send tensor into the device of the module.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): Torch module.
|
||||
x (Tensor): Torch tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Torch tensor located in the same place as torch module.
|
||||
|
||||
"""
|
||||
assert isinstance(m, torch.nn.Module)
|
||||
device = next(m.parameters()).device
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def pad_list(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
max_len = max(x.size(0) for x in xs)
|
||||
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
pad[i, :xs[i].size(0)] = xs[i]
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]],
|
||||
[[0, 0, 0, 1],
|
||||
[0, 0, 0, 1]],
|
||||
[[0, 0, 1, 1],
|
||||
[0, 0, 1, 1]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_pad_mask(lengths, xs, 1)
|
||||
tensor([[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
>>> make_pad_mask(lengths, xs, 2)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
if length_dim == 0:
|
||||
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
|
||||
|
||||
if not isinstance(lengths, list):
|
||||
lengths = lengths.tolist()
|
||||
bs = int(len(lengths))
|
||||
if xs is None:
|
||||
maxlen = int(max(lengths))
|
||||
else:
|
||||
maxlen = xs.size(length_dim)
|
||||
|
||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
if xs is not None:
|
||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
||||
|
||||
if length_dim < 0:
|
||||
length_dim = xs.dim() + length_dim
|
||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
||||
ind = tuple(slice(None) if i in (0, length_dim) else None
|
||||
for i in range(xs.dim()))
|
||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
||||
return mask
|
||||
|
||||
|
||||
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of non-padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
ByteTensor: mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[1, 1, 1, 1 ,1],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]],
|
||||
[[1, 1, 1, 0],
|
||||
[1, 1, 1, 0]],
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_non_pad_mask(lengths, xs, 1)
|
||||
tensor([[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
>>> make_non_pad_mask(lengths, xs, 2)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
return ~make_pad_mask(lengths, xs, length_dim)
|
||||
|
||||
|
||||
def mask_by_length(xs, lengths, fill=0):
|
||||
"""Mask tensor according to length.
|
||||
|
||||
Args:
|
||||
xs (Tensor): Batch of input tensor (B, `*`).
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
fill (int or float): Value to fill masked part.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of masked input tensor (B, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = torch.arange(5).repeat(3, 1) + 1
|
||||
>>> x
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5]])
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> mask_by_length(x, lengths)
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 0, 0],
|
||||
[1, 2, 0, 0, 0]])
|
||||
|
||||
"""
|
||||
assert xs.size(0) == len(lengths)
|
||||
ret = xs.data.new(*xs.size()).fill_(fill)
|
||||
for i, l in enumerate(lengths):
|
||||
ret[i, :l] = xs[i, :l]
|
||||
return ret
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(
|
||||
pad_targets.size(0),
|
||||
pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return float(numerator) / float(denominator)
|
||||
|
||||
|
||||
def to_torch_tensor(x):
|
||||
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
|
||||
|
||||
Args:
|
||||
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
|
||||
|
||||
Returns:
|
||||
Tensor or ComplexTensor: Type converted inputs.
|
||||
|
||||
Examples:
|
||||
>>> xs = np.ones(3, dtype=np.float32)
|
||||
>>> xs = to_torch_tensor(xs)
|
||||
tensor([1., 1., 1.])
|
||||
>>> xs = torch.ones(3, 4, 5)
|
||||
>>> assert to_torch_tensor(xs) is xs
|
||||
>>> xs = {'real': xs, 'imag': xs}
|
||||
>>> to_torch_tensor(xs)
|
||||
ComplexTensor(
|
||||
Real:
|
||||
tensor([1., 1., 1.])
|
||||
Imag:
|
||||
tensor([1., 1., 1.])
|
||||
)
|
||||
|
||||
"""
|
||||
# If numpy, change to torch tensor
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype.kind == 'c':
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
return ComplexTensor(x)
|
||||
else:
|
||||
return torch.from_numpy(x)
|
||||
|
||||
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
|
||||
elif isinstance(x, dict):
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
if 'real' not in x or 'imag' not in x:
|
||||
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
|
||||
# Relative importing because of using python3 syntax
|
||||
return ComplexTensor(x['real'], x['imag'])
|
||||
|
||||
# If torch.Tensor, as it is
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x
|
||||
|
||||
else:
|
||||
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
|
||||
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
|
||||
"but got {}".format(type(x)))
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except Exception:
|
||||
# If PY2
|
||||
raise ValueError(error)
|
||||
else:
|
||||
# If PY3
|
||||
if isinstance(x, ComplexTensor):
|
||||
return x
|
||||
else:
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_subsample(train_args, mode, arch):
|
||||
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
|
||||
|
||||
Args:
|
||||
train_args: argument Namespace containing options.
|
||||
mode: one of ('asr', 'mt', 'st')
|
||||
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
|
||||
|
||||
Returns:
|
||||
np.ndarray / List[np.ndarray]: subsampling factors.
|
||||
"""
|
||||
if arch == 'transformer':
|
||||
return np.array([1])
|
||||
|
||||
elif mode == 'mt' and arch == 'rnn':
|
||||
# +1 means input (+1) and layers outputs (train_args.elayers)
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=int)
|
||||
logging.warning('Subsampling is not performed for machine translation.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
|
||||
(mode == 'mt' and arch == 'rnn') or \
|
||||
(mode == 'st' and arch == 'rnn'):
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mix':
|
||||
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mulenc':
|
||||
subsample_list = []
|
||||
for idx in range(train_args.num_encs):
|
||||
subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
|
||||
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
|
||||
ss = train_args.subsample[idx].split("_")
|
||||
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Encoder %d: Subsampling is not performed for vgg*. '
|
||||
'It is performed in max pooling layers at CNN.', idx + 1)
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
subsample_list.append(subsample)
|
||||
return subsample_list
|
||||
|
||||
else:
|
||||
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
|
||||
|
||||
|
||||
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
|
||||
"""Replace keys of old prefix with new prefix in state dict."""
|
||||
# need this list not to break the dict iterator
|
||||
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
|
||||
if len(old_keys) > 0:
|
||||
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
|
||||
for k in old_keys:
|
||||
v = state_dict.pop(k)
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
state_dict[new_k] = v
|
||||
@@ -1,22 +0,0 @@
import torch


def gcd(a, b):
    """Greatest common divisor."""
    a, b = (a, b) if a >= b else (b, a)
    if a % b == 0:
        return b
    else:
        return gcd(b, a % b)


def lcm(a, b):
    """Least common multiple."""
    return a * b // gcd(a, b)


def get_mask_from_lengths(lengths, max_len=None):
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask
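Note that get_mask_from_lengths above allocates the index tensor with torch.cuda.LongTensor and therefore only runs on a GPU. A device-agnostic variant that behaves the same way (a sketch, not part of the original code):

def get_mask_from_lengths_any_device(lengths, max_len=None):
    if max_len is None:
        max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device)
    return ids < lengths.unsqueeze(1)   # (B, max_len) boolean mask

# e.g. get_mask_from_lengths_any_device(torch.tensor([5, 3, 2]))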
@@ -1,102 +0,0 @@
import argparse
import torch
from pathlib import Path
import yaml

from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder

_model = None  # type: PPGModel
_device = None


class PPGModel(torch.nn.Module):
    def __init__(
        self,
        frontend,
        normalizer,
        encoder,
    ):
        super().__init__()
        self.frontend = frontend
        self.normalize = normalizer
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        """
        Args:
            speech (tensor): (B, L)
            speech_lengths (tensor): (B, )

        Returns:
            bottle_neck_feats (tensor): (B, L//hop_size, 144)
        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        feats, feats_lengths = self.normalize(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ):
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def extract_from_wav(self, src_wav):
        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
        return self(src_wav_tensor, src_wav_lengths)


def build_model(args):
    normalizer = UtteranceMVN(**args.normalize_conf)
    frontend = DefaultFrontend(**args.frontend_conf)
    encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
    model = PPGModel(frontend, normalizer, encoder)

    return model


def load_model(model_file, device=None):
    global _model, _device

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = device
    # search a config file
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    config_file = model_config_fpaths[0]
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)

    args = argparse.Namespace(**args)

    model = build_model(args)
    model_state_dict = model.state_dict()

    ckpt_state_dict = torch.load(model_file, map_location=_device)
    ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if 'encoder' in k}

    model_state_dict.update(ckpt_state_dict)
    model.load_state_dict(model_state_dict)

    _model = model.eval().to(_device)
    return _model
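End-to-end use of the helpers above: load a checkpoint (a *.yaml config is expected next to it) and extract PPG features from a waveform. The checkpoint path, the sampling rate, and the one-second random waveform are assumptions for this sketch and require the package's relative imports to resolve:

from pathlib import Path
import numpy as np
import torch

ppg_model = load_model(Path("ppg_extractor/saved_models/ppg_model.pt"))  # hypothetical path
wav = np.random.randn(16000).astype(np.float32)   # stand-in for a 1 s, 16 kHz utterance
with torch.no_grad():
    ppgs = ppg_model.extract_from_wav(wav)        # (1, n_frames, 144)
print(ppgs.shape)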
@@ -1,398 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Common functions for ASR."""
|
||||
|
||||
import argparse
|
||||
import editdistance
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import six
|
||||
import sys
|
||||
|
||||
from itertools import groupby
|
||||
|
||||
|
||||
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
|
||||
"""End detection.
|
||||
|
||||
described in Eq. (50) of S. Watanabe et al.
|
||||
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
|
||||
|
||||
:param ended_hyps:
|
||||
:param i:
|
||||
:param M:
|
||||
:param D_end:
|
||||
:return:
|
||||
"""
|
||||
if len(ended_hyps) == 0:
|
||||
return False
|
||||
count = 0
|
||||
best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
|
||||
for m in six.moves.range(M):
|
||||
# get ended_hyps with their length is i - m
|
||||
hyp_length = i - m
|
||||
hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
|
||||
if len(hyps_same_length) > 0:
|
||||
best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
|
||||
if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
|
||||
count += 1
|
||||
|
||||
if count == M:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# TODO(takaaki-hori): add different smoothing methods
|
||||
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
|
||||
"""Obtain label distribution for loss smoothing.
|
||||
|
||||
:param odim:
|
||||
:param lsm_type:
|
||||
:param blank:
|
||||
:param transcript:
|
||||
:return:
|
||||
"""
|
||||
if transcript is not None:
|
||||
with open(transcript, 'rb') as f:
|
||||
trans_json = json.load(f)['utts']
|
||||
|
||||
if lsm_type == 'unigram':
|
||||
assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
|
||||
labelcount = np.zeros(odim)
|
||||
for k, v in trans_json.items():
|
||||
ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
|
||||
# to avoid an error when there is no text in an utterance
|
||||
if len(ids) > 0:
|
||||
labelcount[ids] += 1
|
||||
labelcount[odim - 1] = len(transcript) # count <eos>
|
||||
labelcount[labelcount == 0] = 1 # flooring
|
||||
labelcount[blank] = 0 # remove counts for blank
|
||||
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
|
||||
else:
|
||||
logging.error(
|
||||
"Error: unexpected label smoothing type: %s" % lsm_type)
|
||||
sys.exit()
|
||||
|
||||
return labeldist
|
||||
|
||||
|
||||
def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
|
||||
"""Return the output size of the VGG frontend.
|
||||
|
||||
:param in_channel: input channel size
|
||||
:param out_channel: output channel size
|
||||
:return: output size
|
||||
:rtype int
|
||||
"""
|
||||
idim = idim / in_channel
|
||||
if downsample:
|
||||
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
|
||||
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
|
||||
return int(idim) * out_channel  # number of channels
|
||||
|
||||
|
||||
class ErrorCalculator(object):
|
||||
"""Calculate CER and WER for E2E_ASR and CTC models during training.
|
||||
|
||||
:param y_hats: numpy array with predicted text
|
||||
:param y_pads: numpy array with true (target) text
|
||||
:param char_list:
|
||||
:param sym_space:
|
||||
:param sym_blank:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
|
||||
trans_type="char"):
|
||||
"""Construct an ErrorCalculator object."""
|
||||
super(ErrorCalculator, self).__init__()
|
||||
|
||||
self.report_cer = report_cer
|
||||
self.report_wer = report_wer
|
||||
self.trans_type = trans_type
|
||||
self.char_list = char_list
|
||||
self.space = sym_space
|
||||
self.blank = sym_blank
|
||||
self.idx_blank = self.char_list.index(self.blank)
|
||||
if self.space in self.char_list:
|
||||
self.idx_space = self.char_list.index(self.space)
|
||||
else:
|
||||
self.idx_space = None
|
||||
|
||||
def __call__(self, ys_hat, ys_pad, is_ctc=False):
|
||||
"""Calculate sentence-level WER/CER score.
|
||||
|
||||
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
||||
:param bool is_ctc: calculate CER score for CTC
|
||||
:return: sentence-level WER score
|
||||
:rtype float
|
||||
:return: sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
cer, wer = None, None
|
||||
if is_ctc:
|
||||
return self.calculate_cer_ctc(ys_hat, ys_pad)
|
||||
elif not self.report_cer and not self.report_wer:
|
||||
return cer, wer
|
||||
|
||||
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(seqs_hat, seqs_true)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(seqs_hat, seqs_true)
|
||||
return cer, wer
|
||||
|
||||
def calculate_cer_ctc(self, ys_hat, ys_pad):
|
||||
"""Calculate sentence-level CER score for CTC.
|
||||
|
||||
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
||||
:return: average sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
cers, char_ref_lens = [], []
|
||||
for i, y in enumerate(ys_hat):
|
||||
y_hat = [x[0] for x in groupby(y)]
|
||||
y_true = ys_pad[i]
|
||||
seq_hat, seq_true = [], []
|
||||
for idx in y_hat:
|
||||
idx = int(idx)
|
||||
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
||||
seq_hat.append(self.char_list[int(idx)])
|
||||
|
||||
for idx in y_true:
|
||||
idx = int(idx)
|
||||
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
||||
seq_true.append(self.char_list[int(idx)])
|
||||
if self.trans_type == "char":
|
||||
hyp_chars = "".join(seq_hat)
|
||||
ref_chars = "".join(seq_true)
|
||||
else:
|
||||
hyp_chars = " ".join(seq_hat)
|
||||
ref_chars = " ".join(seq_true)
|
||||
|
||||
if len(ref_chars) > 0:
|
||||
cers.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
|
||||
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
|
||||
return cer_ctc
|
||||
|
||||
def convert_to_char(self, ys_hat, ys_pad):
|
||||
"""Convert index to character.
|
||||
|
||||
:param torch.Tensor seqs_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor seqs_true: reference (batch, seqlen)
|
||||
:return: token list of prediction
|
||||
:rtype list
|
||||
:return: token list of reference
|
||||
:rtype list
|
||||
"""
|
||||
seqs_hat, seqs_true = [], []
|
||||
for i, y_hat in enumerate(ys_hat):
|
||||
y_true = ys_pad[i]
|
||||
eos_true = np.where(y_true == -1)[0]
|
||||
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
||||
# To avoid wrong higher WER than the one obtained from the decoding
|
||||
# eos from y_true is used to mark the eos in y_hat
|
||||
# because of that y_hats has not padded outs with -1.
|
||||
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
|
||||
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
||||
# seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = seq_hat_text.replace(self.blank, '')
|
||||
# seq_true_text = "".join(seq_true).replace(self.space, ' ')
|
||||
seq_true_text = " ".join(seq_true).replace(self.space, ' ')
|
||||
seqs_hat.append(seq_hat_text)
|
||||
seqs_true.append(seq_true_text)
|
||||
return seqs_hat, seqs_true
|
||||
|
||||
def calculate_cer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level CER score.
|
||||
|
||||
:param list seqs_hat: prediction
|
||||
:param list seqs_true: reference
|
||||
:return: average sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
char_eds, char_ref_lens = [], []
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_chars = seq_hat_text.replace(' ', '')
|
||||
ref_chars = seq_true_text.replace(' ', '')
|
||||
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
return float(sum(char_eds)) / sum(char_ref_lens)
|
||||
|
||||
def calculate_wer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level WER score.
|
||||
|
||||
:param list seqs_hat: prediction
|
||||
:param list seqs_true: reference
|
||||
:return: average sentence-level WER score
|
||||
:rtype float
|
||||
"""
|
||||
word_eds, word_ref_lens = [], []
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_words = seq_hat_text.split()
|
||||
ref_words = seq_true_text.split()
|
||||
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
||||
word_ref_lens.append(len(ref_words))
|
||||
return float(sum(word_eds)) / sum(word_ref_lens)
|
||||
|
||||
|
||||
class ErrorCalculatorTrans(object):
|
||||
"""Calculate CER and WER for transducer models.
|
||||
|
||||
Args:
|
||||
decoder (nn.Module): decoder module
|
||||
args (Namespace): argument Namespace containing options
|
||||
report_cer (boolean): compute CER option
|
||||
report_wer (boolean): compute WER option
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, decoder, args, report_cer=False, report_wer=False):
|
||||
"""Construct an ErrorCalculator object for transducer model."""
|
||||
super(ErrorCalculatorTrans, self).__init__()
|
||||
|
||||
self.dec = decoder
|
||||
|
||||
recog_args = {'beam_size': args.beam_size,
|
||||
'nbest': args.nbest,
|
||||
'space': args.sym_space,
|
||||
'score_norm_transducer': args.score_norm_transducer}
|
||||
|
||||
self.recog_args = argparse.Namespace(**recog_args)
|
||||
|
||||
self.char_list = args.char_list
|
||||
self.space = args.sym_space
|
||||
self.blank = args.sym_blank
|
||||
|
||||
self.report_cer = args.report_cer
|
||||
self.report_wer = args.report_wer
|
||||
|
||||
def __call__(self, hs_pad, ys_pad):
|
||||
"""Calculate sentence-level WER/CER score for transducer models.
|
||||
|
||||
Args:
|
||||
hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D)
|
||||
ys_pad (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): sentence-level CER score
|
||||
(float): sentence-level WER score
|
||||
|
||||
"""
|
||||
cer, wer = None, None
|
||||
|
||||
if not self.report_cer and not self.report_wer:
|
||||
return cer, wer
|
||||
|
||||
batchsize = int(hs_pad.size(0))
|
||||
batch_nbest = []
|
||||
|
||||
for b in six.moves.range(batchsize):
|
||||
if self.recog_args.beam_size == 1:
|
||||
nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
|
||||
else:
|
||||
nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
|
||||
batch_nbest.append(nbest_hyps)
|
||||
|
||||
ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]
|
||||
|
||||
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())
|
||||
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(seqs_hat, seqs_true)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(seqs_hat, seqs_true)
|
||||
|
||||
return cer, wer
|
||||
|
||||
def convert_to_char(self, ys_hat, ys_pad):
|
||||
"""Convert index to character.
|
||||
|
||||
Args:
|
||||
ys_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
ys_pad (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(list): token list of prediction
|
||||
(list): token list of reference
|
||||
|
||||
"""
|
||||
seqs_hat, seqs_true = [], []
|
||||
|
||||
for i, y_hat in enumerate(ys_hat):
|
||||
y_true = ys_pad[i]
|
||||
|
||||
eos_true = np.where(y_true == -1)[0]
|
||||
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
||||
|
||||
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
|
||||
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
||||
|
||||
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = seq_hat_text.replace(self.blank, '')
|
||||
seq_true_text = "".join(seq_true).replace(self.space, ' ')
|
||||
|
||||
seqs_hat.append(seq_hat_text)
|
||||
seqs_true.append(seq_true_text)
|
||||
|
||||
return seqs_hat, seqs_true
|
||||
|
||||
def calculate_cer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level CER score for transducer model.
|
||||
|
||||
Args:
|
||||
seqs_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
seqs_true (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): average sentence-level CER score
|
||||
|
||||
"""
|
||||
char_eds, char_ref_lens = [], []
|
||||
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_chars = seq_hat_text.replace(' ', '')
|
||||
ref_chars = seq_true_text.replace(' ', '')
|
||||
|
||||
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
|
||||
return float(sum(char_eds)) / sum(char_ref_lens)
|
||||
|
||||
def calculate_wer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level WER score for transducer model.
|
||||
|
||||
Args:
|
||||
seqs_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
seqs_true (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): average sentence-level WER score
|
||||
|
||||
"""
|
||||
word_eds, word_ref_lens = [], []
|
||||
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_words = seq_hat_text.split()
|
||||
ref_words = seq_true_text.split()
|
||||
|
||||
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
||||
word_ref_lens.append(len(ref_words))
|
||||
|
||||
return float(sum(word_eds)) / sum(word_ref_lens)
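# --- Hedged example (added for illustration, not part of the original file) --
# A self-contained sketch of the corpus-level CER/WER arithmetic used by the
# two methods above, assuming only that the `editdistance` package is
# installed. The hypothesis/reference strings below are made up.
import editdistance

hyps = ["i like apples", "hello word"]
refs = ["i like apple", "hello world"]

# CER: character-level edit distance (spaces removed), normalised by the
# total number of reference characters across the corpus.
char_eds = [editdistance.eval(h.replace(" ", ""), r.replace(" ", "")) for h, r in zip(hyps, refs)]
char_lens = [len(r.replace(" ", "")) for r in refs]
cer = float(sum(char_eds)) / sum(char_lens)

# WER: the same normalisation applied to whitespace-separated word sequences.
word_eds = [editdistance.eval(h.split(), r.split()) for h, r in zip(hyps, refs)]
word_lens = [len(r.split()) for r in refs]
wer = float(sum(word_eds)) / sum(word_lens)
print(cer, wer)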
|
||||
@@ -1,183 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Multi-Head Attention layer definition."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
:param int n_head: the number of heads
|
||||
:param int n_feat: the number of features
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
self.linear_k = nn.Linear(n_feat, n_feat)
|
||||
self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward_qkv(self, query, key, value):
|
||||
"""Transform query, key and value.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:return torch.Tensor transformed query, key and value
|
||||
|
||||
"""
|
||||
n_batch = query.size(0)
|
||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward_attention(self, value, scores, mask):
|
||||
"""Compute attention context vector.
|
||||
|
||||
:param torch.Tensor value: (batch, head, time2, size)
|
||||
:param torch.Tensor scores: (batch, head, time1, time2)
|
||||
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
|
||||
:return torch.Tensor transformed `value` (batch, time1, d_model)
|
||||
weighted by the attention score (batch, time1, time2)
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
min_value = float(
|
||||
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
|
||||
)
|
||||
scores = scores.masked_fill(mask, min_value)
|
||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||
mask, 0.0
|
||||
) # (batch, head, time1, time2)
|
||||
else:
|
||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(self.attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (
|
||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
"""Compute 'Scaled Dot Product Attention'.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
|
||||
:param torch.nn.Dropout dropout:
|
||||
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||
return self.forward_attention(v, scores, mask)
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
:param int n_head: the number of heads
|
||||
:param int n_feat: the number of features
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate)
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable biases are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x, zero_triu=False):
|
||||
"""Compute relative positinal encoding.
|
||||
|
||||
:param torch.Tensor x: (batch, time, size)
|
||||
:param bool zero_triu: return the lower triangular part of the matrix
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)
|
||||
|
||||
if zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)))
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:param torch.Tensor pos_emb: (batch, time1, size)
|
||||
:param torch.Tensor mask: (batch, time1, time2)
|
||||
:param torch.nn.Dropout dropout:
|
||||
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, time2)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
return self.forward_attention(v, scores, mask)
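# --- Hedged example (added for illustration, not part of the original file) --
# A minimal shape check for the two attention layers defined above, assuming
# it runs in the same module (or that both classes have been imported from
# it). All sizes below are arbitrary illustration values.
import torch

n_batch, time1, n_feat, n_head = 2, 5, 256, 4
q = torch.randn(n_batch, time1, n_feat)
mask = torch.ones(n_batch, 1, time1, dtype=torch.bool)   # (batch, 1, time2)

mha = MultiHeadedAttention(n_head, n_feat, dropout_rate=0.1)
out = mha(q, q, q, mask)                                  # (2, 5, 256)

# The relative-position variant additionally takes a positional embedding
# shaped (1, time1, n_feat), e.g. as produced by RelPositionalEncoding.
rel = RelPositionMultiHeadedAttention(n_head, n_feat, dropout_rate=0.1)
pos_emb = torch.randn(1, time1, n_feat)
out_rel = rel(q, q, q, pos_emb, mask)                     # (2, 5, 256)
print(out.shape, out_rel.shape)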
|
||||
@@ -1,262 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder definition."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from typing import Callable
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from .convolution import ConvolutionModule
|
||||
from .encoder_layer import EncoderLayer
|
||||
from ..nets_utils import get_activation, make_pad_mask
|
||||
from .vgg import VGG2L
|
||||
from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
|
||||
from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding
|
||||
from .layer_norm import LayerNorm
|
||||
from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d
|
||||
from .positionwise_feed_forward import PositionwiseFeedForward
|
||||
from .repeat import repeat
|
||||
from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling
|
||||
|
||||
|
||||
class ConformerEncoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int attention_dim: dimension of attention
|
||||
:param int attention_heads: the number of heads of multi head attention
|
||||
:param int linear_units: the number of units of position-wise feed forward
|
||||
:param int num_blocks: the number of encoder blocks
|
||||
:param float dropout_rate: dropout rate
|
||||
:param float attention_dropout_rate: dropout rate in attention
|
||||
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
||||
:param str or torch.nn.Module input_layer: input layer type
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
:param str positionwise_layer_type: linear or conv1d
|
||||
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
|
||||
:param str encoder_attn_layer_type: encoder attention layer type
|
||||
:param str activation_type: encoder activation function type
|
||||
:param bool macaron_style: whether to use macaron style for positionwise layer
|
||||
:param bool use_cnn_module: whether to use convolution module
|
||||
:param int cnn_module_kernel: kernel size of convolution module
|
||||
:param int padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
macaron_style=False,
|
||||
pos_enc_layer_type="abs_pos",
|
||||
selfattention_layer_type="selfattn",
|
||||
activation_type="swish",
|
||||
use_cnn_module=False,
|
||||
cnn_module_kernel=31,
|
||||
padding_idx=-1,
|
||||
no_subsample=False,
|
||||
subsample_by_2=False,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super().__init__()
|
||||
|
||||
self._output_size = attention_dim
|
||||
idim = input_size
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
logging.info("Encoder input layer type: conv2d")
|
||||
if no_subsample:
|
||||
self.embed = Conv2dNoSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
else:
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
subsample_by_2, # NOTE(Sx): added by songxiang
|
||||
)
|
||||
elif input_layer == "vgg2l":
|
||||
self.embed = VGG2L(idim, attention_dim)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input lengths (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
Position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)):
|
||||
# print(xs_pad.shape)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
# print(xs_pad[0].size())
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
if isinstance(xs_pad, tuple):
|
||||
xs_pad = xs_pad[0]
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
return xs_pad, olens, None
|
||||
|
||||
# def forward(self, xs, masks):
|
||||
# """Encode input sequence.
|
||||
|
||||
# :param torch.Tensor xs: input tensor
|
||||
# :param torch.Tensor masks: input mask
|
||||
# :return: position embedded tensor and mask
|
||||
# :rtype Tuple[torch.Tensor, torch.Tensor]:
|
||||
# """
|
||||
# if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
|
||||
# xs, masks = self.embed(xs, masks)
|
||||
# else:
|
||||
# xs = self.embed(xs)
|
||||
|
||||
# xs, masks = self.encoders(xs, masks)
|
||||
# if isinstance(xs, tuple):
|
||||
# xs = xs[0]
|
||||
|
||||
# if self.normalize_before:
|
||||
# xs = self.after_norm(xs)
|
||||
# return xs, masks
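# --- Hedged example (added for illustration, not part of the original file) --
# Assuming this module and its relative imports resolve, a caller passes
# padded features (B, L, D) plus their lengths to ConformerEncoder and gets
# back subsampled hidden states with the updated lengths. The hyper-parameters
# below are illustrative only, not values used elsewhere in this repository.
import torch

enc = ConformerEncoder(
    input_size=80,                        # e.g. 80-dim mel-spectrogram frames
    attention_dim=144,
    attention_heads=4,
    linear_units=576,
    num_blocks=2,
    pos_enc_layer_type="rel_pos",
    selfattention_layer_type="rel_selfattn",
    macaron_style=True,
    use_cnn_module=True,
)
xs = torch.randn(2, 100, 80)              # padded batch of features
ilens = torch.tensor([100, 73])           # true lengths before padding
hs, olens, _ = enc(xs, ilens)
print(hs.shape, olens)                    # roughly (2, 25, 144); time subsampled by 4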
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""ConvolutionModule definition."""
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
|
||||
:param int channels: channels of cnn
|
||||
:param int kernel_size: kernel size of cnn
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernel_size should be an odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute convolution module.
|
||||
|
||||
:param torch.Tensor x: (batch, time, size)
|
||||
:return torch.Tensor: convoluted `value` (batch, time, d_model)
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
return x.transpose(1, 2)
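# --- Hedged example (added for illustration, not part of the original file) --
# Shape sketch for ConvolutionModule: it keeps the (batch, time, channels)
# layout, so it can sit between attention blocks without extra reshaping.
# Sizes are arbitrary.
import torch

conv = ConvolutionModule(channels=256, kernel_size=31)
x = torch.randn(4, 100, 256)
print(conv(x).shape)                      # torch.Size([4, 100, 256])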
|
||||
@@ -1,166 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Positonal Encoding Module."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _pre_hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""Perform pre-hook in load_state_dict for backward compatibility.
|
||||
|
||||
Note:
|
||||
We saved self.pe until v.0.5.2, but it has been omitted since then.
|
||||
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
||||
|
||||
"""
|
||||
k = prefix + "pe"
|
||||
if k in state_dict:
|
||||
state_dict.pop(k)
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
:param reverse: whether to reverse the input position
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.reverse = reverse
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
self._register_load_state_dict_pre_hook(_pre_hook)
|
||||
|
||||
def extend_pe(self, x):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model)
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale + self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class ScaledPositionalEncoding(PositionalEncoding):
|
||||
"""Scaled positional encoding module.
|
||||
|
||||
See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Initialize class.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Reset parameters."""
|
||||
self.alpha.data = torch.tensor(1.0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x + self.alpha * self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class RelPositionalEncoding(PositionalEncoding):
|
||||
"""Relitive positional encoding module.
|
||||
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Initialize class.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: x. Its shape is (batch, time, ...)
|
||||
torch.Tensor: pos_emb. Its shape is (1, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.pe[:, : x.size(1)]
|
||||
return self.dropout(x), self.dropout(pos_emb)
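# --- Hedged example (added for illustration, not part of the original file) --
# The absolute and scaled variants return a single encoded tensor, while
# RelPositionalEncoding returns the scaled input together with a separate
# positional-embedding tensor for relative self-attention. Sizes are arbitrary.
import torch

x = torch.randn(2, 50, 256)

abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.1)
y = abs_pe(x)                             # (2, 50, 256)

rel_pe = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
y_rel, pos_emb = rel_pe(x)                # (2, 50, 256) and (1, 50, 256)
print(y.shape, y_rel.shape, pos_emb.shape)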
|
||||
@@ -1,217 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder definition."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
|
||||
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
|
||||
from espnet.nets.pytorch_backend.nets_utils import get_activation
|
||||
from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
|
||||
from espnet.nets.pytorch_backend.transformer.attention import (
|
||||
MultiHeadedAttention, # noqa: H301
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import (
|
||||
PositionalEncoding, # noqa: H301
|
||||
ScaledPositionalEncoding, # noqa: H301
|
||||
RelPositionalEncoding, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
|
||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
||||
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int attention_dim: dimension of attention
|
||||
:param int attention_heads: the number of heads of multi head attention
|
||||
:param int linear_units: the number of units of position-wise feed forward
|
||||
:param int num_blocks: the number of encoder blocks
|
||||
:param float dropout_rate: dropout rate
|
||||
:param float attention_dropout_rate: dropout rate in attention
|
||||
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
||||
:param str or torch.nn.Module input_layer: input layer type
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
:param str positionwise_layer_type: linear or conv1d
|
||||
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
|
||||
:param str encoder_attn_layer_type: encoder attention layer type
|
||||
:param str activation_type: encoder activation function type
|
||||
:param bool macaron_style: whether to use macaron style for positionwise layer
|
||||
:param bool use_cnn_module: whether to use convolution module
|
||||
:param int cnn_module_kernel: kernel size of convolution module
|
||||
:param int padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idim,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
macaron_style=False,
|
||||
pos_enc_layer_type="abs_pos",
|
||||
selfattention_layer_type="selfattn",
|
||||
activation_type="swish",
|
||||
use_cnn_module=False,
|
||||
cnn_module_kernel=31,
|
||||
padding_idx=-1,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "vgg2l":
|
||||
self.embed = VGG2L(idim, attention_dim)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def forward(self, xs, masks):
|
||||
"""Encode input sequence.
|
||||
|
||||
:param torch.Tensor xs: input tensor
|
||||
:param torch.Tensor masks: input mask
|
||||
:return: position embedded tensor and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
|
||||
xs, masks = self.embed(xs, masks)
|
||||
else:
|
||||
xs = self.embed(xs)
|
||||
|
||||
xs, masks = self.encoders(xs, masks)
|
||||
if isinstance(xs, tuple):
|
||||
xs = xs[0]
|
||||
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
return xs, masks
|
||||
@@ -1,152 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder self-attention layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .layer_norm import LayerNorm
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
:param int size: input dim
|
||||
:param espnet.nets.pytorch_backend.transformer.attention.
|
||||
MultiHeadedAttention self_attn: self attention module
|
||||
RelPositionMultiHeadedAttention self_attn: self attention module
|
||||
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
|
||||
PositionwiseFeedForward feed_forward:
|
||||
feed forward module
|
||||
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward
|
||||
for macaron style
|
||||
PositionwiseFeedForward feed_forward:
|
||||
feed forward module
|
||||
:param espnet.nets.pytorch_backend.conformer.convolution.
|
||||
ConvolutionModule conv_module:
|
||||
convolution module
|
||||
:param float dropout_rate: dropout rate
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
feed_forward_macaron,
|
||||
conv_module,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.feed_forward_macaron = feed_forward_macaron
|
||||
self.conv_module = conv_module
|
||||
self.norm_ff = LayerNorm(size) # for the FNN module
|
||||
self.norm_mha = LayerNorm(size) # for the MHA module
|
||||
if feed_forward_macaron is not None:
|
||||
self.norm_ff_macaron = LayerNorm(size)
|
||||
self.ff_scale = 0.5
|
||||
else:
|
||||
self.ff_scale = 1.0
|
||||
if self.conv_module is not None:
|
||||
self.norm_conv = LayerNorm(size) # for the CNN module
|
||||
self.norm_final = LayerNorm(size) # for the final output of the block
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, x_input, mask, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
:param torch.Tensor x_input: encoded source features, w/o pos_emb
|
||||
tuple((batch, max_time_in, size), (1, max_time_in, size))
|
||||
or (batch, max_time_in, size)
|
||||
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
||||
:param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
|
||||
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
if isinstance(x_input, tuple):
|
||||
x, pos_emb = x_input[0], x_input[1]
|
||||
else:
|
||||
x, pos_emb = x_input, None
|
||||
|
||||
# whether to use macaron style
|
||||
if self.feed_forward_macaron is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if pos_emb is not None:
|
||||
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
||||
else:
|
||||
x_att = self.self_attn(x_q, x, x, mask)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, x_att), dim=-1)
|
||||
x = residual + self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(x_att)
|
||||
if not self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
# convolution module
|
||||
if self.conv_module is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
x = residual + self.dropout(self.conv_module(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
|
||||
# feed forward module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
|
||||
if self.conv_module is not None:
|
||||
x = self.norm_final(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
if pos_emb is not None:
|
||||
return (x, pos_emb), mask
|
||||
|
||||
return x, mask
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Layer normalization module."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
"""Layer normalization module.
|
||||
|
||||
:param int nout: output dim size
|
||||
:param int dim: dimension to be normalized
|
||||
"""
|
||||
|
||||
def __init__(self, nout, dim=-1):
|
||||
"""Construct an LayerNorm object."""
|
||||
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply layer normalization.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:return: layer normalized tensor
|
||||
:rtype torch.Tensor
|
||||
"""
|
||||
if self.dim == -1:
|
||||
return super(LayerNorm, self).forward(x)
|
||||
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
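# --- Hedged example (added for illustration, not part of the original file) --
# With dim=-1 this wrapper behaves exactly like torch.nn.LayerNorm over the
# last axis; dim=1 lets it normalise channel-first tensors such as
# (batch, channels, time) without the caller transposing. Sizes are arbitrary.
import torch

x = torch.randn(8, 256, 100)              # (batch, channels, time)
norm = LayerNorm(256, dim=1)
print(norm(x).shape)                      # torch.Size([8, 256, 100])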
|
||||
@@ -1,105 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiLayeredConv1d(torch.nn.Module):
|
||||
"""Multi-layered conv1d for Transformer block.
|
||||
|
||||
This is a module of multi-layered conv1d designed
|
||||
to replace positionwise feed-forward network
|
||||
in the Transformer block, which is introduced in
|
||||
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
||||
|
||||
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
||||
https://arxiv.org/pdf/1905.09263.pdf
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
||||
"""Initialize MultiLayeredConv1d module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(MultiLayeredConv1d, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Conv1d(
|
||||
hidden_chans,
|
||||
in_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of output tensors (B, ..., in_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
||||
|
||||
|
||||
class Conv1dLinear(torch.nn.Module):
|
||||
"""Conv1D + Linear for Transformer block.
|
||||
|
||||
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
||||
"""Initialize Conv1dLinear module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(Conv1dLinear, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of output tensors (B, ..., in_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.dropout(x))
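# --- Hedged example (added for illustration, not part of the original file) --
# Both positionwise variants above map (batch, time, in_chans) back to the
# same shape; they differ only in whether the second projection is a conv1d
# or a linear layer. Sizes are arbitrary.
import torch

x = torch.randn(2, 50, 384)
mlc = MultiLayeredConv1d(in_chans=384, hidden_chans=1536, kernel_size=3, dropout_rate=0.1)
cl = Conv1dLinear(in_chans=384, hidden_chans=1536, kernel_size=3, dropout_rate=0.1)
print(mlc(x).shape, cl(x).shape)          # both torch.Size([2, 50, 384])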
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Positionwise feed forward layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
"""Positionwise feed forward layer.
|
||||
|
||||
:param int idim: input dimension
|
||||
:param int hidden_units: number of hidden units
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
|
||||
"""Construct an PositionwiseFeedForward object."""
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
||||
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward funciton."""
|
||||
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
||||
@@ -1,30 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Repeat the same layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiSequential(torch.nn.Sequential):
|
||||
"""Multi-input multi-output torch.nn.Sequential."""
|
||||
|
||||
def forward(self, *args):
|
||||
"""Repeat."""
|
||||
for m in self:
|
||||
args = m(*args)
|
||||
return args
|
||||
|
||||
|
||||
def repeat(N, fn):
|
||||
"""Repeat module N times.
|
||||
|
||||
:param int N: repeat time
|
||||
:param function fn: function to generate module
|
||||
:return: repeated modules
|
||||
:rtype: MultiSequential
|
||||
"""
|
||||
return MultiSequential(*[fn(n) for n in range(N)])
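# --- Hedged example (added for illustration, not part of the original file) --
# repeat() builds N layers from a factory function and chains them while
# forwarding several values (e.g. a tensor and its mask) through every layer,
# which plain torch.nn.Sequential cannot do. The toy layer below is made up.
import torch

class AddOne(torch.nn.Module):
    def forward(self, x, mask):
        return x + 1, mask

stack = repeat(3, lambda layer_idx: AddOne())
x, mask = stack(torch.zeros(2, 4), torch.ones(2, 4, dtype=torch.bool))
print(x)                                  # every element is 3.0 after three layers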
|
||||
@@ -1,218 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Subsampling layer definition."""
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
||||
|
||||
|
||||
class Conv2dSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length or 1/2 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate, pos_enc=None,
|
||||
subsample_by_2=False,
|
||||
):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling, self).__init__()
|
||||
self.subsample_by_2 = subsample_by_2
|
||||
if subsample_by_2:
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (idim // 2), odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
else:
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (idim // 4), odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
if self.subsample_by_2:
|
||||
return x, x_mask[:, :, ::2]
|
||||
else:
|
||||
return x, x_mask[:, :, ::2][:, :, ::2]
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Subsample x.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positional encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class Conv2dNoSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D without subsampling.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
logging.info("Encoder does not do down-sample on mel-spectrogram.")
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * idim, odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Subsample x.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positional encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class Conv2dSubsampling6(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/6 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling6, self).__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 5, 3),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
|
||||
PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :, :-2:2][:, :, :-4:3]
|
||||
|
||||
|
||||
class Conv2dSubsampling8(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/8 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling8, self).__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
|
||||
PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
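# --- Hedged example (added for illustration, not part of the original file) --
# Shape sketch for the two Conv2dSubsampling modes above, assuming this
# module's import of espnet's PositionalEncoding resolves: the default path
# shortens the time axis by 4 and the subsample_by_2 path by 2, trimming the
# mask to match. Sizes are arbitrary.
import torch

idim, odim = 80, 256
x = torch.randn(2, 100, idim)
mask = torch.ones(2, 1, 100, dtype=torch.bool)

sub4 = Conv2dSubsampling(idim, odim, dropout_rate=0.1)
y4, m4 = sub4(x, mask)                    # (2, 25, 256) and (2, 1, 25)

sub2 = Conv2dSubsampling(idim, odim, dropout_rate=0.1, subsample_by_2=True)
y2, m2 = sub2(x, mask)                    # (2, 50, 256) and (2, 1, 50)
print(y4.shape, m4.shape, y2.shape, m2.shape)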
|
||||
@@ -1,18 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Swish() activation function for Conformer."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x):
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
@@ -1,77 +0,0 @@
"""VGG2L definition for transformer-transducer."""

import torch


class VGG2L(torch.nn.Module):
    """VGG2L module for transformer-transducer encoder."""

    def __init__(self, idim, odim):
        """Construct a VGG2L object.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs

        """
        super(VGG2L, self).__init__()

        self.vgg2l = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((3, 2)),
            torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((2, 2)),
        )

        self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)

    def forward(self, x, x_mask):
        """VGG2L forward for x.

        Args:
            x (torch.Tensor): input tensor (B, T, idim)
            x_mask (torch.Tensor): input mask (B, 1, T)

        Returns:
            x (torch.Tensor): subsampled tensor (B, sub(T), attention_dim)
            x_mask (torch.Tensor): subsampled mask (B, 1, sub(T))

        """
        x = x.unsqueeze(1)
        x = self.vgg2l(x)

        b, c, t, f = x.size()

        x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))

        if x_mask is None:
            return x, None
        else:
            x_mask = self.create_new_mask(x_mask, x)

            return x, x_mask

    def create_new_mask(self, x_mask, x):
        """Create a subsampled version of x_mask.

        Args:
            x_mask (torch.Tensor): input mask (B, 1, T)
            x (torch.Tensor): subsampled tensor (B, sub(T), attention_dim)

        Returns:
            x_mask (torch.Tensor): subsampled mask (B, 1, sub(T))

        """
        x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
        x_mask = x_mask[:, :, :x_t1][:, :, ::3]

        x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
        x_mask = x_mask[:, :, :x_t2][:, :, ::2]

        return x_mask
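The MaxPool2d((3, 2)) and MaxPool2d((2, 2)) layers shrink the time axis by 3 and then 2, and create_new_mask applies the same two factors to the mask. An illustrative sketch, assuming this class is importable:

import torch

vgg = VGG2L(idim=80, odim=256)
x = torch.randn(2, 60, 80)                      # (B, T, idim)
mask = torch.ones(2, 1, 60, dtype=torch.bool)   # (B, 1, T)
y, y_mask = vgg(x, mask)
# Time axis: 60 -> 20 (pool by 3) -> 10 (pool by 2); the mask follows suit.
print(y.shape, y_mask.shape)                    # (2, 10, 256), (2, 1, 10)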
@@ -1,298 +0,0 @@
import logging
import six

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from .e2e_asr_common import get_vgg2l_odim
from .nets_utils import make_pad_mask, to_device


class RNNP(torch.nn.Module):
    """RNN with projection layer module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (results in cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        bidir = typ[0] == "b"
        for i in six.moves.range(elayers):
            if i == 0:
                inputdim = idim
            else:
                inputdim = hdim
            rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
                                batch_first=True) if "lstm" in typ \
                else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
                                  batch_first=True)
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
            # bottleneck layer to merge
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
        elayer_states = []
        for layer in six.moves.range(self.elayers):
            xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                ilens = [int(i + 1) // sub for i in ilens]
            # (sum _utt frame_utt) x dim
            projected = getattr(self, 'bt' + str(layer))(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            if layer == self.elayers - 1:
                xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            else:
                xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim

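A minimal sketch of driving RNNP on padded features. The values are illustrative; the convention here is that subsample[layer + 1] gives the decimation applied after that layer.

import torch

rnnp = RNNP(idim=80, elayers=4, cdim=320, hdim=320,
            subsample=[1, 2, 2, 1, 1], dropout=0.0, typ="blstm")
xs = torch.randn(2, 100, 80)         # (B, Tmax, idim)
ilens = torch.tensor([100, 80])
hs, hlens, states = rnnp(xs, ilens)
# Time axis decimated by 2 after layers 0 and 1: 100 -> 50 -> 25 frames.
print(hs.shape, hlens)               # hs: (2, 25, 320); hlens: [25, 20]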
class RNN(torch.nn.Module):
    """RNN module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (results in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
                                   dropout=dropout, bidirectional=bidir) if "lstm" in typ \
            else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
                              bidirectional=bidir)
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
        xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
        self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when a previous state is passed, we are streaming the input
            # and therefore cannot propagate the backward BRNN state (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(self.l_last(
            ys_pad.contiguous().view(-1, ys_pad.size(2))))
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim


def reset_backward_rnn_state(states):
    """Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs"""
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.
    else:
        states[1::2] = 0.
    return states

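The backward-state reset matters when a bidirectional RNN is run over consecutive chunks of one long input. A hedged sketch of that streaming pattern (chunk size and dimensions are illustrative):

import torch

rnn = RNN(idim=80, elayers=2, cdim=320, hdim=320, dropout=0.0, typ="blstm").eval()
chunks = torch.randn(4, 1, 50, 80)   # 4 consecutive chunks of 50 frames, batch of 1
state = None
with torch.no_grad():
    for chunk in chunks:
        # forward() zeroes the backward half of the carried state via
        # reset_backward_rnn_state before reusing it.
        ys, olens, state = rnn(chunk, torch.tensor([50]), prev_state=state)
        print(ys.shape)              # (1, 50, 320) per chunk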
class VGG2L(torch.nn.Module):
    """VGG-like module

    :param int in_channel: number of input channels
    """

    def __init__(self, in_channel=1, downsample=True):
        super(VGG2L, self).__init__()
        # CNN layer (VGG motivated)
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.in_channel = in_channel
        self.downsample = downsample
        if downsample:
            self.stride = 2
        else:
            self.stride = 1

    def forward(self, xs_pad, ilens, **kwargs):
        """VGG2L forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))

        # x: utt x frame x dim
        # xs_pad = F.pad_sequence(xs_pad)

        # x: utt x 1 (input channel num) x frame x dim
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
                             xs_pad.size(2) // self.in_channel).transpose(1, 2)

        # NOTE: max_pool1d ?
        xs_pad = F.relu(self.conv1_1(xs_pad))
        xs_pad = F.relu(self.conv1_2(xs_pad))
        if self.downsample:
            xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)

        xs_pad = F.relu(self.conv2_1(xs_pad))
        xs_pad = F.relu(self.conv2_2(xs_pad))
        if self.downsample:
            xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        if self.downsample:
            ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
            ilens = np.array(
                np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()

        # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
        return xs_pad, ilens, None  # no state in this layer

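With downsample=True, the two pooled stages quarter both the time and feature axes and halve ilens twice (with ceiling). A small illustrative check for the default in_channel=1:

import torch

vgg = VGG2L(in_channel=1)
xs = torch.randn(2, 103, 80)         # (B, Tmax, D)
ilens = torch.tensor([103, 97])
ys, olens, _ = vgg(xs, ilens)
# Tmax: ceil(ceil(103 / 2) / 2) = 26; feature dim: 128 * (80 // 4) = 2560.
print(ys.shape, olens)               # (2, 26, 2560), [26, 25]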
class Encoder(torch.nn.Module):
    """Encoder module

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers: number of layers of encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    """

    def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
        super(Encoder, self).__init__()
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
            logging.error("Error: need to specify an appropriate encoder architecture")

        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList([VGG2L(in_channel),
                                                RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
                                                     eprojs, subsample, dropout, typ=typ)])
                logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
            else:
                self.enc = torch.nn.ModuleList([VGG2L(in_channel),
                                                RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
                                                    eprojs, dropout, typ=typ)])
                logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
        else:
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
                logging.info(typ.upper() + ' with every-layer projection for encoder')
            else:
                self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
                logging.info(typ.upper() + ' without projection for encoder')

    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # make mask to remove bias value in padded part
        mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states

def encoder_for(args, idim, subsample):
    """Instantiates an encoder module given the program arguments

    :param Namespace args: The arguments
    :param int or List of integer idim: dimension of input, e.g. 83, or
                                        List of dimensions of inputs, e.g. [83,83]
    :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
                                           List of subsample factors of each encoder, e.g. [[1,2,2,1,1], [1,2,2,1,1]]
    :rtype torch.nn.Module
    :return: The encoder module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
    elif num_encs >= 1:
        enc_list = torch.nn.ModuleList()
        for idx in range(num_encs):
            enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs,
                          subsample[idx], args.dropout_rate[idx])
            enc_list.append(enc)
        return enc_list
    else:
        raise ValueError("Number of encoders needs to be at least one. {}".format(num_encs))
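A hedged sketch of how encoder_for might be driven from a parsed argument namespace. The field values below are illustrative, not defaults mandated by this code, and the sketch assumes the module and its helpers (get_vgg2l_odim, make_pad_mask, to_device) are importable.

import argparse
import torch

args = argparse.Namespace(etype="vggblstmp", elayers=4, eunits=320,
                          eprojs=320, dropout_rate=0.0, num_encs=1)
enc = encoder_for(args, idim=80, subsample=[1, 2, 2, 1, 1])
xs = torch.randn(2, 100, 80)
ilens = torch.tensor([100, 90])
hs, hlens, _ = enc(xs, ilens)        # VGG2L front-end, then BLSTM-P layers
print(hs.shape)                      # (B, subsampled T, eprojs)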
Some files were not shown because too many files have changed in this diff.