Skip to content
GitLab
Projects Groups Snippets
  • /
  • Help
    • Help
    • Support
    • Community forum
    • Submit feedback
    • Contribute to GitLab
  • Sign in / Register
  • P pyod
  • Project information
    • Project information
    • Activity
    • Labels
    • Members
  • Repository
    • Repository
    • Files
    • Commits
    • Branches
    • Tags
    • Contributors
    • Graph
    • Compare
  • Issues 144
    • Issues 144
    • List
    • Boards
    • Service Desk
    • Milestones
  • Merge requests 16
    • Merge requests 16
  • CI/CD
    • CI/CD
    • Pipelines
    • Jobs
    • Schedules
  • Deployments
    • Deployments
    • Environments
    • Releases
  • Packages and registries
    • Packages and registries
    • Package Registry
    • Infrastructure Registry
  • Monitor
    • Monitor
    • Incidents
  • Analytics
    • Analytics
    • Value stream
    • CI/CD
    • Repository
  • Wiki
    • Wiki
  • Snippets
    • Snippets
  • Activity
  • Graph
  • Create a new issue
  • Jobs
  • Commits
  • Issue Boards
Collapse sidebar
  • Yue Zhao
  • pyod
  • Merge requests
  • !435

Refactored PyTorch Autonencoder built up (Proposal to solve #434)

  • Review changes

  • Download
  • Email patches
  • Plain diff
Merged Lucew requested to merge github/fork/Lucew/development into development Sep 03, 2022
  • Overview 7
  • Commits 3
  • Pipelines 0
  • Changes 1

This Pullrequest is a proposal to fix #434 (closed).

I refactored the setup of the Autoencoder (PyTorch Version) in order to address the points of #434 (closed). The order of layers is now consistent with the Tensorflow implementation.

The order of Neurons->BatchNorm->Activation follows the Standard-Implementation of ResNet

We adopt batch normalization (BN) [16] right after each convolution and before activation, following [16].

The printed Network now looks like:

AutoEncoder(batch_norm=True, batch_size=32, contamination=0.1,
      device=device(type='cpu'), dropout_rate=0.2, epochs=10,
      hidden_activation='relu', hidden_neurons=[64, 32],
      learning_rate=0.001, loss_fn=MSELoss(), preprocessing=True,
      weight_decay=1e-05)
InnerAutoencoder(
  (activation): ReLU()
  (encoder): Sequential(
    (linear0): Linear(in_features=300, out_features=64, bias=True)
    (batch_norm0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU()
    (dropout0): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=64, out_features=32, bias=True)
    (batch_norm1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU()
    (dropout1): Dropout(p=0.2, inplace=False)
  )
  (decoder): Sequential(
    (linear0): Linear(in_features=32, out_features=64, bias=True)
    (batch_norm0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU()
    (dropout0): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=64, out_features=300, bias=True)
    (relu1): ReLU()
  )
)

All Submissions Basics:

  • Have you followed the guidelines in our Contributing document?
  • Have you checked to ensure there aren't other open Pull Requests for the same update/change?
  • Have you checked all Issues to tie the PR to a specific one?

All Submissions Cores:

  • Have you added an explanation of what your changes do and why you'd like us to include them?
  • Have you written new tests for your core changes, as applicable?
  • Have you successfully ran tests with your changes locally?
  • Does your submission pass tests, including CircleCI, Travis CI, and AppVeyor?
  • Does your submission have appropriate code coverage? The cutoff threshold is 95% by Coversall.
Assignee
Assign to
Reviewers
Request review from
Time tracking
Source branch: github/fork/Lucew/development