diff --git a/implementations/wgan_gp/wgan_gp.py b/implementations/wgan_gp/wgan_gp.py
index 17d5afbf8ae15115d465b1142b6efcfb6b1ea1c6..6418d990f5a5f75a75c2a1331759b9b9282f29d6 100644
--- a/implementations/wgan_gp/wgan_gp.py
+++ b/implementations/wgan_gp/wgan_gp.py
@@ -119,7 +119,7 @@ Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 def compute_gradient_penalty(D, real_samples, fake_samples):
     """Calculates the gradient penalty loss for WGAN GP"""
     # Random weight term for interpolation between real and fake samples
-    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
+    alpha = Tensor(np.random.random((real_samples.size(0), 1)))
     # Get random interpolation between real and fake samples
     interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
     d_interpolates = D(interpolates)