mirror of https://github.com/babysor/Realtime-Voice-Clone-Chinese.git
synced 2026-05-12 11:35:56 +08:00

Refactor Project to 3 parts: Models, Control, Data

Need readme

This commit is contained in:
0 models/encoder/__init__.py Normal file
117 models/encoder/audio.py Normal file
@@ -0,0 +1,117 @@
from scipy.ndimage.morphology import binary_dilation
from models.encoder.params_data import *
from pathlib import Path
from typing import Optional, Union
from warnings import warn
import numpy as np
import librosa
import struct

try:
    import webrtcvad
except ImportError:
    warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
    webrtcvad = None

int16_max = (2 ** 15) - 1


def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
                   source_sr: Optional[int] = None,
                   normalize: Optional[bool] = True,
                   trim_silence: Optional[bool] = True):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), or the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, (str, Path)):
        wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
    else:
        wav = fpath_or_wav

    # Resample the wav if needed
    if source_sr is not None and source_sr != sampling_rate:
        wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    if normalize:
        wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
    if webrtcvad and trim_silence:
        wav = trim_long_silences(wav)

    return wav


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this is not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=int(sampling_rate * mel_window_length / 1000),
        hop_length=int(sampling_rate * mel_window_step / 1000),
        n_mels=mel_n_channels
    )
    return frames.astype(np.float32).T


def trim_long_silences(wav):
    """
    Ensures that segments without voice in the waveform remain no longer than a
    threshold determined by the VAD parameters in params_data.py.

    :param wav: the raw waveform as a numpy array of floats
    :return: the same waveform with silences trimmed away (length <= original wav length)
    """
    # Compute the voice detection window size
    samples_per_window = (vad_window_length * sampling_rate) // 1000

    # Trim the end of the audio to have a multiple of the window size
    wav = wav[:len(wav) - (len(wav) % samples_per_window)]

    # Convert the float waveform to 16-bit mono PCM
    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))

    # Perform voice activation detection
    voice_flags = []
    vad = webrtcvad.Vad(mode=3)
    for window_start in range(0, len(wav), samples_per_window):
        window_end = window_start + samples_per_window
        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
                                         sample_rate=sampling_rate))
    voice_flags = np.array(voice_flags)

    # Smooth the voice detection with a moving average
    def moving_average(array, width):
        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
        ret = np.cumsum(array_padded, dtype=float)
        ret[width:] = ret[width:] - ret[:-width]
        return ret[width - 1:] / width

    audio_mask = moving_average(voice_flags, vad_moving_average_width)
    audio_mask = np.round(audio_mask).astype(bool)

    # Dilate the voiced regions
    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)

    return wav[audio_mask]


def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
        return wav
    return wav * (10 ** (dBFS_change / 20))
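
Note: a minimal usage sketch for the module above (not part of this commit; "sample.wav" is a placeholder path, and silence trimming assumes webrtcvad is installed):

from models.encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("sample.wav")   # load, resample to sampling_rate, normalize, trim silences
mel = wav_to_mel_spectrogram(wav)    # float32 array of shape (n_frames, mel_n_channels)
print(mel.shape)
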
45 models/encoder/config.py Normal file
@@ -0,0 +1,45 @@
librispeech_datasets = {
    "train": {
        "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
        "other": ["LibriSpeech/train-other-500"]
    },
    "test": {
        "clean": ["LibriSpeech/test-clean"],
        "other": ["LibriSpeech/test-other"]
    },
    "dev": {
        "clean": ["LibriSpeech/dev-clean"],
        "other": ["LibriSpeech/dev-other"]
    },
}
libritts_datasets = {
    "train": {
        "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
        "other": ["LibriTTS/train-other-500"]
    },
    "test": {
        "clean": ["LibriTTS/test-clean"],
        "other": ["LibriTTS/test-other"]
    },
    "dev": {
        "clean": ["LibriTTS/dev-clean"],
        "other": ["LibriTTS/dev-other"]
    },
}
voxceleb_datasets = {
    "voxceleb1": {
        "train": ["VoxCeleb1/wav"],
        "test": ["VoxCeleb1/test_wav"]
    },
    "voxceleb2": {
        "train": ["VoxCeleb2/dev/aac"],
        "test": ["VoxCeleb2/test_wav"]
    }
}

other_datasets = [
    "LJSpeech-1.1",
    "VCTK-Corpus/wav48",
]

anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
2 models/encoder/data_objects/__init__.py Normal file
@@ -0,0 +1,2 @@
from models.encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from models.encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
37 models/encoder/data_objects/random_cycler.py Normal file
@@ -0,0 +1,37 @@
import random

class RandomCycler:
    """
    Creates an internal copy of a sequence and allows access to its items in a constrained random
    order. For a source sequence of n items and one or several consecutive queries of a total
    of m items, the following guarantees hold (one implies the other):
    - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
    - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
    """

    def __init__(self, source):
        if len(source) == 0:
            raise Exception("Can't create RandomCycler from an empty collection")
        self.all_items = list(source)
        self.next_items = []

    def sample(self, count: int):
        shuffle = lambda l: random.sample(l, len(l))

        out = []
        while count > 0:
            if count >= len(self.all_items):
                out.extend(shuffle(list(self.all_items)))
                count -= len(self.all_items)
                continue
            n = min(count, len(self.next_items))
            out.extend(self.next_items[:n])
            count -= n
            self.next_items = self.next_items[n:]
            if len(self.next_items) == 0:
                self.next_items = shuffle(list(self.all_items))
        return out

    def __next__(self):
        return self.sample(1)[0]
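
Note: a quick sanity check of the guarantees documented above (a sketch, not part of this commit). With n = 3 items and m = 8 samples, each item must appear either 8 // 3 = 2 or ((8 - 1) // 3) + 1 = 3 times:

from collections import Counter
from models.encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])            # n = 3
counts = Counter(cycler.sample(8))                # m = 8
assert all(c in (2, 3) for c in counts.values())  # each item seen 2 or 3 times
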
40 models/encoder/data_objects/speaker.py Normal file
@@ -0,0 +1,40 @@
from models.encoder.data_objects.random_cycler import RandomCycler
from models.encoder.data_objects.utterance import Utterance
from pathlib import Path

# Contains the set of utterances of a single speaker
class Speaker:
    def __init__(self, root: Path):
        self.root = root
        self.name = root.name
        self.utterances = None
        self.utterance_cycler = None

    def _load_utterances(self):
        with self.root.joinpath("_sources.txt").open("r") as sources_file:
            sources = [l.split(",") for l in sources_file]
        sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
        self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
        self.utterance_cycler = RandomCycler(self.utterances)

    def random_partial(self, count, n_frames):
        """
        Samples a batch of <count> unique partial utterances from the disk in a way that all
        utterances come up at least once every two cycles and in a random order every time.

        :param count: The number of partial utterances to sample from the set of utterances from
        that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
        the number of utterances available.
        :param n_frames: The number of frames in the partial utterance.
        :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
        frames are the frames of the partial utterances and range is the range of the partial
        utterance with regard to the complete utterance.
        """
        if self.utterances is None:
            self._load_utterances()

        utterances = self.utterance_cycler.sample(count)

        a = [(u,) + u.random_partial(n_frames) for u in utterances]

        return a
12 models/encoder/data_objects/speaker_batch.py Normal file
@@ -0,0 +1,12 @@
import numpy as np
from typing import List
from models.encoder.data_objects.speaker import Speaker

class SpeakerBatch:
    def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
        self.speakers = speakers
        self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}

        # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
        # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
        self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
56 models/encoder/data_objects/speaker_verification_dataset.py Normal file
@@ -0,0 +1,56 @@
from models.encoder.data_objects.random_cycler import RandomCycler
from models.encoder.data_objects.speaker_batch import SpeakerBatch
from models.encoder.data_objects.speaker import Speaker
from models.encoder.params_data import partials_n_frames
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

# TODO: improve with a pool of speakers for data efficiency

class SpeakerVerificationDataset(Dataset):
    def __init__(self, datasets_root: Path):
        self.root = datasets_root
        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
        if len(speaker_dirs) == 0:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        return int(1e10)

    def __getitem__(self, index):
        return next(self.speaker_cycler)

    def get_logs(self):
        log_string = ""
        for log_fpath in self.root.glob("*.txt"):
            with log_fpath.open("r") as log_file:
                log_string += "".join(log_file.readlines())
        return log_string


class SpeakerVerificationDataLoader(DataLoader):
    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
                 worker_init_fn=None):
        self.utterances_per_speaker = utterances_per_speaker

        super().__init__(
            dataset=dataset,
            batch_size=speakers_per_batch,
            shuffle=False,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=self.collate,
            pin_memory=pin_memory,
            drop_last=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn
        )

    def collate(self, speakers):
        return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
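
Note: a sketch of how the dataset and loader fit together (not part of this commit; assumes a preprocessed dataset directory, and the path is a placeholder):

from pathlib import Path
from models.encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

dataset = SpeakerVerificationDataset(Path("datasets/SV2TTS/encoder"))
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=4, utterances_per_speaker=5)
batch = next(iter(loader))    # a SpeakerBatch
print(batch.data.shape)       # (4 * 5, partials_n_frames, mel_n_channels)
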
26 models/encoder/data_objects/utterance.py Normal file
@@ -0,0 +1,26 @@
import numpy as np


class Utterance:
    def __init__(self, frames_fpath, wave_fpath):
        self.frames_fpath = frames_fpath
        self.wave_fpath = wave_fpath

    def get_frames(self):
        return np.load(self.frames_fpath)

    def random_partial(self, n_frames):
        """
        Crops the frames into a partial utterance of n_frames

        :param n_frames: The number of frames of the partial utterance
        :return: the partial utterance frames and a tuple indicating the start and end of the
        partial utterance in the complete utterance.
        """
        frames = self.get_frames()
        if frames.shape[0] == n_frames:
            start = 0
        else:
            start = np.random.randint(0, frames.shape[0] - n_frames)
        end = start + n_frames
        return frames[start:end], (start, end)
195 models/encoder/inference.py Normal file
@@ -0,0 +1,195 @@
from models.encoder.params_data import *
from models.encoder.model import SpeakerEncoder
from models.encoder.audio import preprocess_wav   # We want to expose this function from here
from matplotlib import cm
from models.encoder import audio
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch

_model = None   # type: SpeakerEncoder
_device = None  # type: torch.device


def load_model(weights_fpath: Path, device=None):
    """
    Loads the model in memory. If this function is not explicitly called, it will be run on the
    first call to embed_frames() with the default weights file.

    :param weights_fpath: the path to saved model weights.
    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
    model will be loaded and will run on this device. Outputs will however always be on the cpu.
    If None, will default to your GPU if it's available, otherwise your CPU.
    """
    # TODO: I think the slow loading of the encoder might have something to do with the device it
    # was saved on. Worth investigating.
    global _model, _device
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    _model = SpeakerEncoder(_device, torch.device("cpu"))
    checkpoint = torch.load(weights_fpath, _device)
    _model.load_state_dict(checkpoint["model_state"])
    _model.eval()
    print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
    return _model


def set_model(model, device=None):
    global _model, _device
    _model = model
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _device = device
    _model.to(device)


def is_loaded():
    return _model is not None


def embed_frames_batch(frames_batch):
    """
    Computes embeddings for a batch of mel spectrograms.

    :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
    (batch_size, n_frames, n_channels)
    :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
    """
    if _model is None:
        raise Exception("Model was not loaded. Call load_model() before inference.")

    frames = torch.from_numpy(frames_batch).to(_device)
    embed = _model.forward(frames).detach().cpu().numpy()
    return embed


def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5, rate=None):
    """
    Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
    partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
    spectrogram slices are returned, so as to make each partial utterance waveform correspond to
    its spectrogram. This function assumes that the mel spectrogram parameters used are those
    defined in params_data.py.

    The returned ranges may be indexing further than the length of the waveform. It is
    recommended that you pad the waveform with zeros up to wave_slices[-1].stop.

    :param n_samples: the number of samples in the waveform
    :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
    utterance
    :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
    enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
    then the last partial utterance will be considered, as if we padded the audio. Otherwise,
    it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
    utterance, this parameter is ignored so that the function always returns at least 1 slice.
    :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
    utterances are entirely disjoint.
    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
    respectively the waveform and the mel spectrogram with these slices to obtain the partial
    utterances.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    samples_per_frame = int((sampling_rate * mel_window_step / 1000))
    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
    if rate is not None:
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
    else:
        frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    assert 0 < frame_step, "The rate is too high"
    assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
        (sampling_rate / (samples_per_frame * partials_n_frames))

    # Compute the slices
    wav_slices, mel_slices = [], []
    steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
    for i in range(0, steps, frame_step):
        mel_range = np.array([i, i + partial_utterance_n_frames])
        wav_range = mel_range * samples_per_frame
        mel_slices.append(slice(*mel_range))
        wav_slices.append(slice(*wav_range))

    # Evaluate whether extra padding is warranted or not
    last_wav_range = wav_slices[-1]
    coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        mel_slices = mel_slices[:-1]
        wav_slices = wav_slices[:-1]

    return wav_slices, mel_slices


def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Computes an embedding for a single utterance.

    # TODO: handle multiple wavs to benefit from batching on GPU
    :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
    :param using_partials: if True, then the utterance is split into partial utterances of
    <partial_utterance_n_frames> frames and the utterance embedding is computed from their
    normalized average. If False, the utterance embedding is instead computed from feeding the
    entire spectrogram to the network.
    :param return_partials: if True, the partial embeddings will also be returned along with the
    wav slices that correspond to the partial embeddings.
    :param kwargs: additional arguments to compute_partial_slices()
    :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
    <return_partials> is True, the partial utterances as a numpy array of float32 of shape
    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
    returned. If <using_partials> is simultaneously set to False, both these values will be None
    instead.
    """
    # Process the entire utterance if not using partials
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed


def embed_speaker(wavs, **kwargs):
    raise NotImplementedError()


def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    if ax is None:
        ax = plt.gca()

    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = cm.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    sm = cm.ScalarMappable(cmap=cmap)
    sm.set_clim(*color_range)

    ax.set_xticks([]), ax.set_yticks([])
    ax.set_title(title)
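
Note: the intended call flow for this module, as a sketch (not part of this commit; "encoder.pt" and "sample.wav" are placeholder paths):

from pathlib import Path
from models.encoder import inference as encoder

encoder.load_model(Path("encoder.pt"))       # must be called before embedding
wav = encoder.preprocess_wav("sample.wav")   # re-exported from models.encoder.audio
embed = encoder.embed_utterance(wav)         # float32 array of shape (model_embedding_size,)
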
135 models/encoder/model.py Normal file
@@ -0,0 +1,135 @@
from models.encoder.params_model import *
from models.encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values)
        self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        #                           ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
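
Note: a shape sanity check for the GE2E loss above (a sketch, not part of this commit; random unit vectors stand in for real model output):

import torch
from models.encoder.model import SpeakerEncoder

model = SpeakerEncoder(torch.device("cpu"), torch.device("cpu"))
embeds = torch.rand(4, 5, 256)                             # (speakers, utterances, embedding_size)
embeds = embeds / torch.norm(embeds, dim=2, keepdim=True)  # L2-normalize, as forward() does
loss, eer = model.loss(embeds)                             # scalar loss tensor and EER float
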
29 models/encoder/params_data.py Normal file
@@ -0,0 +1,29 @@
## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10    # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160     # 1600 ms
# Number of spectrogram frames at inference
inference_n_frames = 80     # 800 ms


## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30
11 models/encoder/params_model.py Normal file
@@ -0,0 +1,11 @@
## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3


## Training parameters
learning_rate_init = 1e-4
speakers_per_batch = 64
utterances_per_speaker = 10
184 models/encoder/preprocess.py Normal file
@@ -0,0 +1,184 @@
from multiprocess.pool import ThreadPool
from models.encoder.params_data import *
from models.encoder.config import librispeech_datasets, anglophone_nationalites
from datetime import datetime
from models.encoder import audio
from pathlib import Path
from tqdm import tqdm
import numpy as np


class DatasetLog:
    """
    Registers metadata about the dataset in a text file.
    """
    def __init__(self, root, name):
        self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
        self.sample_data = dict()

        start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Creating dataset %s on %s" % (name, start_time))
        self.write_line("-----")
        self._log_params()

    def _log_params(self):
        from models.encoder import params_data
        self.write_line("Parameter values:")
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            self.write_line("\t%s: %s" % (param_name, value))
        self.write_line("-----")

    def write_line(self, line):
        self.text_file.write("%s\n" % line)

    def add_sample(self, **kwargs):
        for param_name, value in kwargs.items():
            if param_name not in self.sample_data:
                self.sample_data[param_name] = []
            self.sample_data[param_name].append(value)

    def finalize(self):
        self.write_line("Statistics:")
        for param_name, values in self.sample_data.items():
            self.write_line("\t%s:" % param_name)
            self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
            self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
        self.write_line("-----")
        end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Finished on %s" % end_time)
        self.text_file.close()


def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
    dataset_root = datasets_root.joinpath(dataset_name)
    if not dataset_root.exists():
        print("Couldn't find %s, skipping this dataset." % dataset_root)
        return None, None
    return dataset_root, DatasetLog(out_dir, dataset_name)


def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
                             skip_existing, logger):
    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))

    # Function to preprocess utterances for one speaker
    def preprocess_speaker(speaker_dir: Path):
        # Give a name to the speaker that includes its dataset
        speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

        # Create an output directory with that name, as well as a txt file containing a
        # reference to each source file.
        speaker_out_dir = out_dir.joinpath(speaker_name)
        speaker_out_dir.mkdir(exist_ok=True)
        sources_fpath = speaker_out_dir.joinpath("_sources.txt")

        # There's a possibility that the preprocessing was interrupted earlier, check if
        # there already is a sources file.
        if sources_fpath.exists():
            try:
                with sources_fpath.open("r") as sources_file:
                    existing_fnames = {line.split(",")[0] for line in sources_file}
            except Exception:
                existing_fnames = set()
        else:
            existing_fnames = set()

        # Gather all audio files for that speaker recursively
        sources_file = sources_fpath.open("a" if skip_existing else "w")
        for in_fpath in speaker_dir.glob("**/*.%s" % extension):
            # Check if the target output file already exists
            out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
            out_fname = out_fname.replace(".%s" % extension, ".npy")
            if skip_existing and out_fname in existing_fnames:
                continue

            # Load and preprocess the waveform
            wav = audio.preprocess_wav(in_fpath)
            if len(wav) == 0:
                continue

            # Create the mel spectrogram, discard those that are too short
            frames = audio.wav_to_mel_spectrogram(wav)
            if len(frames) < partials_n_frames:
                continue

            out_fpath = speaker_out_dir.joinpath(out_fname)
            np.save(out_fpath, frames)
            logger.add_sample(duration=len(wav) / sampling_rate)
            sources_file.write("%s,%s\n" % (out_fname, in_fpath))

        sources_file.close()

    # Process the utterances for each speaker
    with ThreadPool(8) as pool:
        list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
                  unit="speakers"))
    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)


def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False):
    dataset_name = "aidatatang_200zh"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
                             skip_existing, logger)


def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
    for dataset_name in librispeech_datasets["train"]["other"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            return

        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
                                 skip_existing, logger)


def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb1"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the contents of the meta file
    with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
        metadata = [line.split("\t") for line in metafile][1:]

    # Select the ID and the nationality, filter out non-anglophone speakers
    nationalities = {line[0]: line[3] for line in metadata}
    keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
                        nationality.lower() in anglophone_nationalites]
    print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
          (len(keep_speaker_ids), len(nationalities)))

    # Get the speaker directories for anglophone speakers only
    speaker_dirs = dataset_root.joinpath("wav").glob("*")
    speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
                    speaker_dir.name in keep_speaker_ids]
    print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
          (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))

    # Preprocess all speakers
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
                             skip_existing, logger)


def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb2"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the speaker directories and preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
                             skip_existing, logger)
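
Note: a hypothetical invocation of the aidatatang_200zh pipeline above (not part of this commit; both paths are placeholders, and the dataset must already be extracted under datasets_root):

from pathlib import Path
from models.encoder.preprocess import preprocess_aidatatang_200zh

preprocess_aidatatang_200zh(Path("datasets"), Path("datasets/SV2TTS/encoder"), skip_existing=True)
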
123 models/encoder/train.py Normal file
@@ -0,0 +1,123 @@
from models.encoder.visualizations import Visualizations
from models.encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from models.encoder.params_model import *
from models.encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch

def sync(device: torch.device):
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Set up the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")
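
Note: a sketch of how train() might be invoked (not part of this commit; all paths and intervals are illustrative, and no_visdom=True avoids needing a running visdom server):

from pathlib import Path
from models.encoder.train import train

train(run_id="my_run",
      clean_data_root=Path("datasets/SV2TTS/encoder"),
      models_dir=Path("saved_models"),
      umap_every=100, save_every=500, backup_every=7500, vis_every=10,
      force_restart=False, visdom_server="http://localhost", no_visdom=True)
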
178 models/encoder/visualizations.py Normal file
@@ -0,0 +1,178 @@
from models.encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from datetime import datetime
from time import perf_counter as timer
import matplotlib.pyplot as plt
import numpy as np
# import webbrowser
import visdom
import umap

colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=float) / 255


class Visualizations:
    def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
        # Tracking data
        self.last_update_timestamp = timer()
        self.update_every = update_every
        self.step_times = []
        self.losses = []
        self.eers = []
        print("Updating the visualizations every %d steps." % update_every)

        # If visdom is disabled TODO: use a better paradigm for that
        self.disabled = disabled
        if self.disabled:
            return

        # Set the environment name
        now = str(datetime.now().strftime("%d-%m %Hh%M"))
        if env_name is None:
            self.env_name = now
        else:
            self.env_name = "%s (%s)" % (env_name, now)

        # Connect to visdom and open the corresponding window in the browser
        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError:
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.")
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        # Create the windows
        self.loss_win = None
        self.eer_win = None
        # self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

    def log_params(self):
        if self.disabled:
            return
        from models.encoder import params_data
        from models.encoder import params_model
        param_string = "<b>Model parameters</b>:<br>"
        for param_name in (p for p in dir(params_model) if not p.startswith("__")):
            value = getattr(params_model, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        param_string += "<b>Data parameters</b>:<br>"
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: SpeakerVerificationDataset):
        if self.disabled:
            return
        dataset_string = ""
        dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        if self.disabled:
            return
        implementation_string = ""
        for param, value in params.items():
            implementation_string += "<b>%s</b>: %s\n" % (param, value)
            implementation_string = implementation_string.replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, step):
        # Update the tracking data
        now = timer()
        self.step_times.append(1000 * (now - self.last_update_timestamp))
        self.last_update_timestamp = now
        self.losses.append(loss)
        self.eers.append(eer)
        print(".", end="")

        # Update the plots every <update_every> steps
        if step % self.update_every != 0:
            return
        time_string = "Step time: mean: %5dms std: %5dms" % \
                      (int(np.mean(self.step_times)), int(np.std(self.step_times)))
        print("\nStep %6d Loss: %.4f EER: %.4f %s" %
              (step, np.mean(self.losses), np.mean(self.eers), time_string))
        if not self.disabled:
            self.loss_win = self.vis.line(
                [np.mean(self.losses)],
                [step],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [np.mean(self.eers)],
                [step],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            if self.implementation_win is not None:
                self.vis.text(
                    self.implementation_string + ("<b>%s</b>" % time_string),
                    win=self.implementation_win,
                    opts={"title": "Training implementation"},
                )

        # Reset the tracking
        self.losses.clear()
        self.eers.clear()
        self.step_times.clear()

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
                         max_speakers=10):
        max_speakers = min(max_speakers, len(colormap))
        embeds = embeds[:max_speakers * utterances_per_speaker]

        n_speakers = len(embeds) // utterances_per_speaker
        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = [colormap[i] for i in ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        plt.clf()

    def save(self):
        if not self.disabled:
            self.vis.save([self.env_name])
209 models/ppg2mel/__init__.py Normal file
@@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2020 Songxiang Liu
|
||||
# Apache 2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .utils.abs_model import AbsMelDecoder
|
||||
from .rnn_decoder_mol import Decoder
|
||||
from .utils.cnn_postnet import Postnet
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
class MelDecoderMOLv2(AbsMelDecoder):
|
||||
"""Use an encoder to preprocess ppg."""
|
||||
def __init__(
|
||||
self,
|
||||
num_speakers: int,
|
||||
spk_embed_dim: int,
|
||||
bottle_neck_feature_dim: int,
|
||||
encoder_dim: int = 256,
|
||||
encoder_downsample_rates: List = [2, 2],
|
||||
attention_rnn_dim: int = 512,
|
||||
decoder_rnn_dim: int = 512,
|
||||
num_decoder_rnn_layer: int = 1,
|
||||
concat_context_to_last: bool = True,
|
||||
prenet_dims: List = [256, 128],
|
||||
num_mixtures: int = 5,
|
||||
frames_per_step: int = 2,
|
||||
mask_padding: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mask_padding = mask_padding
|
||||
self.bottle_neck_feature_dim = bottle_neck_feature_dim
|
||||
self.num_mels = 80
|
||||
self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1]
|
||||
self.frames_per_step = frames_per_step
|
||||
self.use_spk_dvec = True
|
||||
|
||||
input_dim = bottle_neck_feature_dim
|
||||
|
||||
# Downsampling convolution
|
||||
self.bnf_prenet = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[0],
                stride=encoder_downsample_rates[0],
                padding=encoder_downsample_rates[0]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[1],
                stride=encoder_downsample_rates[1],
                padding=encoder_downsample_rates[1]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        )
        decoder_enc_dim = encoder_dim
        self.pitch_convs = torch.nn.Sequential(
            torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[0],
                stride=encoder_downsample_rates[0],
                padding=encoder_downsample_rates[0]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[1],
                stride=encoder_downsample_rates[1],
                padding=encoder_downsample_rates[1]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        )

        self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)

        # Decoder
        self.decoder = Decoder(
            enc_dim=decoder_enc_dim,
            num_mels=self.num_mels,
            frames_per_step=frames_per_step,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=decoder_rnn_dim,
            num_decoder_rnn_layer=num_decoder_rnn_layer,
            prenet_dims=prenet_dims,
            num_mixtures=num_mixtures,
            use_stop_tokens=True,
            concat_context_to_last=concat_context_to_last,
            encoder_down_factor=self.encoder_down_factor,
        )

        # Mel-Spec Postnet: some residual CNN layers
        self.postnet = Postnet()

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
            mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
        return outputs

    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        output_att_ws: bool = False,
    ):
        decoder_inputs = self.bnf_prenet(
            bottle_neck_features.transpose(1, 2)
        ).transpose(1, 2)
        logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        decoder_inputs = decoder_inputs + logf0_uv

        assert spembs is not None
        spk_embeds = F.normalize(
            spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
        decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
        decoder_inputs = self.reduce_proj(decoder_inputs)

        # (B, num_mels, T_dec)
        T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
        mel_outputs, predicted_stop, alignments = self.decoder(
            decoder_inputs, speech, T_dec)
        ## Post-processing
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        if output_att_ws:
            return self.parse_output(
                [mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
        else:
            return self.parse_output(
                [mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)

    def inference(
        self,
        bottle_neck_features: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ):
        decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
        logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        decoder_inputs = decoder_inputs + logf0_uv

        assert spembs is not None
        spk_embeds = F.normalize(
            spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
        bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
        bottle_neck_features = self.reduce_proj(bottle_neck_features)

        ## Decoder
        if bottle_neck_features.size(0) > 1:
            mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
        else:
            mel_outputs, alignments = self.decoder.inference(bottle_neck_features)
        ## Post-processing
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return mel_outputs[0], mel_outputs_postnet[0], alignments[0]


def load_model(model_file, device=None):
    # Search for a yaml config file next to the model checkpoint
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    if len(model_config_fpaths) == 0:
        raise FileNotFoundError("No model yaml config found for converter")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = HpsYaml(model_config_fpaths[0])
    ppg2mel_model = MelDecoderMOLv2(
        **model_config["model"]
    ).to(device)
    ckpt = torch.load(model_file, map_location=device)
    ppg2mel_model.load_state_dict(ckpt["model"])
    ppg2mel_model.eval()
    return ppg2mel_model
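As a quick sanity check, the converter can be exercised end-to-end on random inputs. This is a sketch only: the checkpoint path is hypothetical, and the feature dimensions (144-dim BNFs, 256-dim speaker d-vector) are placeholders that must match the loaded yaml config.

# --- usage sketch (illustrative, not part of the diff) ---
from pathlib import Path
import torch

model = load_model(Path("ppg2mel/saved_models/best_loss_step.pth"),
                   device=torch.device("cpu"))  # path is hypothetical
bnf = torch.randn(1, 200, 144)    # (B, T_enc, bnf_dim); dim must match the config
lf0_uv = torch.randn(1, 200, 2)   # log-F0 + U/V flag, matching pitch_convs' 2 input channels
spemb = torch.randn(1, 256)       # speaker embedding; dim must match spk_embed_dim
with torch.no_grad():
    mel, mel_post, att = model.inference(bnf, logf0_uv=lf0_uv, spembs=spemb)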
113
models/ppg2mel/preprocess.py
Normal file
@@ -0,0 +1,113 @@
import os
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import soundfile
import resampy

from models.ppg_extractor import load_model
import models.encoder.inference as Encoder  # refactored path (was `encoder.inference`)
from models.encoder.audio import preprocess_wav
from models.encoder import audio
from utils.f0_utils import compute_f0

from torch.multiprocessing import Pool, cpu_count
from functools import partial

SAMPLE_RATE = 16000


def _compute_bnf(
    wav: any,
    output_fpath: str,
    device: torch.device,
    ppg_model_local: any,
):
    """
    Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
    """
    ppg_model_local.to(device)
    wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
    wav_length = torch.LongTensor([wav.shape[0]]).to(device)
    with torch.no_grad():
        bnf = ppg_model_local(wav_tensor, wav_length)
    bnf_npy = bnf.squeeze(0).cpu().numpy()
    np.save(output_fpath, bnf_npy, allow_pickle=False)
    return bnf_npy, len(bnf_npy)


def _compute_f0_from_wav(wav, output_fpath):
    """Compute merged f0 values."""
    f0 = compute_f0(wav, SAMPLE_RATE)
    np.save(output_fpath, f0, allow_pickle=False)
    return f0, len(f0)


def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
    Encoder.set_model(encoder_model_local)
    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = Encoder.embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    np.save(output_fpath, embed, allow_pickle=False)
    return embed, len(embed)


def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
    wav, sr = soundfile.read(wav_path)
    if len(wav) < sr:
        return None, sr, len(wav)
    if sr != SAMPLE_RATE:
        wav = resampy.resample(wav, sr, SAMPLE_RATE)
        sr = SAMPLE_RATE
    # NOTE: rstrip(".wav") would strip characters, not the suffix
    utt_id = os.path.basename(wav_path)[:-len(".wav")]

    _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
    _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
    _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)


def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
    # Glob wav files
    wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
    print(f"Globbed {len(wav_file_list)} wav files.")

    out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
    encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
    if n_processes is None:
        n_processes = cpu_count()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
    job = Pool(n_processes).imap(func, wav_file_list)
    list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))

    # Finished processing: write train/dev/eval file-id lists
    t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
    d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
    e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
    for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
        id = os.path.basename(file).split(".f0.npy")[0]
        if id.endswith("01"):
            d_fid_file.write(id + "\n")
        elif id.endswith("09"):
            e_fid_file.write(id + "\n")
        else:
            t_fid_file.write(id + "\n")
    t_fid_file.close()
    d_fid_file.close()
    e_fid_file.close()
    return len(wav_file_list)
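Invocation is then a single call per dataset. Everything below is illustrative; the dataset name, directory layout, and model paths are assumptions:

# --- usage sketch (illustrative, not part of the diff) ---
from pathlib import Path

n_wavs = preprocess_dataset(
    datasets_root=Path("datasets"),                  # hypothetical layout
    dataset="aidatatang_200zh",
    out_dir=Path("datasets/ppg2mel_out"),
    n_processes=4,
    ppg_encoder_model_fpath=Path("ppg_extractor/saved_models/24epoch.pt"),
    speaker_encoder_model=Path("encoder/saved_models/pretrained.pt"),
)
print(f"Preprocessed {n_wavs} wav files.")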
374
models/ppg2mel/rnn_decoder_mol.py
Normal file
@@ -0,0 +1,374 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils.mol_attention import MOLAttention
from .utils.basic_layers import Linear
from .utils.vc_utils import get_mask_from_lengths


class DecoderPrenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super().__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [Linear(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            # NOTE: dropout is deliberately kept active at inference time (training=True)
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Decoder(nn.Module):
    """Mixture of Logistic (MoL) attention-based RNN Decoder."""
    def __init__(
        self,
        enc_dim,
        num_mels,
        frames_per_step,
        attention_rnn_dim,
        decoder_rnn_dim,
        prenet_dims,
        num_mixtures,
        encoder_down_factor=1,
        num_decoder_rnn_layer=1,
        use_stop_tokens=False,
        concat_context_to_last=False,
    ):
        super().__init__()
        self.enc_dim = enc_dim
        self.encoder_down_factor = encoder_down_factor
        self.num_mels = num_mels
        self.frames_per_step = frames_per_step
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dims = prenet_dims
        self.use_stop_tokens = use_stop_tokens
        self.num_decoder_rnn_layer = num_decoder_rnn_layer
        self.concat_context_to_last = concat_context_to_last

        # Mel prenet
        self.prenet = DecoderPrenet(num_mels, prenet_dims)
        self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)

        # Attention RNN
        self.attention_rnn = nn.LSTMCell(
            prenet_dims[-1] + enc_dim,
            attention_rnn_dim
        )

        # Attention
        self.attention_layer = MOLAttention(
            attention_rnn_dim,
            r=frames_per_step/encoder_down_factor,
            M=num_mixtures,
        )

        # Decoder RNN
        self.decoder_rnn_layers = nn.ModuleList()
        for i in range(num_decoder_rnn_layer):
            if i == 0:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        enc_dim + attention_rnn_dim,
                        decoder_rnn_dim))
            else:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        decoder_rnn_dim,
                        decoder_rnn_dim))

        if concat_context_to_last:
            self.linear_projection = Linear(
                enc_dim + decoder_rnn_dim,
                num_mels * frames_per_step
            )
        else:
            self.linear_projection = Linear(
                decoder_rnn_dim,
                num_mels * frames_per_step
            )

        # Stop-token layer
        if self.use_stop_tokens:
            if concat_context_to_last:
                self.stop_layer = Linear(
                    enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )
            else:
                self.stop_layer = Linear(
                    decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )

    def get_go_frame(self, memory):
        B = memory.size(0)
        go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
                               device=memory.device)
        return go_frame

    def initialize_decoder_states(self, memory, mask):
        device = next(self.parameters()).device
        B = memory.size(0)

        # attention rnn states
        self.attention_hidden = torch.zeros(
            (B, self.attention_rnn_dim), device=device)
        self.attention_cell = torch.zeros(
            (B, self.attention_rnn_dim), device=device)

        # decoder rnn states
        self.decoder_hiddens = []
        self.decoder_cells = []
        for i in range(self.num_decoder_rnn_layer):
            self.decoder_hiddens.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
            self.decoder_cells.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )

        self.attention_context = torch.zeros(
            (B, self.enc_dim), device=device)

        self.memory = memory
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepare decoder inputs, i.e. ground-truth mel frames.
        Args:
            decoder_inputs: (B, T_out, num_mels) inputs used for teacher-forced training.
        """
        decoder_inputs = decoder_inputs.reshape(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.frames_per_step), -1)
        # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        # (T_out//r, B, num_mels): keep only the last frame of each r-frame group
        decoder_inputs = decoder_inputs[:, :, -self.num_mels:]
        return decoder_inputs
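Concretely, with frames_per_step r = 4 and num_mels = 80, a 400-frame target collapses to 100 decoder steps, and only the last frame of each 4-frame group is kept as the teacher-forcing input. A standalone check of the same reshaping:

# --- shape check (illustrative, not part of the diff) ---
import torch

B, T_out, num_mels, r = 2, 400, 80, 4
mels = torch.randn(B, T_out, num_mels)
grouped = mels.reshape(B, T_out // r, r * num_mels)    # (2, 100, 320)
stepwise = grouped.transpose(0, 1)[:, :, -num_mels:]   # (100, 2, 80)
assert stepwise.shape == (T_out // r, B, num_mels)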
    def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
        """Prepare decoder outputs for output.
        Args:
            mel_outputs: list of (B, num_mels*r) per-step outputs.
            alignments: list of (B, T_enc) attention weights.
            stop_outputs: list of per-step stop logits, or None.
        """
        # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out//r, B) -> (B, T_out//r)
        if stop_outputs is not None:
            if alignments.size(0) == 1:
                stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
            else:
                stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
            stop_outputs = stop_outputs.contiguous()
        # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # Decouple frames per step: (B, T_out, num_mels)
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.num_mels)
        return mel_outputs, alignments, stop_outputs

    def attend(self, decoder_input):
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_context, attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, None, self.mask)

        decoder_rnn_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)

        return decoder_rnn_input, self.attention_context, attention_weights

    def decode(self, decoder_input):
        for i in range(self.num_decoder_rnn_layer):
            if i == 0:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
            else:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
        return self.decoder_hiddens[-1]

    def forward(self, memory, mel_inputs, memory_lengths):
        """Decoder forward pass for training.
        Args:
            memory: (B, T_enc, enc_dim) encoder outputs.
            mel_inputs: (B, T, num_mels) decoder inputs for teacher forcing.
            memory_lengths: (B,) encoder output lengths for attention masking.
        Returns:
            mel_outputs: (B, T, num_mels) mel outputs from the decoder.
            stop_outputs: (B, T//r) stop-token logits (when use_stop_tokens is set).
            alignments: (B, T//r, T_enc) attention weights.
        """
        # [1, B, num_mels]
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        # [T//r, B, num_mels]
        mel_inputs = self.parse_decoder_inputs(mel_inputs)
        # [T//r + 1, B, num_mels]
        mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
        # [T//r + 1, B, prenet_dim]
        decoder_inputs = self.prenet(mel_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths),
        )

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        if self.use_stop_tokens:
            stop_outputs = []
        else:
            stop_outputs = None
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]

            decoder_rnn_input, context, attention_weights = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_rnn_input)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            if self.use_stop_tokens:
                stop_output = self.stop_layer(decoder_rnn_output)
                stop_outputs += [stop_output.squeeze()]
            mel_outputs += [mel_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        if stop_outputs is None:
            return mel_outputs, alignments
        else:
            return mel_outputs, stop_outputs, alignments
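A teacher-forced forward pass can be smoke-tested in isolation along these lines (the dimensions are arbitrary; this assumes the relative imports above resolve, in particular get_mask_from_lengths from models/ppg2mel/utils/vc_utils.py):

# --- smoke test (illustrative, not part of the diff) ---
import torch

dec = Decoder(
    enc_dim=256, num_mels=80, frames_per_step=4,
    attention_rnn_dim=512, decoder_rnn_dim=512,
    prenet_dims=[256, 256], num_mixtures=5,
    encoder_down_factor=4, use_stop_tokens=True,
)
memory = torch.randn(2, 50, 256)       # (B, T_enc, enc_dim)
mels = torch.randn(2, 200, 80)         # T_out = T_enc * encoder_down_factor
lengths = torch.tensor([50, 42])
mel_out, stop_logits, att = dec(memory, mels, lengths)
print(mel_out.shape, stop_logits.shape, att.shape)  # (2, 200, 80) (2, 50) (2, 50, 50)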
    def inference(self, memory, stop_threshold=0.5):
        """Decoder inference (single utterance).
        Args:
            memory: (1, T_enc, D_enc) encoder outputs.
        Returns:
            mel_outputs: mel outputs from the decoder.
            alignments: sequence of attention weights from the decoder.
        """
        # [1, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        # NOTE(sx): heuristic bounds on the number of decoding steps
        max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
        min_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            stop_output = self.stop_layer(decoder_rnn_output)

            mel_outputs += [mel_output.squeeze(1)]
            alignments += [alignment]

            if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                break

            # Feed the last predicted frame back in as the next input
            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, _ = self.parse_decoder_outputs(
            mel_outputs, alignments, None)

        return mel_outputs, alignments

    def inference_batched(self, memory, stop_threshold=0.5):
        """Decoder inference (batched).
        Args:
            memory: (B, T_enc, D_enc) encoder outputs.
        Returns:
            mel_outputs: mel outputs from the decoder.
            alignments: sequence of attention weights from the decoder.
        """
        # [B, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        stop_outputs = []
        # NOTE(sx): heuristic bounds on the number of decoding steps
        max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
        min_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            # (B, 1)
            stop_output = self.stop_layer(decoder_rnn_output)
            stop_outputs += [stop_output.squeeze()]

            mel_outputs += [mel_output.squeeze(1)]
            alignments += [alignment]
            # Stop only when every utterance in the batch has crossed the threshold
            if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
                    and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                break

            # Feed the last predicted frame back in as the next input
            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        # Trim each utterance at its first above-threshold stop step
        mel_outputs_stacked = []
        for mel, stop_logit in zip(mel_outputs, stop_outputs):
            idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
            mel_outputs_stacked.append(mel[:idx, :])
        mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
        return mel_outputs, alignments
62
models/ppg2mel/train.py
Normal file
@@ -0,0 +1,62 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver

# For reproducibility; commenting these out may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def main():
    # Arguments
    parser = argparse.ArgumentParser(description=
                                     'Training PPG2Mel VC model.')
    parser.add_argument('--config', type=str,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ckpt/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducible results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    # NOTE: this option must stay defined, since paras.no_pin is read below
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')

    paras = parser.parse_args()
    setattr(paras, 'gpu', not paras.cpu)
    setattr(paras, 'pin_memory', not paras.no_pin)
    setattr(paras, 'verbose', not paras.no_msg)
    # Make the config dict dot-accessible
    config = HpsYaml(paras.config)

    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> OneShot VC training finished!")
    sys.exit(0)


if __name__ == "__main__":
    main()
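A run can also be kicked off programmatically, which amounts to the same thing as the command line (hedged: config/vc.yaml is a hypothetical config path):

# --- usage sketch (illustrative, not part of the diff) ---
import sys

sys.argv = ["train.py", "--config", "config/vc.yaml", "--name", "oneshot_vc"]
main()  # equivalent to: python train.py --config config/vc.yaml --name oneshot_vc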
1
models/ppg2mel/train/__init__.py
Normal file
@@ -0,0 +1 @@
#
50
models/ppg2mel/train/loss.py
Normal file
@@ -0,0 +1,50 @@
from typing import Dict
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.nets_utils import make_pad_mask


class MaskedMSELoss(nn.Module):
    def __init__(self, frames_per_step):
        super().__init__()
        self.frames_per_step = frames_per_step
        self.mel_loss_criterion = nn.MSELoss(reduction='none')
        self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')

    def get_mask(self, lengths, max_len=None):
        # lengths: (B,)
        if max_len is None:
            max_len = torch.max(lengths)
        batch_size = lengths.size(0)
        seq_range = torch.arange(0, max_len).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
        seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
        return (seq_range_expand < seq_length_expand).float()

    def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
                stop_target, stop_pred):
        ## Process stop_target: one stop label per decoder step
        B = stop_target.size(0)
        stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
        stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
        stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))

        mel_trg.requires_grad = False
        # (B, T, 1)
        mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
        # (B, T, D)
        mel_mask = mel_mask.expand_as(mel_trg)
        mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
        mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()

        mel_loss = mel_loss_pre + mel_loss_post

        # Stop-token loss
        stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()

        return mel_loss, stop_loss
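A shape-level usage sketch of the loss (values are random; note that stop_target is supplied at frame resolution T and reduced internally to one label per decoder step):

# --- usage sketch (illustrative, not part of the diff) ---
import torch

criterion = MaskedMSELoss(frames_per_step=4)
B, T, D = 2, 200, 80
mel_pred, mel_post, mel_trg = (torch.randn(B, T, D) for _ in range(3))
lengths = torch.tensor([200, 160])
stop_target = torch.zeros(B, T)
stop_target[0, 196:] = 1.0            # mark the final frames as "stop"
stop_target[1, 156:160] = 1.0
stop_pred = torch.randn(B, T // 4)    # per-step logits from the decoder
mel_loss, stop_loss = criterion(mel_pred, mel_post, mel_trg, lengths, stop_target, stop_pred)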
45
models/ppg2mel/train/optim.py
Normal file
@@ -0,0 +1,45 @@
import torch
import numpy as np


class Optimizer():
    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
                 **kwargs):

        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
            self.opt = opt(parameters, lr=1.0)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: is 1e-8 better?

    def get_opt_state_dict(self):
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        else:
            cur_lr = self.init_lr
        self.opt.zero_grad()
        return cur_lr

    def step(self):
        self.opt.step()

    def create_msg(self):
        return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
                .format(self.opt_type, self.init_lr, self.sch_type)]
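The 'warmup' branch is the inverse-square-root schedule familiar from Transformer training: lr(step) = init_lr * sqrt(warmup) * min((step+1) * warmup^-1.5, (step+1)^-0.5), i.e. a linear ramp over the first 4000 steps followed by 1/sqrt(step) decay, peaking at init_lr. A quick look at the values:

# --- schedule sketch (illustrative, not part of the diff) ---
import numpy as np

init_lr, warmup = 1e-3, 4000.0
lr = lambda step: init_lr * warmup ** 0.5 * np.minimum((step + 1) * warmup ** -1.5,
                                                       (step + 1) ** -0.5)
for s in (0, 1000, 3999, 16000, 64000):
    print(s, f"{lr(s):.2e}")   # ramps up to 1e-3 at step 4000, then decays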
10
models/ppg2mel/train/option.py
Normal file
@@ -0,0 +1,10 @@
# Default parameters which will be imported by solver
default_hparas = {
    'GRAD_CLIP': 5.0,        # Gradient clipping threshold
    'PROGRESS_STEP': 100,    # Stdout refresh frequency
    # Decode steps for objective validation (step = ratio * input_txt_len)
    'DEV_STEP_RATIO': 1.2,
    # Number of examples (alignment/text) to show in tensorboard
    'DEV_N_EXAMPLE': 4,
    'TB_FLUSH_FREQ': 180,    # Update frequency of tensorboard (secs)
}
217
models/ppg2mel/train/solver.py
Normal file
@@ -0,0 +1,217 @@
import os
import sys
import abc
import math
import yaml
import torch
from torch.utils.tensorboard import SummaryWriter

from .option import default_hparas
from utils.util import human_format, Timer
from utils.load_yaml import HpsYaml


class BaseSolver():
    '''
    Prototype solver for all kinds of tasks.
    Arguments
        config - yaml-styled config
        paras  - argparse outcome
        mode   - "train"/"test"
    '''

    def __init__(self, config, paras, mode="train"):
        # General Settings
        self.config = config  # loaded from yaml file
        self.paras = paras    # command line args
        self.mode = mode      # 'train' or 'test'
        for k, v in default_hparas.items():
            setattr(self, k, v)
        self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
            else torch.device('cpu')

        # Name experiment
        self.exp_name = paras.name
        if self.exp_name is None:
            if 'exp_name' in self.config:
                self.exp_name = self.config.exp_name
            else:
                # By default, the experiment is named after the config file
                self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
            if mode == 'train':
                self.exp_name += '_seed{}'.format(paras.seed)

        if mode == 'train':
            # Filepath setup
            os.makedirs(paras.ckpdir, exist_ok=True)
            self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
            os.makedirs(self.ckpdir, exist_ok=True)

            # Logger settings
            self.logdir = os.path.join(paras.logdir, self.exp_name)
            self.log = SummaryWriter(
                self.logdir, flush_secs=self.TB_FLUSH_FREQ)
            self.timer = Timer()

            # Hyper-parameters
            self.step = 0
            self.valid_step = config.hparas.valid_step
            self.max_step = config.hparas.max_step

            self.verbose('Exp. name : {}'.format(self.exp_name))
            self.verbose('Loading data... a large corpus may take a while.')

    def backward(self, loss):
        '''
        Standard backward step with self.timer and debugger
        Arguments
            loss - the loss to perform loss.backward() on
        '''
        self.timer.set()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.GRAD_CLIP)
        if math.isnan(grad_norm):
            self.verbose('Error : grad norm is NaN @ step '+str(self.step))
        else:
            self.optimizer.step()
        self.timer.cnt('bw')
        return grad_norm

    def load_ckpt(self):
        ''' Load ckpt if --load option is specified '''
        if self.paras.load is not None:
            if self.paras.warm_start:
                self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
                ckpt = torch.load(
                    self.paras.load, map_location=self.device if self.mode == 'train'
                    else 'cpu')
                model_dict = ckpt['model']
                if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
                    model_dict = {k: v for k, v in model_dict.items()
                                  if k not in self.config.model.ignore_layers}
                    dummy_dict = self.model.state_dict()
                    dummy_dict.update(model_dict)
                    model_dict = dummy_dict
                self.model.load_state_dict(model_dict)
            else:
                # Load weights
                ckpt = torch.load(
                    self.paras.load, map_location=self.device if self.mode == 'train'
                    else 'cpu')
                self.model.load_state_dict(ckpt['model'])

                # Load task-dependent items
                if self.mode == 'train':
                    self.step = ckpt['global_step']
                    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
                    self.verbose('Load ckpt from {}, restarting at step {}'.format(
                        self.paras.load, self.step))
                else:
                    for k, v in ckpt.items():
                        if type(v) is float:
                            metric, score = k, v
                    self.model.eval()
                    self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
                        self.paras.load, metric, score))

    def verbose(self, msg):
        ''' Verbose function for printing information to stdout '''
        if self.paras.verbose:
            if type(msg) == list:
                for m in msg:
                    print('[INFO]', m.ljust(100))
            else:
                print('[INFO]', msg.ljust(100))

    def progress(self, msg):
        ''' Verbose function for updating progress on stdout (does not print a newline) '''
        if self.paras.verbose:
            sys.stdout.write("\033[K")  # Clear line
            print('[{}] {}'.format(human_format(self.step), msg), end='\r')

    def write_log(self, log_name, log_dict):
        '''
        Write log to TensorBoard
            log_name - <str> name of the tensorboard variable
            log_dict - <dict>/<array> value of the variable (e.g. dict of losses)
        '''
        if type(log_dict) is dict:
            log_dict = {key: val for key, val in log_dict.items() if (
                val is not None and not math.isnan(val))}
        if log_dict is None:
            pass
        elif len(log_dict) > 0:
            if 'align' in log_name or 'spec' in log_name:
                img, form = log_dict
                self.log.add_image(
                    log_name, img, global_step=self.step, dataformats=form)
            elif 'text' in log_name or 'hyp' in log_name:
                self.log.add_text(log_name, log_dict, self.step)
            else:
                self.log.add_scalars(log_name, log_dict, self.step)

    def save_checkpoint(self, f_name, metric, score, show_msg=True):
        '''
        Checkpoint saver
            f_name - <str> name of the ckpt file (w/o prefix) to store, overwritten if it exists
            score  - <float> value of the metric used to evaluate the model
        '''
        ckpt_path = os.path.join(self.ckpdir, f_name)
        full_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.get_opt_state_dict(),
            "global_step": self.step,
            metric: score
        }

        torch.save(full_dict, ckpt_path)
        if show_msg:
            self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
                         format(human_format(self.step), metric, score, ckpt_path))

    # ----------------------------------- Abstract Methods ------------------------------------------ #
    @abc.abstractmethod
    def load_data(self):
        '''
        Called by main to load all data
        After this call, data-related attributes should be set up (e.g. self.tr_set, self.dev_set)
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def set_model(self):
        '''
        Called by main to set models
        After this call, model-related attributes should be set up (e.g. self.l2_loss)
        The following MUST be set up
            - self.model (torch.nn.Module)
            - self.optimizer (src.Optimizer),
                init. w/ self.optimizer = src.Optimizer(self.model.parameters(), **self.config['hparas'])
        Loading a pre-trained model should also be performed here
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def exec(self):
        '''
        Called by main to execute training/inference
        '''
        raise NotImplementedError
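A minimal concrete subclass shows the contract the abstract methods impose. This is purely illustrative (not in the repo): it assumes config['hparas'] carries the Optimizer arguments (optimizer, lr, eps, lr_scheduler) and that construction follows train.py, i.e. ToySolver(config, paras, 'train').

# --- subclass sketch (illustrative, not part of the diff) ---
import torch
from .optim import Optimizer  # assumed import path

class ToySolver(BaseSolver):
    """Trains a single linear layer on random data."""
    def load_data(self):
        self.tr_set = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(100)]

    def set_model(self):
        self.model = torch.nn.Linear(4, 1).to(self.device)
        self.optimizer = Optimizer(self.model.parameters(), **self.config['hparas'])

    def exec(self):
        for x, y in self.tr_set:
            self.optimizer.pre_step(self.step)
            loss = torch.nn.functional.mse_loss(self.model(x.to(self.device)),
                                                y.to(self.device))
            self.backward(loss)
            self.step += 1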
288
models/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py
Normal file
@@ -0,0 +1,288 @@
import os, sys
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import torch
from torch.utils.data import DataLoader
import numpy as np
from .solver import BaseSolver
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
from .loss import MaskedMSELoss
from .optim import Optimizer
from utils.util import human_format
from models.ppg2mel import MelDecoderMOLv2


class Solver(BaseSolver):
    """Customized Solver."""
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        self.num_att_plots = 5
        self.att_ws_dir = f"{self.logdir}/att_ws"
        os.makedirs(self.att_ws_dir, exist_ok=True)
        self.best_loss = np.inf

    def fetch_data(self, data):
        """Move data to device"""
        data = [i.to(self.device) for i in data]
        return data

    def load_data(self):
        """Load data for training/validation/plotting."""
        train_dataset = OneshotVcDataset(
            meta_file=self.config.data.train_fid_list,
            vctk_ppg_dir=self.config.data.vctk_ppg_dir,
            libri_ppg_dir=self.config.data.libri_ppg_dir,
            vctk_f0_dir=self.config.data.vctk_f0_dir,
            libri_f0_dir=self.config.data.libri_f0_dir,
            vctk_wav_dir=self.config.data.vctk_wav_dir,
            libri_wav_dir=self.config.data.libri_wav_dir,
            vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
            libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
            ppg_file_ext=self.config.data.ppg_file_ext,
            min_max_norm_mel=self.config.data.min_max_norm_mel,
            mel_min=self.config.data.mel_min,
            mel_max=self.config.data.mel_max,
        )
        dev_dataset = OneshotVcDataset(
            meta_file=self.config.data.dev_fid_list,
            vctk_ppg_dir=self.config.data.vctk_ppg_dir,
            libri_ppg_dir=self.config.data.libri_ppg_dir,
            vctk_f0_dir=self.config.data.vctk_f0_dir,
            libri_f0_dir=self.config.data.libri_f0_dir,
            vctk_wav_dir=self.config.data.vctk_wav_dir,
            libri_wav_dir=self.config.data.libri_wav_dir,
            vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
            libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
            ppg_file_ext=self.config.data.ppg_file_ext,
            min_max_norm_mel=self.config.data.min_max_norm_mel,
            mel_min=self.config.data.mel_min,
            mel_max=self.config.data.mel_max,
        )
        self.train_dataloader = DataLoader(
            train_dataset,
            num_workers=self.paras.njobs,
            shuffle=True,
            batch_size=self.config.hparas.batch_size,
            pin_memory=False,
            drop_last=True,
            collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
                                         use_spk_dvec=True),
        )
        self.dev_dataloader = DataLoader(
            dev_dataset,
            num_workers=self.paras.njobs,
            shuffle=False,
            batch_size=self.config.hparas.batch_size,
            pin_memory=False,
            drop_last=False,
            collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
                                         use_spk_dvec=True),
        )
        self.plot_dataloader = DataLoader(
            dev_dataset,
            num_workers=self.paras.njobs,
            shuffle=False,
            batch_size=1,
            pin_memory=False,
            drop_last=False,
            collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
                                         use_spk_dvec=True,
                                         give_uttids=True),
        )
        msg = "Have prepared training set and dev set."
        self.verbose(msg)

    def load_pretrained_params(self):
        print("Load pretrained model from: ", self.config.data.pretrain_model_file)
        ignore_layer_prefixes = ["speaker_embedding_table"]
        pretrain_model_file = self.config.data.pretrain_model_file
        pretrain_ckpt = torch.load(
            pretrain_model_file, map_location=self.device
        )["model"]
        model_dict = self.model.state_dict()

        # 1. Filter out unnecessary keys
        for prefix in ignore_layer_prefixes:
            pretrain_ckpt = {k: v
                             for k, v in pretrain_ckpt.items() if not k.startswith(prefix)
                             }
        # 2. Overwrite entries in the existing state dict
        model_dict.update(pretrain_ckpt)

        # 3. Load the new state dict
        self.model.load_state_dict(model_dict)

    def set_model(self):
        """Setup model and optimizer"""
        # Model
        print("[INFO] Model name: ", self.config["model_name"])
        self.model = MelDecoderMOLv2(
            **self.config["model"]
        ).to(self.device)
        # self.load_pretrained_params()

        model_params = [{'params': self.model.parameters()}]

        # Loss criterion
        self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)

        # Optimizer
        self.optimizer = Optimizer(model_params, **self.config["hparas"])
        self.verbose(self.optimizer.create_msg())

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

    def exec(self):
        self.verbose("Total training steps {}.".format(
            human_format(self.max_step)))

        mel_loss = None
        n_epochs = 0
        # Set as current time
        self.timer.set()

        while self.step < self.max_step:
            for data in self.train_dataloader:
                # Pre-step: update the learning rate and zero the gradients
                lr_rate = self.optimizer.pre_step(self.step)
                total_loss = 0
                # Move data to device
                ppgs, lf0_uvs, mels, in_lengths, \
                    out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
                self.timer.cnt("rd")
                mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids
                )
                mel_loss, stop_loss = self.loss_criterion(
                    mel_outputs,
                    mel_outputs_postnet,
                    mels,
                    out_lengths,
                    stop_tokens,
                    predicted_stop
                )
                loss = mel_loss + stop_loss

                self.timer.cnt("fw")

                # Back-prop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
                                  .format(loss.cpu().item(), mel_loss.cpu().item(),
                                          stop_loss.cpu().item(), grad_norm, self.timer.show()))
                    self.write_log('loss', {'tr/loss': loss,
                                            'tr/mel-loss': mel_loss,
                                            'tr/stop-loss': stop_loss})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        self.log.close()

    def validate(self):
        self.model.eval()
        dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0

        for i, data in enumerate(self.dev_dataloader):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
            # Fetch data
            ppgs, lf0_uvs, mels, in_lengths, \
                out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
            with torch.no_grad():
                mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids
                )
                mel_loss, stop_loss = self.loss_criterion(
                    mel_outputs,
                    mel_outputs_postnet,
                    mels,
                    out_lengths,
                    stop_tokens,
                    predicted_stop
                )
                loss = mel_loss + stop_loss

                dev_loss += loss.cpu().item()
                dev_mel_loss += mel_loss.cpu().item()
                dev_stop_loss += stop_loss.cpu().item()

        dev_loss = dev_loss / (i + 1)
        dev_mel_loss = dev_mel_loss / (i + 1)
        dev_stop_loss = dev_stop_loss / (i + 1)
        self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
        self.write_log('loss', {'dv/loss': dev_loss,
                                'dv/mel-loss': dev_mel_loss,
                                'dv/stop-loss': dev_stop_loss})

        # Plot attention
        for i, data in enumerate(self.plot_dataloader):
            if i == self.num_att_plots:
                break
            # Fetch data
            ppgs, lf0_uvs, mels, in_lengths, \
                out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
            fid = data[-1][0]
            with torch.no_grad():
                _, _, _, att_ws = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids,
                    output_att_ws=True
                )
            att_ws = att_ws.squeeze(0).cpu().numpy()
            att_ws = att_ws[None]
            w, h = plt.figaspect(1.0 / len(att_ws))
            fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
            axes = fig.subplots(1, len(att_ws))
            if len(att_ws) == 1:
                axes = [axes]

            for ax, aw in zip(axes, att_ws):
                ax.imshow(aw.astype(np.float32), aspect="auto")
                ax.set_title(f"{fid}")
                ax.set_xlabel("Input")
                ax.set_ylabel("Output")
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
            fig.savefig(fig_name)

        # Resume training
        self.model.train()
23
models/ppg2mel/utils/abs_model.py
Normal file
@@ -0,0 +1,23 @@
from abc import ABC
from abc import abstractmethod

import torch

class AbsMelDecoder(torch.nn.Module, ABC):
    """The abstract PPG-based voice conversion class.

    This "model" is one of the mediator objects for the "Task" class.
    """

    @abstractmethod
    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        styleembs: torch.Tensor = None,
    ) -> torch.Tensor:
        raise NotImplementedError
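Any concrete decoder then only needs to implement forward with this signature. A minimal stub for illustration (not a real model; the 144/80 dims are assumptions):

# --- stub sketch (illustrative, not part of the diff) ---
import torch

class TinyMelDecoder(AbsMelDecoder):
    """Projects bottleneck features straight to mel frames."""
    def __init__(self, bnf_dim=144, num_mels=80):
        super().__init__()
        self.proj = torch.nn.Linear(bnf_dim, num_mels)

    def forward(self, bottle_neck_features, feature_lengths, speech,
                speech_lengths, logf0_uv=None, spembs=None, styleembs=None):
        return self.proj(bottle_neck_features)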
79
models/ppg2mel/utils/basic_layers.py
Normal file
@@ -0,0 +1,79 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1) \
         .transpose(0, 1) \
         .repeat(count, 1) \
         .transpose(0, 1) \
         .contiguous() \
         .view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
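For example, tiling a (2, 3) tensor 3 times along dim 0 repeats each batch row consecutively, which is exactly torch.repeat_interleave along that dimension:

# --- behavior check (illustrative, not part of the diff) ---
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = tile(x, count=3, dim=0)            # shape (6, 3): rows a,a,a,b,b,b
assert torch.equal(y, x.repeat_interleave(3, dim=0))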
class Linear(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class Conv1d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(Conv1d, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        # x: (B, D, T)
        return self.conv(x)
52
models/ppg2mel/utils/cnn_postnet.py
Normal file
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .basic_layers import Linear, Conv1d


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolutions with 512 channels and kernel size 5
    """
    def __init__(self, num_mels=80,
                 num_layers=5,
                 hidden_dim=512,
                 kernel_size=5):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                Conv1d(
                    num_mels, hidden_dim,
                    kernel_size=kernel_size, stride=1,
                    padding=int((kernel_size - 1) / 2),
                    dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hidden_dim)))

        for i in range(1, num_layers - 1):
            self.convolutions.append(
                nn.Sequential(
                    Conv1d(
                        hidden_dim,
                        hidden_dim,
                        kernel_size=kernel_size, stride=1,
                        padding=int((kernel_size - 1) / 2),
                        dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hidden_dim)))

        self.convolutions.append(
            nn.Sequential(
                Conv1d(
                    hidden_dim, num_mels,
                    kernel_size=kernel_size, stride=1,
                    padding=int((kernel_size - 1) / 2),
                    dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(num_mels)))

    def forward(self, x):
        # x: (B, num_mels, T_dec)
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
        return x
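The postnet is a residual refiner: input and output shapes match, so the caller adds its output back onto the coarse mel (as MelDecoderMOLv2 does). A quick shape check:

# --- shape check (illustrative, not part of the diff) ---
import torch

postnet = Postnet(num_mels=80)
mel = torch.randn(4, 80, 120)    # (B, num_mels, T_dec)
refined = mel + postnet(mel)     # residual connection
assert refined.shape == mel.shape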
123
models/ppg2mel/utils/mol_attention.py
Normal file
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class MOLAttention(nn.Module):
    """ Discretized Mixture of Logistic (MOL) attention.
    C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
    the GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
    """
    def __init__(
        self,
        query_dim,
        r=1,
        M=5,
    ):
        """
        Args:
            query_dim: attention_rnn_dim.
            M: number of mixtures.
        """
        super().__init__()
        if r < 1:
            self.r = float(r)
        else:
            self.r = int(r)
        self.M = M
        self.score_mask_value = 0.0  # -float("inf")
        self.eps = 1e-5
        # Position array for encoder time steps
        self.J = None
        # Query layer: predicts [w, sigma, Delta]
        self.query_layer = torch.nn.Sequential(
            nn.Linear(query_dim, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, 3*M, bias=True)
        )
        self.mu_prev = None
        self.initialize_bias()

    def initialize_bias(self):
        """Initialize sigma and Delta."""
        # sigma
        torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
        # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
        #        softplus(-0.432) = 0.5003
        if self.r == 2:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
        elif self.r == 4:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
        elif self.r == 1:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
        else:
            torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)

    def init_states(self, memory):
        """Initialize mu_prev and J.
        This function should be called by the decoder before decoding one batch.
        Args:
            memory: (B, T, D_enc) encoder output.
        """
        B, T_enc, _ = memory.size()
        device = memory.device
        self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5  # NOTE: for discretized usage
        self.mu_prev = torch.zeros(B, self.M).to(device)

    def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
        """
        Args:
            att_rnn_h: attention rnn hidden state.
            memory: encoder outputs (B, T_enc, D).
            mask: binary mask for padded data (B, T_enc).
        """
        # [B, 3M]
        mixture_params = self.query_layer(att_rnn_h)

        # [B, M]
        w_hat = mixture_params[:, :self.M]
        sigma_hat = mixture_params[:, self.M:2*self.M]
        Delta_hat = mixture_params[:, 2*self.M:3*self.M]

        # Dropout to de-correlate attention heads
        w_hat = F.dropout(w_hat, p=0.5, training=self.training)  # NOTE(sx): needed?

        # Mixture parameters
        w = torch.softmax(w_hat, dim=-1) + self.eps
        sigma = F.softplus(sigma_hat) + self.eps
        Delta = F.softplus(Delta_hat)
        mu_cur = self.mu_prev + Delta
        j = self.J[:memory.size(1) + 1]

        # Attention weights: CDF of the logistic distribution
        phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
            (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))

        # Discretize attention weights
        # (B, T_enc + 1)
        alpha_t = torch.sum(phi_t, dim=1)
        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
        alpha_t[alpha_t == 0] = self.eps
        # Apply masking
        if mask is not None:
            alpha_t.data.masked_fill_(mask, self.score_mask_value)

        context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
        if memory_pitch is not None:
            context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)

        self.mu_prev = mu_cur

        if memory_pitch is not None:
            return context, context_pitch, alpha_t
        return context, alpha_t
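Since mu_prev only ever advances by softplus(Delta) >= 0, the mixture means move monotonically along the encoder axis, which is what makes this attention robust on long utterances. A minimal driver (the shapes are the contract; the values are random):

# --- usage sketch (illustrative, not part of the diff) ---
import torch

att = MOLAttention(query_dim=512, r=1, M=5)
memory = torch.randn(2, 60, 256)   # (B, T_enc, D_enc)
att.init_states(memory)
query = torch.randn(2, 512)        # attention RNN hidden state
context, alpha = att(query, memory)
print(context.shape, alpha.shape)  # torch.Size([2, 256]) torch.Size([2, 60])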
451
models/ppg2mel/utils/nets_utils.py
Normal file
@@ -0,0 +1,451 @@
# -*- coding: utf-8 -*-

"""Network related utility tools."""

import logging
from typing import Dict

import numpy as np
import torch


def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    """
    assert isinstance(m, torch.nn.Module)
    device = next(m.parameters()).device
    return x.to(device)


def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError('length_dim cannot be 0: {}'.format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(slice(None) if i in (0, length_dim) else None
                    for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        ByteTensor: Mask tensor containing indices of non-padded part.
                    dtype=torch.uint8 in PyTorch 1.2-
                    dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0),
        pad_targets.size(1),
        pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == 'c':
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor
            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if 'real' not in x or 'imag' not in x:
            raise ValueError("x must have 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x['real'], x['imag'])

    # If torch.Tensor, return as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
                 "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
                 "but got {}".format(type(x)))
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)


def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the training args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == 'transformer':
        return np.array([1])

    elif mode == 'mt' and arch == 'rnn':
        # +1 means input (+1) and layer outputs (train_args.elayers)
        subsample = np.ones(train_args.elayers + 1, dtype=int)
        logging.warning('Subsampling is not performed for machine translation.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
            (mode == 'mt' and arch == 'rnn') or \
            (mode == 'st' and arch == 'rnn'):
        subsample = np.ones(train_args.elayers + 1, dtype=int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mix':
        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mulenc':
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    'Encoder %d: Subsampling is not performed for vgg*. '
                    'It is performed in max pooling layers at CNN.', idx + 1)
            logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))


def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
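
# --- Illustrative sketch (not in the original file) ---
# Masked mean-pooling over valid frames with make_non_pad_mask; the shapes
# below are assumptions for the example.
#
#   feats = torch.randn(3, 5, 256)                  # (B, Tmax, D)
#   mask = make_non_pad_mask([5, 3, 2], feats, 1)   # 1 on valid frames
#   summed = (feats * mask).sum(dim=1)              # zero out padding, then sum
#   mean = summed / torch.tensor([5., 3., 2.]).unsqueeze(-1)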
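
# --- Sketch (illustrative; the checkpoint path and prefixes are assumptions) ---
# Typical use: migrate an old checkpoint layout before load_state_dict, e.g.
# when a submodule was renamed between code versions.
#
#   state_dict = torch.load("old_ckpt.pth", map_location="cpu")   # hypothetical path
#   rename_state_dict("encoder.enc.", "encoder.", state_dict)     # in-place rename
#   model.load_state_dict(state_dict)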
22
models/ppg2mel/utils/vc_utils.py
Normal file
22
models/ppg2mel/utils/vc_utils.py
Normal file
@@ -0,0 +1,22 @@
import torch


def gcd(a, b):
    """Greatest common divisor."""
    a, b = (a, b) if a >= b else (b, a)
    if a % b == 0:
        return b
    else:
        return gcd(b, a % b)


def lcm(a, b):
    """Least common multiple."""
    return a * b // gcd(a, b)


def get_mask_from_lengths(lengths, max_len=None):
    if max_len is None:
        max_len = torch.max(lengths).item()
    # build ids on the same device as lengths (the original hard-coded CUDA,
    # which breaks on CPU-only machines)
    ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask
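
# --- Worked example (illustrative) ---
# lengths = [3, 1] with max_len inferred as 3 yields
# mask = [[True, True, True], [True, False, False]]; True marks valid
# positions, so the mask can gate losses on padded mel frames.
#
#   lengths = torch.tensor([3, 1])
#   mask = get_mask_from_lengths(lengths)   # (2, 3) bool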
102
models/ppg_extractor/__init__.py
Normal file
102
models/ppg_extractor/__init__.py
Normal file
@@ -0,0 +1,102 @@
import argparse
from pathlib import Path

import torch
import yaml

from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder

_model = None  # type: PPGModel
_device = None


class PPGModel(torch.nn.Module):
    def __init__(
        self,
        frontend,
        normalizer,
        encoder,
    ):
        super().__init__()
        self.frontend = frontend
        self.normalize = normalizer
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        """
        Args:
            speech (tensor): (B, L)
            speech_lengths (tensor): (B, )

        Returns:
            bottle_neck_feats (tensor): (B, L//hop_size, 144)

        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        feats, feats_lengths = self.normalize(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ):
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend, e.g. STFT and feature extraction.
            # The data loader may send a time-domain signal in this case:
            # speech (Batch, NSamples) -> feats (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extraction
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def extract_from_wav(self, src_wav):
        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
        return self(src_wav_tensor, src_wav_lengths)


def build_model(args):
    normalizer = UtteranceMVN(**args.normalize_conf)
    frontend = DefaultFrontend(**args.frontend_conf)
    encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
    model = PPGModel(frontend, normalizer, encoder)

    return model


def load_model(model_file, device=None):
    global _model, _device

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = device
    # Search for a config file next to the checkpoint
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    config_file = model_config_fpaths[0]
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)

    args = argparse.Namespace(**args)

    model = build_model(args)
    model_state_dict = model.state_dict()

    # Keep only the encoder weights from the checkpoint
    ckpt_state_dict = torch.load(model_file, map_location=_device)
    ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if 'encoder' in k}

    model_state_dict.update(ckpt_state_dict)
    model.load_state_dict(model_state_dict)

    _model = model.eval().to(_device)
    return _model
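
# --- Usage sketch (illustrative; paths and sample rate are assumptions) ---
# `load_model` expects a checkpoint whose directory also holds a YAML config
# describing the frontend/normalizer/encoder.
#
#   from pathlib import Path
#   import librosa
#   ppg_model = load_model(Path("ckpt/ppg_extractor/best.pth"))
#   wav, _ = librosa.load("sample.wav", sr=16000)
#   ppg = ppg_model.extract_from_wav(wav)   # (1, n_frames, encoder_dim)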
398
models/ppg_extractor/e2e_asr_common.py
Normal file
398
models/ppg_extractor/e2e_asr_common.py
Normal file
@@ -0,0 +1,398 @@
#!/usr/bin/env python3

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Common functions for ASR."""

import argparse
import json
import logging
import sys

from itertools import groupby

import editdistance
import numpy as np
import six


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    :param ended_hyps: hypotheses that have already emitted <eos>
    :param i: current output position
    :param M: number of hypothesis lengths to check
    :param D_end: score-margin threshold
    :return: whether decoding can stop
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
    for m in six.moves.range(M):
        # get ended_hyps whose length is i - m
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
            if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
                count += 1

    if count == M:
        return True
    else:
        return False


# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param odim: output (vocabulary) dimension
    :param lsm_type: label smoothing type; only 'unigram' is supported
    :param blank: index of the blank symbol
    :param transcript: path to a json transcript file
    :return: label distribution (odim,)
    """
    if transcript is not None:
        with open(transcript, 'rb') as f:
            trans_json = json.load(f)['utts']

    if lsm_type == 'unigram':
        assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        labelcount[odim - 1] = len(trans_json)  # count <eos> once per utterance
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error(
            "Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()

    return labeldist


def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
    """Return the output size of the VGG frontend.

    :param idim: input feature dimension
    :param in_channel: input channel size
    :param out_channel: output channel size
    :return: output size
    :rtype: int
    """
    idim = idim / in_channel
    if downsample:
        idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 1st max pooling
        idim = np.ceil(np.array(idim, dtype=np.float32) / 2)  # 2nd max pooling
    return int(idim) * out_channel  # number of channels


class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list: list of output symbols
    :param sym_space: space symbol
    :param sym_blank: blank symbol
    """

    def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
                 trans_type="char"):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer
        self.trans_type = trans_type
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level WER score
        :rtype: float
        :return: sentence-level CER score
        :rtype: float
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype: float
        """
        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat, seq_true = [], []
            for idx in y_hat:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_hat.append(self.char_list[int(idx)])

            for idx in y_true:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_true.append(self.char_list[int(idx)])
            if self.trans_type == "char":
                hyp_chars = "".join(seq_hat)
                ref_chars = "".join(seq_true)
            else:
                hyp_chars = " ".join(seq_hat)
                ref_chars = " ".join(seq_true)

            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index sequences to character sequences.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: token list of prediction
        :rtype: list
        :return: token list of reference
        :rtype: list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # To avoid a wrongly higher WER than the one obtained from decoding,
            # the eos position from y_true is used to mark the eos in y_hat,
            # because y_hats has not been padded with -1.
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            seq_true_text = " ".join(seq_true).replace(self.space, ' ')
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype: float
        """
        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(' ', '')
            ref_chars = seq_true_text.replace(' ', '')
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype: float
        """
        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)


class ErrorCalculatorTrans(object):
    """Calculate CER and WER for transducer models.

    Args:
        decoder (nn.Module): decoder module
        args (Namespace): argument Namespace containing options
        report_cer (boolean): compute CER option
        report_wer (boolean): compute WER option

    """

    def __init__(self, decoder, args, report_cer=False, report_wer=False):
        """Construct an ErrorCalculator object for the transducer model."""
        super(ErrorCalculatorTrans, self).__init__()

        self.dec = decoder

        recog_args = {'beam_size': args.beam_size,
                      'nbest': args.nbest,
                      'space': args.sym_space,
                      'score_norm_transducer': args.score_norm_transducer}

        self.recog_args = argparse.Namespace(**recog_args)

        self.char_list = args.char_list
        self.space = args.sym_space
        self.blank = args.sym_blank

        self.report_cer = args.report_cer
        self.report_wer = args.report_wer

    def __call__(self, hs_pad, ys_pad):
        """Calculate sentence-level WER/CER score for transducer models.

        Args:
            hs_pad (torch.Tensor): batch of padded input sequences (batch, T, D)
            ys_pad (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): sentence-level CER score
            (float): sentence-level WER score

        """
        cer, wer = None, None

        if not self.report_cer and not self.report_wer:
            return cer, wer

        batchsize = int(hs_pad.size(0))
        batch_nbest = []

        for b in six.moves.range(batchsize):
            if self.recog_args.beam_size == 1:
                nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
            else:
                nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
            batch_nbest.append(nbest_hyps)

        ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())

        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)

        return cer, wer

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index sequences to character sequences.

        Args:
            ys_hat (torch.Tensor): prediction (batch, seqlen)
            ys_pad (torch.Tensor): reference (batch, seqlen)

        Returns:
            (list): token list of prediction
            (list): token list of reference

        """
        seqs_hat, seqs_true = [], []

        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]

            eos_true = np.where(y_true == -1)[0]
            eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)

            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]

            seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            seq_true_text = "".join(seq_true).replace(self.space, ' ')

            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)

        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score for the transducer model.

        Args:
            seqs_hat (torch.Tensor): prediction (batch, seqlen)
            seqs_true (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): average sentence-level CER score

        """
        char_eds, char_ref_lens = [], []

        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_chars = seq_hat_text.replace(' ', '')
            ref_chars = seq_true_text.replace(' ', '')

            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))

        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score for the transducer model.

        Args:
            seqs_hat (torch.Tensor): prediction (batch, seqlen)
            seqs_true (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): average sentence-level WER score

        """
        word_eds, word_ref_lens = [], []

        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()

            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))

        return float(sum(word_eds)) / sum(word_ref_lens)
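
# --- Worked example (illustrative) ---
# With idim=80, in_channel=1, out_channel=128 and downsampling enabled,
# 80 -> 40 -> 20 after the two max-pooling layers, so the flattened output
# size is 20 * 128 = 2560.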
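
# --- Illustrative check of the CER/WER arithmetic above (strings assumed) ---
# hyp "ni hao ma" vs ref "ni hao ba": word edit distance 1 over 3 reference
# words gives WER = 1/3; stripping spaces gives "nihaoma" vs "nihaoba", one
# character edit over 7 reference characters, i.e. CER = 1/7.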
0
models/ppg_extractor/encoder/__init__.py
Normal file
0
models/ppg_extractor/encoder/__init__.py
Normal file
183
models/ppg_extractor/encoder/attention.py
Normal file
183
models/ppg_extractor/encoder/attention.py
Normal file
@@ -0,0 +1,183 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    :param int n_head: the number of heads
    :param int n_feat: the number of features
    :param float dropout_rate: dropout rate

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor key: (batch, time2, size)
        :param torch.Tensor value: (batch, time2, size)
        :return torch.Tensor: transformed query, key and value

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        :param torch.Tensor value: (batch, head, time2, size)
        :param torch.Tensor scores: (batch, head, time1, time2)
        :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
        :return torch.Tensor: transformed `value` (batch, time1, d_model)
            weighted by the attention score (batch, time1, time2)

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute 'Scaled Dot Product Attention'.

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor key: (batch, time2, size)
        :param torch.Tensor value: (batch, time2, size)
        :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
        :return torch.Tensor: attention output (batch, time1, d_model)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.

    Paper: https://arxiv.org/abs/1901.02860

    :param int n_head: the number of heads
    :param int n_feat: the number of features
    :param float dropout_rate: dropout rate

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x, zero_triu=False):
        """Compute relative positional encoding.

        :param torch.Tensor x: (batch, time, size)
        :param bool zero_triu: whether to return the lower triangular part of the matrix
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with relative positional encoding.

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor key: (batch, time2, size)
        :param torch.Tensor value: (batch, time2, size)
        :param torch.Tensor pos_emb: (batch, time1, size)
        :param torch.Tensor mask: (batch, time1, time2)
        :return torch.Tensor: attention output (batch, time1, d_model)
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score:
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
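
    # --- Shape sketch (illustrative): self-attention over a 50-frame batch ---
    #
    #   mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    #   x = torch.randn(2, 50, 256)
    #   out = mha(x, x, x, None)   # (2, 50, 256); mha.attn is (2, 4, 50, 50)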
262
models/ppg_extractor/encoder/conformer_encoder.py
Normal file
262
models/ppg_extractor/encoder/conformer_encoder.py
Normal file
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""

import logging
from typing import Optional
from typing import Tuple

import torch

from .convolution import ConvolutionModule
from .encoder_layer import EncoderLayer
from ..nets_utils import get_activation, make_pad_mask
from .vgg import VGG2L
from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding
from .layer_norm import LayerNorm
from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d
from .positionwise_feed_forward import PositionwiseFeedForward
from .repeat import repeat
from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling


class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module.

    :param int input_size: input dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of encoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate in attention
    :param float positional_dropout_rate: dropout rate after adding positional encoding
    :param str or torch.nn.Module input_layer: input layer type
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    :param str positionwise_layer_type: linear or conv1d
    :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
    :param str pos_enc_layer_type: encoder positional encoding layer type
    :param str selfattention_layer_type: encoder attention layer type
    :param str activation_type: encoder activation function type
    :param bool macaron_style: whether to use macaron style for positionwise layer
    :param bool use_cnn_module: whether to use convolution module
    :param int cnn_module_kernel: kernel size of convolution module
    :param int padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
        input_size,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        macaron_style=False,
        pos_enc_layer_type="abs_pos",
        selfattention_layer_type="selfattn",
        activation_type="swish",
        use_cnn_module=False,
        cnn_module_kernel=31,
        padding_idx=-1,
        no_subsample=False,
        subsample_by_2=False,
    ):
        """Construct an Encoder object."""
        super().__init__()

        self._output_size = attention_dim
        idim = input_size

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            logging.info("Encoder input layer type: conv2d")
            if no_subsample:
                self.embed = Conv2dNoSubsampling(
                    idim,
                    attention_dim,
                    dropout_rate,
                    pos_enc_class(attention_dim, positional_dropout_rate),
                )
            else:
                self.embed = Conv2dSubsampling(
                    idim,
                    attention_dim,
                    dropout_rate,
                    pos_enc_class(attention_dim, positional_dropout_rate),
                    subsample_by_2,  # NOTE(Sx): added by songxiang
                )
        elif input_layer == "vgg2l":
            self.embed = VGG2L(idim, attention_dim)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input lengths (B)
            prev_states: not used yet.
        Returns:
            Encoded output (B, L', attention_dim), output lengths (B), and None.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)):
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
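
    # --- Usage sketch (illustrative; shapes and subsampling factor assumed) ---
    # With the default "conv2d" input layer the time axis is subsampled
    # (typically by 4), so ~100 fbank frames come out as ~25 encoded frames.
    #
    #   enc = ConformerEncoder(input_size=80, attention_dim=256, num_blocks=6)
    #   xs = torch.randn(2, 100, 80)
    #   ilens = torch.tensor([100, 73])
    #   ys, olens, _ = enc(xs, ilens)   # ys: (2, ~25, 256); olens: valid lengths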
74
models/ppg_extractor/encoder/convolution.py
Normal file
74
models/ppg_extractor/encoder/convolution.py
Normal file
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#           Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""ConvolutionModule definition."""

from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in the Conformer model.

    :param int channels: channels of cnn
    :param int kernel_size: kernel size of cnn

    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute the convolution module.

        :param torch.Tensor x: (batch, time, size)
        :return torch.Tensor: convolved `value` (batch, time, d_model)
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # 1D depthwise conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
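
    # --- Shape check (illustrative) ---
    # The module is time-preserving thanks to the 'SAME' padding of the
    # depthwise conv.
    #
    #   conv = ConvolutionModule(channels=256, kernel_size=31)
    #   x = torch.randn(2, 50, 256)
    #   assert conv(x).shape == (2, 50, 256)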
166
models/ppg_extractor/encoder/embedding.py
Normal file
166
models/ppg_extractor/encoder/embedding.py
Normal file
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positional Encoding Module."""

import math

import torch


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.

    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length
    :param reverse: whether to reverse the input position

    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See also: Sec. 3.2 of https://arxiv.org/pdf/1809.08895.pdf

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length

        """
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length

        """
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: x. Its shape is (batch, time, ...)
            torch.Tensor: pos_emb. Its shape is (1, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x), self.dropout(pos_emb)
|
||||
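A minimal usage sketch of the absolute and relative encodings above, assuming the module lands at models.ppg_extractor.encoder.embedding (path inferred from this commit's layout):

import torch
# Hypothetical import path; adjust to wherever this module lives in the repo.
from models.ppg_extractor.encoder.embedding import (
    PositionalEncoding,
    RelPositionalEncoding,
)

x = torch.randn(2, 50, 256)              # (batch, time, d_model)

abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
y = abs_pe(x)                            # x * sqrt(d_model) + PE; shape unchanged

rel_pe = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
y, pos_emb = rel_pe(x)                   # scaled x plus a (1, 50, 256) PE table that
                                         # RelPositionMultiHeadedAttention consumes later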
217
models/ppg_extractor/encoder/encoder.py
Normal file
@@ -0,0 +1,217 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""

import logging
import torch

from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.nets_utils import get_activation
from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling


class Encoder(torch.nn.Module):
    """Conformer encoder module.

    :param int idim: input dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of encoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate in attention
    :param float positional_dropout_rate: dropout rate after adding positional encoding
    :param str or torch.nn.Module input_layer: input layer type
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    :param str positionwise_layer_type: linear or conv1d
    :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
    :param str pos_enc_layer_type: encoder positional encoding layer type
    :param str selfattention_layer_type: encoder attention layer type
    :param str activation_type: encoder activation function type
    :param bool macaron_style: whether to use macaron style for positionwise layer
    :param bool use_cnn_module: whether to use convolution module
    :param int cnn_module_kernel: kernel size of convolution module
    :param int padding_idx: padding_idx for input_layer=embed
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        macaron_style=False,
        pos_enc_layer_type="abs_pos",
        selfattention_layer_type="selfattn",
        activation_type="swish",
        use_cnn_module=False,
        cnn_module_kernel=31,
        padding_idx=-1,
    ):
        """Construct an Encoder object."""
        super(Encoder, self).__init__()

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "vgg2l":
            self.embed = VGG2L(idim, attention_dim)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self, xs, masks):
        """Encode input sequence.

        :param torch.Tensor xs: input tensor
        :param torch.Tensor masks: input mask
        :return: position embedded tensor and mask
        :rtype Tuple[torch.Tensor, torch.Tensor]:
        """
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        xs, masks = self.encoders(xs, masks)
        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks
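A forward-pass sketch under the defaults above, assuming espnet is installed and the module lands at models.ppg_extractor.encoder.encoder (hypothetical path); the default "conv2d" input layer subsamples time roughly 4x:

import torch
from models.ppg_extractor.encoder.encoder import Encoder

enc = Encoder(idim=80, attention_dim=256, attention_heads=4, num_blocks=2)
xs = torch.randn(4, 100, 80)                     # (batch, frames, mel bins)
masks = torch.ones(4, 1, 100, dtype=torch.bool)  # True = valid frame
out, out_masks = enc(xs, masks)
print(out.shape, out_masks.shape)                # time shrinks to ~1/4, features to attention_dim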
152
models/ppg_extractor/encoder/encoder_layer.py
Normal file
@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder self-attention layer definition."""

import torch

from torch import nn

from .layer_norm import LayerNorm


class EncoderLayer(nn.Module):
    """Encoder layer module.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        or RelPositionMultiHeadedAttention self_attn: self-attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed-forward module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward_macaron: feed-forward module for
        macaron style
    :param espnet.nets.pytorch_backend.conformer.convolution.
        ConvolutionModule conv_module: convolution module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)

    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FFN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        :param torch.Tensor x_input: encoded source features, w/ or w/o pos_emb
            tuple((batch, max_time_in, size), (1, max_time_in, size))
            or (batch, max_time_in, size)
        :param torch.Tensor mask: mask for x (batch, max_time_in)
        :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
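A smoke-test sketch with identity stand-ins (the IdentityAttn/IdentityFF modules are hypothetical) that exposes the block's ordering: macaron FF, then MHA, then the final FF, each wrapped in a residual, with ff_scale = 0.5 on the FF branches when the macaron branch is present:

import torch
from torch import nn

class IdentityAttn(nn.Module):
    """Stand-in with the (query, key, value, mask) signature used above."""
    def forward(self, q, k, v, mask):
        return v

class IdentityFF(nn.Module):
    def forward(self, x):
        return x

layer = EncoderLayer(
    size=256,
    self_attn=IdentityAttn(),
    feed_forward=IdentityFF(),
    feed_forward_macaron=IdentityFF(),  # enables macaron style (ff_scale = 0.5)
    conv_module=None,
    dropout_rate=0.0,
)
x = torch.randn(2, 50, 256)
mask = torch.ones(2, 1, 50, dtype=torch.bool)
y, _ = layer(x, mask)
print(y.shape)  # torch.Size([2, 50, 256])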
33
models/ppg_extractor/encoder/layer_norm.py
Normal file
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
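The dim argument is the one addition over torch.nn.LayerNorm: it normalizes a non-trailing axis by transposing around the parent's forward. A quick sketch:

import torch

ln_chan = LayerNorm(80, dim=1)    # normalize axis 1 of a (batch, channel, time) tensor
x = torch.randn(4, 80, 100)
y = ln_chan(x)                    # transpose(1, -1) -> normalize -> transpose back
print(y.shape)                    # torch.Size([4, 80, 100])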
105
models/ppg_extractor/encoder/multi_layer_conv.py
Normal file
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""

import torch


class MultiLayeredConv1d(torch.nn.Module):
    """Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in a Transformer block, which is introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(MultiLayeredConv1d, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            in_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., in_chans).

        Returns:
            Tensor: Batch of output tensors (B, ..., in_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


class Conv1dLinear(torch.nn.Module):
    """Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d, which replaces the second conv layer with a linear layer.

    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.

        """
        super(Conv1dLinear, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., in_chans).

        Returns:
            Tensor: Batch of output tensors (B, ..., in_chans).

        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
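Both variants preserve (B, T, in_chans), which is what lets them substitute for the linear position-wise FFN. A shape sketch:

import torch

ff = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
x = torch.randn(2, 50, 256)       # (batch, time, channels)
print(ff(x).shape)                # torch.Size([2, 50, 256]); same-padding conv keeps T

ff2 = Conv1dLinear(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
print(ff2(x).shape)               # torch.Size([2, 50, 256])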
31
models/ppg_extractor/encoder/positionwise_feed_forward.py
Normal file
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    :param int idim: input dimension
    :param int hidden_units: number of hidden units
    :param float dropout_rate: dropout rate

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
30
models/ppg_extractor/encoder/repeat.py
Normal file
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Repeat the same layer definition."""

import torch


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def forward(self, *args):
        """Repeat."""
        for m in self:
            args = m(*args)
        return args


def repeat(N, fn):
    """Repeat module N times.

    :param int N: number of repeats
    :param function fn: function to generate a module
    :return: repeated modules
    :rtype: MultiSequential
    """
    return MultiSequential(*[fn(n) for n in range(N)])
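MultiSequential threads every positional argument through each module, which is how the encoder above passes (xs, masks), or ((x, pos_emb), mask), through a stack of blocks. A toy sketch with a hypothetical AddOne module standing in for an encoder layer:

import torch

class AddOne(torch.nn.Module):
    def forward(self, x, mask):
        return x + 1, mask        # both values are threaded to the next module

stack = repeat(3, lambda n: AddOne())
x, mask = stack(torch.zeros(2, 5), None)
print(x[0, 0].item())             # 3.0 after three repeated layers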
218
models/ppg_extractor/encoder/subsampling.py
Normal file
@@ -0,0 +1,218 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Subsampling layer definition."""
import logging
import torch

from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding


class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length or 1/2 length).

    :param int idim: input dim
    :param int odim: output dim
    :param float dropout_rate: dropout rate
    :param torch.nn.Module pos_enc: custom position encoding layer

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None,
                 subsample_by_2=False,
                 ):
        """Construct a Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.subsample_by_2 = subsample_by_2
        if subsample_by_2:
            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
                torch.nn.ReLU(),
                torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
                torch.nn.ReLU(),
            )
            self.out = torch.nn.Sequential(
                torch.nn.Linear(odim * (idim // 2), odim),
                pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
            )
        else:
            self.conv = torch.nn.Sequential(
                torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
                torch.nn.ReLU(),
            )
            self.out = torch.nn.Sequential(
                torch.nn.Linear(odim * (idim // 4), odim),
                pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
            )

    def forward(self, x, x_mask):
        """Subsample x.

        :param torch.Tensor x: input tensor
        :param torch.Tensor x_mask: input mask
        :return: subsampled x and mask
        :rtype Tuple[torch.Tensor, torch.Tensor]

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        if self.subsample_by_2:
            return x, x_mask[:, :, ::2]
        else:
            return x, x_mask[:, :, ::2][:, :, ::2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


class Conv2dNoSubsampling(torch.nn.Module):
    """Convolutional 2D without subsampling.

    :param int idim: input dim
    :param int odim: output dim
    :param float dropout_rate: dropout rate
    :param torch.nn.Module pos_enc: custom position encoding layer

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dNoSubsampling object."""
        super().__init__()
        logging.info("Encoder does not do down-sample on mel-spectrogram.")
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * idim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Process x without subsampling.

        :param torch.Tensor x: input tensor
        :param torch.Tensor x_mask: input mask
        :return: processed x and the unchanged mask
        :rtype Tuple[torch.Tensor, torch.Tensor]

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]


class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    :param int idim: input dim
    :param int odim: output dim
    :param float dropout_rate: dropout rate

    """

    def __init__(self, idim, odim, dropout_rate):
        """Construct a Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
            PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        :param torch.Tensor x: input tensor
        :param torch.Tensor x_mask: input mask
        :return: subsampled x and mask
        :rtype Tuple[torch.Tensor, torch.Tensor]
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-4:3]


class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    :param int idim: input dim
    :param int odim: output dim
    :param float dropout_rate: dropout rate

    """

    def __init__(self, idim, odim, dropout_rate):
        """Construct a Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
            PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        :param torch.Tensor x: input tensor
        :param torch.Tensor x_mask: input mask
        :return: subsampled x and mask
        :rtype Tuple[torch.Tensor, torch.Tensor]
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
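The output-length bookkeeping is easiest to see numerically. With the default 4x variant, two stride-2 convs quarter the time axis and the mask is decimated to match; a shape sketch assuming idim=80:

import torch

sub = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1)
x = torch.randn(2, 100, 80)
mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, m = sub(x, mask)
print(y.shape, m.shape)    # torch.Size([2, 25, 256]) torch.Size([2, 1, 25])

sub2 = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1, subsample_by_2=True)
y2, m2 = sub2(x, mask)
print(y2.shape, m2.shape)  # torch.Size([2, 50, 256]) torch.Size([2, 1, 50])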
18
models/ppg_extractor/encoder/swish.py
Normal file
@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#           Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Swish() activation function for Conformer."""

import torch


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x):
        """Return Swish activation function."""
        return x * torch.sigmoid(x)
77
models/ppg_extractor/encoder/vgg.py
Normal file
@@ -0,0 +1,77 @@
"""VGG2L definition for transformer-transducer."""

import torch


class VGG2L(torch.nn.Module):
    """VGG2L module for transformer-transducer encoder."""

    def __init__(self, idim, odim):
        """Construct a VGG2L object.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs

        """
        super(VGG2L, self).__init__()

        self.vgg2l = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((3, 2)),
            torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((2, 2)),
        )

        self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)

    def forward(self, x, x_mask):
        """VGG2L forward for x.

        Args:
            x (torch.Tensor): input tensor (B, T, idim)
            x_mask (torch.Tensor): (B, 1, T)

        Returns:
            x (torch.Tensor): output tensor (B, sub(T), attention_dim)
            x_mask (torch.Tensor): (B, 1, sub(T))

        """
        x = x.unsqueeze(1)
        x = self.vgg2l(x)

        b, c, t, f = x.size()

        x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))

        if x_mask is None:
            return x, None
        else:
            x_mask = self.create_new_mask(x_mask, x)

            return x, x_mask

    def create_new_mask(self, x_mask, x):
        """Create a subsampled version of x_mask.

        Args:
            x_mask (torch.Tensor): (B, 1, T)
            x (torch.Tensor): (B, sub(T), attention_dim)

        Returns:
            x_mask (torch.Tensor): (B, 1, sub(T))

        """
        x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
        x_mask = x_mask[:, :, :x_t1][:, :, ::3]

        x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
        x_mask = x_mask[:, :, :x_t2][:, :, ::2]

        return x_mask
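Here the time axis shrinks by the pooling widths (3, then 2), so sub(T) = T // 6, and create_new_mask applies the same arithmetic to the mask. A shape sketch:

import torch

vgg = VGG2L(idim=80, odim=256)
x = torch.randn(2, 60, 80)
mask = torch.ones(2, 1, 60, dtype=torch.bool)
y, m = vgg(x, mask)
print(y.shape, m.shape)   # torch.Size([2, 10, 256]) torch.Size([2, 1, 10])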
298
models/ppg_extractor/encoders.py
Normal file
@@ -0,0 +1,298 @@
import logging
import six

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from .e2e_asr_common import get_vgg2l_odim
from .nets_utils import make_pad_mask, to_device


class RNNP(torch.nn.Module):
    """RNN with projection layer module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulting in cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        bidir = typ[0] == "b"
        for i in six.moves.range(elayers):
            if i == 0:
                inputdim = idim
            else:
                inputdim = hdim
            rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
                                batch_first=True) if "lstm" in typ \
                else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
            # bottleneck layer to merge
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
        elayer_states = []
        for layer in six.moves.range(self.elayers):
            xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                ilens = [int(i + 1) // sub for i in ilens]
            # (sum _utt frame_utt) x dim
            projected = getattr(self, 'bt' + str(layer)
                                )(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            if layer == self.elayers - 1:
                xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            else:
                xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim


class RNN(torch.nn.Module):
    """RNN module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulting in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
                                   dropout=dropout, bidirectional=bidir) if "lstm" in typ \
            else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
                              bidirectional=bidir)
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
        xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
        self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed, it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(self.l_last(
            ys_pad.contiguous().view(-1, ys_pad.size(2))))
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim


def reset_backward_rnn_state(states):
    """Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs"""
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.
    else:
        states[1::2] = 0.
    return states


class VGG2L(torch.nn.Module):
    """VGG-like module

    :param int in_channel: number of input channels
    """

    def __init__(self, in_channel=1, downsample=True):
        super(VGG2L, self).__init__()
        # CNN layer (VGG motivated)
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.in_channel = in_channel
        self.downsample = downsample
        if downsample:
            self.stride = 2
        else:
            self.stride = 1

    def forward(self, xs_pad, ilens, **kwargs):
        """VGG2L forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))

        # x: utt x frame x dim
        # xs_pad = F.pad_sequence(xs_pad)

        # x: utt x 1 (input channel num) x frame x dim
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
                             xs_pad.size(2) // self.in_channel).transpose(1, 2)

        # NOTE: max_pool1d ?
        xs_pad = F.relu(self.conv1_1(xs_pad))
        xs_pad = F.relu(self.conv1_2(xs_pad))
        if self.downsample:
            xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)

        xs_pad = F.relu(self.conv2_1(xs_pad))
        xs_pad = F.relu(self.conv2_2(xs_pad))
        if self.downsample:
            xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        if self.downsample:
            ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
            ilens = np.array(
                np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()

        # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
        return xs_pad, ilens, None  # no state in this layer


class Encoder(torch.nn.Module):
    """Encoder module

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers: number of layers of encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    """

    def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
        super(Encoder, self).__init__()
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
            logging.error("Error: need to specify an appropriate encoder architecture")

        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList([VGG2L(in_channel),
                                                RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
                                                     eprojs,
                                                     subsample, dropout, typ=typ)])
                logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
            else:
                self.enc = torch.nn.ModuleList([VGG2L(in_channel),
                                                RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
                                                    eprojs,
                                                    dropout, typ=typ)])
                logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
        else:
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
                logging.info(typ.upper() + ' with every-layer projection for encoder')
            else:
                self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
                logging.info(typ.upper() + ' without projection for encoder')

    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # make mask to remove bias value in padded part
        mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states


def encoder_for(args, idim, subsample):
    """Instantiates an encoder module given the program arguments

    :param Namespace args: The arguments
    :param int or List of integer idim: dimension of input, e.g. 83, or
                                        List of dimensions of inputs, e.g. [83,83]
    :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
                                           List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]]
    :rtype torch.nn.Module
    :return: The encoder module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
    elif num_encs >= 1:
        enc_list = torch.nn.ModuleList()
        for idx in range(num_encs):
            enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx],
                          args.dropout_rate[idx])
            enc_list.append(enc)
        return enc_list
    else:
        raise ValueError("Number of encoders needs to be at least one. {}".format(num_encs))
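A usage sketch, assuming the module lands at models.ppg_extractor.encoders (hypothetical path). Note that RNNP reads subsample[layer + 1], so the list needs elayers + 1 entries, and etype strings compose "vgg" + rnn type + optional "p":

import numpy as np
import torch
from models.ppg_extractor.encoders import Encoder

subsample = np.array([1, 2, 2, 1, 1])     # 2x decimation after the first two RNNP layers
enc = Encoder('vggblstmp', idim=80, elayers=4, eunits=320, eprojs=320,
              subsample=subsample, dropout=0.1)

xs = torch.randn(2, 120, 80)              # (B, Tmax, D)
ilens = torch.tensor([120, 96])
hs, hlens, states = enc(xs, ilens)        # VGG front-end (~T/4), then projected BLSTMs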
115
models/ppg_extractor/frontend.py
Normal file
@@ -0,0 +1,115 @@
import copy
from typing import Tuple
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from .log_mel import LogMel
from .stft import Stft


class DefaultFrontend(torch.nn.Module):
    """Conventional frontend structure for ASR

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 1024,
        win_length: int = 800,
        hop_length: int = 160,
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        norm=1,
        frontend_conf=None,  # Optional[dict] = get_default_kwargs(Frontend),
        kaldi_padding_mode=False,
        downsample_rate: int = 1,
    ):
        super().__init__()
        self.downsample_rate = downsample_rate

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=center,
            pad_mode=pad_mode,
            normalized=normalized,
            onesided=onesided,
            kaldi_padding_mode=kaldi_padding_mode
        )
        if frontend_conf is not None:
            # NOTE: `Frontend` is not imported in this file, so `frontend_conf`
            # must stay None unless the enhancement frontend is wired in.
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
        )
        self.n_mels = n_mels

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])

        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        # NOTE(sx): pad
        max_len = input_feats.size(1)
        if self.downsample_rate > 1 and max_len % self.downsample_rate != 0:
            padding = self.downsample_rate - max_len % self.downsample_rate
            # print("Logmel: ", input_feats.size())
            input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding),
                                                  "constant", 0)
            # print("Logmel(after padding): ", input_feats.size())
            feats_lens[torch.argmax(feats_lens)] = max_len + padding

        return input_feats, feats_lens
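End to end, the module maps raw waveforms to padded log-mel features. A sketch, assuming the sibling Stft module accepts (waveform, lengths) and returns a (B, T, F, 2) real/imag tensor as the assertions above require:

import torch

frontend = DefaultFrontend(fs=16000, n_fft=1024, hop_length=160, n_mels=80)
wav = torch.randn(2, 16000)               # one second of audio per utterance
lens = torch.tensor([16000, 12000])
feats, feat_lens = frontend(wav, lens)
print(feats.shape)                        # (2, T, 80) log-mel frames, T ~ samples / hop_length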
74
models/ppg_extractor/log_mel.py
Normal file
@@ -0,0 +1,74 @@
import librosa
import numpy as np
import torch
from typing import Tuple

from .nets_utils import make_pad_mask


class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization). Otherwise, leave all the triangles aiming for
            a peak value of 1.0

    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
        inv_mel = np.linalg.pinv(melmat)
        self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
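LogMel expects a power spectrum rather than a waveform, so its natural input is step 4 of the frontend above. A sketch:

import torch

logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
power = torch.rand(2, 50, 257)            # (B, T, n_fft // 2 + 1) power spectrum
feats, olens = logmel(power, ilens=torch.tensor([50, 40]))
print(feats.shape)                        # torch.Size([2, 50, 80]); padded frames zeroed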
465
models/ppg_extractor/nets_utils.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Network related utility tools."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(m, x):
|
||||
"""Send tensor into the device of the module.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): Torch module.
|
||||
x (Tensor): Torch tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Torch tensor located in the same place as torch module.
|
||||
|
||||
"""
|
||||
assert isinstance(m, torch.nn.Module)
|
||||
device = next(m.parameters()).device
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def pad_list(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
max_len = max(x.size(0) for x in xs)
|
||||
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
pad[i, :xs[i].size(0)] = xs[i]
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]],
|
||||
[[0, 0, 0, 1],
|
||||
[0, 0, 0, 1]],
|
||||
[[0, 0, 1, 1],
|
||||
[0, 0, 1, 1]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_pad_mask(lengths, xs, 1)
|
||||
tensor([[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
>>> make_pad_mask(lengths, xs, 2)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
if length_dim == 0:
|
||||
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
|
||||
|
||||
if not isinstance(lengths, list):
|
||||
lengths = lengths.tolist()
|
||||
bs = int(len(lengths))
|
||||
if xs is None:
|
||||
maxlen = int(max(lengths))
|
||||
else:
|
||||
maxlen = xs.size(length_dim)
|
||||
|
||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
if xs is not None:
|
||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
||||
|
||||
if length_dim < 0:
|
||||
length_dim = xs.dim() + length_dim
|
||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
||||
ind = tuple(slice(None) if i in (0, length_dim) else None
|
||||
for i in range(xs.dim()))
|
||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
||||
return mask
|
||||
|
||||
|
||||
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of the non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor. See the example.

    Returns:
        ByteTensor: mask tensor containing indices of the non-padded part.
                    dtype=torch.uint8 before PyTorch 1.2,
                    dtype=torch.bool from PyTorch 1.2 on.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> x
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5],
                [1, 2, 3, 4, 5]])
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    ret = xs.data.new(*xs.size()).fill_(fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret

def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0),
        pad_targets.size(1),
        pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)

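# Illustrative usage sketch (not part of the original file; values are made up).
# The shape contract is the point: logits come flattened as (B * Lmax, D),
# labels stay batched as (B, Lmax).
# >>> pad_outputs = torch.randn(2 * 3, 5)        # (B*Lmax, D) logits
# >>> pad_targets = torch.randint(0, 5, (2, 3))  # (B, Lmax) labels
# >>> acc = th_accuracy(pad_outputs, pad_targets, ignore_label=-1)
# >>> 0.0 <= acc <= 1.0
# True
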
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag:
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == 'c':
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor
            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if 'real' not in x or 'imag' not in x:
            raise ValueError("x must have 'real' and 'imag' keys: {}".format(list(x)))
        return ComplexTensor(x['real'], x['imag'])

    # If torch.Tensor, return it as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
                 "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
                 "but got {}".format(type(x)))
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # If PY2
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)

def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the training args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.
    """
    if arch == 'transformer':
        return np.array([1])

    elif mode == 'mt' and arch == 'rnn':
        # +1 means the input (+1) plus the layer outputs (train_args.elayers)
        subsample = np.ones(train_args.elayers + 1, dtype=int)
        logging.warning('Subsampling is not performed for machine translation.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
            (mode == 'st' and arch == 'rnn'):
        subsample = np.ones(train_args.elayers + 1, dtype=int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mix':
        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=int)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        return subsample

    elif mode == 'asr' and arch == 'rnn_mulenc':
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    'Encoder %d: Subsampling is not performed for vgg*. '
                    'It is performed in max pooling layers at CNN.', idx + 1)
            logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))

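# Illustrative sketch (hypothetical Namespace values, not from any real training
# config): an RNN ASR encoder with projection layers and per-layer subsampling.
# >>> from argparse import Namespace
# >>> args = Namespace(elayers=4, etype="blstmp", subsample="1_2_2_1_1")
# >>> get_subsample(args, mode="asr", arch="rnn")
# array([1, 2, 2, 1, 1])
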
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict."""
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
    for k in old_keys:
        v = state_dict.pop(k)
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v

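# Illustrative sketch (hypothetical keys): migrating a checkpoint whose
# attention parameters moved from "attn." to "decoder.attn.".
# >>> sd = {"attn.w": torch.zeros(2), "embed.w": torch.zeros(2)}
# >>> rename_state_dict("attn.", "decoder.attn.", sd)
# >>> sorted(sd)
# ['decoder.attn.w', 'embed.w']
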
def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from .encoder.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }

    return activation_funcs[act]()

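# Illustrative sketch: each call returns a fresh activation module instance.
# >>> act = get_activation("relu")
# >>> act(torch.tensor([-1.0, 2.0]))
# tensor([0., 2.])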
118
models/ppg_extractor/stft.py
Normal file
@@ -0,0 +1,118 @@
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from .nets_utils import make_pad_mask


class Stft(torch.nn.Module):
    def __init__(
        self,
        n_fft: int = 512,
        win_length: Union[int, None] = 512,
        hop_length: int = 128,
        center: bool = True,
        pad_mode: str = "reflect",
        normalized: bool = False,
        onesided: bool = True,
        kaldi_padding_mode=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode
        self.normalized = normalized
        self.onesided = onesided
        self.kaldi_padding_mode = kaldi_padding_mode
        if self.kaldi_padding_mode:
            self.win_length = 400

    def extra_repr(self):
        return (
            f"n_fft={self.n_fft}, "
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
            f"pad_mode={self.pad_mode}, "
            f"normalized={self.normalized}, "
            f"onesided={self.onesided}"
        )

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """STFT forward function.

        Args:
            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)

        """
        bs = input.size(0)
        if input.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input = input.transpose(1, 2).reshape(-1, input.size(1))
        else:
            multi_channel = False

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if not self.kaldi_padding_mode:
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=self.center,
                pad_mode=self.pad_mode,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=False
            )
        else:
            # NOTE(sx): Use Kaldi-fashion padding, maybe wrong
            num_pads = self.n_fft - self.win_length
            input = torch.nn.functional.pad(input, (num_pads, 0))
            output = torch.stft(
                input,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=False,
                pad_mode=self.pad_mode,
                normalized=self.normalized,
                onesided=self.onesided,
                return_complex=False
            )

        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
                1, 2
            )

        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad
            olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens

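# Illustrative usage sketch (shapes only; the lengths are made up):
# >>> stft = Stft(n_fft=512, hop_length=128)
# >>> wav = torch.randn(2, 16000)                      # (Batch, Nsamples)
# >>> spec, olens = stft(wav, torch.tensor([16000, 12000]))
# >>> spec.shape                                       # (Batch, Frames, Freq, 2)
# torch.Size([2, 126, 257, 2])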
82
models/ppg_extractor/utterance_mvn.py
Normal file
@@ -0,0 +1,82 @@
from typing import Tuple

import torch

from .nets_utils import make_pad_mask


class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Args:
            x: (B, L, ...)
            ilens: (B,)

        """
        return utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means: whether to normalize the mean
        norm_vars: whether to normalize the variance
        eps: lower bound for the standard deviation

    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
    # Zero padding
    if x.requires_grad:
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
    else:
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    # mean: (B, 1, D)
    mean = x.sum(dim=1, keepdim=True) / ilens_

    if norm_means:
        x -= mean

        if norm_vars:
            var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x = x / std
        return x, ilens
    else:
        if norm_vars:
            y = x - mean
            y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
            var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x /= std
        return x, ilens

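# Illustrative sketch: the statistics are computed over valid frames only.
# >>> feats = torch.randn(2, 100, 80)
# >>> ilens = torch.tensor([100, 60])
# >>> y, _ = utterance_mvn(feats, ilens)
# >>> bool(y[0].mean(dim=0).abs().max() < 1e-4)   # zero mean along time
# True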
24
models/synthesizer/LICENSE.txt
Normal file
@@ -0,0 +1,24 @@
MIT License

Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
0
models/synthesizer/__init__.py
Normal file
206
models/synthesizer/audio.py
Normal file
@@ -0,0 +1,206 @@
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
import soundfile as sf


def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    #proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    sf.write(path, wav.astype(np.float32), sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
def start_and_end_indices(quantized, silence_threshold=2):
    for start in range(quantized.size):
        if abs(quantized[start] - 127) > silence_threshold:
            break
    for end in range(quantized.size - 1, 1, -1):
        if abs(quantized[end] - 127) > silence_threshold:
            break

    assert abs(quantized[start] - 127) > silence_threshold
    assert abs(quantized[end] - 127) > silence_threshold

    return start, end

def get_hop_size(hparams):
    hop_size = hparams.hop_size
    if hop_size is None:
        assert hparams.frame_shift_ms is not None
        hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
    return hop_size

def linearspectrogram(wav, hparams):
    D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
    S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db

    if hparams.signal_normalization:
        return _normalize(S, hparams)
    return S

def melspectrogram(wav, hparams):
    D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
    S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db

    if hparams.signal_normalization:
        return _normalize(S, hparams)
    return S

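# Illustrative sketch (random audio; assumes the project's default synthesizer
# hparams: 16 kHz sample rate, hop_size 200, 80 mels). One second of audio
# yields roughly sample_rate / hop_size = 80 frames:
# >>> wav = np.random.randn(16000).astype(np.float32)
# >>> melspectrogram(wav, hparams).shape   # (num_mels, frames)
# (80, 81)
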
def inv_linear_spectrogram(linear_spectrogram, hparams):
    """Converts linear spectrogram to waveform using librosa"""
    if hparams.signal_normalization:
        D = _denormalize(linear_spectrogram, hparams)
    else:
        D = linear_spectrogram

    S = _db_to_amp(D + hparams.ref_level_db)  #Convert back to linear

    if hparams.use_lws:
        processor = _lws_processor(hparams)
        D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
        y = processor.istft(D).astype(np.float32)
        return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
    else:
        return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)

def inv_mel_spectrogram(mel_spectrogram, hparams):
    """Converts mel spectrogram to waveform using librosa"""
    if hparams.signal_normalization:
        D = _denormalize(mel_spectrogram, hparams)
    else:
        D = mel_spectrogram

    S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams)  # Convert back to linear

    if hparams.use_lws:
        processor = _lws_processor(hparams)
        D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
        y = processor.istft(D).astype(np.float32)
        return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
    else:
        return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)

def _lws_processor(hparams):
    import lws
    return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")

def _griffin_lim(S, hparams):
    """librosa implementation of Griffin-Lim
    Based on https://github.com/librosa/librosa/issues/434
    """
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(complex)
    y = _istft(S_complex * angles, hparams)
    for i in range(hparams.griffin_lim_iters):
        angles = np.exp(1j * np.angle(_stft(y, hparams)))
        y = _istft(S_complex * angles, hparams)
    return y

def _stft(y, hparams):
    if hparams.use_lws:
        return _lws_processor(hparams).stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)

def _istft(y, hparams):
    return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)

##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram
    """
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding
    """
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r

##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None
_inv_mel_basis = None

def _linear_to_mel(spectrogram, hparams):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis(hparams)
    return np.dot(_mel_basis, spectrogram)

def _mel_to_linear(mel_spectrogram, hparams):
    global _inv_mel_basis
    if _inv_mel_basis is None:
        _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
    return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))

def _build_mel_basis(hparams):
    assert hparams.fmax <= hparams.sample_rate // 2
    return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
                               fmin=hparams.fmin, fmax=hparams.fmax)

def _amp_to_db(x, hparams):
    min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S, hparams):
    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
                           -hparams.max_abs_value, hparams.max_abs_value)
        else:
            return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)

    assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
    if hparams.symmetric_mels:
        return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
    else:
        return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))

def _denormalize(D, hparams):
    if hparams.allow_clipping_in_normalization:
        if hparams.symmetric_mels:
            return (((np.clip(D, -hparams.max_abs_value,
                              hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
                    + hparams.min_level_db)
        else:
            return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)

    if hparams.symmetric_mels:
        return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
    else:
        return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
13
models/synthesizer/gst_hyperparameters.py
Normal file
@@ -0,0 +1,13 @@
class GSTHyperparameters():
    E = 512

    # reference encoder
    ref_enc_filters = [32, 32, 64, 64, 128, 128]

    # style token layer
    token_num = 10
    # token_emb_size = 256
    num_heads = 8

    n_mels = 256  # Number of Mel banks to generate

110
models/synthesizer/hparams.py
Normal file
@@ -0,0 +1,110 @@
import ast
import pprint
import json

class HParams(object):
    def __init__(self, **kwargs): self.__dict__.update(kwargs)
    def __setitem__(self, key, value): setattr(self, key, value)
    def __getitem__(self, key): return getattr(self, key)
    def __repr__(self): return pprint.pformat(self.__dict__)

    def parse(self, string):
        # Overrides hparams from a comma-separated string of name=value pairs
        if len(string) > 0:
            overrides = [s.split("=") for s in string.split(",")]
            keys, values = zip(*overrides)
            keys = list(map(str.strip, keys))
            values = list(map(str.strip, values))
            for k in keys:
                self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
        return self

    def loadJson(self, json_dict):
        print("\nLoading the json with %s\n" % json_dict)
        for k in json_dict.keys():
            if k not in ["tts_schedule", "tts_finetune_layers"]:
                self.__dict__[k] = json_dict[k]
        return self

    def dumpJson(self, fp):
        print("\nSaving the json to %s\n" % fp)
        with fp.open("w", encoding="utf-8") as f:
            json.dump(self.__dict__, f)
        return self

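# Illustrative sketch: overriding hyperparameters from a CLI-style string.
# >>> hp = HParams(sample_rate=16000, num_mels=80)
# >>> hp.parse("num_mels=40").num_mels
# 40
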
hparams = HParams(
        ### Signal Processing (used in both synthesizer and vocoder)
        sample_rate = 16000,
        n_fft = 800,
        num_mels = 80,
        hop_size = 200,                             # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
        win_size = 800,                             # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
        fmin = 55,
        min_level_db = -100,
        ref_level_db = 20,
        max_abs_value = 4.,                         # Gradient explodes if too big, premature convergence if too small.
        preemphasis = 0.97,                         # Filter coefficient to use if preemphasize is True
        preemphasize = True,

        ### Tacotron Text-to-Speech (TTS)
        tts_embed_dims = 512,                       # Embedding dimension for the graphemes/phoneme inputs
        tts_encoder_dims = 256,
        tts_decoder_dims = 128,
        tts_postnet_dims = 512,
        tts_encoder_K = 5,
        tts_lstm_dims = 1024,
        tts_postnet_K = 5,
        tts_num_highways = 4,
        tts_dropout = 0.5,
        tts_cleaner_names = ["basic_cleaners"],
        tts_stop_threshold = -3.4,                  # Value below which audio generation ends.
                                                    # For example, for a range of [-4, 4], this
                                                    # will terminate the sequence at the first
                                                    # frame that has all values < -3.4

        ### Tacotron Training
        tts_schedule = [(2,  1e-3,  10_000,  12),   # Progressive training schedule
                        (2,  5e-4,  15_000,  12),   # (r, lr, step, batch_size)
                        (2,  2e-4,  20_000,  12),
                        (2,  1e-4,  30_000,  12),
                        (2,  5e-5,  40_000,  12),
                        (2,  1e-5,  60_000,  12),
                        (2,  5e-6, 160_000,  12),   # r = reduction factor (# of mel frames
                        (2,  3e-6, 320_000,  12),   #     synthesized for each decoder iteration)
                        (2,  1e-6, 640_000,  12)],  # lr = learning rate

        tts_clip_grad_norm = 1.0,                   # clips the gradient norm to prevent explosion - set to None if not needed
        tts_eval_interval = 500,                    # Number of steps between model evaluation (sample generation)
                                                    # Set to -1 to generate after completing epoch, or 0 to disable
        tts_eval_num_samples = 1,                   # Makes this number of samples

        ## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj
        tts_finetune_layers = [],

        ### Data Preprocessing
        max_mel_frames = 900,
        rescale = True,
        rescaling_max = 0.9,
        synthesis_batch_size = 16,                  # For vocoder preprocessing and inference.

        ### Mel Visualization and Griffin-Lim
        signal_normalization = True,
        power = 1.5,
        griffin_lim_iters = 60,

        ### Audio processing options
        fmax = 7600,                                # Should not exceed (sample_rate // 2)
        allow_clipping_in_normalization = True,     # Used when signal_normalization = True
        clip_mels_length = True,                    # If true, discards samples exceeding max_mel_frames
        use_lws = False,                            # "Fast spectrogram phase recovery using local weighted sums"
        symmetric_mels = True,                      # Sets mel range to [-max_abs_value, max_abs_value] if True,
                                                    # and [0, max_abs_value] if False
        trim_silence = True,                        # Use with sample_rate of 16000 for best results

        ### SV2TTS
        speaker_embedding_size = 256,               # Dimension for the speaker embedding
        silence_min_duration_split = 0.4,           # Duration in seconds of a silence for an utterance to be split
        utterance_min_duration = 1.6,               # Duration in seconds below which utterances are discarded
        use_gst = True,                             # Whether to use global style token
        use_ser_for_gst = True,                     # Whether to use speaker embedding referenced for global style token
        )
187
models/synthesizer/inference.py
Normal file
@@ -0,0 +1,187 @@
import torch
from models.synthesizer import audio
from models.synthesizer.hparams import hparams
from models.synthesizer.models.tacotron import Tacotron
from models.synthesizer.utils.symbols import symbols
from models.synthesizer.utils.text import text_to_sequence
from models.vocoder.display import simple_table
from pathlib import Path
from typing import Union, List
import numpy as np
import librosa
from utils import logmmse
import json
from pypinyin import lazy_pinyin, Style

class Synthesizer:
    sample_rate = hparams.sample_rate
    hparams = hparams

    def __init__(self, model_fpath: Path, verbose=True):
        """
        The model isn't instantiated and loaded in memory until needed or until load() is called.

        :param model_fpath: path to the trained model file
        :param verbose: if False, prints less information when using the model
        """
        self.model_fpath = model_fpath
        self.verbose = verbose

        # Check for GPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        if self.verbose:
            print("Synthesizer using device:", self.device)

        # Tacotron model will be instantiated later on first use.
        self._model = None

    def is_loaded(self):
        """
        Whether the model is loaded in memory.
        """
        return self._model is not None

    def load(self):
        """
        Instantiates and loads the model given the weights file that was passed in the constructor.
        """
        # Try to scan config file
        model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
        if len(model_config_fpaths) > 0 and model_config_fpaths[0].exists():
            with model_config_fpaths[0].open("r", encoding="utf-8") as f:
                hparams.loadJson(json.load(f))
        self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
                               num_chars=len(symbols),
                               encoder_dims=hparams.tts_encoder_dims,
                               decoder_dims=hparams.tts_decoder_dims,
                               n_mels=hparams.num_mels,
                               fft_bins=hparams.num_mels,
                               postnet_dims=hparams.tts_postnet_dims,
                               encoder_K=hparams.tts_encoder_K,
                               lstm_dims=hparams.tts_lstm_dims,
                               postnet_K=hparams.tts_postnet_K,
                               num_highways=hparams.tts_num_highways,
                               dropout=hparams.tts_dropout,
                               stop_threshold=hparams.tts_stop_threshold,
                               speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)

        self._model.load(self.model_fpath, self.device)
        self._model.eval()

        if self.verbose:
            print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))

    def synthesize_spectrograms(self, texts: List[str],
                                embeddings: Union[np.ndarray, List[np.ndarray]],
                                return_alignments=False, style_idx=0, min_stop_token=5, steps=2000):
        """
        Synthesizes mel spectrograms from texts and speaker embeddings.

        :param texts: a list of N text prompts to be synthesized
        :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
        :param return_alignments: if True, a matrix representing the alignments between the
        characters and each decoder output step will be returned for each spectrogram
        :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
        sequence length of spectrogram i, and possibly the alignments.
        """
        # Load the model on the first request.
        if not self.is_loaded():
            self.load()

        # Print some info about the model when it is loaded
        tts_k = self._model.get_step() // 1000

        simple_table([("Tacotron", str(tts_k) + "k"),
                      ("r", self._model.r)])

        print("Read " + str(texts))
        texts = [" ".join(lazy_pinyin(v, style=Style.TONE3, neutral_tone_with_five=True)) for v in texts]
        print("Synthesizing " + str(texts))
        # Preprocess text inputs
        inputs = [text_to_sequence(text, hparams.tts_cleaner_names) for text in texts]
        if not isinstance(embeddings, list):
            embeddings = [embeddings]

        # Batch inputs
        batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
                          for i in range(0, len(inputs), hparams.synthesis_batch_size)]
        batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
                          for i in range(0, len(embeddings), hparams.synthesis_batch_size)]

        specs = []
        for i, batch in enumerate(batched_inputs, 1):
            if self.verbose:
                print(f"\n| Generating {i}/{len(batched_inputs)}")

            # Pad texts so they are all the same length
            text_lens = [len(text) for text in batch]
            max_text_len = max(text_lens)
            chars = [pad1d(text, max_text_len) for text in batch]
            chars = np.stack(chars)

            # Stack speaker embeddings into 2D array for batch processing
            speaker_embeds = np.stack(batched_embeds[i-1])

            # Convert to tensor
            chars = torch.tensor(chars).long().to(self.device)
            speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)

            # Inference
            _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token, steps=steps)
            mels = mels.detach().cpu().numpy()
            for m in mels:
                # Trim silence from end of each spectrogram
                while np.max(m[:, -1]) < hparams.tts_stop_threshold:
                    m = m[:, :-1]
                specs.append(m)

        if self.verbose:
            print("\n\nDone.\n")
        return (specs, alignments) if return_alignments else specs

    @staticmethod
    def load_preprocess_wav(fpath):
        """
        Loads and preprocesses an audio file under the same conditions the audio files were used to
        train the synthesizer.
        """
        wav = librosa.load(path=str(fpath), sr=hparams.sample_rate)[0]
        if hparams.rescale:
            wav = wav / np.abs(wav).max() * hparams.rescaling_max
        # Denoise, using the first and last 150 ms as the noise profile
        if len(wav) > hparams.sample_rate * (0.3 + 0.1):
            noise_wav = np.concatenate([wav[:int(hparams.sample_rate*0.15)],
                                        wav[-int(hparams.sample_rate*0.15):]])
            profile = logmmse.profile_noise(noise_wav, hparams.sample_rate)
            wav = logmmse.denoise(wav, profile)
        return wav

    @staticmethod
    def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
        """
        Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
        were fed to the synthesizer when training.
        """
        if isinstance(fpath_or_wav, (str, Path)):
            wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
        else:
            wav = fpath_or_wav

        mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
        return mel_spectrogram

    @staticmethod
    def griffin_lim(mel):
        """
        Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
        with the same parameters present in hparams.py.
        """
        return audio.inv_mel_spectrogram(mel, hparams)


def pad1d(x, max_len, pad_value=0):
    return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)

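# Illustrative sketch: right-pad a token id sequence to a common batch length.
# >>> pad1d(np.array([3, 7, 9]), 5)
# array([3, 7, 9, 0, 0])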
73
models/synthesizer/models/base.py
Normal file
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import numpy as np

class Base(nn.Module):
    def __init__(self, stop_threshold):
        super().__init__()

        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))

    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    def init_model(self):
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)

    def finetune_partial(self, whitelist_layers):
        self.zero_grad()
        for name, child in self.named_children():
            if name in whitelist_layers:
                print("Trainable Layer: %s" % name)
                print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
            else:
                # Freeze every layer outside the whitelist
                for param in child.parameters():
                    param.requires_grad = False

    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, device, optimizer=None):
        # Use device of model params as location for loaded state
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"], strict=False)

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        if optimizer is not None:
            torch.save({
                "model_state": self.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, str(path))
        else:
            torch.save({
                "model_state": self.state_dict(),
            }, str(path))

    def num_params(self, print_out=True):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters

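# Illustrative checkpointing sketch (hypothetical paths, given a trained
# subclass instance `model` and its `optimizer`):
# >>> model.save("tacotron_ckpt.pt", optimizer)
# >>> model.load("tacotron_ckpt.pt", torch.device("cpu"), optimizer)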
1
models/synthesizer/models/sublayer/__init__.py
Normal file
@@ -0,0 +1 @@
#
85
models/synthesizer/models/sublayer/cbhg.py
Normal file
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
from .common.batch_norm_conv import BatchNormConv
from .common.highway_network import HighwayNetwork

class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        self.bank_kernels = [i for i in range(1, K + 1)]
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
            conv = BatchNormConv(in_channels, channels, k)
            self.conv1d_bank.append(conv)

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # Fix the highway input if necessary
        if proj_channels[-1] != channels:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
        else:
            self.highway_mismatch = False

        self.highways = nn.ModuleList()
        for i in range(num_highways):
            hn = HighwayNetwork(channels)
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x):
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self.rnn.flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)
        conv_bank = []

        # Convolution Bank
        for conv in self.conv1d_bank:
            c = conv(x)  # Convolution
            conv_bank.append(c[:, :, :seq_len])

        # Stack along the channel axis
        conv_bank = torch.cat(conv_bank, dim=1)

        # dump the last padding to fit residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual Connect
        x = x + residual

        # Through the highways
        x = x.transpose(1, 2)
        if self.highway_mismatch is True:
            x = self.pre_highway(x)
        for h in self.highways: x = h(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
        to improve efficiency and avoid PyTorch yelling at us."""
        [m.flatten_parameters() for m in self._to_flatten]

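# Illustrative shape sketch: with proj_channels[-1] == channels the module maps
# (batch, channels, time) to (batch, time, channels); the two bidirectional GRU
# halves come back concatenated to `channels`.
# >>> cbhg = CBHG(K=5, in_channels=256, channels=256, proj_channels=[256, 256], num_highways=4)
# >>> x = torch.randn(4, 256, 100)
# >>> cbhg(x).shape
# torch.Size([4, 100, 256])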
14
models/synthesizer/models/sublayer/common/batch_norm_conv.py
Normal file
@@ -0,0 +1,14 @@
import torch.nn as nn
import torch.nn.functional as F

class BatchNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x) if self.relu is True else x
        return self.bnorm(x)
17
models/synthesizer/models/sublayer/common/highway_network.py
Normal file
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class HighwayNetwork(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        self.W1.bias.data.fill_(0.)

    def forward(self, x):
        x1 = self.W1(x)
        x2 = self.W2(x)
        g = torch.sigmoid(x2)
        y = g * F.relu(x1) + (1. - g) * x
        return y

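# Highway layer (Srivastava et al., 2015): y = g * H(x) + (1 - g) * x, where
# the gate g = sigmoid(W2 x) mixes the transformed path H(x) = relu(W1 x)
# with the identity path, matching the forward() above.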
145
models/synthesizer/models/sublayer/global_style_token.py
Normal file
@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as tFunctional
from models.synthesizer.gst_hyperparameters import GSTHyperparameters as hp
from models.synthesizer.hparams import hparams


class GlobalStyleToken(nn.Module):
    """
    inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
    speaker_embedding: speaker embedding vectors [batch_size, speaker_embedding_dim]
    outputs: [batch_size, embedding_dim]
    """
    def __init__(self, speaker_embedding_dim=None):

        super().__init__()
        self.encoder = ReferenceEncoder()
        self.stl = STL(speaker_embedding_dim)

    def forward(self, inputs, speaker_embedding=None):
        enc_out = self.encoder(inputs)
        # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
        if hparams.use_ser_for_gst and speaker_embedding is not None:
            enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
        style_embed = self.stl(enc_out)

        return style_embed


class ReferenceEncoder(nn.Module):
    '''
    inputs --- [N, Ty/r, n_mels*r]  mels
    outputs --- [N, ref_enc_gru_size]
    '''

    def __init__(self):

        super().__init__()
        K = len(hp.ref_enc_filters)
        filters = [1] + hp.ref_enc_filters
        convs = [nn.Conv2d(in_channels=filters[i],
                           out_channels=filters[i + 1],
                           kernel_size=(3, 3),
                           stride=(2, 2),
                           padding=(1, 1)) for i in range(K)]
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)])

        out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K)
        self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
                          hidden_size=hp.E // 2,
                          batch_first=True)

    def forward(self, inputs):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, hp.n_mels)  # [N, 1, Ty, n_mels]
        for conv, bn in zip(self.convs, self.bns):
            out = conv(out)
            out = bn(out)
            out = tFunctional.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, E//2]

        return out.squeeze(0)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class STL(nn.Module):
    '''
    inputs --- [N, E//2]
    '''

    def __init__(self, speaker_embedding_dim=None):

        super().__init__()
        self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
        d_q = hp.E // 2
        d_k = hp.E // hp.num_heads
        # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
        if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
            d_q += speaker_embedding_dim
        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)

        init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, inputs):
        N = inputs.size(0)
        query = inputs.unsqueeze(1)  # [N, 1, E//2]
        keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
        style_embed = self.attention(query, keys)

        return style_embed


class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):

        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = tFunctional.softmax(scores, dim=3)

        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out

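# Illustrative shape sketch (dimensions match the GST defaults above, without
# the optional speaker embedding):
# >>> attn = MultiHeadAttention(query_dim=256, key_dim=64, num_units=512, num_heads=8)
# >>> q = torch.randn(4, 1, 256)    # [N, T_q, query_dim]
# >>> k = torch.randn(4, 10, 64)    # [N, T_k, key_dim]
# >>> attn(q, k).shape              # [N, T_q, num_units]
# torch.Size([4, 1, 512])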
42
models/synthesizer/models/sublayer/lsa.py
Normal file
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSA(nn.Module):
    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        device = encoder_seq_proj.device  # use same device as parameters
        b, t, c = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

    def forward(self, encoder_seq_proj, query, times, chars):

        if times == 0: self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)

        # Mask zero padding chars
        u = u * (chars != 0).float()

        # Smooth Attention
        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(-1).transpose(1, 2)
27
models/synthesizer/models/sublayer/pre_net.py
Normal file
@@ -0,0 +1,27 @@
import torch.nn as nn
import torch.nn.functional as F

class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        """forward

        Args:
            x (3D tensor with size `[batch_size, num_chars, in_dims]`): input texts list

        Returns:
            3D tensor with size `[batch_size, num_chars, fc2_dims]`

        """
        x = self.fc1(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        return x

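# Design note (assumption, following the Tacotron 2 recipe this code mirrors):
# dropout is deliberately kept active at inference via training=True above,
# so two passes over the same input give different outputs:
# >>> import torch
# >>> net = PreNet(80)
# >>> x = torch.randn(1, 4, 80)
# >>> bool(torch.equal(net(x), net(x)))
# False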
298
models/synthesizer/models/tacotron.py
Normal file
@@ -0,0 +1,298 @@
import torch
import torch.nn as nn
from .sublayer.global_style_token import GlobalStyleToken
from .sublayer.pre_net import PreNet
from .sublayer.cbhg import CBHG
from .sublayer.lsa import LSA
from .base import Base
from models.synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
from models.synthesizer.hparams import hparams

class Encoder(nn.Module):
    def __init__(self, num_chars, embed_dims=512, encoder_dims=256, K=5, num_highways=4, dropout=0.5):
        """ Encoder for SV2TTS

        Args:
            num_chars (int): length of symbols
            embed_dims (int, optional): embedding dim for input texts. Defaults to 512.
            encoder_dims (int, optional): output dim for encoder. Defaults to 256.
            K (int, optional): max kernel size of the CBHG convolution bank. Defaults to 5.
            num_highways (int, optional): number of highway layers in the CBHG. Defaults to 4.
            dropout (float, optional): dropout rate of the pre-net. Defaults to 0.5.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, fc1_dims=encoder_dims, fc2_dims=encoder_dims,
                              dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=encoder_dims, channels=encoder_dims,
                         proj_channels=[encoder_dims, encoder_dims],
                         num_highways=num_highways)

    def forward(self, x):
        """forward pass for encoder

        Args:
            x (2D tensor with size `[batch_size, text_num_chars]`): input texts list

        Returns:
            3D tensor with size `[batch_size, text_num_chars, encoder_dims]`

        """
        x = self.embedding(x)  # return: [batch_size, text_num_chars, tts_embed_dims]
        x = self.pre_net(x)    # return: [batch_size, text_num_chars, encoder_dims]
        x.transpose_(1, 2)     # return: [batch_size, encoder_dims, text_num_chars]
        return self.cbhg(x)    # return: [batch_size, text_num_chars, encoder_dims]

class Decoder(nn.Module):
    # Class variable because its value doesn't change between instances,
    # yet it ought to be scoped by the class because it's a property of a Decoder
    max_r = 20
    def __init__(self, n_mels, input_dims, decoder_dims, lstm_dims,
                 dropout, speaker_embedding_size):
        super().__init__()
        self.register_buffer("r", torch.tensor(1, dtype=torch.int))
        self.n_mels = n_mels
        self.prenet = PreNet(n_mels, fc1_dims=decoder_dims * 2, fc2_dims=decoder_dims * 2,
                             dropout=dropout)
        self.attn_net = LSA(decoder_dims)
        if hparams.use_gst:
            speaker_embedding_size += gst_hp.E
        self.attn_rnn = nn.GRUCell(input_dims + decoder_dims * 2, decoder_dims)
        self.rnn_input = nn.Linear(input_dims + decoder_dims, lstm_dims)
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
        self.stop_proj = nn.Linear(input_dims + lstm_dims, 1)

    def zoneout(self, prev, current, device, p=0.1):
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                hidden_states, cell_states, context_vec, times, chars):
        """Run one decoder step.

        Args:
            encoder_seq (3D tensor `[batch_size, text_num_chars, project_dim(default to 512)]`): encoder output sequence
            encoder_seq_proj (3D tensor `[batch_size, text_num_chars, decoder_dims(default to 128)]`): projected encoder output
            prenet_in (2D tensor `[batch_size, n_mels]`): previous mel frame fed to the pre-net
            hidden_states (tuple): hidden states of the attention GRU and the two residual LSTMs
            cell_states (tuple): cell states of the two residual LSTMs
            context_vec (2D tensor `[batch_size, project_dim(default to 512)]`): attention context vector
            times (int): the number of decoder steps run so far
            chars (2D tensor with size `[batch_size, text_num_chars]`): original texts list input

        """
        # Need this for reshaping mels
        batch_size = encoder_seq.size(0)
        device = encoder_seq.device
        # Unpack the hidden and cell states
        attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
        rnn1_cell, rnn2_cell = cell_states

        # PreNet for the Attention RNN
        prenet_out = self.prenet(prenet_in)  # return: `[batch_size, decoder_dims * 2 (256)]`

        # Compute the Attention RNN hidden state
        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)  # `[batch_size, project_dim + decoder_dims * 2 (768)]`
        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)  # `[batch_size, decoder_dims (128)]`

        # Compute the attention scores
        scores = self.attn_net(encoder_seq_proj, attn_hidden, times, chars)

        # Dot product to create the context vector
        context_vec = scores @ encoder_seq
        context_vec = context_vec.squeeze(1)

        # Concat Attention RNN output w. Context Vector & project
        x = torch.cat([context_vec, attn_hidden], dim=1)  # `[batch_size, project_dim + decoder_dims (640)]`
        x = self.rnn_input(x)  # `[batch_size, lstm_dims(1024)]`

        # Compute first Residual RNN, training with fixed zoneout rate 0.1
        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))  # `[batch_size, lstm_dims(1024)]`
        if self.training:
            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
        else:
            rnn1_hidden = rnn1_hidden_next
        x = x + rnn1_hidden

        # Compute second Residual RNN
        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))  # `[batch_size, lstm_dims(1024)]`
        if self.training:
            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
        else:
            rnn2_hidden = rnn2_hidden_next
        x = x + rnn2_hidden

        # Project Mels
        mels = self.mel_proj(x)  # `[batch_size, 1600]`
        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]  # `[batch_size, n_mels, r]`
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
        cell_states = (rnn1_cell, rnn2_cell)

        # Stop token prediction
        s = torch.cat((x, context_vec), dim=1)
        s = self.stop_proj(s)
        stop_tokens = torch.sigmoid(s)

        return mels, scores, hidden_states, cell_states, context_vec, stop_tokens

class Tacotron(Base):
    def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
                 fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
                 dropout, stop_threshold, speaker_embedding_size):
        super().__init__(stop_threshold)
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self.speaker_embedding_size = speaker_embedding_size
        self.encoder = Encoder(num_chars, embed_dims, encoder_dims,
                               encoder_K, num_highways, dropout)
        self.project_dims = encoder_dims + speaker_embedding_size
        if hparams.use_gst:
            self.project_dims += gst_hp.E
        self.encoder_proj = nn.Linear(self.project_dims, decoder_dims, bias=False)
        if hparams.use_gst:
            self.gst = GlobalStyleToken(speaker_embedding_size)
        self.decoder = Decoder(n_mels, self.project_dims, decoder_dims, lstm_dims,
                               dropout, speaker_embedding_size)
        self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
                            [postnet_dims, fft_bins], num_highways)
        self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)

    @staticmethod
    def _concat_speaker_embedding(outputs, speaker_embeddings):
        speaker_embeddings_ = speaker_embeddings.expand(
            outputs.size(0), outputs.size(1), -1)
        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
        return outputs

    @staticmethod
    def _add_speaker_embedding(x, speaker_embedding):
        """Add speaker embedding
        This concats the speaker embedding for each char in the encoder output
        Args:
            x (3D tensor with size `[batch_size, text_num_chars, encoder_dims]`): the encoder output
            speaker_embedding (2D tensor `[batch_size, speaker_embedding_size]`): the speaker embedding

        Returns:
            3D tensor with size `[batch_size, text_num_chars, encoder_dims+speaker_embedding_size]`
        """
        # Save the dimensions as human-readable names
        batch_size = x.size()[0]
        text_num_chars = x.size()[1]

        # Start by making a copy of each speaker embedding to match the input text length
        # The output of this has size (batch_size, text_num_chars * speaker_embedding_size)
        speaker_embedding_size = speaker_embedding.size()[1]
        e = speaker_embedding.repeat_interleave(text_num_chars, dim=1)

        # Reshape it and transpose
        e = e.reshape(batch_size, speaker_embedding_size, text_num_chars)
        e = e.transpose(1, 2)

        # Concatenate the tiled speaker embedding with the encoder output
        x = torch.cat((x, e), 2)
        return x
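    # Shape sketch (illustrative): with batch_size=2, text_num_chars=3 and a 256-dim
    # speaker embedding, `e` ends up as [2, 3, 256], one copy of the embedding per
    # input character, so the output is [2, 3, encoder_dims + 256]. The
    # repeat_interleave/reshape/transpose sequence is equivalent to
    # speaker_embedding.unsqueeze(1).expand(-1, text_num_chars, -1).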

    def forward(self, texts, mels, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
        """Forward pass for Tacotron

        Args:
            texts (`[batch_size, text_num_chars]`): input texts list
            mels (`[batch_size, n_mels, steps]`): target mels for comparison (training only)
            speaker_embedding (`[batch_size, speaker_embedding_size(default to 256)]`): reference speaker embedding
            steps (int, optional): maximum number of decoder steps at inference. Defaults to 2000.
            style_idx (int, optional): GST style selected. Defaults to 0.
            min_stop_token (int, optional): decoder min_stop_token. Defaults to 5.
        """
        device = texts.device  # use same device as parameters

        if self.training:
            self.step += 1
            batch_size, _, steps = mels.size()
        else:
            batch_size, _ = texts.size()

        # Initialise all hidden states and pack into tuple
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Initialise all lstm cell states and pack into tuple
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(texts)

        encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)

        if hparams.use_gst and self.gst is not None:
            if self.training:
                style_embed = self.gst(speaker_embedding, speaker_embedding)  # for training, speaker embedding can represent both style inputs and referenced
                # style_embed = style_embed.expand_as(encoder_seq)
                # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
            elif style_idx >= 0 and style_idx < 10:
                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
                if device.type == 'cuda':
                    query = query.cuda()
                gst_embed = torch.tanh(self.gst.stl.embed)
                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
                style_embed = self.gst.stl.attention(query, key)
            else:
                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)  # return: [batch_size, text_num_chars, project_dims]

        encoder_seq_proj = self.encoder_proj(encoder_seq)  # return: [batch_size, text_num_chars, decoder_dims]

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Need an initial context vector
        context_vec = torch.zeros(batch_size, self.project_dims, device=device)

        # Run the decoder loop
        for t in range(0, steps, self.r):
            if self.training:
                prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
            else:
                prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, texts)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)
            if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        # attn_scores = attn_scores.cpu().data.numpy()
        stop_outputs = torch.cat(stop_outputs, 1)

        if self.training:
            self.train()

        return mel_outputs, linear, attn_scores, stop_outputs

    def generate(self, x, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
        self.eval()
        mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
        return mel_outputs, linear, attn_scores
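A minimal inference sketch (illustrative: it assumes a `model` constructed and checkpoint-loaded as in models/synthesizer/synthesize.py below, and a precomputed speaker embedding; the embed file path is hypothetical):

import numpy as np
import torch
from models.synthesizer.hparams import hparams
from models.synthesizer.utils.text import text_to_sequence

seq = text_to_sequence("ni3 hao3", hparams.tts_cleaner_names)
texts = torch.tensor([seq]).long()
embed = torch.from_numpy(np.load("embeds/embed-sample.npy")).float().unsqueeze(0)  # hypothetical file
mels, linear, attn = model.generate(texts, embed, steps=2000, style_idx=0)         # style_idx 0..9 picks a GST token when use_gst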
120
models/synthesizer/preprocess.py
Normal file
@@ -0,0 +1,120 @@
from multiprocessing.pool import Pool

from functools import partial
from itertools import chain
from pathlib import Path
from tqdm import tqdm
import numpy as np
from models.encoder import inference as encoder
from models.synthesizer.preprocess_speaker import preprocess_speaker_general
from models.synthesizer.preprocess_transcript import preprocess_transcript_aishell3, preprocess_transcript_magicdata

data_info = {
    "aidatatang_200zh": {
        "subfolders": ["corpus/train"],
        "trans_filepath": "transcript/aidatatang_200_zh_transcript.txt",
        "speak_func": preprocess_speaker_general
    },
    "magicdata": {
        "subfolders": ["train"],
        "trans_filepath": "train/TRANS.txt",
        "speak_func": preprocess_speaker_general,
        "transcript_func": preprocess_transcript_magicdata,
    },
    "aishell3": {
        "subfolders": ["train/wav"],
        "trans_filepath": "train/content.txt",
        "speak_func": preprocess_speaker_general,
        "transcript_func": preprocess_transcript_aishell3,
    },
    "data_aishell": {
        "subfolders": ["wav/train"],
        "trans_filepath": "transcript/aishell_transcript_v0.8.txt",
        "speak_func": preprocess_speaker_general
    }
}
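# To support another corpus (hypothetical example), add an entry describing where
# its audio and transcript live:
#
#     data_info["my_corpus"] = {
#         "subfolders": ["wavs"],
#         "trans_filepath": "transcript.txt",
#         "speak_func": preprocess_speaker_general,
#     }
#
# A "transcript_func" key is only needed when the transcript format differs from
# the default "utterance_id text" lines parsed in preprocess_dataset below.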

def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int,
                       skip_existing: bool, hparams, no_alignments: bool,
                       dataset: str):
    dataset_info = data_info[dataset]
    # Gather the input directories
    dataset_root = datasets_root.joinpath(dataset)
    input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in dataset_info["subfolders"]]
    print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
    assert all(input_dir.exists() for input_dir in input_dirs)

    # Create the output directories for each output file type
    out_dir.joinpath("mels").mkdir(exist_ok=True)
    out_dir.joinpath("audio").mkdir(exist_ok=True)

    # Create a metadata file
    metadata_fpath = out_dir.joinpath("train.txt")
    metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")

    # Preprocess the dataset
    dict_info = {}
    transcript_dirs = dataset_root.joinpath(dataset_info["trans_filepath"])
    assert transcript_dirs.exists(), str(transcript_dirs) + " does not exist."
    with open(transcript_dirs, "r", encoding="utf-8") as dict_transcript:
        # process with specific function for your dataset
        if "transcript_func" in dataset_info:
            dataset_info["transcript_func"](dict_info, dict_transcript)
        else:
            for v in dict_transcript:
                if not v:
                    continue
                v = v.strip().replace("\n", "").replace("\t", " ").split(" ")
                dict_info[v[0]] = " ".join(v[1:])

    speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
    func = partial(dataset_info["speak_func"], out_dir=out_dir, skip_existing=skip_existing,
                   hparams=hparams, dict_info=dict_info, no_alignments=no_alignments)
    job = Pool(n_processes).imap(func, speaker_dirs)
    for speaker_metadata in tqdm(job, dataset, len(speaker_dirs), unit="speakers"):
        for metadatum in speaker_metadata:
            metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
    metadata_file.close()

    # Verify the contents of the metadata file
    with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
        metadata = [line.split("|") for line in metadata_file]
        mel_frames = sum([int(m[4]) for m in metadata])
        timesteps = sum([int(m[3]) for m in metadata])
        sample_rate = hparams.sample_rate
        hours = (timesteps / sample_rate) / 3600
        print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
              (len(metadata), mel_frames, timesteps, hours))
        print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
        print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
        print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))

def embed_utterance(fpaths, encoder_model_fpath):
    if not encoder.is_loaded():
        encoder.load_model(encoder_model_fpath)

    # Compute the speaker embedding of the utterance
    wav_fpath, embed_fpath = fpaths
    wav = np.load(wav_fpath)
    wav = encoder.preprocess_wav(wav)
    embed = encoder.embed_utterance(wav)
    np.save(embed_fpath, embed, allow_pickle=False)


def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
    wav_dir = synthesizer_root.joinpath("audio")
    metadata_fpath = synthesizer_root.joinpath("train.txt")
    assert wav_dir.exists() and metadata_fpath.exists()
    embed_dir = synthesizer_root.joinpath("embeds")
    embed_dir.mkdir(exist_ok=True)

    # Gather the input wave filepath and the target output embed filepath
    with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
        metadata = [line.split("|") for line in metadata_file]
        fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]

    # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
    # Embed the utterances in separate threads
    func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
    job = Pool(n_processes).imap(func, fpaths)
    list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
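Each row written to train.txt above is the tuple returned by _process_utterance in preprocess_speaker.py, joined with "|". An illustrative line (values hypothetical):

audio-T0055G0013S0001.wav_00.npy|mel-T0055G0013S0001.wav_00.npy|embed-T0055G0013S0001.wav_00.npy|58880|368|ni3 hao3 ma5

create_embeddings reads columns 0 and 2 of these rows to pair each saved wav with its target embed file.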
99
models/synthesizer/preprocess_speaker.py
Normal file
@@ -0,0 +1,99 @@
import librosa
import numpy as np

from models.encoder import inference as encoder
from utils import logmmse
from models.synthesizer import audio
from pathlib import Path
from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin

class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter):
    pass

pinyin = Pinyin(PinyinConverter()).pinyin


def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
                       skip_existing: bool, hparams):
    ## FOR REFERENCE:
    # For you not to lose your head if you ever wish to change things here or implement your own
    # synthesizer.
    # - Both the audios and the mel spectrograms are saved as numpy arrays
    # - There is no processing done to the audios that will be saved to disk beyond volume
    #   normalization (in split_on_silences)
    # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
    #   is why we re-apply it on the audio on the side of the vocoder.
    # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
    #   without extra padding. This means that you won't have an exact relation between the length
    #   of the wav and of the mel spectrogram. See the vocoder data loader.


    # Skip existing utterances if needed
    mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
    wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
    if skip_existing and mel_fpath.exists() and wav_fpath.exists():
        return None

    # Trim silence
    if hparams.trim_silence:
        wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)

    # Skip utterances that are too short
    if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
        return None

    # Compute the mel spectrogram
    mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
    mel_frames = mel_spectrogram.shape[1]

    # Skip utterances that are too long
    if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
        return None

    # Write the spectrogram, embed and audio to disk
    np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
    np.save(wav_fpath, wav, allow_pickle=False)

    # Return a tuple describing this training example
    return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text


def _split_on_silences(wav_fpath, words, hparams):
    # Load the audio waveform
    wav, _ = librosa.load(wav_fpath, sr=hparams.sample_rate)
    wav = librosa.effects.trim(wav, top_db=40, frame_length=2048, hop_length=512)[0]
    if hparams.rescale:
        wav = wav / np.abs(wav).max() * hparams.rescaling_max
    # Denoise; we may not need it here.
    if len(wav) > hparams.sample_rate * (0.3 + 0.1):
        noise_wav = np.concatenate([wav[:int(hparams.sample_rate * 0.15)],
                                    wav[-int(hparams.sample_rate * 0.15):]])
        profile = logmmse.profile_noise(noise_wav, hparams.sample_rate)
        wav = logmmse.denoise(wav, profile, eta=0)

    resp = pinyin(words, style=Style.TONE3)
    res = [v[0] for v in resp if v[0].strip()]
    res = " ".join(res)

    return wav, res

def preprocess_speaker_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool):
    metadata = []
    extensions = ["*.wav", "*.flac", "*.mp3"]
    for extension in extensions:
        wav_fpath_list = speaker_dir.glob(extension)
        # Iterate over each wav
        for wav_fpath in wav_fpath_list:
            words = dict_info.get(wav_fpath.name.split(".")[0])
            words = dict_info.get(wav_fpath.name) if not words else words  # try with wav
            if not words:
                print("no words")
                continue
            sub_basename = "%s_%02d" % (wav_fpath.name, 0)
            wav, text = _split_on_silences(wav_fpath, words, hparams)
            metadata.append(_process_utterance(wav, text, out_dir, sub_basename,
                                               skip_existing, hparams))
    return [m for m in metadata if m is not None]
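A quick sketch of what the pinyin conversion in _split_on_silences produces (assuming pypinyin behaves as configured above, with neutral tones rendered as 5):

resp = pinyin("你好吗", style=Style.TONE3)          # -> [["ni3"], ["hao3"], ["ma5"]]
res = " ".join(v[0] for v in resp if v[0].strip())  # -> "ni3 hao3 ma5"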
18
models/synthesizer/preprocess_transcript.py
Normal file
@@ -0,0 +1,18 @@
def preprocess_transcript_aishell3(dict_info, dict_transcript):
    for v in dict_transcript:
        if not v:
            continue
        v = v.strip().replace("\n", "").replace("\t", " ").split(" ")
        transList = []
        for i in range(2, len(v), 2):
            transList.append(v[i])
        dict_info[v[0]] = " ".join(transList)


def preprocess_transcript_magicdata(dict_info, dict_transcript):
    for v in dict_transcript:
        if not v:
            continue
        v = v.strip().replace("\n", "").replace("\t", " ").split(" ")
        dict_info[v[0]] = " ".join(v[2:])
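Illustrative input lines for the two parsers above (formats inferred from the parsing logic, not quoted from the corpora):

# aishell3 content.txt: characters and pinyin alternate, so every second token from index 2 is kept:
#   "SSB00050001.wav	广 guang3 州 zhou1 女 nv3"  ->  dict_info["SSB00050001.wav"] = "guang3 zhou1 nv3"
# MAGICDATA TRANS.txt: column 1 is the speaker ID, everything from column 2 on is the transcript:
#   "A_0001.wav	SPK001	你好 世界"  ->  dict_info["A_0001.wav"] = "你好 世界"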
97
models/synthesizer/synthesize.py
Normal file
@@ -0,0 +1,97 @@
import torch
from torch.utils.data import DataLoader
from models.synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
from models.synthesizer.models.tacotron import Tacotron
from models.synthesizer.utils import data_parallel_workaround
from models.synthesizer.utils.text import text_to_sequence
from models.synthesizer.utils.symbols import symbols
import numpy as np
from pathlib import Path
from tqdm import tqdm
import sys


def run_synthesis(in_dir, out_dir, model_dir, hparams):
    # This generates ground truth-aligned mels for vocoder training
    synth_dir = Path(out_dir).joinpath("mels_gta")
    synth_dir.mkdir(parents=True, exist_ok=True)
    print(str(hparams))

    # Check for GPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
            raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Synthesizer using device:", device)

    # Instantiate Tacotron model
    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=0.,  # Use zero dropout for gta mels
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Load the weights
    model_dir = Path(model_dir)
    model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
    print("\nLoading weights at %s" % model_fpath)
    model.load(model_fpath, device)
    print("Tacotron weights loaded from step %d" % model.step)

    # Synthesize using same reduction factor as the model is currently trained
    r = np.int32(model.r)

    # Set model to eval mode (disable gradient and zoneout)
    model.eval()

    # Initialize the dataset
    in_dir = Path(in_dir)
    metadata_fpath = in_dir.joinpath("train.txt")
    mel_dir = in_dir.joinpath("mels")
    embed_dir = in_dir.joinpath("embeds")
    num_workers = 0 if sys.platform.startswith("win") else 2
    dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
    data_loader = DataLoader(dataset,
                             collate_fn=lambda batch: collate_synthesizer(batch),
                             batch_size=hparams.synthesis_batch_size,
                             num_workers=num_workers,
                             shuffle=False,
                             pin_memory=True)

    # Generate GTA mels
    meta_out_fpath = Path(out_dir).joinpath("synthesized.txt")
    with open(meta_out_fpath, "w") as file:
        for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
            texts = texts.to(device)
            mels = mels.to(device)
            embeds = embeds.to(device)

            # Parallelize model onto GPUS using workaround due to python bug
            if device.type == "cuda" and torch.cuda.device_count() > 1:
                _, mels_out, _, _ = data_parallel_workaround(model, texts, mels, embeds)
            else:
                _, mels_out, _, _ = model(texts, mels, embeds)

            for j, k in enumerate(idx):
                # Note: outputs mel-spectrogram files and target ones have same names, just different folders
                mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
                mel_out = mels_out[j].detach().cpu().numpy().T

                # Use the length of the ground truth mel to remove padding from the generated mels
                mel_out = mel_out[:int(dataset.metadata[k][4])]

                # Write the spectrogram to disk
                np.save(mel_filename, mel_out, allow_pickle=False)

                # Write metadata into the synthesized file
                file.write("|".join(dataset.metadata[k]))
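A minimal invocation sketch (directory names hypothetical; in_dir is the synthesizer preprocessing output, out_dir receives the GTA mels for vocoder training):

from models.synthesizer.synthesize import run_synthesis
from models.synthesizer.hparams import hparams

run_synthesis("data/SV2TTS/synthesizer", "data/SV2TTS/vocoder", "saved_models/mandarin", hparams)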
93
models/synthesizer/synthesizer_dataset.py
Normal file
@@ -0,0 +1,93 @@
import torch
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from models.synthesizer.utils.text import text_to_sequence


class SynthesizerDataset(Dataset):
    def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))

        with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
            metadata = [line.split("|") for line in metadata_file]

        mel_fnames = [x[1] for x in metadata if int(x[4])]
        mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
        embed_fnames = [x[2] for x in metadata if int(x[4])]
        embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
        self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
        self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
        self.metadata = metadata
        self.hparams = hparams

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        # Sometimes index may be a list of 2 (not sure why this happens)
        # If that is the case, return a single item corresponding to first element in index
        if isinstance(index, list):
            index = index[0]

        mel_path, embed_path = self.samples_fpaths[index]
        mel = np.load(mel_path).T.astype(np.float32)

        # Load the embed
        embed = np.load(embed_path)

        # Get the text and clean it
        text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)

        # Convert the list returned by text_to_sequence to a numpy array
        text = np.asarray(text).astype(np.int32)

        return text, mel.astype(np.float32), embed.astype(np.float32), index

    def __len__(self):
        return len(self.samples_fpaths)


def collate_synthesizer(batch):
    # Text
    x_lens = [len(x[0]) for x in batch]
    max_x_len = max(x_lens)

    chars = [pad1d(x[0], max_x_len) for x in batch]
    chars = np.stack(chars)

    # Mel spectrogram
    spec_lens = [x[1].shape[-1] for x in batch]
    max_spec_len = max(spec_lens) + 1
    if max_spec_len % 2 != 0:  # FIXME: hardcoded due to incompatibility with Windows (no lambda)
        max_spec_len += 2 - max_spec_len % 2

    # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
    # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
    # if hparams.symmetric_mels:
    #     mel_pad_value = -1 * hparams.max_abs_value
    # else:
    #     mel_pad_value = 0
    mel_pad_value = -4  # FIXME: hardcoded due to incompatibility with Windows (no lambda)
    mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
    mel = np.stack(mel)

    # Speaker embedding (SV2TTS)
    embeds = [x[2] for x in batch]
    embeds = np.stack(embeds)

    # Index (for vocoder preprocessing)
    indices = [x[3] for x in batch]


    # Convert all to tensor
    chars = torch.tensor(chars).long()
    mel = torch.tensor(mel)
    embeds = torch.tensor(embeds)

    return chars, mel, embeds, indices

def pad1d(x, max_len, pad_value=0):
    return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)

def pad2d(x, max_len, pad_value=0):
    return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
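A small sketch of the padding helpers (shapes illustrative):

import numpy as np
x = np.array([1, 2, 3])
pad1d(x, 5)                        # -> array([1, 2, 3, 0, 0])
m = np.zeros((80, 3))              # [n_mels, frames]
pad2d(m, 5, pad_value=-4).shape    # -> (80, 5); the two new frames are filled with -4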
317
models/synthesizer/train.py
Normal file
@@ -0,0 +1,317 @@
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models.synthesizer import audio
from models.synthesizer.models.tacotron import Tacotron
from models.synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
from models.synthesizer.utils import ValueWindow, data_parallel_workaround
from models.synthesizer.utils.plot import plot_spectrogram, plot_spectrogram_and_trace
from models.synthesizer.utils.symbols import symbols
from models.synthesizer.utils.text import sequence_to_text
from models.vocoder.display import *
from datetime import datetime
import json
import numpy as np
from pathlib import Path
import time
import os

def np_now(x: torch.Tensor): return x.detach().cpu().numpy()

def time_string():
    return datetime.now().strftime("%Y-%m-%d %H:%M")

def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
          backup_every: int, log_every: int, force_restart: bool, hparams):

    syn_dir = Path(syn_dir)
    models_dir = Path(models_dir)
    models_dir.mkdir(exist_ok=True)

    model_dir = models_dir.joinpath(run_id)
    plot_dir = model_dir.joinpath("plots")
    wav_dir = model_dir.joinpath("wavs")
    mel_output_dir = model_dir.joinpath("mel-spectrograms")
    meta_folder = model_dir.joinpath("metas")
    model_dir.mkdir(exist_ok=True)
    plot_dir.mkdir(exist_ok=True)
    wav_dir.mkdir(exist_ok=True)
    mel_output_dir.mkdir(exist_ok=True)
    meta_folder.mkdir(exist_ok=True)

    weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
    metadata_fpath = syn_dir.joinpath("train.txt")

    print("Checkpoint path: {}".format(weights_fpath))
    print("Loading training data from: {}".format(metadata_fpath))
    print("Using model: Tacotron")

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)


    # From WaveRNN/train_tacotron.py
    if torch.cuda.is_available():
        device = torch.device("cuda")

        for session in hparams.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Instantiate Tacotron Model
    print("\nInitialising Tacotron Model...\n")
    num_chars = len(symbols)
    if weights_fpath.exists():
        # for compatibility purpose, change symbols accordingly:
        loaded_shape = torch.load(str(weights_fpath), map_location=device)["model_state"]["encoder.embedding.weight"].shape
        if num_chars != loaded_shape[0]:
            print("WARNING: you are using compatibility mode due to a mismatched symbols length, please modify the variable _characters in `utils/symbols.py`")
            num_chars = loaded_shape[0]
        # Try to scan config file
        model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
        if len(model_config_fpaths) > 0 and model_config_fpaths[0].exists():
            with model_config_fpaths[0].open("r", encoding="utf-8") as f:
                hparams.loadJson(json.load(f))
    else:  # save a config
        hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))


    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=num_chars,
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=hparams.tts_dropout,
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters(), amsgrad=True)

    # Load the weights
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of Tacotron from scratch\n")
        model.save(weights_fpath)

        # Embeddings metadata
        char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
        with open(char_embedding_fpath, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, device, optimizer)
        print("Tacotron weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt")
    mel_dir = syn_dir.joinpath("mels")
    embed_dir = syn_dir.joinpath("embeds")
    dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    # tracing training step
    sw = SummaryWriter(log_dir=model_dir.joinpath("logs"))

    for i, session in enumerate(hparams.tts_schedule):
        current_step = model.get_step()

        r, lr, max_step, batch_size = session

        training_steps = max_step - current_step

        # Do we need to change to the next session?
        if current_step >= max_step:
            # Are there no further sessions than the current one?
            if i == len(hparams.tts_schedule) - 1:
                # We have completed training. Save the model and exit
                model.save(weights_fpath, optimizer)
                break
            else:
                # There is a following session, go to it
                continue

        model.r = r
        # Begin the training
        simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                      ("Batch Size", batch_size),
                      ("Learning Rate", lr),
                      ("Outputs/Step (r)", model.r)])

        for p in optimizer.param_groups:
            p["lr"] = lr
        if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0:
            model.finetune_partial(hparams.tts_finetune_layers)

        data_loader = DataLoader(dataset,
                                 collate_fn=collate_synthesizer,
                                 batch_size=batch_size,  # change if your graphics card runs out of memory
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)

        total_iters = len(dataset)
        steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
        epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)

        for epoch in range(1, epochs + 1):
            for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
                start_time = time.time()

                # Generate stop tokens for training
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for j, k in enumerate(idx):
                    stop[j, :int(dataset.metadata[k][4]) - 1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass
                # Parallelize model onto GPUS using workaround due to python bug
                if device.type == "cuda" and torch.cuda.device_count() > 1:
                    m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
                                                                                    mels, embeds)
                else:
                    m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)

                # Backward pass
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss

                optimizer.zero_grad()
                loss.backward()

                if hparams.tts_clip_grad_norm is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
                    if np.isnan(grad_norm.cpu()):
                        print("grad_norm was NaN!")

                optimizer.step()

                time_window.append(time.time() - start_time)
                loss_window.append(loss.item())

                step = model.get_step()
                k = step // 1000


                msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
                stream(msg)

                if log_every != 0 and step % log_every == 0:
                    sw.add_scalar("training/loss", loss_window.average, step)

                # Backup or save model as appropriate
                if backup_every != 0 and step % backup_every == 0:
                    backup_fpath = Path("{}/{}_{}.pt".format(str(weights_fpath.parent), run_id, step))
                    model.save(backup_fpath, optimizer)

                if save_every != 0 and step % save_every == 0:
                    # Must save latest optimizer state to ensure that resuming training
                    # doesn't produce artifacts
                    model.save(weights_fpath, optimizer)


                # Evaluate model to generate samples
                epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch  # If epoch is done
                step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0  # Every N steps
                if epoch_eval or step_eval:
                    for sample_idx in range(hparams.tts_eval_num_samples):
                        # At most, generate samples equal to number in the batch
                        if sample_idx + 1 <= len(texts):
                            # Remove padding from mels using frame length in metadata
                            mel_length = int(dataset.metadata[idx[sample_idx]][4])
                            mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
                            target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
                            attention_len = mel_length // model.r
                            # eval_loss = F.mse_loss(mel_prediction, target_spectrogram)
                            # sw.add_scalar("validing/loss", eval_loss.item(), step)
                            eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
                                       mel_prediction=mel_prediction,
                                       target_spectrogram=target_spectrogram,
                                       input_seq=np_now(texts[sample_idx]),
                                       step=step,
                                       plot_dir=plot_dir,
                                       mel_output_dir=mel_output_dir,
                                       wav_dir=wav_dir,
                                       sample_num=sample_idx + 1,
                                       loss=loss,
                                       hparams=hparams,
                                       sw=sw)
                    MAX_SAVED_COUNT = 20
                    if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
                        # clean up: keep only the last MAX_SAVED_COUNT of each artifact
                        plots = next(os.walk(plot_dir), (None, None, []))[2]
                        for plot in plots[:-MAX_SAVED_COUNT]:
                            os.remove(plot_dir.joinpath(plot))
                        mel_files = next(os.walk(mel_output_dir), (None, None, []))[2]
                        for mel_file in mel_files[:-MAX_SAVED_COUNT]:
                            os.remove(mel_output_dir.joinpath(mel_file))
                        wavs = next(os.walk(wav_dir), (None, None, []))[2]
                        for w in wavs[:-MAX_SAVED_COUNT]:
                            os.remove(wav_dir.joinpath(w))

                # Break out of loop to update training schedule
                if step >= max_step:
                    break

            # Add line break after every epoch
            print("")

def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
               plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, sw):
    # Save some results for evaluation
    attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
    # save_attention(attention, attention_path)
    save_and_trace_attention(attention, attention_path, sw, step)

    # save predicted mel spectrogram to disk (debug)
    mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
    np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)

    # save griffin lim inverted wav for debug (mel -> wav)
    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
    wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
    audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)

    # save real and predicted mel-spectrogram plot to disk (control purposes)
    spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
    title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
    # plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
    #                  target_spectrogram=target_spectrogram,
    #                  max_len=target_spectrogram.size // hparams.num_mels)
    plot_spectrogram_and_trace(
        mel_prediction,
        str(spec_fpath),
        title=title_str,
        target_spectrogram=target_spectrogram,
        max_len=target_spectrogram.size // hparams.num_mels,
        sw=sw,
        step=step)
    print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
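The schedule unpacked in train() is a list of (reduction factor r, learning rate, max step, batch size) sessions; an illustrative value (the real numbers live in hparams.py):

tts_schedule = [(7, 1e-3,  10_000, 32),   # coarse frames early for fast alignment
                (5, 3e-4,  50_000, 32),
                (2, 1e-4, 180_000, 16)]   # finer-grained frames near the end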
45
models/synthesizer/utils/__init__.py
Normal file
@@ -0,0 +1,45 @@
import torch


_output_ref = None
_replicas_ref = None

def data_parallel_workaround(model, *input):
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat


class ValueWindow():
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []
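A quick sketch of the sliding-window average used for the loss readout in train.py:

w = ValueWindow(window_size=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    w.append(v)
print(w.average)   # -> 3.0, the mean of the last 3 values (2.0, 3.0, 4.0)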
62
models/synthesizer/utils/_cmudict.py
Normal file
@@ -0,0 +1,62 @@
import re

valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
    "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
    "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
    "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
    "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
    "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
    """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries


    def __len__(self):
        return len(self._entries)


    def lookup(self, word):
        """Returns list of ARPAbet pronunciations of the given word."""
        return self._entries.get(word.upper())



_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
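A usage sketch (the two-space separator between word and pronunciation follows the cmudict file format that _parse_cmudict assumes):

import io
sample = "HOUSTON  HH Y UW1 S T AH0 N\n"
d = CMUDict(io.StringIO(sample))
d.lookup("houston")   # -> ["HH Y UW1 S T AH0 N"]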
88
models/synthesizer/utils/cleaners.py
Normal file
@@ -0,0 +1,88 @@
"""
Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
"""

import re
from unidecode import unidecode
from .numbers import normalize_numbers

# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
    ("mrs", "misess"),
    ("mr", "mister"),
    ("dr", "doctor"),
    ("st", "saint"),
    ("co", "company"),
    ("jr", "junior"),
    ("maj", "major"),
    ("gen", "general"),
    ("drs", "doctors"),
    ("rev", "reverend"),
    ("lt", "lieutenant"),
    ("hon", "honorable"),
    ("sgt", "sergeant"),
    ("capt", "captain"),
    ("esq", "esquire"),
    ("ltd", "limited"),
    ("col", "colonel"),
    ("ft", "fort"),
]]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    """lowercase input tokens."""
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    """Pipeline for non-English text that transliterates to ASCII."""
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    """Pipeline for English text, including number and abbreviation expansion."""
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
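Tracing english_cleaners through its stages (ascii, lowercase, numbers, abbreviations, whitespace):

english_cleaners("Dr. Smith  paid $5.")
# "dr. smith  paid $5."        after convert_to_ascii + lowercase
# "dr. smith  paid five dollars."  after expand_numbers
# "doctor smith  paid five dollars."  after expand_abbreviations
# "doctor smith paid five dollars."   after collapse_whitespace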
68
models/synthesizer/utils/numbers.py
Normal file
@@ -0,0 +1,68 @@
import re
import inflect

_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
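A few worked inputs for normalize_numbers (outputs follow from the rules above; year-like values between 1000 and 3000 are read in pairs):

normalize_numbers("in 2016")   # -> "in twenty sixteen"
normalize_numbers("$3.50")     # -> "three dollars, fifty cents"
normalize_numbers("3rd")       # -> "third"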
115
models/synthesizer/utils/plot.py
Normal file
@@ -0,0 +1,115 @@
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def split_title_line(title_text, max_words=5):
    """
    Splits a title string into multiple lines, with at most max_words words per line.
    """
    seq = title_text.split()
    return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])


def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
    if max_len is not None:
        alignment = alignment[:, :max_len]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    im = ax.imshow(
        alignment,
        aspect="auto",
        origin="lower",
        interpolation="none")
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"

    if split_title:
        title = split_title_line(title)

    plt.xlabel(xlabel)
    plt.title(title)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()
    plt.savefig(path, format="png")
    plt.close()


def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
    if max_len is not None:
        target_spectrogram = target_spectrogram[:max_len]
        pred_spectrogram = pred_spectrogram[:max_len]

    if split_title:
        title = split_title_line(title)

    fig = plt.figure(figsize=(10, 8))
    # Set common labels
    fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)

    # target spectrogram subplot
    if target_spectrogram is not None:
        ax1 = fig.add_subplot(311)
        ax2 = fig.add_subplot(312)

        if auto_aspect:
            im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
        else:
            im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
        ax1.set_title("Target Mel-Spectrogram")
        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
        ax2.set_title("Predicted Mel-Spectrogram")
    else:
        ax2 = fig.add_subplot(211)

    if auto_aspect:
        im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
    else:
        im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
    fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)

    plt.tight_layout()
    plt.savefig(path, format="png")
    plt.close()


def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0):
    if max_len is not None:
        target_spectrogram = target_spectrogram[:max_len]
        pred_spectrogram = pred_spectrogram[:max_len]

    if split_title:
        title = split_title_line(title)

    fig = plt.figure(figsize=(10, 8))
    # Set common labels
    fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)

    # target spectrogram subplot
    if target_spectrogram is not None:
        ax1 = fig.add_subplot(311)
        ax2 = fig.add_subplot(312)

        if auto_aspect:
            im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
        else:
            im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
        ax1.set_title("Target Mel-Spectrogram")
        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
        ax2.set_title("Predicted Mel-Spectrogram")
    else:
        ax2 = fig.add_subplot(211)

    if auto_aspect:
        im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
    else:
        im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
    fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)

    plt.tight_layout()
    plt.savefig(path, format="png")
    sw.add_figure("spectrogram", fig, step)
    plt.close()
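Usage sketch for the helpers above (hypothetical output path; any (frames, n_mels) arrays work, since the functions rotate with np.rot90 before plotting):

import numpy as np
from models.synthesizer.utils.plot import plot_spectrogram

pred = np.random.rand(400, 80)      # (frames, n_mels)
target = np.random.rand(400, 80)
plot_spectrogram(pred, "step_1000_mel.png", title="step 1000",
                 target_spectrogram=target, max_len=300)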
18
models/synthesizer/utils/symbols.py
Normal file
@@ -0,0 +1,18 @@
"""
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""
# from . import cmudict

_pad = "_"
_eos = "~"
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? '

#_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old one if you want to train old model
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
#_arpabet = ["@" + s for s in cmudict.valid_symbols]

# Export all symbols:
symbols = [_pad, _eos] + list(_characters)  #+ _arpabet
74
models/synthesizer/utils/text.py
Normal file
@@ -0,0 +1,74 @@
from .symbols import symbols
from . import cleaners
import re

# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []

    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # Append EOS token
    sequence.append(_symbol_to_id["~"])
    return sequence


def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        # Default to None so a missing cleaner raises the error below
        # instead of an AttributeError:
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception("Unknown cleaner: %s" % name)
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s not in ("_", "~")
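Round-trip sketch for the front end above (assumes the sibling cleaners module provides basic_cleaners, as in standard Tacotron-style codebases; ARPAbet ids are only emitted if _arpabet is enabled in symbols.py):

from models.synthesizer.utils.text import text_to_sequence, sequence_to_text

seq = text_to_sequence("Hello world.", ["basic_cleaners"])
print(seq)                    # symbol ids, ending with the id of the EOS symbol "~"
print(sequence_to_text(seq))  # "hello world.~" (lowercased, with the trailing EOS marker)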
22
models/vocoder/LICENSE.txt
Normal file
@@ -0,0 +1,22 @@
MIT License

Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
1
models/vocoder/__init__.py
Normal file
@@ -0,0 +1 @@
#
128
models/vocoder/display.py
Normal file
@@ -0,0 +1,128 @@
import matplotlib.pyplot as plt
import time
import numpy as np
import sys


def progbar(i, n, size=16):
    done = (i * size) // n
    bar = ''
    for idx in range(size):  # renamed from "i" to avoid shadowing the argument
        bar += '█' if idx <= done else '░'
    return bar


def stream(message):
    try:
        sys.stdout.write("\r{%s}" % message)
    except UnicodeEncodeError:
        # Remove non-ASCII characters from message
        message = ''.join(i for i in message if ord(i) < 128)
        sys.stdout.write("\r{%s}" % message)


def simple_table(item_tuples):

    border_pattern = '+---------------------------------------'
    whitespace = ' ' * 40  # pool of spaces used for padding below

    headings, cells = [], []

    for item in item_tuples:

        heading, cell = str(item[0]), str(item[1])

        pad_head = len(heading) < len(cell)

        pad = abs(len(heading) - len(cell))
        pad = whitespace[:pad]

        pad_left = pad[:len(pad) // 2]
        pad_right = pad[len(pad) // 2:]

        if pad_head:
            heading = pad_left + heading + pad_right
        else:
            cell = pad_left + cell + pad_right

        headings += [heading]
        cells += [cell]

    border, head, body = '', '', ''

    for i in range(len(item_tuples)):

        temp_head = f'| {headings[i]} '
        temp_body = f'| {cells[i]} '

        border += border_pattern[:len(temp_head)]
        head += temp_head
        body += temp_body

        if i == len(item_tuples) - 1:
            head += '|'
            body += '|'
            border += '+'

    print(border)
    print(head)
    print(border)
    print(body)
    print(border)
    print(' ')


def time_since(started):
    elapsed = time.time() - started
    m = int(elapsed // 60)
    s = int(elapsed % 60)
    if m >= 60:
        h = int(m // 60)
        m = m % 60
        return f'{h}h {m}m {s}s'
    else:
        return f'{m}m {s}s'


def save_attention(attn, path):
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(attn.T, interpolation='nearest', aspect='auto')
    fig.savefig(f'{path}.png', bbox_inches='tight')
    plt.close(fig)


def save_and_trace_attention(attn, path, sw, step):
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(attn.T, interpolation='nearest', aspect='auto')
    fig.savefig(f'{path}.png', bbox_inches='tight')
    sw.add_figure('attention', fig, step)
    plt.close(fig)


def save_spectrogram(M, path, length=None):
    M = np.flip(M, axis=0)
    if length:
        M = M[:, :length]
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(M, interpolation='nearest', aspect='auto')
    fig.savefig(f'{path}.png', bbox_inches='tight')
    plt.close(fig)


def plot(array):
    fig = plt.figure(figsize=(30, 5))
    ax = fig.add_subplot(111)
    ax.xaxis.label.set_color('grey')
    ax.yaxis.label.set_color('grey')
    ax.xaxis.label.set_fontsize(23)
    ax.yaxis.label.set_fontsize(23)
    ax.tick_params(axis='x', colors='grey', labelsize=23)
    ax.tick_params(axis='y', colors='grey', labelsize=23)
    plt.plot(array)


def plot_spec(M):
    M = np.flip(M, axis=0)
    plt.figure(figsize=(18, 4))
    plt.imshow(M, interpolation='nearest', aspect='auto')
    plt.show()
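Example call for simple_table above; it pads each heading/cell pair to a common width and prints a bordered ASCII table with one column per pair:

simple_table([("Steps", "10k"), ("Batch", 32)])
# prints a border line, the headings row, the values row, and closing borders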
132
models/vocoder/distribution.py
Normal file
@@ -0,0 +1,132 @@
import numpy as np
import torch
import torch.nn.functional as F


def log_sum_exp(x):
    """ numerically stable log_sum_exp implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))


# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
                                  log_scale_min=None, reduce=True):
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    y_hat = y_hat.permute(0, 2, 1)
    assert y_hat.dim() == 3
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(F.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - F.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # tf equivalent
    """
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    """
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * \
        torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
        (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out

    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.mean(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)


def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions
    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value
    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.clamp(torch.sum(
        y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    x = torch.clamp(torch.clamp(x, min=-1.), max=1.)

    return x


def to_one_hot(tensor, n, fill_with=1.):
    # we perform one hot encoding with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
    if tensor.is_cuda:
        one_hot = one_hot.cuda()
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot
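Shape sketch for the mixture code above. Note the two functions use different layouts: the loss expects (B, T, 3*nr_mix) network outputs and (B, T, 1) targets, while sampling expects (B, 3*nr_mix, T):

import torch
from models.vocoder.distribution import discretized_mix_logistic_loss, sample_from_discretized_mix_logistic

y_hat = torch.randn(2, 100, 30)        # (B, T, 3 * nr_mix): logits, means, log scales for 10 logistics
y = torch.rand(2, 100, 1) * 2 - 1      # target samples in [-1, 1]
loss = discretized_mix_logistic_loss(y_hat, y)                       # scalar when reduce=True

wav = sample_from_discretized_mix_logistic(y_hat.transpose(1, 2))    # (B, T), values in [-1, 1]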
129
models/vocoder/fregan/.gitignore
vendored
Normal file
@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
21
models/vocoder/fregan/LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Rishikesh (ऋषिकेश)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
1
models/vocoder/fregan/__init__.py
Normal file
@@ -0,0 +1 @@
#
42
models/vocoder/fregan/config.json
Normal file
@@ -0,0 +1,42 @@
{
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,
    "disc_start_step": 0,

    "upsample_rates": [5, 5, 2, 2, 2],
    "upsample_kernel_sizes": [10, 10, 4, 4, 4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5, 7], [1, 3, 5, 7], [1, 3, 5, 7]],

    "segment_size": 6400,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 200,
    "win_size": 800,

    "sampling_rate": 16000,

    "fmin": 0,
    "fmax": 7600,
    "fmax_for_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
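A consistency note on the config above: the product of upsample_rates (5·5·2·2·2 = 200) must equal hop_size, so each mel frame maps to exactly 200 samples, i.e. 80 frames per second at 16 kHz. A minimal check, assuming it is run from the repository root:

import json

with open("models/vocoder/fregan/config.json") as f:
    h = json.load(f)

factor = 1
for u in h["upsample_rates"]:
    factor *= u
assert factor == h["hop_size"]             # 5*5*2*2*2 == 200
print(h["sampling_rate"] / h["hop_size"])  # 80.0 mel frames per second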
303
models/vocoder/fregan/discriminator.py
Normal file
@@ -0,0 +1,303 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, spectral_norm
from models.vocoder.fregan.utils import get_padding
from models.vocoder.fregan.stft_loss import stft
from models.vocoder.fregan.dwt import DWT_1D

LRELU_SLOPE = 0.1


class SpecDiscriminator(nn.Module):
    """docstring for Discriminator."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
        super(SpecDiscriminator, self).__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.discriminators = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
        ])

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        fmap = []
        with torch.no_grad():
            y = y.squeeze(1)
            # use y.device rather than y.get_device() so CPU tensors also work
            y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
        y = y.unsqueeze(1)
        for i, d in enumerate(self.discriminators):
            y = d(y)
            y = F.leaky_relu(y, LRELU_SLOPE)
            fmap.append(y)

        y = self.out(y)
        fmap.append(y)

        return torch.flatten(y, 1, -1), fmap


class MultiResSpecDiscriminator(torch.nn.Module):

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):

        super(MultiResSpecDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.dwt1d = DWT_1D()
        self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
        self.dwt_proj1 = norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
        self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
        self.dwt_proj2 = norm_f(Conv2d(1, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
        self.dwt_conv3 = norm_f(Conv1d(8, 1, 1))
        self.dwt_proj3 = norm_f(Conv2d(1, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # DWT 1
        x_d1_high1, x_d1_low1 = self.dwt1d(x)
        x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
        # 1d to 2d
        b, c, t = x_d1.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x_d1 = F.pad(x_d1, (0, n_pad), "reflect")
            t = t + n_pad
        x_d1 = x_d1.view(b, c, t // self.period, self.period)

        x_d1 = self.dwt_proj1(x_d1)

        # DWT 2
        x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
        x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
        x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
        # 1d to 2d
        b, c, t = x_d2.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x_d2 = F.pad(x_d2, (0, n_pad), "reflect")
            t = t + n_pad
        x_d2 = x_d2.view(b, c, t // self.period, self.period)

        x_d2 = self.dwt_proj2(x_d2)

        # DWT 3
        x_d3_high1, x_d3_low1 = self.dwt1d(x_d2_high1)
        x_d3_high2, x_d3_low2 = self.dwt1d(x_d2_low1)
        x_d3_high3, x_d3_low3 = self.dwt1d(x_d2_high2)
        x_d3_high4, x_d3_low4 = self.dwt1d(x_d2_low2)
        x_d3 = self.dwt_conv3(
            torch.cat([x_d3_high1, x_d3_low1, x_d3_high2, x_d3_low2, x_d3_high3, x_d3_low3, x_d3_high4, x_d3_low4],
                      dim=1))
        # 1d to 2d
        b, c, t = x_d3.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x_d3 = F.pad(x_d3, (0, n_pad), "reflect")
            t = t + n_pad
        x_d3 = x_d3.view(b, c, t // self.period, self.period)

        x_d3 = self.dwt_proj3(x_d3)

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        # after each of the first three convs, concatenate the matching DWT projection
        for i, l in enumerate(self.convs):
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)

            fmap.append(x)
            if i == 0:
                x = torch.cat([x, x_d1], dim=2)
            elif i == 1:
                x = torch.cat([x, x_d2], dim=2)
            elif i == 2:
                x = torch.cat([x, x_d3], dim=2)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class ResWiseMultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(ResWiseMultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.dwt1d = DWT_1D()
        self.dwt_conv1 = norm_f(Conv1d(2, 128, 15, 1, padding=7))
        self.dwt_conv2 = norm_f(Conv1d(4, 128, 41, 2, padding=20))
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        # DWT 1
        x_d1_high1, x_d1_low1 = self.dwt1d(x)
        x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))

        # DWT 2
        x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
        x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
        x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))

        for i, l in enumerate(self.convs):
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
            if i == 0:
                x = torch.cat([x, x_d1], dim=2)
            if i == 1:
                x = torch.cat([x, x_d2], dim=2)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class ResWiseMultiScaleDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(ResWiseMultiScaleDiscriminator, self).__init__()
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.dwt1d = DWT_1D()
        self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
        self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        # DWT 1
        y_hi, y_lo = self.dwt1d(y)
        y_1 = self.dwt_conv1(torch.cat([y_hi, y_lo], dim=1))
        x_d1_high1, x_d1_low1 = self.dwt1d(y_hat)
        y_hat_1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))

        # DWT 2
        x_d2_high1, x_d2_low1 = self.dwt1d(y_hi)
        x_d2_high2, x_d2_low2 = self.dwt1d(y_lo)
        y_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))

        x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
        x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
        y_hat_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))

        for i, d in enumerate(self.discriminators):

            if i == 1:
                y = y_1
                y_hat = y_hat_1
            if i == 2:
                y = y_2
                y_hat = y_hat_2

            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
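Smoke-test sketch for the DWT-based discriminators above (CPU, random tensors; one 6400-sample training segment as in the config):

import torch
from models.vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator

real = torch.randn(2, 1, 6400)    # (B, 1, samples)
fake = torch.randn(2, 1, 6400)
mpd = ResWiseMultiPeriodDiscriminator()
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(real, fake)
print(len(y_d_rs))                # 5 period discriminators (periods 2, 3, 5, 7, 11)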
76
models/vocoder/fregan/dwt.py
Normal file
@@ -0,0 +1,76 @@
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.

# DWT code borrowed from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py


import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['DWT_1D']
Pad_Mode = ['constant', 'reflect', 'replicate', 'circular']


class DWT_1D(nn.Module):
    def __init__(self, pad_type='reflect', wavename='haar',
                 stride=2, in_channels=1, out_channels=None, groups=None,
                 kernel_size=None, trainable=False):

        super(DWT_1D, self).__init__()
        self.trainable = trainable
        self.kernel_size = kernel_size
        if not self.trainable:
            assert self.kernel_size is None
        self.in_channels = in_channels
        self.out_channels = self.in_channels if out_channels is None else out_channels
        self.groups = self.in_channels if groups is None else groups
        assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
        self.stride = stride
        assert self.stride == 2
        self.wavename = wavename
        self.pad_type = pad_type
        assert self.pad_type in Pad_Mode
        self.get_filters()
        self.initialization()

    def get_filters(self):
        wavelet = pywt.Wavelet(self.wavename)
        band_low = torch.tensor(wavelet.rec_lo)
        band_high = torch.tensor(wavelet.rec_hi)
        length_band = band_low.size()[0]
        self.kernel_size = length_band if self.kernel_size is None else self.kernel_size
        assert self.kernel_size >= length_band
        a = (self.kernel_size - length_band) // 2
        b = - (self.kernel_size - length_band - a)
        b = None if b == 0 else b
        self.filt_low = torch.zeros(self.kernel_size)
        self.filt_high = torch.zeros(self.kernel_size)
        self.filt_low[a:b] = band_low
        self.filt_high[a:b] = band_high

    def initialization(self):
        self.filter_low = self.filt_low[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
        self.filter_high = self.filt_high[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
        if torch.cuda.is_available():
            self.filter_low = self.filter_low.cuda()
            self.filter_high = self.filter_high.cuda()
        if self.trainable:
            self.filter_low = nn.Parameter(self.filter_low)
            self.filter_high = nn.Parameter(self.filter_high)
        if self.kernel_size % 2 == 0:
            self.pad_sizes = [self.kernel_size // 2 - 1, self.kernel_size // 2 - 1]
        else:
            self.pad_sizes = [self.kernel_size // 2, self.kernel_size // 2]

    def forward(self, input):
        assert isinstance(input, torch.Tensor)
        assert len(input.size()) == 3
        assert input.size()[1] == self.in_channels
        input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
        return F.conv1d(input, self.filter_low.to(input.device), stride=self.stride, groups=self.groups), \
               F.conv1d(input, self.filter_high.to(input.device), stride=self.stride, groups=self.groups)
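The module above implements a single-level wavelet analysis step: stride-2 convolutions with the wavelet's low-pass and high-pass filters, returned in that order, each half the input length:

import torch
from models.vocoder.fregan.dwt import DWT_1D

dwt = DWT_1D()                    # haar wavelet, reflect padding, stride 2
x = torch.randn(1, 1, 16000)      # (B, C, T)
low, high = dwt(x)
print(low.shape, high.shape)      # torch.Size([1, 1, 8000]) twice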
210
models/vocoder/fregan/generator.py
Normal file
@@ -0,0 +1,210 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from models.vocoder.fregan.utils import init_weights, get_padding

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5, 7)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[3],
                               padding=get_padding(kernel_size, dilation[3])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class FreGAN(torch.nn.Module):
    def __init__(self, h, top_k=4):
        super(FreGAN, self).__init__()
        self.h = h

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.upsample_rates = h.upsample_rates
        self.up_kernels = h.upsample_kernel_sizes
        self.cond_level = self.num_upsamples - top_k
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        self.cond_up = nn.ModuleList()
        self.res_output = nn.ModuleList()
        upsample_ = 1
        kr = 80

        for i, (u, k) in enumerate(zip(self.upsample_rates, self.up_kernels)):
            # self.ups.append(weight_norm(
            #     ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
            #                     k, u, padding=(k - u) // 2)))
            self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
                                                        h.upsample_initial_channel // (2 ** (i + 1)),
                                                        k, u, padding=(u // 2 + u % 2), output_padding=u % 2)))

            if i > (self.num_upsamples - top_k):
                self.res_output.append(
                    nn.Sequential(
                        nn.Upsample(scale_factor=u, mode='nearest'),
                        weight_norm(nn.Conv1d(h.upsample_initial_channel // (2 ** i),
                                              h.upsample_initial_channel // (2 ** (i + 1)), 1))
                    )
                )
            if i >= (self.num_upsamples - top_k):
                self.cond_up.append(
                    weight_norm(
                        ConvTranspose1d(kr, h.upsample_initial_channel // (2 ** i),
                                        self.up_kernels[i - 1], self.upsample_rates[i - 1],
                                        padding=(self.upsample_rates[i - 1] // 2 + self.upsample_rates[i - 1] % 2),
                                        output_padding=self.upsample_rates[i - 1] % 2))
                )
                kr = h.upsample_initial_channel // (2 ** i)

            upsample_ *= u

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.cond_up.apply(init_weights)
        self.res_output.apply(init_weights)

    def forward(self, x):
        mel = x
        x = self.conv_pre(x)
        output = None
        for i in range(self.num_upsamples):
            if i >= self.cond_level:
                mel = self.cond_up[i - self.cond_level](mel)
                x += mel
            if i > self.cond_level:
                if output is None:
                    output = self.res_output[i - self.cond_level - 1](x)
                else:
                    output = self.res_output[i - self.cond_level - 1](output)
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
            if output is not None:
                output = output + x

        x = F.leaky_relu(output)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        for l in self.cond_up:
            remove_weight_norm(l)
        for l in self.res_output:
            remove_weight_norm(l[1])
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


'''
to run this, fix
from . import ResStack
into
from res_stack import ResStack
'''
if __name__ == '__main__':
    '''
    torch.Size([3, 80, 10])
    torch.Size([3, 1, 2000])
    4527362
    '''
    with open('config.json') as f:
        data = f.read()
    from utils import AttrDict
    import json
    json_config = json.loads(data)
    h = AttrDict(json_config)
    model = FreGAN(h)

    c = torch.randn(3, 80, 10)  # (B, channels, T).
    print(c.shape)

    y = model(c)  # (B, 1, T * prod(upsample_scales))
    print(y.shape)
    # 10 frames * 200x upsampling with this config; a normal 256x melgan would give [3, 1, 2560]
    assert y.shape == torch.Size([3, 1, 2000])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)
74
models/vocoder/fregan/inference.py
Normal file
@@ -0,0 +1,74 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json
import torch
from utils.util import AttrDict
from models.vocoder.fregan.generator import FreGAN

generator = None  # type: FreGAN
output_sample_rate = None
_device = None


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def load_model(weights_fpath, config_fpath=None, verbose=True):
    global generator, _device, output_sample_rate

    if verbose:
        print("Building fregan")

    if config_fpath is None:
        model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
        if len(model_config_fpaths) > 0:
            config_fpath = model_config_fpaths[0]
        else:
            config_fpath = "./vocoder/fregan/config.json"
    with open(config_fpath) as f:
        data = f.read()
    json_config = json.loads(data)
    h = AttrDict(json_config)
    output_sample_rate = h.sampling_rate
    torch.manual_seed(h.seed)

    if torch.cuda.is_available():
        # _model = _model.cuda()
        _device = torch.device('cuda')
    else:
        _device = torch.device('cpu')

    generator = FreGAN(h).to(_device)
    state_dict_g = load_checkpoint(
        weights_fpath, _device
    )
    generator.load_state_dict(state_dict_g['generator'])
    generator.eval()
    generator.remove_weight_norm()


def is_loaded():
    return generator is not None


def infer_waveform(mel, progress_callback=None):
    if generator is None:
        raise Exception("Please load fre-gan in memory before using it")

    mel = torch.FloatTensor(mel).to(_device)
    mel = mel.unsqueeze(0)

    with torch.no_grad():
        y_g_hat = generator(mel)
        audio = y_g_hat.squeeze()
        audio = audio.cpu().numpy()

    return audio, output_sample_rate
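End-to-end sketch of this module's API (hypothetical checkpoint and mel paths; the checkpoint is expected to hold a 'generator' state dict, and load_model takes a Path so it can search for a config next to the weights):

from pathlib import Path
import numpy as np
from models.vocoder.fregan import inference as fregan

fregan.load_model(Path("saved_models/fregan/g_latest.pt"))
mel = np.load("demo_mel.npy")             # (num_mels, frames)
wav, sr = fregan.infer_waveform(mel)      # float waveform at h.sampling_rate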
35
models/vocoder/fregan/loss.py
Normal file
@@ -0,0 +1,35 @@
import torch


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
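These are the LSGAN objectives: the discriminator pushes real scores toward 1 and fake scores toward 0, while the generator is rewarded when fakes score near 1. A worked example with constant scores:

import torch
from models.vocoder.fregan.loss import discriminator_loss, generator_loss

real_scores = [torch.full((2, 10), 0.9)]
fake_scores = [torch.full((2, 10), 0.1)]
d_loss, r_losses, g_losses = discriminator_loss(real_scores, fake_scores)
g_loss, gen_losses = generator_loss(fake_scores)
print(d_loss.item())   # (1 - 0.9)^2 + 0.1^2 = 0.02
print(g_loss.item())   # (1 - 0.1)^2 = 0.81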
176
models/vocoder/fregan/meldataset.py
Normal file
@@ -0,0 +1,176 @@
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # check the cache with the same key it is stored under (fmax plus device)
    if str(fmax) + '_' + str(y.device) not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


def get_dataset_filelist(a):
    #with open(a.input_training_file, 'r', encoding='utf-8') as fi:
    #    training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
    #                      for x in fi.read().split('\n') if len(x) > 0]

    #with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
    #    validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
    #                        for x in fi.read().split('\n') if len(x) > 0]
    files = os.listdir(a.input_wavs_dir)
    random.shuffle(files)
    files = [os.path.join(a.input_wavs_dir, f) for f in files]
    training_files = files[: -int(len(files) * 0.05)]
    validation_files = files[-int(len(files) * 0.05):]
    return training_files, validation_files


class MelDataset(torch.utils.data.Dataset):
    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            #audio, sampling_rate = load_wav(filename)
            #audio = audio / MAX_WAV_VALUE
            audio = np.load(filename)
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            #if sampling_rate != self.sampling_rate:
            #    raise ValueError("{} SR doesn't match target {} SR".format(
            #        sampling_rate, self.sampling_rate))
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start:audio_start + self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

            mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                  self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                                  center=False)
        else:
            mel_path = os.path.join(self.base_mels_path, "mel" + "-" + filename.split("/")[-1].split("-")[-1])
            mel = np.load(mel_path).T
            #mel = np.load(
            #    os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)
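Dataset sketch (hypothetical .npy paths; per the loader above, audio files are raw float waveforms saved with np.save, not .wav):

import torch
from models.vocoder.fregan.meldataset import MelDataset

train_files = ["data/wavs_npy/0001.npy", "data/wavs_npy/0002.npy"]
dataset = MelDataset(train_files, segment_size=6400, n_fft=1024, num_mels=80,
                     hop_size=200, win_size=800, sampling_rate=16000,
                     fmin=0, fmax=7600, fmax_loss=None)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
# each item is (mel, audio, filename, mel_for_loss), cut to segment_size samples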
201
models/vocoder/fregan/modules.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class KernelPredictor(torch.nn.Module):
    ''' Kernel predictor for the location-variable convolutions
    '''

    def __init__(self,
                 cond_channels,
                 conv_in_channels,
                 conv_out_channels,
                 conv_layers,
                 conv_kernel_size=3,
                 kpnet_hidden_channels=64,
                 kpnet_conv_size=3,
                 kpnet_dropout=0.0,
                 kpnet_nonlinear_activation="LeakyReLU",
                 kpnet_nonlinear_activation_params={"negative_slope": 0.1}
                 ):
        '''
        Args:
            cond_channels (int): number of channels of the conditioning sequence
            conv_in_channels (int): number of channels of the input sequence
            conv_out_channels (int): number of channels of the output sequence
            conv_layers (int): number of location-variable convolution layers to predict kernels for
            kpnet_* : hyperparameters of the kernel-predictor network itself
                (hidden width, conv size, dropout, and nonlinearity)
        '''
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        # Total number of kernel weights / biases to predict per conditioning step
        l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        l_b = conv_out_channels * conv_layers

        padding = (kpnet_conv_size - 1) // 2
        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.residual_conv = torch.nn.Sequential(
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size,
                                           padding=padding, bias=True)
        self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding,
                                         bias=True)

    def forward(self, c):
        '''
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        Returns:
            kernels (Tensor): predicted kernels
                (batch, conv_layers, conv_in_channels, conv_out_channels, conv_kernel_size, cond_length)
            bias (Tensor): predicted biases (batch, conv_layers, conv_out_channels, cond_length)
        '''
        batch, cond_channels, cond_length = c.shape

        c = self.input_conv(c)
        c = c + self.residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)

        kernels = k.contiguous().view(batch,
                                      self.conv_layers,
                                      self.conv_in_channels,
                                      self.conv_out_channels,
                                      self.conv_kernel_size,
                                      cond_length)
        bias = b.contiguous().view(batch,
                                   self.conv_layers,
                                   self.conv_out_channels,
                                   cond_length)
        return kernels, bias

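# A hedged shape check for KernelPredictor (illustrative only, not part of the
# original commit; the sizes below are arbitrary assumptions):
#
#     kp = KernelPredictor(cond_channels=80, conv_in_channels=32,
#                          conv_out_channels=64, conv_layers=4)
#     c = torch.randn(2, 80, 100)    # (batch, cond_channels, cond_length)
#     kernels, bias = kp(c)
#     kernels.shape                  # torch.Size([2, 4, 32, 64, 3, 100])
#     bias.shape                     # torch.Size([2, 4, 64, 100])
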
class LVCBlock(torch.nn.Module):
    ''' the location-variable convolutions
    '''

    def __init__(self,
                 in_channels,
                 cond_channels,
                 upsample_ratio,
                 conv_layers=4,
                 conv_kernel_size=3,
                 cond_hop_length=256,
                 kpnet_hidden_channels=64,
                 kpnet_conv_size=3,
                 kpnet_dropout=0.0
                 ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = conv_layers
        self.conv_kernel_size = conv_kernel_size
        self.convs = torch.nn.ModuleList()

        self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels,
                                                 kernel_size=upsample_ratio * 2, stride=upsample_ratio,
                                                 padding=upsample_ratio // 2 + upsample_ratio % 2,
                                                 output_padding=upsample_ratio % 2)

        # The predictor outputs 2 * in_channels so the gated activation in forward()
        # can split its result into sigmoid and tanh halves.
        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=conv_layers,
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout
        )

        # Dilated convolutions with an exponentially growing receptive field (3 ** i)
        for i in range(conv_layers):
            padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
            conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i)
            self.convs.append(conv)

    def forward(self, x, c):
        ''' forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        '''
        batch, in_channels, in_length = x.shape

        kernels, bias = self.kernel_predictor(c)

        x = F.leaky_relu(x, 0.2)
        x = self.upsample(x)

        for i in range(self.conv_layers):
            y = F.leaky_relu(x, 0.2)
            y = self.convs[i](y)
            y = F.leaky_relu(y, 0.2)

            k = kernels[:, i, :, :, :, :]
            b = bias[:, i, :, :]
            y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
            # Gated activation unit: the first half of the channels gates the second half
            x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
        return x

    def location_variable_convolution(self, x, kernel, bias, dilation, hop_size):
        ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernel.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        '''
        batch, in_channels, in_length = x.shape
        batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape

        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), 'constant', 0)    # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), 'constant', 0)
        x = x.unfold(3, dilation,
                     dilation)  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        # Batched local matmul: apply each per-position kernel to its own window
        o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
        o = o + bias.unsqueeze(-1).unsqueeze(-1)
        o = o.contiguous().view(batch, out_channels, -1)
        return o

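A hedged usage sketch for LVCBlock (not part of the commit; all sizes are illustrative assumptions — note the upsampled input length must equal cond_length * cond_hop_length for the assertion in location_variable_convolution to hold):

import torch
from models.vocoder.fregan.modules import LVCBlock

block = LVCBlock(in_channels=32, cond_channels=80, upsample_ratio=8, cond_hop_length=8)
x = torch.randn(2, 32, 100)   # (batch, in_channels, in_length)
c = torch.randn(2, 80, 100)   # (batch, cond_channels, cond_length)
y = block(x, c)               # upsampled 8x and modulated by c
print(y.shape)                # torch.Size([2, 32, 800])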
136
models/vocoder/fregan/stft_loss.py
Normal file
@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""STFT-based Loss modules."""

import torch
import torch.nn.functional as F

def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window function tensor (e.g. a hann window of length win_length).
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # NOTE: this follows the older real-valued torch.stft API; on recent PyTorch
    # the call needs return_complex handling.
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)

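# Hedged example of calling stft() directly (illustrative values, not part of
# the original file):
#
#     window = torch.hann_window(600)
#     x = torch.randn(4, 16000)     # (B, T)
#     mag = stft(x, 1024, 120, 600, window)
#     mag.shape                     # (4, #frames, 513), i.e. fft_size // 2 + 1 bins
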
class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super(SpectralConvergengeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value.
        """
        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")

class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))

class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module."""
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        # .to(x.get_device()) assumes CUDA inputs; the window is moved to the input's device.
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device()))
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(x.get_device()))
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss

class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):
        """Initialize Multi resolution STFT loss module.
        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
        """
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(fs, ss, wl, window)]

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss

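A hedged usage sketch for MultiResolutionSTFTLoss (not part of the commit; since STFTLoss moves its window with x.get_device(), the inputs below are assumed to be CUDA tensors):

import torch
from models.vocoder.fregan.stft_loss import MultiResolutionSTFTLoss

criterion = MultiResolutionSTFTLoss().cuda()
y_hat = torch.randn(2, 16000, device='cuda')  # predicted waveform (B, T)
y = torch.randn(2, 16000, device='cuda')      # ground-truth waveform (B, T)
sc_loss, mag_loss = criterion(y_hat, y)       # each averaged over the three resolutions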
246
models/vocoder/fregan/train.py
Normal file
@@ -0,0 +1,246 @@
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from models.vocoder.fregan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from models.vocoder.fregan.generator import FreGAN
from models.vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator, ResWiseMultiScaleDiscriminator
from models.vocoder.fregan.loss import feature_loss, generator_loss, discriminator_loss
from models.vocoder.fregan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint


torch.backends.cudnn.benchmark = True

def train(rank, a, h):

    a.checkpoint_path = a.models_dir.joinpath(a.run_id + '_fregan')
    a.checkpoint_path.mkdir(exist_ok=True)
    a.training_epochs = 3100
    a.stdout_interval = 5
    a.checkpoint_interval = a.backup_every
    a.summary_interval = 5000
    a.validation_interval = 1000
    a.fine_tuning = True

    a.input_wavs_dir = a.syn_dir.joinpath("audio")
    a.input_mels_dir = a.syn_dir.joinpath("mels")

    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = FreGAN(h).to(device)
    mpd = ResWiseMultiPeriodDiscriminator().to(device)
    msd = ResWiseMultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_fregan_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_fregan_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            y = y.unsqueeze(1)
            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
                                          h.win_size,
                                          h.fmin, h.fmax_for_loss)

            # Discriminators are only trained once the generator has warmed up
            if steps > h.disc_start_step:
                optim_d.zero_grad()

                # MPD
                y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
                loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

                # MSD
                y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
                loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

                loss_disc_all = loss_disc_s + loss_disc_f

                loss_disc_all.backward()
                optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            # sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
            # loss_mel = h.lambda_aux * (sc_loss + mag_loss)  # STFT Loss

            if steps > h.disc_start_step:
                y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
                y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
                loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
                loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
                loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
                loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
                loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
            else:
                loss_gen_all = loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, mel_error, time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                    checkpoint_path = "{}/do_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'mpd': (mpd.module if h.num_gpus > 1
                                             else mpd).state_dict(),
                                     'msd': (msd.module if h.num_gpus > 1
                                             else msd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                          h.hop_size, h.win_size,
                                                          h.fmin, h.fmax_for_loss)
                            # val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                             h.sampling_rate, h.hop_size, h.win_size,
                                                             h.fmin, h.fmax)
                                sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

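A hedged launch sketch (not part of the commit): train(rank, a, h) follows the usual HiFi-GAN-style entry point, so a caller might spawn it per GPU roughly as below; the a (run arguments) and h (hyperparameters) objects are assumptions based on the attributes used above.

import torch.multiprocessing as mp

# a: argparse-style namespace with models_dir, run_id, syn_dir, backup_every, ...
# h: hyperparameter object with num_gpus, batch_size, learning_rate, ...
if h.num_gpus > 1:
    mp.spawn(train, nprocs=h.num_gpus, args=(a, h))
else:
    train(0, a, h)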
65
models/vocoder/fregan/utils.py
Normal file
@@ -0,0 +1,65 @@
import glob
import os
import shutil

import matplotlib
import torch
from torch.nn.utils import weight_norm

matplotlib.use("Agg")  # headless backend; must be set before importing pylab
import matplotlib.pylab as plt


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    # "same" padding for a dilated convolution with stride 1
    return int((kernel_size * dilation - dilation) / 2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????.pt')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
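A hedged round-trip sketch for the checkpoint helpers (not part of the commit; the directory and step number are illustrative):

import os
import torch
from models.vocoder.fregan.utils import save_checkpoint, scan_checkpoint, load_checkpoint

ckpt_dir = '/tmp/fregan_ckpts'                        # illustrative path
os.makedirs(ckpt_dir, exist_ok=True)
save_checkpoint(os.path.join(ckpt_dir, 'g_fregan_00001000.pt'), {'steps': 1000})
latest = scan_checkpoint(ckpt_dir, 'g_fregan_')       # highest-numbered match, or None
state = load_checkpoint(latest, torch.device('cpu'))
print(state['steps'])                                 # 1000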
Some files were not shown because too many files have changed in this diff.