Skip to content
GitLab
    • Explore Projects Groups Snippets
Projects Groups Snippets
  • /
  • Help
    • Help
    • Support
    • Community forum
    • Submit feedback
    • Contribute to GitLab
  • Sign in / Register
  • P PyTorch-GAN
  • Project information
    • Project information
    • Activity
    • Labels
    • Members
  • Repository
    • Repository
    • Files
    • Commits
    • Branches
    • Tags
    • Contributors
    • Graph
    • Compare
  • Issues 102
    • Issues 102
    • List
    • Boards
    • Service Desk
    • Milestones
  • Merge requests 24
    • Merge requests 24
  • CI/CD
    • CI/CD
    • Pipelines
    • Jobs
    • Schedules
  • Deployments
    • Deployments
    • Environments
    • Releases
  • Packages and registries
    • Packages and registries
    • Package Registry
    • Infrastructure Registry
  • Monitor
    • Monitor
    • Incidents
  • Analytics
    • Analytics
    • Value stream
    • CI/CD
    • Repository
  • Wiki
    • Wiki
  • Snippets
    • Snippets
  • Activity
  • Graph
  • Create a new issue
  • Jobs
  • Commits
  • Issue Boards
Collapse sidebar
  • Erik Linder-Norén
  • PyTorch-GAN
  • Merge requests
  • !174

fix WGAN-GP alpha dimension error

  • Review changes

  • Download
  • Email patches
  • Plain diff
Closed Administrator requested to merge github/fork/sudrizzz/master into master 3 years ago
  • Overview 1
  • Commits 1
  • Pipelines 0
  • Changes 1

Created by: sudrizzz

Compare
  • master (base)

and
  • latest version
    3815b577
    1 commit, 2 years ago

1 file
+ 1
- 1

    Preferences

    File browser
    Compare changes
implementations/wgan_gp/wgan_gp.py
+ 1
- 1
  • View file @ 3815b577


@@ -119,7 +119,7 @@ Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
def compute_gradient_penalty(D, real_samples, fake_samples):
"""Calculates the gradient penalty loss for WGAN GP"""
# Random weight term for interpolation between real and fake samples
alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
alpha = Tensor(np.random.random((real_samples.size(0), 1)))
# Get random interpolation between real and fake samples
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
d_interpolates = D(interpolates)
0 Assignees
None
Assign to
0 Reviewers
None
Request review from
Labels
0
None
0
None
    Assign labels
  • Manage project labels

Milestone
No milestone
None
None
Time tracking
No estimate or time spent
Lock merge request
Unlocked
1
1 participant
Administrator
Reference: eriklindernoren/PyTorch-GAN!174
Source branch: github/fork/sudrizzz/master

Menu

Explore Projects Groups Snippets