diff --git a/implementations/wgan_gp/wgan_gp.py b/implementations/wgan_gp/wgan_gp.py index 17d5afbf8ae15115d465b1142b6efcfb6b1ea1c6..6418d990f5a5f75a75c2a1331759b9b9282f29d6 100644 --- a/implementations/wgan_gp/wgan_gp.py +++ b/implementations/wgan_gp/wgan_gp.py @@ -119,7 +119,7 @@ Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor def compute_gradient_penalty(D, real_samples, fake_samples): """Calculates the gradient penalty loss for WGAN GP""" # Random weight term for interpolation between real and fake samples - alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))) + alpha = Tensor(np.random.random((real_samples.size(0), 1))) # Get random interpolation between real and fake samples interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) d_interpolates = D(interpolates)