Unverified Commit 62484754 authored by Yue Zhao's avatar Yue Zhao Committed by GitHub
Browse files

Merge pull request #449 from yzhao062/development

v1.0.6
parents f6029d57 4bcb5643
Showing with 835 additions and 54 deletions
+835 -54
......@@ -170,3 +170,5 @@ v<1.0.4>, <07/29/2022> -- Add LUNAR (#415).
v<1.0.5>, <07/29/2022> -- Import optimization.
v<1.0.5>, <08/27/2022> -- Code optimization.
v<1.0.5>, <09/14/2022> -- Add ALAD.
v<1.0.6>, <09/23/2022> -- Update ADBench benchmark for NeruIPS 2022.
v<1.0.6>, <10/23/2022> -- ADD KPCA.
......@@ -58,8 +58,11 @@ Python Outlier Detection (PyOD)
-----
**News**: We just released a 36-page, the most comprehensive `anomaly detection benchmark paper <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-preprint-adbench.pdf>`_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 55 benchmark datasets.
**News**: We just released a 45-page, the most comprehensive `anomaly detection benchmark paper <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-neurips-adbench.pdf>`_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 57 benchmark datasets.
**For time-series outlier detection**, please use `TODS <https://github.com/datamllab/tods>`_.
**For graph outlier detection**, please use `PyGOD <https://pygod.org/>`_.
PyOD is the most comprehensive and scalable **Python library** for **detecting outlying objects** in
multivariate data. This exciting yet challenging field is commonly referred as
......@@ -68,7 +71,7 @@ or `Anomaly Detection <https://en.wikipedia.org/wiki/Anomaly_detection>`_.
PyOD includes more than 40 detection algorithms, from classical LOF (SIGMOD 2000) to
the latest ECOD (TKDE 2022). Since 2017, PyOD has been successfully used in numerous academic researches and
commercial products with more than `8 million downloads <https://pepy.tech/project/pyod>`_.
commercial products with more than `10 million downloads <https://pepy.tech/project/pyod>`_.
It is also well acknowledged by the machine learning community with various dedicated posts/tutorials, including
`Analytics Vidhya <https://www.analyticsvidhya.com/blog/2019/02/outlier-detection-python-pyod/>`_,
`KDnuggets <https://www.kdnuggets.com/2019/02/outlier-detection-methods-cheat-sheet.html>`_, and
......@@ -114,20 +117,29 @@ If you use PyOD in a scientific publication, we would appreciate
citations to the following paper::
@article{zhao2019pyod,
author = {Zhao, Yue and Nasrullah, Zain and Li, Zheng},
title = {PyOD: A Python Toolbox for Scalable Outlier Detection},
journal = {Journal of Machine Learning Research},
year = {2019},
volume = {20},
number = {96},
pages = {1-7},
url = {http://jmlr.org/papers/v20/19-011.html}
author = {Zhao, Yue and Nasrullah, Zain and Li, Zheng},
title = {PyOD: A Python Toolbox for Scalable Outlier Detection},
journal = {Journal of Machine Learning Research},
year = {2019},
volume = {20},
number = {96},
pages = {1-7},
url = {http://jmlr.org/papers/v20/19-011.html}
}
or::
Zhao, Y., Nasrullah, Z. and Li, Z., 2019. PyOD: A Python Toolbox for Scalable Outlier Detection. Journal of machine learning research (JMLR), 20(96), pp.1-7.
If you want more general insights of anomaly detection and/or algorithm performance comparison, please see our
NeurIPS 2022 paper `ADBench: Anomaly Detection Benchmark <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-neurips-adbench.pdf>`_::
@inproceedings{han2022adbench,
title={ADBench: Anomaly Detection Benchmark},
author={Songqiao Han and Xiyang Hu and Hailiang Huang and Mingqi Jiang and Yue Zhao},
booktitle={Neural Information Processing Systems (NeurIPS)}
year={2022},
}
**Key Links and Resources**\ :
......@@ -238,8 +250,8 @@ Key Attributes of a fitted model:
ADBench Benchmark
^^^^^^^^^^^^^^^^^
We just released a 36-page, the most comprehensive `anomaly detection benchmark paper <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-preprint-adbench.pdf>`_ [#Han2022ADBench]_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 55 benchmark datasets.
We just released a 45-page, the most comprehensive `ADBench: Anomaly Detection Benchmark <https://arxiv.org/abs/2206.09426>`_ [#Han2022ADBench]_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 57 benchmark datasets.
The organization of **ADBench** is provided below:
......@@ -342,6 +354,7 @@ Probabilistic KDE Outlier Detection with Kernel Density F
Probabilistic Sampling Rapid distance-based outlier detection via sampling 2013 [#Sugiyama2013Rapid]_
Probabilistic GMM Probabilistic Mixture Modeling for Outlier Analysis [#Aggarwal2015Outlier]_ [Ch.2]
Linear Model PCA Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003 [#Shyu2003A]_
Linear Model KPCA Kernel Principal Component Analysis 2007 [#Hoffmann2007Kernel]_
Linear Model MCD Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores) 1999 [#Hardin2004Outlier]_ [#Rousseeuw1999A]_
Linear Model CD Use Cook's distance for outlier detection 1977 [#Cook1977Detection]_
Linear Model OCSVM One-Class Support Vector Machines 2001 [#Scholkopf2001Estimating]_
......@@ -564,6 +577,8 @@ Reference
.. [#He2003Discovering] He, Z., Xu, X. and Deng, S., 2003. Discovering cluster-based local outliers. *Pattern Recognition Letters*\ , 24(9-10), pp.1641-1650.
.. [#Hoffmann2007Kernel] Hoffmann, H., 2007. Kernel PCA for novelty detection. Pattern recognition, 40(3), pp.863-874.
.. [#Iglewicz1993How] Iglewicz, B. and Hoaglin, D.C., 1993. How to detect and handle outliers (Vol. 16). Asq Press.
.. [#Janssens2012Stochastic] Janssens, J.H.M., Huszár, F., Postma, E.O. and van den Herik, H.J., 2012. Stochastic outlier selection. Technical report TiCC TR 2012-001, Tilburg University, Tilburg Center for Cognition and Communication, Tilburg, The Netherlands.
......
......@@ -64,8 +64,11 @@ Welcome to PyOD documentation!
----
**News**: We just released a 36-page, the most comprehensive `anomaly detection benchmark paper <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-preprint-adbench.pdf>`_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 55 benchmark datasets.
**News**: We just released a 45-page, the most comprehensive `anomaly detection benchmark paper <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-neurips-adbench.pdf>`_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 57 benchmark datasets.
**For time-series outlier detection**, please use `TODS <https://github.com/datamllab/tods>`_.
**For graph outlier detection**, please use `PyGOD <https://pygod.org/>`_.
PyOD is the most comprehensive and scalable **Python library** for **detecting outlying objects** in
multivariate data. This exciting yet challenging field is commonly referred as
......@@ -73,8 +76,8 @@ multivariate data. This exciting yet challenging field is commonly referred as
or `Anomaly Detection <https://en.wikipedia.org/wiki/Anomaly_detection>`_.
PyOD includes more than 40 detection algorithms, from classical LOF (SIGMOD 2000) to
the latest ECOD (TKDE 2020). Since 2017, PyOD :cite:`a-zhao2019pyod` has been successfully used in numerous
academic researches and commercial products with more than `8 million downloads <https://pepy.tech/project/pyod>`_.
the latest ECOD (TKDE 2022). Since 2017, PyOD :cite:`a-zhao2019pyod` has been successfully used in numerous
academic researches and commercial products with more than `10 million downloads <https://pepy.tech/project/pyod>`_.
It is also well acknowledged by the machine learning community with various dedicated posts/tutorials, including
`Analytics Vidhya <https://www.analyticsvidhya.com/blog/2019/02/outlier-detection-python-pyod/>`_,
`KDnuggets <https://www.kdnuggets.com/2019/02/outlier-detection-methods-cheat-sheet.html>`_, and
......@@ -121,20 +124,29 @@ If you use PyOD in a scientific publication, we would appreciate
citations to the following paper::
@article{zhao2019pyod,
author = {Zhao, Yue and Nasrullah, Zain and Li, Zheng},
title = {PyOD: A Python Toolbox for Scalable Outlier Detection},
journal = {Journal of Machine Learning Research},
year = {2019},
volume = {20},
number = {96},
pages = {1-7},
url = {http://jmlr.org/papers/v20/19-011.html}
author = {Zhao, Yue and Nasrullah, Zain and Li, Zheng},
title = {PyOD: A Python Toolbox for Scalable Outlier Detection},
journal = {Journal of Machine Learning Research},
year = {2019},
volume = {20},
number = {96},
pages = {1-7},
url = {http://jmlr.org/papers/v20/19-011.html}
}
or::
Zhao, Y., Nasrullah, Z. and Li, Z., 2019. PyOD: A Python Toolbox for Scalable Outlier Detection. Journal of machine learning research (JMLR), 20(96), pp.1-7.
If you want more general insights of anomaly detection and/or algorithm performance comparison, please see our
NeurIPS 2022 paper `ADBench: Anomaly Detection Benchmark <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-neurips-adbench.pdf>`_::
@inproceedings{han2022adbench,
title={ADBench: Anomaly Detection Benchmark},
author={Songqiao Han and Xiyang Hu and Hailiang Huang and Mingqi Jiang and Yue Zhao},
booktitle={Neural Information Processing Systems (NeurIPS)}
year={2022},
}
**Key Links and Resources**\ :
......@@ -148,8 +160,8 @@ or::
Benchmark
=========
We just released a 36-page, the most comprehensive `anomaly detection benchmark paper <https://www.andrew.cmu.edu/user/yuezhao2/papers/22-preprint-adbench.pdf>`_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 55 benchmark datasets.
We just released a 45-page, the most comprehensive `ADBench: Anomaly Detection Benchmark <https://arxiv.org/abs/2206.09426>`_.
The fully `open-sourced ADBench <https://github.com/Minqi824/ADBench>`_ compares 30 anomaly detection algorithms on 57 benchmark datasets.
The organization of **ADBench** is provided below:
......@@ -178,6 +190,7 @@ Probabilistic KDE Outlier Detection with Kernel Density Fun
Probabilistic Sampling Rapid distance-based outlier detection via sampling 2013 :class:`pyod.models.sampling.Sampling` :cite:`a-sugiyama2013rapid`
Probabilistic GMM Probabilistic Mixture Modeling for Outlier Analysis :class:`pyod.models.gmm.GMM` :cite:`a-aggarwal2015outlier` [Ch.2]
Linear Model PCA Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003 :class:`pyod.models.pca.PCA` :cite:`a-shyu2003novel`
Linear Model KPCA Kernel Principal Component Analysis 2007 :class:`pyod.models.kpca.KPCA` :cite:`a-hoffmann2007kernel`
Linear Model MCD Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores) 1999 :class:`pyod.models.mcd.MCD` :cite:`a-rousseeuw1999fast,a-hardin2004outlier`
Linear Model CD Use Cook's distance for outlier detection 1977 :class:`pyod.models.cd.CD` :cite:`a-cook1977detection`
Linear Model OCSVM One-Class Support Vector Machines 2001 :class:`pyod.models.ocsvm.OCSVM` :cite:`a-scholkopf2001estimating`
......
......@@ -181,6 +181,15 @@ pyod.models.knn module
:show-inheritance:
:inherited-members:
pyod.models.kpca module
-----------------------
.. automodule:: pyod.models.kpca
:members:
:undoc-members:
:show-inheritance:
:inherited-members:
pyod.models.lmdd module
-----------------------
......
......@@ -467,4 +467,15 @@
pages={727--736},
year={2018},
organization={IEEE}
}
@article{hoffmann2007kernel,
title={Kernel PCA for novelty detection},
author={Hoffmann, Heiko},
journal={Pattern recognition},
volume={40},
number={3},
pages={863--874},
year={2007},
publisher={Elsevier}
}
\ No newline at end of file
examples/ALL.png

1.06 MB | W: 0px | H: 0px

examples/ALL.png

1.05 MB | W: 0px | H: 0px

examples/ALL.png
examples/ALL.png
examples/ALL.png
examples/ALL.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -101,10 +101,15 @@ classifiers = {
'Locally Selective Combination (LSCP)': LSCP(
detector_list, contamination=outliers_fraction,
random_state=random_state),
'INNE': INNE(contamination=outliers_fraction),
'GMM': GMM(contamination=outliers_fraction),
'INNE': INNE(
max_samples=2, contamination=outliers_fraction,
random_state=random_state,
),
'GMM': GMM(contamination=outliers_fraction,
random_state=random_state),
'KDE': KDE(contamination=outliers_fraction),
'LMDD': LMDD(contamination=outliers_fraction),
'LMDD': LMDD(contamination=outliers_fraction,
random_state=random_state),
}
# Show all detectors
......
# -*- coding: utf-8 -*-
"""Example of outlier detection based on Kernel PCA.
"""
# Author: Akira Tamamori <tamamori5917@gmail.com>
# License: BSD 2 clause
from __future__ import division, print_function
import os
import sys
from pyod.models.kpca import KPCA
from pyod.utils.data import evaluate_print, generate_data
from pyod.utils.example import visualize
# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(os.path.abspath(os.path.join(os.path.dirname("__file__"), "..")))
if __name__ == "__main__":
contamination = 0.1 # percentage of outliers
n_train = 200 # number of training points
n_test = 100 # number of testing points
n_features = 2
# Generate sample data
X_train, X_test, y_train, y_test = generate_data(
n_train=n_train,
n_test=n_test,
n_features=2,
contamination=contamination,
random_state=42,
behaviour="new",
)
# train KPCA detector
clf_name = "KPCA"
clf = KPCA()
clf.fit(X_train)
# get the prediction labels and outlier scores of the training data
y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers)
y_train_scores = clf.decision_scores_ # raw outlier scores
# get the prediction on the test data
y_test_pred = clf.predict(X_test) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test) # outlier scores
# evaluate and print the results
print("\nOn Training Data:")
evaluate_print(clf_name, y_train, y_train_scores)
print("\nOn Test Data:")
evaluate_print(clf_name, y_test, y_test_scores)
# visualize the results
visualize(
clf_name,
X_train,
y_train,
X_test,
y_test,
y_train_pred,
y_test_pred,
show_figure=True,
)
......@@ -43,55 +43,85 @@ class PyODDataset(torch.utils.data.Dataset):
return torch.from_numpy(sample), idx
class inner_autoencoder(nn.Module):
class InnerAutoencoder(nn.Module):
def __init__(self,
n_features,
hidden_neurons=[128, 64],
hidden_neurons=(128, 64),
dropout_rate=0.2,
batch_norm=True,
hidden_activation='relu'):
super(inner_autoencoder, self).__init__()
# initialize the super class
super(InnerAutoencoder, self).__init__()
# save the default values
self.n_features = n_features
self.dropout_rate = dropout_rate
self.batch_norm = batch_norm
self.hidden_activation = hidden_activation
# create the dimensions for the input and hidden layers
self.layers_neurons_encoder_ = [self.n_features, *hidden_neurons]
self.layers_neurons_decoder_ = self.layers_neurons_encoder_[::-1]
# get the object for the activations functions
self.activation = get_activation_by_name(hidden_activation)
self.layers_neurons_ = [self.n_features, *hidden_neurons]
self.layers_neurons_decoder_ = self.layers_neurons_[::-1]
# initialize encoder and decoder as a sequential
self.encoder = nn.Sequential()
self.decoder = nn.Sequential()
for idx, layer in enumerate(self.layers_neurons_[:-1]):
# fill the encoder sequential with hidden layers
for idx, layer in enumerate(self.layers_neurons_encoder_[:-1]):
# create a linear layer of neurons
self.encoder.add_module(
"linear" + str(idx),
torch.nn.Linear(layer,self.layers_neurons_encoder_[idx + 1]))
# add a batch norm per layer if wanted (leave out first layer)
if batch_norm:
self.encoder.add_module("batch_norm" + str(idx),
nn.BatchNorm1d(
self.layers_neurons_[idx]))
self.encoder.add_module("linear" + str(idx),
torch.nn.Linear(self.layers_neurons_[idx],
self.layers_neurons_[
idx + 1]))
self.layers_neurons_encoder_[
idx + 1]))
# create the activation
self.encoder.add_module(self.hidden_activation + str(idx),
self.activation)
# create a dropout layer
self.encoder.add_module("dropout" + str(idx),
torch.nn.Dropout(dropout_rate))
for idx, layer in enumerate(self.layers_neurons_[:-1]):
if batch_norm:
# fill the decoder layer
for idx, layer in enumerate(self.layers_neurons_decoder_[:-1]):
# create a linear layer of neurons
self.decoder.add_module(
"linear" + str(idx),
torch.nn.Linear(layer,self.layers_neurons_decoder_[idx + 1]))
# create a batch norm per layer if wanted (only if it is not the
# last layer)
if batch_norm and idx < len(self.layers_neurons_decoder_[:-1]) - 1:
self.decoder.add_module("batch_norm" + str(idx),
nn.BatchNorm1d(
self.layers_neurons_decoder_[idx]))
self.decoder.add_module("linear" + str(idx), torch.nn.Linear(
self.layers_neurons_decoder_[idx],
self.layers_neurons_decoder_[idx + 1]))
self.encoder.add_module(self.hidden_activation + str(idx),
self.layers_neurons_decoder_[
idx + 1]))
# create the activation
self.decoder.add_module(self.hidden_activation + str(idx),
self.activation)
self.decoder.add_module("dropout" + str(idx),
torch.nn.Dropout(dropout_rate))
# create a dropout layer (only if it is not the last layer)
if idx < len(self.layers_neurons_decoder_[:-1]) - 1:
self.decoder.add_module("dropout" + str(idx),
torch.nn.Dropout(dropout_rate))
def forward(self, x):
# we could return the latent representation here after the encoder as the latent representation
# we could return the latent representation here after the encoder
# as the latent representation
x = self.encoder(x)
x = self.decoder(x)
return x
......@@ -293,7 +323,7 @@ class AutoEncoder(BaseDetector):
shuffle=True)
# initialize the model
self.model = inner_autoencoder(
self.model = InnerAutoencoder(
n_features=n_features,
hidden_neurons=self.hidden_neurons,
dropout_rate=self.dropout_rate,
......
......@@ -92,7 +92,7 @@ class KNN(BaseDetector):
Valid values for metric are:
- from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
- from scikit-learn: ['cityblock', 'euclidean', 'l1', 'l2',
'manhattan']
- from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
......
# -*- coding: utf-8 -*-
"""Kernel Principal Component Analysis (KPCA) Outlier Detector
"""
# Author: Akira Tamamori <tamamori5917@gmail.com>
# License: BSD 2 clause
import numpy as np
import sklearn
from sklearn.decomposition import KernelPCA
from sklearn.utils import check_array, check_random_state
from sklearn.utils.validation import check_is_fitted
from .base import BaseDetector
from ..utils.utility import check_parameter
class PyODKernelPCA(KernelPCA):
"""A wrapper class for KernelPCA class of scikit-learn."""
def __init__(
self,
n_components=None,
kernel="rbf",
gamma=None,
degree=3,
coef0=1,
kernel_params=None,
alpha=1.0,
fit_inverse_transform=False,
eigen_solver="auto",
tol=0,
max_iter=None,
remove_zero_eig=False,
copy_X=True,
n_jobs=None,
random_state=None,
):
super().__init__(
kernel=kernel,
gamma=gamma,
degree=degree,
coef0=coef0,
kernel_params=kernel_params,
alpha=alpha,
fit_inverse_transform=fit_inverse_transform,
eigen_solver=eigen_solver,
tol=tol,
max_iter=max_iter,
remove_zero_eig=remove_zero_eig,
n_jobs=n_jobs,
copy_X=copy_X,
random_state=check_random_state(random_state),
)
@property
def get_centerer(self):
"""Return a protected member _centerer."""
return self._centerer
@property
def get_kernel(self):
"""Return a protected member _get_kernel."""
return self._get_kernel
class KPCA(BaseDetector):
"""KPCA class for outlier detection.
PCA is performed on the feature space uniquely determined by the kernel,
and the reconstruction error on the feature space is used as the anomaly score.
See :cite:`hoffmann2007kernel`
Heiko Hoffmann, "Kernel PCA for novelty detection,"
Pattern Recognition, vol.40, no.3, pp. 863-874, 2007.
https://www.sciencedirect.com/science/article/pii/S0031320306003414
for details.
Parameters
----------
n_components : int, optional (default=None)
Number of components. If None, all non-zero components are kept.
n_selected_components : int, optional (default=None)
Number of selected principal components
for calculating the outlier scores. It is not necessarily equal to
the total number of the principal components. If not set, use
all principal components.
kernel : string {'linear', 'poly', 'rbf', 'sigmoid',
'cosine', 'precomputed'}, optional (default='rbf')
Kernel used for PCA.
gamma : float, optional (default=None)
Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other
kernels. If ``gamma`` is ``None``, then it is set to ``1/n_features``.
degree : int, optional (default=3)
Degree for poly kernels. Ignored by other kernels.
coef0 : float, optional (default=1)
Independent term in poly and sigmoid kernels.
Ignored by other kernels.
kernel_params : dict, optional (default=None)
Parameters (keyword arguments) and
values for kernel passed as callable object.
Ignored by other kernels.
alpha : float, optional (default=1.0)
Hyperparameter of the ridge regression that learns the
inverse transform (when inverse_transform=True).
eigen_solver : string, {'auto', 'dense', 'arpack', 'randomized'}, \
default='auto'
Select eigensolver to use. If `n_components` is much
less than the number of training samples, randomized (or arpack to a
smaller extend) may be more efficient than the dense eigensolver.
Randomized SVD is performed according to the method of Halko et al.
auto :
the solver is selected by a default policy based on n_samples
(the number of training samples) and `n_components`:
if the number of components to extract is less than 10 (strict) and
the number of samples is more than 200 (strict), the 'arpack'
method is enabled. Otherwise the exact full eigenvalue
decomposition is computed and optionally truncated afterwards
('dense' method).
dense :
run exact full eigenvalue decomposition calling the standard
LAPACK solver via `scipy.linalg.eigh`, and select the components
by postprocessing.
arpack :
run SVD truncated to n_components calling ARPACK solver using
`scipy.sparse.linalg.eigsh`. It requires strictly
0 < n_components < n_samples
randomized :
run randomized SVD.
implementation selects eigenvalues based on their module; therefore
using this method can lead to unexpected results if the kernel is
not positive semi-definite.
tol : float, optional (default=0)
Convergence tolerance for arpack.
If 0, optimal value will be chosen by arpack.
max_iter : int, optional (default=None)
Maximum number of iterations for arpack.
If None, optimal value will be chosen by arpack.
remove_zero_eig : bool, optional (default=False)
If True, then all components with zero eigenvalues are removed, so
that the number of components in the output may be < n_components
(and sometimes even zero due to numerical instability).
When n_components is None, this parameter is ignored and components
with zero eigenvalues are removed regardless.
copy_X : bool, optional (default=True)
If True, input X is copied and stored by the model in the `X_fit_`
attribute. If no further changes will be done to X, setting
`copy_X=False` saves memory by storing a reference.
n_jobs : int, optional (default=None)
The number of parallel jobs to run.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors.
sampling : bool, optional (default=False)
If True, sampling subset from the dataset is performed only once,
in order to reduce time complexity while keeping detection performance.
subset_size : float in (0., 1.0) or int (0, n_samples), optional (default=20)
If sampling is True, the size of subset is specified.
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance
used by np.random.
Attributes
----------
decision_scores_ : numpy array of shape (n_samples,)
The outlier scores of the training data.
The higher, the more abnormal. Outliers tend to have higher
scores. This value is available once the detector is
fitted.
threshold_ : float
The threshold is based on ``contamination``. It is the
``n_samples * contamination`` most abnormal samples in
``decision_scores_``. The threshold is calculated for generating
binary outlier labels.
labels_ : int, either 0 or 1
The binary labels of the training data. 0 stands for inliers
and 1 for outliers/anomalies. It is generated by applying
``threshold_`` on ``decision_scores_``.
"""
def __init__(
self,
contamination=0.1,
n_components=None,
n_selected_components=None,
kernel="rbf",
gamma=None,
degree=3,
coef0=1,
kernel_params=None,
alpha=1.0,
eigen_solver="auto",
tol=0,
max_iter=None,
remove_zero_eig=False,
copy_X=True,
n_jobs=None,
sampling=False,
subset_size=20,
random_state=None,
):
super().__init__(contamination=contamination)
self.n_components = n_components
self.n_selected_components = n_selected_components
self.copy_x = copy_X
self.sampling = sampling
self.subset_size = subset_size
self.random_state = check_random_state(random_state)
self.decision_scores_ = None
self.n_selected_components_ = None
self.kpca = PyODKernelPCA(
n_components=n_components,
kernel=kernel,
gamma=gamma,
degree=degree,
coef0=coef0,
kernel_params=kernel_params,
alpha=alpha,
fit_inverse_transform=False,
eigen_solver=eigen_solver,
tol=tol,
max_iter=max_iter,
remove_zero_eig=remove_zero_eig,
copy_X=copy_X,
n_jobs=n_jobs,
)
def _check_subset_size(self, array):
"""Check subset size."""
n_samples, _ = array.shape
if isinstance(self.subset_size, int) is True:
if 0 < self.subset_size <= n_samples:
subset_size = self.subset_size
else:
raise ValueError(
f"subset_size={self.subset_size} "
f"must be between 0 and n_samples={n_samples}."
)
if isinstance(self.subset_size, float) is True:
if 0.0 < self.subset_size <= 1.0:
subset_size = int(self.subset_size * n_samples)
else:
raise ValueError("subset_size=%r must be between 0.0 and 1.0")
return subset_size
def fit(self, X, y=None):
"""Fit detector. y is ignored in unsupervised methods.
Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.
y : Ignored
Not used, present for API consistency by convention.
Returns
-------
self : object
Fitted estimator.
"""
# validate inputs X and y (optional)
X = check_array(X, copy=self.copy_x)
self._set_n_classes(y)
# perform subsampling to reduce time complexity
if self.sampling is True:
subset_size = self._check_subset_size(X)
random_indices = self.random_state.choice(
X.shape[0],
size=subset_size,
replace=False,
)
X = X[random_indices, :]
# copy the attributes from the sklearn Kernel PCA object
if self.n_components is None:
n_components = X.shape[1] # use all dimensions
else:
if self.n_components < 1:
raise ValueError(
f"`n_components` should be >= 1, got: {self.n_components}"
)
n_components = min(X.shape[0], self.n_components)
# validate the number of components to be used for outlier detection
if self.n_selected_components is None:
self.n_selected_components_ = n_components
else:
self.n_selected_components_ = self.n_selected_components
check_parameter(
self.n_selected_components_,
1,
n_components,
include_left=True,
include_right=True,
param_name="n_selected_components",
)
self.kpca.fit(X)
centerer = self.kpca.get_centerer
kernel = self.kpca.get_kernel
if int(sklearn.__version__[0]) < 1:
eigenvalues_ = self.kpca.lambdas_
eigenvectors_ = self.kpca.alphas_
else:
eigenvalues_ = self.kpca.eigenvalues_
eigenvectors_ = self.kpca.eigenvectors_
x_transformed = eigenvectors_ * np.sqrt(eigenvalues_)
x_transformed = x_transformed[:, : self.n_selected_components_]
potential = []
for i in range(X.shape[0]):
sample = X[i, :].reshape(1, -1)
potential.append(kernel(sample))
potential = np.array(potential).squeeze()
potential = potential - 2 * centerer.K_fit_rows_ + centerer.K_fit_all_
# reconstruction error
self.decision_scores_ = potential - np.sum(np.square(x_transformed), axis=1)
self._process_decision_scores()
return self
def decision_function(self, X):
"""Predict raw anomaly score of X using the fitted detector.
The anomaly score of an input sample is computed based on different
detector algorithms. For consistency, outliers are assigned with
larger anomaly scores.
Parameters
----------
X : numpy array of shape (n_samples, n_features)
The training input samples. Sparse matrices are accepted only
if they are supported by the base estimator.
Returns
-------
anomaly_scores : numpy array of shape (n_samples,)
The anomaly score of the input samples.
"""
check_is_fitted(self, ["decision_scores_", "threshold_", "labels_"])
X = check_array(X)
# Compute centered gram matrix between X and training data X_fit_
centerer = self.kpca.get_centerer
kernel = self.kpca.get_kernel
gram_matrix = kernel(X, self.kpca.X_fit_)
centered_g = centerer.transform(gram_matrix)
if int(sklearn.__version__[0]) < 1:
eigenvalues_ = self.kpca.lambdas_
eigenvectors_ = self.kpca.alphas_
else:
eigenvalues_ = self.kpca.eigenvalues_
eigenvectors_ = self.kpca.eigenvectors_
# scale eigenvectors (properly account for null-space for dot product)
non_zeros = np.flatnonzero(eigenvalues_)
scaled_alphas = np.zeros_like(eigenvectors_)
scaled_alphas[:, non_zeros] = eigenvectors_[:, non_zeros] / np.sqrt(
eigenvalues_[non_zeros]
)
# Project with a scalar product between K and the scaled eigenvectors
x_transformed = np.dot(centered_g, scaled_alphas)
x_transformed = x_transformed[:, : self.n_selected_components_]
potential = []
for i in range(X.shape[0]):
sample = X[i, :].reshape(1, -1)
potential.append(kernel(sample))
potential = np.array(potential).squeeze()
gram_fit_rows = np.sum(gram_matrix, axis=1) / gram_matrix.shape[1]
potential = potential - 2 * gram_fit_rows + centerer.K_fit_all_
# reconstruction error
anomaly_scores = potential - np.sum(np.square(x_transformed), axis=1)
return anomaly_scores
# -*- coding: utf-8 -*-
from __future__ import division, print_function
import os
import sys
import unittest
# noinspection PyProtectedMember
from numpy.testing import (assert_allclose, assert_array_less, assert_equal,
assert_raises)
from scipy.stats import rankdata
from sklearn.base import clone
from sklearn.metrics import roc_auc_score
from pyod.models.kpca import KPCA
from pyod.utils.data import generate_data
# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
class TestKPCA(unittest.TestCase):
def setUp(self):
self.n_train = 200
self.n_test = 100
self.contamination = 0.1
self.roc_floor = 0.8
self.X_train, self.X_test, self.y_train, self.y_test = generate_data(
n_train=self.n_train,
n_test=self.n_test,
contamination=self.contamination,
random_state=42,
)
self.clf = KPCA(contamination=self.contamination, random_state=42)
self.clf.fit(self.X_train)
def test_parameters(self):
assert (
hasattr(self.clf, "decision_scores_")
and self.clf.decision_scores_ is not None
)
assert hasattr(self.clf, "labels_") and self.clf.labels_ is not None
assert hasattr(self.clf, "threshold_") and self.clf.threshold_ is not None
def test_train_scores(self):
assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])
def test_prediction_scores(self):
pred_scores = self.clf.decision_function(self.X_test)
# check score shapes
assert_equal(pred_scores.shape[0], self.X_test.shape[0])
# check performance
assert roc_auc_score(self.y_test, pred_scores) >= self.roc_floor
def test_prediction_labels(self):
pred_labels = self.clf.predict(self.X_test)
assert_equal(pred_labels.shape, self.y_test.shape)
def test_prediction_proba(self):
pred_proba = self.clf.predict_proba(self.X_test)
assert pred_proba.min() >= 0
assert pred_proba.max() <= 1
def test_prediction_proba_linear(self):
pred_proba = self.clf.predict_proba(self.X_test, method="linear")
assert pred_proba.min() >= 0
assert pred_proba.max() <= 1
def test_prediction_proba_unify(self):
pred_proba = self.clf.predict_proba(self.X_test, method="unify")
assert pred_proba.min() >= 0
assert pred_proba.max() <= 1
def test_prediction_proba_parameter(self):
with assert_raises(ValueError):
self.clf.predict_proba(self.X_test, method="something")
def test_prediction_labels_confidence(self):
pred_labels, confidence = self.clf.predict(self.X_test, return_confidence=True)
assert_equal(pred_labels.shape, self.y_test.shape)
assert_equal(confidence.shape, self.y_test.shape)
assert confidence.min() >= 0
assert confidence.max() <= 1
def test_prediction_proba_linear_confidence(self):
pred_proba, confidence = self.clf.predict_proba(
self.X_test, method="linear", return_confidence=True
)
assert pred_proba.min() >= 0
assert pred_proba.max() <= 1
assert_equal(confidence.shape, self.y_test.shape)
assert confidence.min() >= 0
assert confidence.max() <= 1
def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
def test_fit_predict_score(self):
self.clf.fit_predict_score(self.X_test, self.y_test)
self.clf.fit_predict_score(self.X_test, self.y_test, scoring="roc_auc_score")
self.clf.fit_predict_score(self.X_test, self.y_test, scoring="prc_n_score")
with assert_raises(NotImplementedError):
self.clf.fit_predict_score(self.X_test, self.y_test, scoring="something")
def test_predict_rank(self):
pred_socres = self.clf.decision_function(self.X_test)
pred_ranks = self.clf._predict_rank(self.X_test)
# assert the order is reserved
assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=4)
assert_array_less(pred_ranks, self.X_train.shape[0] + 1)
assert_array_less(-0.1, pred_ranks)
def test_predict_rank_normalized(self):
pred_socres = self.clf.decision_function(self.X_test)
pred_ranks = self.clf._predict_rank(self.X_test, normalized=True)
# assert the order is reserved
assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=4)
assert_array_less(pred_ranks, 1.01)
assert_array_less(-0.1, pred_ranks)
def test_model_clone(self):
clone_clf = clone(self.clf)
def tearDown(self):
pass
class TestKPCASubsetBound(unittest.TestCase):
def setUp(self):
self.n_train = 200
self.n_test = 100
self.contamination = 0.1
self.roc_floor = 0.8
self.X_train, self.X_test, self.y_train, self.y_test = generate_data(
n_train=self.n_train,
n_test=self.n_test,
contamination=self.contamination,
random_state=42,
)
self.clf_float = KPCA(
sampling=True,
subset_size=0.1,
contamination=self.contamination,
random_state=42,
)
self.clf_int = KPCA(
sampling=True,
subset_size=50,
contamination=self.contamination,
random_state=42,
)
self.clf_float_upper = KPCA(sampling=True, subset_size=1.5, random_state=42)
self.clf_float_lower = KPCA(sampling=True, subset_size=0, random_state=42)
self.clf_int_upper = KPCA(
sampling=True, subset_size=self.n_train + 100, random_state=42
)
self.clf_int_lower = KPCA(sampling=True, subset_size=-1, random_state=42)
def test_bound(self):
self.clf_float.fit(self.X_train)
self.clf_int.fit(self.X_train)
with assert_raises(ValueError):
self.clf_float_upper.fit(self.X_train)
with assert_raises(ValueError):
self.clf_float_lower.fit(self.X_train)
with assert_raises(ValueError):
self.clf_int_upper.fit(self.X_train)
with assert_raises(ValueError):
self.clf_int_lower.fit(self.X_train)
def tearDown(self):
pass
class TestKPCAComponentsBound(unittest.TestCase):
def setUp(self):
self.n_train = 200
self.n_test = 100
self.contamination = 0.1
self.roc_floor = 0.8
self.X_train, self.X_test, self.y_train, self.y_test = generate_data(
n_train=self.n_train,
n_test=self.n_test,
contamination=self.contamination,
random_state=42,
)
self.clf = KPCA(contamination=self.contamination, random_state=42)
self.clf_component_neg = KPCA(n_components=-1, random_state=42)
self.clf_selected_components = KPCA(
n_components=10, n_selected_components=5, random_state=42
)
self.clf_selected_components_upper = KPCA(
n_components=10, n_selected_components=50, random_state=42
)
self.clf_selected_components_lower = KPCA(
n_components=10, n_selected_components=0, random_state=42
)
def test_bound(self):
self.clf.fit(self.X_train)
with assert_raises(ValueError):
self.clf_component_neg.fit(self.X_train)
self.clf_selected_components.fit(self.X_train)
with assert_raises(ValueError):
self.clf_selected_components_upper.fit(self.X_train)
with assert_raises(ValueError):
self.clf_selected_components_lower.fit(self.X_train)
def tearDown(self):
pass
if __name__ == "__main__":
unittest.main()
......@@ -20,4 +20,4 @@
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
#
__version__ = '1.0.5' # pragma: no cover
__version__ = '1.0.6' # pragma: no cover
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment