Load Tacotron checkpoints onto an explicit device.

Tacotron.load() gains an optional trailing ``device`` argument that is
forwarded to torch.load() as ``map_location``.  The default (None) keeps
torch.load's own placement, so existing ``load(path)`` and
``load(path, optimizer)`` callers keep working.  The three call sites below
pass the device by keyword.

The GST style query tensor is now allocated directly on ``device`` instead
of being created on the CPU and moved with an unconditional ``.cuda()``.

NOTE(review): leading whitespace and blank context lines were reconstructed
from the hunk line counts; confirm against the target tree before applying
(``git apply --ignore-whitespace`` tolerates indentation drift).  The
post-image blob ids on the ``index`` lines were not recomputed.

diff --git a/synthesizer/inference.py b/synthesizer/inference.py
index e5ab1bf..3a6dc6c 100644
--- a/synthesizer/inference.py
+++ b/synthesizer/inference.py
@@ -62,7 +62,7 @@ class Synthesizer:
                                stop_threshold=hparams.tts_stop_threshold,
                                speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
 
-        self._model.load(self.model_fpath)
+        self._model.load(self.model_fpath, device=self.device)
         self._model.eval()
 
         if self.verbose:
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 534b0fa..1fdc064 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -470,7 +470,7 @@ class Tacotron(nn.Module):
         # put after encoder
         if hparams.use_gst and self.gst is not None:
             if style_idx >= 0 and style_idx < 10:
-                query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units, device=device)
                 gst_embed = torch.tanh(self.gst.stl.embed)
                 key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
                 style_embed = self.gst.stl.attention(query, key)
@@ -539,9 +539,9 @@ class Tacotron(nn.Module):
         with open(path, "a") as f:
             print(msg, file=f)
 
-    def load(self, path, optimizer=None):
-        # Use device of model params as location for loaded state
-        checkpoint = torch.load(str(path))
+    def load(self, path, optimizer=None, device=None):
+        # Map stored tensors onto `device`; None keeps torch.load's default placement
+        checkpoint = torch.load(str(path), map_location=device)
         self.load_state_dict(checkpoint["model_state"], strict=False)
 
         if "optimizer_state" in checkpoint and optimizer is not None:
diff --git a/synthesizer/synthesize.py b/synthesizer/synthesize.py
index e2dd02c..49a06b0 100644
--- a/synthesizer/synthesize.py
+++ b/synthesizer/synthesize.py
@@ -45,7 +45,7 @@ def run_synthesis(in_dir, out_dir, model_dir, hparams):
     model_dir = Path(model_dir)
     model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
     print("\nLoading weights at %s" % model_fpath)
-    model.load(model_fpath)
+    model.load(model_fpath, device=device)
     print("Tacotron weights loaded from step %d" % model.step)
 
     # Synthesize using same reduction factor as the model is currently trained
diff --git a/synthesizer/train.py b/synthesizer/train.py
index d299afe..7446e80 100644
--- a/synthesizer/train.py
+++ b/synthesizer/train.py
@@ -111,7 +111,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
 
     else:
         print("\nLoading weights at %s" % weights_fpath)
-        model.load(weights_fpath, optimizer)
+        model.load(weights_fpath, optimizer, device=device)
         print("Tacotron weights loaded from step %d" % model.step)
 
     # Initialize the dataset