diff --git a/pyod/models/auto_encoder_torch.py b/pyod/models/auto_encoder_torch.py
index a79aad8edb172d0e1d035f4013c3c8c6e32c0a60..58d5dc4abb3875c2f1cc1a7ec1450392b209a4ba 100644
--- a/pyod/models/auto_encoder_torch.py
+++ b/pyod/models/auto_encoder_torch.py
@@ -120,19 +120,21 @@ class AutoEncoder(BaseDetector):
     hidden_activation : str, optional (default='relu')
         Activation function to use for hidden layers.
         All hidden layers are forced to use the same type of activation.
-        See https://keras.io/activations/
+        See https://pytorch.org/docs/stable/nn.html for details.
+        Currently only the following are supported:
+        'relu': nn.ReLU()
+        'sigmoid': nn.Sigmoid()
+        'tanh': nn.Tanh()
+        See pyod/utils/torch_utility.py for details.
 
     batch_norm : boolean, optional (default=True)
         Whether to apply Batch Normalization,
         See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
 
-    loss : str or obj, optional (default=torch.nn.MSELoss)
-        String (name of objective function) or objective function.
-        NOT SUPPORT FOR CHANGE YET.
-
-    optimizer : str, optional (default='adam')
-        String (name of optimizer) or optimizer instance.
-        NOT SUPPORT FOR CHANGE YET.
+    learning_rate : float, optional (default=1e-3)
+        Learning rate for the optimizer. This learning rate is passed to
+        the Adam optimizer (torch.optim.Adam).
+        See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
 
     epochs : int, optional (default=100)
         Number of epochs to train the model.
@@ -143,17 +145,18 @@ class AutoEncoder(BaseDetector):
     dropout_rate : float in (0., 1), optional (default=0.2)
         The dropout to be used across all layers.
 
-    l2_regularizer : float in (0., 1), optional (default=0.1)
-        The regularization strength of activity_regularizer
-        applied on each layer. By default, l2 regularizer is used. See
-        https://keras.io/regularizers/
-
-    validation_size : float in (0., 1), optional (default=0.1)
-        The percentage of data to be used for validation.
+    weight_decay : float, optional (default=1e-5)
+        The weight decay for the Adam optimizer.
+        See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
 
     preprocessing : bool, optional (default=True)
         If True, apply standardization on the data.
 
+    loss_fn : obj, optional (default=torch.nn.MSELoss)
+        Loss function instance which implements torch.nn.modules.loss._Loss.
+        One of https://pytorch.org/docs/stable/nn.html#loss-functions
+        or a custom loss. Custom losses are currently unstable.
+
     verbose : int, optional (default=1)
         Verbosity mode.
 
@@ -162,6 +165,7 @@ class AutoEncoder(BaseDetector):
         - 2 = one line per epoch.
 
         For verbose >= 1, model summary may be printed.
+        !CURRENTLY NOT SUPPORTED!
 
     random_state : random_state: int, RandomState instance or None, optional
         (default=None)
@@ -169,6 +173,7 @@ class AutoEncoder(BaseDetector):
         number generator; If RandomState instance, random_state is the random
         number generator; If None, the random number generator is the
         RandomState instance used by `np.random`.
+        !CURRENTLY NOT SUPPORTED!
 
     contamination : float in (0., 0.5), optional (default=0.1)
         The amount of contamination of the data set, i.e.
@@ -212,13 +217,10 @@ class AutoEncoder(BaseDetector):
                  hidden_neurons=None,
                  hidden_activation='relu',
                  batch_norm=True,
-                 # loss='mse',
-                 # optimizer='adam',
                  learning_rate=1e-3,
                  epochs=100,
                  batch_size=32,
                  dropout_rate=0.2,
-                 # l2_regularizer=0.1,
                  weight_decay=1e-5,
                  # validation_size=0.1,
                  preprocessing=True,
@@ -228,33 +230,34 @@ class AutoEncoder(BaseDetector):
                  contamination=0.1,
                  device=None):
         super(AutoEncoder, self).__init__(contamination=contamination)
+
+        # save the initialization values
         self.hidden_neurons = hidden_neurons
         self.hidden_activation = hidden_activation
         self.batch_norm = batch_norm
         self.learning_rate = learning_rate
-
         self.epochs = epochs
         self.batch_size = batch_size
-
         self.dropout_rate = dropout_rate
         self.weight_decay = weight_decay
         self.preprocessing = preprocessing
+        self.loss_fn = loss_fn
+        # self.verbose = verbose
+        self.device = device
 
-        if loss_fn is None:
+        # create default loss function
+        if self.loss_fn is None:
             self.loss_fn = torch.nn.MSELoss()
 
-        if device is None:
+        # select default computation device (use GPU if available)
+        if self.device is None:
             self.device = torch.device(
                 "cuda:0" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
 
-        # default values
+        # default hidden layer sizes if none are given
         if self.hidden_neurons is None:
             self.hidden_neurons = [64, 32]
 
-        # self.verbose = verbose
-
     # noinspection PyUnresolvedReferences
     def fit(self, X, y=None):
         """Fit detector. y is ignored in unsupervised methods.