diff --git a/CHANGES.txt b/CHANGES.txt index bf7b29d96a00cb43f42c1f403961b95fa78d7971..d589772c968ea776743eaf8a579ee455b776506d 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -42,6 +42,7 @@ v<0.6.5>, <11/28/2018> -- Add Stochastic Outlier Selection (SOS). v<0.6.5>, <11/30/2018> -- Add CircleCI continuous integration. v<0.6.5>, <12/03/2018> -- Add Local Correlation Integral (LOCI). v<0.6.6>, <12/06/2018> -- Add LSCP (production version). +v<0.6.6>, <12/08/2018> -- Add XGBOD. diff --git a/README.rst b/README.rst index d6ec076fb3d2ad8ea9f10aa342c5e9bcbc73c8c9..b7fa65d1b8ba315e807930864f138b5b8d101f8e 100644 --- a/README.rst +++ b/README.rst @@ -118,25 +118,28 @@ PyOD toolkit consists of three major groups of functionalities: **(i) Individual Detection Algorithms** : -=================== ================ ===================================================================================================== ===== ======================================== -Type Abbr Algorithm Year Ref -=================== ================ ===================================================================================================== ===== ======================================== -Linear Model PCA Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003 [#Shyu2003A]_ -Linear Model MCD Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores) 1999 [#Hardin2004Outlier]_ [#Rousseeuw1999A]_ -Linear Model OCSVM One-Class Support Vector Machines 2003 [#Ma2003Time]_ -Proximity-Based LOF Local Outlier Factor 2000 [#Breunig2000LOF]_ -Proximity-Based CBLOF Clustering-Based Local Outlier Factor 2003 [#He2003Discovering]_ -Proximity-Based LOCI LOCI: Fast outlier detection using the local correlation integral 2003 [#Papadimitriou2003LOCI]_ -Proximity-Based HBOS Histogram-based Outlier Score 2012 [#Goldstein2012Histogram]_ -Proximity-Based kNN k Nearest Neighbors (use the distance to the kth nearest neighbor as the outlier score 2000 [#Ramaswamy2000Efficient]_ -Proximity-Based AvgKNN Average kNN (use the average distance to k nearest neighbors as the outlier score) 2002 [#Angiulli2002Fast]_ -Proximity-Based MedKNN Median kNN (use the median distance to k nearest neighbors as the outlier score) 2002 [#Angiulli2002Fast]_ -Probabilistic ABOD Angle-Based Outlier Detection 2008 [#Kriegel2008Angle]_ -Probabilistic FastABOD Fast Angle-Based Outlier Detection using approximation 2008 [#Kriegel2008Angle]_ -Probabilistic SOS Stochastic Outlier Selection 2012 [#Janssens2012Stochastic]_ -Outlier Ensembles IForest Isolation Forest 2008 [#Liu2008Isolation]_ -Neural Networks AutoEncoder Fully connected AutoEncoder (use reconstruction error as the outlier score) [#Aggarwal2015Outlier]_ [Ch.3] -=================== ================ ===================================================================================================== ===== ======================================== +=================== ================ ====================================================================================================== ===== ======================================== +Type Abbr Algorithm Year Ref +=================== ================ ====================================================================================================== ===== ======================================== +Linear Model PCA Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003 [#Shyu2003A]_ +Linear Model MCD Minimum Covariance Determinant (use 
the mahalanobis distances as the outlier scores) 1999 [#Hardin2004Outlier]_ [#Rousseeuw1999A]_ +Linear Model OCSVM One-Class Support Vector Machines 2003 [#Ma2003Time]_ +Proximity-Based LOF Local Outlier Factor 2000 [#Breunig2000LOF]_ +Proximity-Based CBLOF Clustering-Based Local Outlier Factor 2003 [#He2003Discovering]_ +Proximity-Based LOCI LOCI: Fast outlier detection using the local correlation integral 2003 [#Papadimitriou2003LOCI]_ +Proximity-Based HBOS Histogram-based Outlier Score 2012 [#Goldstein2012Histogram]_ +Proximity-Based kNN k Nearest Neighbors (use the distance to the kth nearest neighbor as the outlier score 2000 [#Ramaswamy2000Efficient]_ +Proximity-Based AvgKNN Average kNN (use the average distance to k nearest neighbors as the outlier score) 2002 [#Angiulli2002Fast]_ +Proximity-Based MedKNN Median kNN (use the median distance to k nearest neighbors as the outlier score) 2002 [#Angiulli2002Fast]_ +Probabilistic ABOD Angle-Based Outlier Detection 2008 [#Kriegel2008Angle]_ +Probabilistic FastABOD Fast Angle-Based Outlier Detection using approximation 2008 [#Kriegel2008Angle]_ +Probabilistic SOS Stochastic Outlier Selection 2012 [#Janssens2012Stochastic]_ +Outlier Ensembles IForest Isolation Forest 2008 [#Liu2008Isolation]_ +Outlier Ensembles Feature Bagging 2005 [#Lazarevic2005Feature]_ +Outlier Ensembles LSCP LSCP: Locally Selective Combination of Parallel Outlier Ensembles 2019 [#Zhao2019LSCP]_ +Outlier Ensembles XGBOD **Supervised** XGBOD: Improving Supervised Outlier Detection with Unsupervised Representation Learning 2018 [#Zhao2018XGBOD]_ +Neural Networks AutoEncoder Fully connected AutoEncoder (use reconstruction error as the outlier score) [#Aggarwal2015Outlier]_ [Ch.3] +=================== ================ ====================================================================================================== ===== ======================================== FAQ regarding AutoEncoder in PyOD and debugging advice: `known issues <https://github.com/yzhao062/Pyod/issues/19>`_ @@ -196,12 +199,13 @@ Alternatively, install from github directly (\ **NOT Recommended**\ ) * scipy>=0.19.1 * scikit_learn>=0.19.1 -**Optional Dependencies (required for running examples and AutoEncoder)**\ : +**Optional Dependencies (see details below)**\ : * Keras (optional, required if calling AutoEncoder, other backend works) * Matplotlib (optional, required for running examples) * TensorFlow (optional, required if calling AutoEncoder, other backend works) +* XGBoost (optional, required if calling XGBOD) **Known Issue 1**\ : Running examples needs Matplotlib, which may throw errors in conda virtual environment on mac OS. See reasons and solutions `issue6 <https://github.com/yzhao062/Pyod/issues/6>`_. @@ -221,11 +225,11 @@ Full API Reference: (https://pyod.readthedocs.io/en/latest/pyod.html). API cheat * **fit(X)**\ : Fit detector. -* **fit_predict(X)**\ : Fit detector and predict if a particular sample is an outlier or not. -* **fit_predict_score(X, y)**\ : Fit, predict and then evaluate with predefined metrics (ROC and precision @ rank n). -* **decision_function(X)**\ : Predict anomaly score of X of the base classifiers. -* **predict(X)**\ : Predict if a particular sample is an outlier or not. The model must be fitted first. -* **predict_proba(X)**\ : Predict the probability of a sample being outlier. The model must be fitted first. +* **fit_predict(X)**\ : Fit detector first and then predict whether a particular sample is an outlier or not. 
+* **fit_predict_score(X, y)**\ : Fit the detector, predict on samples, and evaluate the model by predefined metrics, e.g., ROC. +* **decision_function(X)**\ : Predict raw anomaly score of X using the fitted detector. +* **predict(X)**\ : Predict if a particular sample is an outlier or not using the fitted detector. +* **predict_proba(X)**\ : Predict the probability of a sample being outlier using the fitted detector. Key Attributes of a fitted model: diff --git a/docs/api_cc.rst b/docs/api_cc.rst index c70c9905d80cb6f26dd90896dcc3e7f365eb0ada..29dfb865c8039e591919ffda1f941667ccc0ce65 100644 --- a/docs/api_cc.rst +++ b/docs/api_cc.rst @@ -1,12 +1,21 @@ API CheatSheet ============== -* :func:`pyod.models.base.BaseDetector.fit`: Fit detector. -* :func:`pyod.models.base.BaseDetector.fit_predict`: Fit detector and predict if a particular sample is an outlier or not. -* :func:`pyod.models.base.BaseDetector.fit_predict_evaluate`: Fit, predict and then evaluate with predefined metrics (ROC and precision @ rank n). -* :func:`pyod.models.base.BaseDetector.decision_function`: Predict anomaly score of X of the base classifiers. -* :func:`pyod.models.base.BaseDetector.predict`: Predict if a particular sample is an outlier or not. The model must be fitted first. -* :func:`pyod.models.base.BaseDetector.predict_proba`: Predict the probability of a sample being outlier. The model must be fitted first. - -See full API reference :doc:`pyod`. +* :func:`pyod.models.base.BaseDetector.fit`: Fit detector. y is optional for unsupervised methods. +* :func:`pyod.models.base.BaseDetector.fit_predict`: Fit detector first and then predict whether a particular sample is an outlier or not. +* :func:`pyod.models.base.BaseDetector.fit_predict_score`: Fit the detector, predict on samples, and evaluate the model by predefined metrics, e.g., ROC. +* :func:`pyod.models.base.BaseDetector.decision_function`: Predict raw anomaly score of X using the fitted detector. +* :func:`pyod.models.base.BaseDetector.predict`: Predict if a particular sample is an outlier or not using the fitted detector. +* :func:`pyod.models.base.BaseDetector.predict_proba`: Predict the probability of a sample being outlier using the fitted detector. + +See base class definition below: + +pyod.models.base module +----------------------- + +.. automodule:: pyod.models.base + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/docs/index.rst b/docs/index.rst index b059bb357e855d8702decee0fa85461e196ddd61..56dbb6f6f68563435e21413bbd5622afd4240ce8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -126,12 +126,9 @@ detection utility functions. i. **LOF: Local Outlier Factor** :cite:`a-breunig2000lof`: :class:`pyod.models.lof.LOF` ii. **CBLOF: Clustering-Based Local Outlier Factor** :cite:`a-he2003discovering`: :class:`pyod.models.cblof.CBLOF` iii. **LOCI: Local Correlation Integral** :cite:`a-papadimitriou2003loci`: :class:`pyod.models.loci.LOCI` - iv. **kNN: k Nearest Neighbors** (use the distance to the kth nearest - neighbor as the outlier score) :cite:`a-ramaswamy2000efficient,a-angiulli2002fast`: :class:`pyod.models.knn.KNN` - v. **Average kNN** (use the average distance to k nearest neighbors as - the outlier score): :class:`pyod.models.knn.KNN` - vi. **Median kNN** (use the median distance to k nearest neighbors - as the outlier score): :class:`pyod.models.knn.KNN` + iv. 
**kNN: k Nearest Neighbors** (use the distance to the kth nearest neighbor as the outlier score) :cite:`a-ramaswamy2000efficient,a-angiulli2002fast`: :class:`pyod.models.knn.KNN` + v. **Average kNN** (use the average distance to k nearest neighbors as the outlier score): :class:`pyod.models.knn.KNN` + vi. **Median kNN** (use the median distance to k nearest neighbors as the outlier score): :class:`pyod.models.knn.KNN` vii. **HBOS: Histogram-based Outlier Score** :cite:`a-goldstein2012histogram`: :class:`pyod.models.hbos.HBOS` @@ -141,10 +138,13 @@ detection utility functions. ii. **FastABOD: Fast Angle-Based Outlier Detection using approximation** :cite:`a-kriegel2008angle`: :class:`pyod.models.abod.ABOD` iii. **SOS: Stochastic Outlier Selection** :cite:`a-janssens2012stochastic`: :class:`pyod.models.sos.SOS` -4. Outlier Ensembles and Combination Frameworks +4. Outlier Ensembles: i. **Isolation Forest** :cite:`a-liu2008isolation,a-liu2012isolation`: :class:`pyod.models.iforest.IForest` ii. **Feature Bagging** :cite:`a-lazarevic2005feature`: :class:`pyod.models.feature_bagging.FeatureBagging` + iii. **LSCP**: Locally Selective Combination of Parallel Outlier Ensembles :cite:`a-zhao2018lscp`: :class:`pyod.models.lscp.LSCP` + iv. **XGBOD** :cite:`a-zhao2018xgbod`: :class:`pyod.models.xgbod.XGBOD` + 5. Neural Networks and Deep Learning Models (implemented in Keras): @@ -175,12 +175,12 @@ Key APIs & Attributes The following APIs are applicable for all detector models for easy use. -* :func:`pyod.models.base.BaseDetector.fit`: Fit detector. -* :func:`pyod.models.base.BaseDetector.fit_predict`: Fit detector and predict if a particular sample is an outlier or not. -* :func:`pyod.models.base.BaseDetector.fit_predict_evaluate`: Fit, predict and then evaluate with predefined metrics (ROC and precision @ rank n). -* :func:`pyod.models.base.BaseDetector.decision_function`: Predict anomaly score of X of the base classifiers. -* :func:`pyod.models.base.BaseDetector.predict`: Predict if a particular sample is an outlier or not. The model must be fitted first. -* :func:`pyod.models.base.BaseDetector.predict_proba`: Predict the probability of a sample being outlier. The model must be fitted first. +* :func:`pyod.models.base.BaseDetector.fit`: Fit detector. y is optional for unsupervised methods. +* :func:`pyod.models.base.BaseDetector.fit_predict`: Fit detector first and then predict whether a particular sample is an outlier or not. +* :func:`pyod.models.base.BaseDetector.fit_predict_score`: Fit the detector, predict on samples, and evaluate the model by predefined metrics, e.g., ROC. +* :func:`pyod.models.base.BaseDetector.decision_function`: Predict raw anomaly score of X using the fitted detector. +* :func:`pyod.models.base.BaseDetector.predict`: Predict if a particular sample is an outlier or not using the fitted detector. +* :func:`pyod.models.base.BaseDetector.predict_proba`: Predict the probability of a sample being outlier using the fitted detector. 
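To make the cheat sheet above concrete, here is a minimal sketch of the unified API on an unsupervised detector such as ``pyod.models.knn.KNN`` (not part of this patch; the synthetic data and parameter values are illustrative only)::

    import numpy as np
    from pyod.models.knn import KNN

    rng = np.random.RandomState(42)
    X_train = np.vstack([rng.randn(200, 2),                 # inliers
                         rng.uniform(4, 6, size=(20, 2))])  # injected outliers
    X_test = rng.randn(50, 2)

    clf = KNN(n_neighbors=5, contamination=0.1)
    clf.fit(X_train)                     # unsupervised, so y is not needed

    # attributes populated during fit
    train_scores = clf.decision_scores_  # raw outlier scores of the training data
    train_labels = clf.labels_           # binary labels derived from threshold_

    # predictions on new data with the fitted detector
    test_scores = clf.decision_function(X_test)               # raw outlier scores
    test_labels = clf.predict(X_test)                          # 0: inlier, 1: outlier
    test_proba = clf.predict_proba(X_test, method='linear')   # probabilities in [0, 1]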
Key Attributes of a fitted model: diff --git a/docs/install.rst b/docs/install.rst index 649226458aa25e4e6ad7fe8058471c8a6867bd56..aba65a85334cf2b3b1bfdd614c927c317b36d4d6 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -26,11 +26,12 @@ Alternatively, install from github directly (\ **NOT Recommended**\ ) * scikit_learn>=0.19.1 -**Optional Dependencies (required for running examples or AutoEncoder)**: +**Optional Dependencies (see details below)**: - Keras (optional, required if calling AutoEncoder, other backend works) - Matplotlib (optional, required for running examples) - Tensorflow (optional, required if calling AutoEncoder, other backend works) +- XGBoost (optional, required if calling XGBOD) .. warning:: diff --git a/docs/pyod.models.rst b/docs/pyod.models.rst index 0dfdee627e17821a8216c63818bd1223e2c518fc..20e94b0b375ea1c8f90a9487d0f936bd6145f6f0 100644 --- a/docs/pyod.models.rst +++ b/docs/pyod.models.rst @@ -4,6 +4,8 @@ pyod.models package Submodules ---------- +---- + pyod.models.abod module ----------------------- @@ -22,15 +24,6 @@ pyod.models.auto\_encoder module :show-inheritance: :inherited-members: -pyod.models.base module ------------------------ - -.. automodule:: pyod.models.base - :members: - :undoc-members: - :show-inheritance: - :inherited-members: - pyod.models.cblof module ------------------------ @@ -153,6 +146,15 @@ pyod.models.sos module :show-inheritance: :inherited-members: +pyod.models.xgbod module +------------------------ + +.. automodule:: pyod.models.xgbod + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + Module contents --------------- diff --git a/docs/requirements.txt b/docs/requirements.txt index 94f58ad967c9522eed2f6303582b3e47c6bcc029..6357757766ae82893ce56f973afd42268fcf176e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -8,3 +8,4 @@ scipy>=0.19.1 scikit_learn>=0.19.1 sphinxcontrib-bibtex tensorflow +xgboost>=0.7 diff --git a/environment.yml b/environment.yml index 320963c634c56de1c16eb55b5e796ea778c0a56d..1f529727fa19a094182c6803bbce415eb1552eae 100644 --- a/environment.yml +++ b/environment.yml @@ -9,3 +9,4 @@ dependencies: - scikit-learn - scipy - tensorflow + - xgboost diff --git a/examples/temp_do_not_use.py b/examples/temp_do_not_use.py index 6cb3e1e3d766ab583186b5eec5f8316de80beb71..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/examples/temp_do_not_use.py +++ b/examples/temp_do_not_use.py @@ -1,118 +0,0 @@ -# License: BSD 2 clause - -from __future__ import division -from __future__ import print_function - -import os -import sys - -# temporary solution for relative imports in case pyod is not installed -# if pyod is installed, no need to use the following line -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) - -from sklearn.utils import check_X_y -import numpy as np -from sklearn.model_selection import train_test_split -from xgboost.sklearn import XGBClassifier -from scipy.io import loadmat - -from pyod.models.knn import KNN -from pyod.models.lof import LOF -from pyod.models.iforest import IForest -from pyod.models.hbos import HBOS -from pyod.models.ocsvm import OCSVM -from pyod.utils.data import generate_data -from pyod.utils.data import get_color_codes -from pyod.utils.data import evaluate_print -from pyod.utils.utility import standardizer - -if __name__ == "__main__": - - # Define data file and read X and y - # Generate some data if the source data is missing - mat_file = 'cardio.mat' - try: - mat = loadmat(os.path.join('data', mat_file)) - - 
except TypeError: - print('{data_file} does not exist. Use generated data'.format( - data_file=mat_file)) - X, y = generate_data(train_only=True) # load data - except IOError: - print('{data_file} does not exist. Use generated data'.format( - data_file=mat_file)) - X, y = generate_data(train_only=True) # load data - else: - X = mat['X'] - y = mat['y'].ravel() - - X_train, X_test, y_train, y_test = train_test_split(X, y, - test_size=0.4, - random_state=1) - # X_train_norm, X_test_norm = X_train, X_test - X_train_norm, X_test_norm = standardizer(X_train, X_test) - - estimator_list = [] - normalization_list = [] - - # predefined range of k - k_range = [1, 2, 3, 4, 5, 10, 15, 20, 30, 40, 50, - 60, 70, 80, 90, 100, 150, 200, 250] - # validate the value of k - k_range = [k for k in k_range if k < X.shape[0]] - - for k in k_range: - estimator_list.append(KNN(n_neighbors=k, method='largest')) - estimator_list.append(KNN(n_neighbors=k, method='mean')) - estimator_list.append(LOF(n_neighbors=k)) - normalization_list.append(True) - normalization_list.append(True) - normalization_list.append(True) - - n_bins_range = [3, 5, 7, 9, 12, 15, 20, 25, 30, 50] - for n_bins in n_bins_range: - estimator_list.append(HBOS(n_bins=n_bins)) - normalization_list.append(False) - - # predefined range of nu for one-class svm - nu_range = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99] - for nu in nu_range: - estimator_list.append(OCSVM(nu=nu)) - normalization_list.append(True) - - # predefined range for number of estimators in isolation forests - n_range = [10, 20, 50, 70, 100, 150, 200, 250] - for n in n_range: - estimator_list.append(IForest(n_estimators=n)) - normalization_list.append(False) - - X_train_add = np.zeros([X_train.shape[0], len(estimator_list)]) - X_test_add = np.zeros([X_test.shape[0], len(estimator_list)]) - - # fit the model - for index, estimator in enumerate(estimator_list): - if normalization_list[index]: - estimator.fit(X_train_norm) - X_train_add[:, index] = estimator.decision_scores_ - X_test_add[:, index] = estimator.decision_function(X_test_norm) - else: - estimator.fit(X_train) - X_train_add[:, index] = estimator.decision_scores_ - X_test_add[:, index] = estimator.decision_function(X_test) - - # prepare the new feature space - X_train_new = np.concatenate((X_train, X_train_add), axis=1) - X_test_new = np.concatenate((X_test, X_test_add), axis=1) - - clf = XGBClassifier() - clf.fit(X_train_new, y_train) - y_test_scores = clf.predict_proba(X_test_new) # outlier scores - - evaluate_print('XGBOD', y_test, y_test_scores[:, 1]) - - clf = XGBClassifier() - clf.fit(X_train, y_train) - y_test_scores_orig = clf.predict_proba(X_test) # outlier scores - - evaluate_print('old', y_test, y_test_scores_orig[:, 1]) diff --git a/examples/xgbod_example.py b/examples/xgbod_example.py new file mode 100644 index 0000000000000000000000000000000000000000..2adb18028be0d419eabde780260fe1426af18f1b --- /dev/null +++ b/examples/xgbod_example.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""Example of using XGBOD for outlier detection +""" +# Author: Yue Zhao <yuezhao@cs.toronto.edu> +# License: BSD 2 clause + +from __future__ import division +from __future__ import print_function + +import os +import sys + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) + +from sklearn.model_selection import train_test_split +from sklearn.utils.validation 
import check_X_y +from scipy.io import loadmat + +from pyod.models.xgbod import XGBOD +from pyod.utils.data import generate_data +from pyod.utils.data import evaluate_print + +if __name__ == "__main__": + # Define data file and read X and y + # Generate some data if the source data is missing + mat_file = 'cardio.mat' + try: + mat = loadmat(os.path.join('data', mat_file)) + + except TypeError: + print('{data_file} does not exist. Use generated data'.format( + data_file=mat_file)) + X, y = generate_data(train_only=True) # load data + except IOError: + print('{data_file} does not exist. Use generated data'.format( + data_file=mat_file)) + X, y = generate_data(train_only=True) # load data + else: + X = mat['X'] + y = mat['y'].ravel() + X, y = check_X_y(X, y) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, + random_state=42) + + # train XGBOD detector + clf_name = 'XGBOD' + clf = XGBOD(random_state=42) + clf.fit(X_train, y_train) + + # get the prediction labels and outlier scores of the training data + y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) + y_train_scores = clf.decision_scores_ # raw outlier scores + + # get the prediction on the test data + y_test_pred = clf.predict(X_test) # outlier labels (0 or 1) + y_test_scores = clf.decision_function(X_test) # outlier scores + + # evaluate and print the results + print("\nOn Training Data:") + evaluate_print(clf_name, y_train, y_train_scores) + print("\nOn Test Data:") + evaluate_print(clf_name, y_test, y_test_scores) diff --git a/pyod/models/abod.py b/pyod/models/abod.py index 0951c0f0f0602b6576d211d702c97dac3e5465db..c98559a7a941fda3aa0b2e25c56fa5c869c86996 100644 --- a/pyod/models/abod.py +++ b/pyod/models/abod.py @@ -115,19 +115,19 @@ class ABOD(BaseDetector): Attributes ---------- - decision_scores\_ : numpy array of shape (n_samples,) + decision_scores_ : numpy array of shape (n_samples,) The outlier scores of the training data. The higher, the more abnormal. Outliers tend to have higher scores. This value is available once the detector is fitted. - threshold\_ : float + threshold_ : float The threshold is based on ``contamination``. It is the ``n_samples * contamination`` most abnormal samples in ``decision_scores_``. The threshold is calculated for generating binary outlier labels. - labels\_ : int, either 0 or 1 + labels_ : int, either 0 or 1 The binary labels of the training data. 0 stands for inliers and 1 for outliers/anomalies. It is generated by applying ``threshold_`` on ``decision_scores_``. diff --git a/pyod/models/auto_encoder.py b/pyod/models/auto_encoder.py index fa19cf614329b14b8b2af083513d7631aad9f50f..7edf52b805624f829495ce1d6eacd23733e91a42 100644 --- a/pyod/models/auto_encoder.py +++ b/pyod/models/auto_encoder.py @@ -92,32 +92,32 @@ class AutoEncoder(BaseDetector): Attributes ---------- - encoding_dim\_ : int + encoding_dim_ : int The number of neurons in the encoding layer. - compression_rate\_ : float + compression_rate_ : float The ratio between the original feature and the number of neurons in the encoding layer. - model\_ : Keras Object + model_ : Keras Object The underlying AutoEncoder in Keras. - history\_: Keras Object + history_: Keras Object The AutoEncoder training history. - decision_scores\_ : numpy array of shape (n_samples,) + decision_scores_ : numpy array of shape (n_samples,) The outlier scores of the training data. The higher, the more abnormal. Outliers tend to have higher scores. This value is available once the detector is fitted. 
- threshold\_ : float + threshold_ : float The threshold is based on ``contamination``. It is the ``n_samples * contamination`` most abnormal samples in ``decision_scores_``. The threshold is calculated for generating binary outlier labels. - labels\_ : int, either 0 or 1 + labels_ : int, either 0 or 1 The binary labels of the training data. 0 stands for inliers and 1 for outliers/anomalies. It is generated by applying ``threshold_`` on ``decision_scores_``. diff --git a/pyod/models/base.py b/pyod/models/base.py index d98548630b649e871542513fbef4f3099db4b1c6..d62079106ca65d6a1b9012b37994c85374f5d2fe 100644 --- a/pyod/models/base.py +++ b/pyod/models/base.py @@ -39,19 +39,19 @@ class BaseDetector(object): Attributes ---------- - decision_scores\_ : numpy array of shape (n_samples,) + decision_scores_ : numpy array of shape (n_samples,) The outlier scores of the training data. The higher, the more abnormal. Outliers tend to have higher scores. This value is available once the detector is fitted. - threshold\_ : float + threshold_ : float The threshold is based on ``contamination``. It is the ``n_samples * contamination`` most abnormal samples in ``decision_scores_``. The threshold is calculated for generating binary outlier labels. - labels\_ : int, either 0 or 1 + labels_ : int, either 0 or 1 The binary labels of the training data. 0 stands for inliers and 1 for outliers/anomalies. It is generated by applying ``threshold_`` on ``decision_scores_``. @@ -68,18 +68,22 @@ class BaseDetector(object): @abc.abstractmethod def decision_function(self, X): - """Predict anomaly score of X of the base classifiers. + """Predict raw anomaly score of X using the fitted detector. The anomaly score of an input sample is computed based on different detector algorithms. For consistency, outliers are assigned with larger anomaly scores. - :param X: The training input samples. Sparse matrices are accepted only + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. - :type X: numpy array of shape (n_samples, n_features) - :return: The anomaly score of the input samples. - :rtype: array, shape (n_samples,) + Returns + ------- + anomaly_scores : numpy array of shape (n_samples,) + The anomaly score of the input samples. """ pass @@ -125,13 +129,17 @@ class BaseDetector(object): def predict(self, X): """Predict if a particular sample is an outlier or not. - :param X: The input samples - :type X: numpy array of shape (n_samples, n_features) + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The input samples. - :return: For each observation, tells whether or not - it should be considered as an outlier according to the fitted - model. 0 stands for inliers and 1 for outliers. - :rtype: array, shape (n_samples,) + Returns + ------- + outlier_labels : numpy array of shape (n_samples,) + For each observation, tells whether or not + it should be considered as an outlier according to the + fitted model. 0 stands for inliers and 1 for outliers. """ check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_']) @@ -148,17 +156,22 @@ class BaseDetector(object): fitted first. 2. use unifying scores, see :cite:`kriegel2011interpreting`. - :param X: The input samples - :type X: numpy array of shape (n_samples, n_features) + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The input samples. - :param method: probability conversion method. 
It must be one of + method : str, optional (default='linear') + probability conversion method. It must be one of 'linear' or 'unify'. - :type method: str, optional (default='linear') - - :return: For each observation, return the outlier probability, ranging - in [0,1] - :rtype: array, shape (n_samples,) + Returns + ------- + outlier_labels : numpy array of shape (n_samples,) + For each observation, tells whether or not + it should be considered as an outlier according to the + fitted model. Return the outlier probability, ranging + in [0,1]. """ check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_']) @@ -215,22 +228,25 @@ class BaseDetector(object): def fit_predict_score(self, X, y, scoring='roc_auc_score'): """Fit the detector, predict on samples, and evaluate the model by - ROC and Precision @ rank n + predefined metrics, e.g., ROC. - :param X: The input samples - :type X: numpy array of shape (n_samples, n_features) + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The input samples. - :param y: Outlier labels of the input samples - :type y: array, shape (n_samples,) + y : numpy array of shape (n_samples,), optional (default=None) + The ground truth of the input samples (labels). - :param scoring: Evaluation metric + scoring : str, optional (default='roc_auc_score') + Evaluation metric - -' roc_auc_score': ROC score - - 'prc_n_score': Precision @ rank n score - :type scoring: str, optional (default='roc_auc_score') + -' roc_auc_score': ROC score + - 'prc_n_score': Precision @ rank n score - :return: Evaluation score - :rtype: float + Returns + ------- + evaluation_score : float """ self.fit(X) diff --git a/pyod/models/hbos.py b/pyod/models/hbos.py index fae14c878ed57742b47128774edd774412de921e..da50b8431c228fd415f11966ceb5908344c3a159 100644 --- a/pyod/models/hbos.py +++ b/pyod/models/hbos.py @@ -13,54 +13,59 @@ from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted from ..utils.utility import check_parameter +from ..utils.utility import invert_order from .base import BaseDetector class HBOS(BaseDetector): - """ - Histogram- based outlier detection (HBOS) is an efficient unsupervised - method [1]. It assumes the feature independence and calculates the degree + """Histogram- based outlier detection (HBOS) is an efficient unsupervised + method. It assumes the feature independence and calculates the degree of outlyingness by building histograms. See :cite:`goldstein2012histogram` for details. - :param n_bins: The number of bins - :type n_bins: int, optional (default=10) + Parameters + ---------- + n_bins : int, optional (default=10) + The number of bins - :param alpha: The regularizer for preventing overflow - :type alpha: float in (0, 1), optional (default=0.1) + alpha : float in (0, 1), optional (default=0.1) + The regularizer for preventing overflow - :param tol: The parameter to decide the flexibility while dealing + tol : float in (0, 1), optional (default=0.1) + The parameter to decide the flexibility while dealing the samples falling outside the bins. - :type tol: float in (0, 1), optional (default=0.1) - :param contamination: The amount of contamination of the data set, i.e. - the proportion of outliers in the data set. When fitting this is used - to define the threshold on the decision function. 
- :type contamination: float in (0., 0.5), optional (default=0.1) - :var bin_edges\_: The edges of the bins - :vartype bin_edges\_: numpy array of shape (n_bins + 1, n_features ) + contamination : float in (0., 0.5), optional (default=0.1) + The amount of contamination of the data set, + i.e. the proportion of outliers in the data set. Used when fitting to + define the threshold on the decision function. + + Attributes + ---------- + bin_edges_ : numpy array of shape (n_bins + 1, n_features ) + The edges of the bins - :var hist\_: The density of each histogram - :vartype hist\_: numpy array of shape (n_bins, n_features) + hist_ : numpy array of shape (n_bins, n_features) + The density of each histogram - :var decision_scores\_: The outlier scores of the training data. + decision_scores_ : numpy array of shape (n_samples,) + The outlier scores of the training data. The higher, the more abnormal. Outliers tend to have higher scores. This value is available once the detector is fitted. - :vartype decision_scores\_: numpy array of shape (n_samples,) - :var threshold\_: The threshold is based on ``contamination``. It is the + threshold_ : float + The threshold is based on ``contamination``. It is the ``n_samples * contamination`` most abnormal samples in ``decision_scores_``. The threshold is calculated for generating binary outlier labels. - :vartype threshold\_: float - :var labels\_: The binary labels of the training data. 0 stands for inliers + labels_ : int, either 0 or 1 + The binary labels of the training data. 0 stands for inliers and 1 for outliers/anomalies. It is generated by applying ``threshold_`` on ``decision_scores_``. - :vartype labels\_: int, either 0 or 1 """ def __init__(self, n_bins=10, alpha=0.1, tol=0.5, contamination=0.1): @@ -73,7 +78,7 @@ class HBOS(BaseDetector): check_parameter(tol, 0, 1, param_name='tol') def fit(self, X, y=None): - # Validate inputs X and y (optional) + # validate inputs X and y (optional) X = check_array(X) self._set_n_classes(y) @@ -95,8 +100,8 @@ class HBOS(BaseDetector): self.n_bins, self.alpha, self.tol) - # Invert decision_scores_. Outliers comes with higher outlier scores - self.decision_scores_ = np.sum(outlier_scores, axis=1) * -1 + # invert decision_scores_. Outliers comes with higher outlier scores + self.decision_scores_ = invert_order(np.sum(outlier_scores, axis=1)) self._process_decision_scores() return self @@ -109,7 +114,7 @@ class HBOS(BaseDetector): self.hist_, self.n_bins, self.alpha, self.tol) - return np.sum(outlier_scores, axis=1).ravel() * -1 + return invert_order(np.sum(outlier_scores, axis=1)) @njit diff --git a/pyod/models/xgbod.py b/pyod/models/xgbod.py index 500aaf3a788a3cafc51e93b87c3a3f4f1bf51896..b6c749b50e4c0ecdd3287b6ca3bdd4ae3c144855 100644 --- a/pyod/models/xgbod.py +++ b/pyod/models/xgbod.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """XGBOD: Improving Supervised Outlier Detection with Unsupervised -Representation Learning +Representation Learning. A semi-supervised outlier detection framework. 
""" # Author: Yue Zhao <yuezhao@cs.toronto.edu> # License: BSD 2 clause @@ -8,10 +8,11 @@ from __future__ import division from __future__ import print_function import numpy as np -from sklearn.neighbors import NearestNeighbors -from sklearn.neighbors import KDTree +from sklearn.metrics import roc_auc_score from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted +from sklearn.utils.validation import check_X_y +from xgboost.sklearn import XGBClassifier from .base import BaseDetector from .knn import KNN @@ -19,47 +20,174 @@ from .lof import LOF from .iforest import IForest from .hbos import HBOS from .ocsvm import OCSVM + +from ..utils.utility import check_parameter from ..utils.utility import check_detector +from ..utils.utility import standardizer +from ..utils.utility import precision_n_scores class XGBOD(BaseDetector): """XGBOD class for outlier detection. - For an observation, its distance to its kth nearest neighbor could be - viewed as the outlying score. It could be viewed as a way to measure - the density. See :cite:`ramaswamy2000efficient,angiulli2002fast` for - details. - + It first use the passed in unsupervised outlier detectors to extract + richer representation of the data and then concatenate the newly + generated features to the original feature for constructing the augmented + feature space. An XGBoost classifier is then applied on this augmented + feature space. Read more in the :cite:`zhao2018xgbod`. Parameters ---------- - estimator_list - standardization_flag_list - random_state + estimator_list : list, optional (default=None) + The list of pyod detectors passed in for unsupervised learning + + standardization_flag_list : list, optional (default=None) + The list of boolean flags for indicating whether to take + standardization for each detector. + + max_depth : int + Maximum tree depth for base learners. + + learning_rate : float + Boosting learning rate (xgb's "eta") + + n_estimators : int + Number of boosted trees to fit. + + silent : boolean + Whether to print messages while running boosting. + + objective : string or callable + Specify the learning task and the corresponding learning objective or + a custom objective function to be used (see note below). + + booster : string + Specify which booster to use: gbtree, gblinear or dart. + + n_jobs : int + Number of parallel threads used to run xgboost. (replaces ``nthread``) + + gamma : float + Minimum loss reduction required to make a further partition on a leaf + node of the tree. + + min_child_weight : int + Minimum sum of instance weight(hessian) needed in a child. + + max_delta_step : int + Maximum delta step we allow each tree's weight estimation to be. + + subsample : float + Subsample ratio of the training instance. + + colsample_bytree : float + Subsample ratio of columns when constructing each tree. + + colsample_bylevel : float + Subsample ratio of columns for each split, in each level. + + reg_alpha : float (xgb's alpha) + L1 regularization term on weights. + + reg_lambda : float (xgb's lambda) + L2 regularization term on weights. + + scale_pos_weight : float + Balancing of positive and negative weights. + + base_score: + The initial prediction score of all instances, global bias. + + random_state : int + Random number seed. (replaces seed) + + missing : float, optional + Value in the data which needs to be present as a missing value. If + None, defaults to np.nan. 
+ + importance_type: string, default "gain" + The feature importance type for the ``feature_importances_`` + property: either "gain", + "weight", "cover", "total_gain" or "total_cover". + + \*\*kwargs : dict, optional + Keyword arguments for XGBoost Booster object. Full documentation of + parameters can be found here: + https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst. + Attempting to set a parameter via the constructor args and \*\*kwargs + dict simultaneously will result in a TypeError. + .. note:: \*\*kwargs unsupported by scikit-learn + \*\*kwargs is unsupported by scikit-learn. We do not guarantee + that parameters passed via this argument will interact properly + with scikit-learn. + + Attributes + ---------- + n_detector_ : int + The number of unsupervised of detectors used. + clf_ : object + The XGBoost classifier. """ def __init__(self, estimator_list=None, standardization_flag_list=None, - random_state=None): + max_depth=3, learning_rate=0.1, + n_estimators=100, silent=True, + objective="binary:logistic", booster='gbtree', + n_jobs=1, nthread=None, gamma=0, min_child_weight=1, + max_delta_step=0, subsample=1, colsample_bytree=1, + colsample_bylevel=1, + reg_alpha=0, reg_lambda=1, scale_pos_weight=1, + base_score=0.5, random_state=0, missing=None, + **kwargs): super(XGBOD, self).__init__() self.estimator_list = estimator_list self.standardization_flag_list = standardization_flag_list + self.max_depth = max_depth + self.learning_rate = learning_rate + self.n_estimators = n_estimators + self.silent = silent + self.objective = objective + self.booster = booster + self.n_jobs = n_jobs + self.nthread = nthread + self.gamma = gamma + self.min_child_weight = min_child_weight + self.max_delta_step = max_delta_step + self.subsample = subsample + self.colsample_bytree = colsample_bytree + self.colsample_bylevel = colsample_bylevel + self.reg_alpha = reg_alpha + self.reg_lambda = reg_lambda + self.scale_pos_weight = scale_pos_weight + self.base_score = base_score self.random_state = random_state - - if self.standardization_flag_list is None: - if len(self.estimator_list) != len(self.standardization_flag_list): - raise ValueError( - "estimator_list length ({0}) is not equal " - "to standardization_flag_list length ({1})".format( - len(self.estimator_list), - len(self.standardization_flag_list))) + self.missing = missing + self.kwargs = kwargs def _init_detectors(self, X): + """initialize unsupervised detectors if no predefined detectors is + provided. 
+ + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The train data + + Returns + ------- + estimator_list : list of object + The initialized list of detectors + + standardization_flag_list : list of boolean + The list of bool flag to indicate whether standardization is needed + + """ estimator_list = [] standardization_flag_list = [] - # predefined range of k - k_range = [1, 2, 3, 4, 5, 10, 15, 20, 30, 40, 50, - 60, 70, 80, 90, 100, 150, 200, 250] + # predefined range of n_neighbors for KNN, AvgKNN, and LOF + k_range = [1, 3, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] + # validate the value of k k_range = [k for k in k_range if k < X.shape[0]] @@ -79,76 +207,194 @@ class XGBOD(BaseDetector): # predefined range of nu for one-class svm nu_range = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99] for nu in nu_range: - estimator_list.append(OCSVM(nu=nu)) + estimator_list.append(OCSVM(nu=nu, random_state=self.random_state)) standardization_flag_list.append(True) # predefined range for number of estimators in isolation forests n_range = [10, 20, 50, 70, 100, 150, 200, 250] for n in n_range: - estimator_list.append(IForest(n_estimators=n)) + estimator_list.append( + IForest(n_estimators=n, random_state=self.random_state)) standardization_flag_list.append(False) return estimator_list, standardization_flag_list - def fit(self, X, y=None): - - # Validate inputs X and y (optional) - X = check_array(X) - self._set_n_classes(y) - + def _validate_estimator(self, X): if self.estimator_list is None: self.estimator_list, \ self.standardization_flag_list = self._init_detectors(X) + # perform standardization for all detectors by default if self.standardization_flag_list is None: self.standardization_flag_list = [True] * len(self.estimator_list) - self.tree_ = KDTree(X, leaf_size=self.leaf_size, metric=self.metric) - self.neigh_.fit(X) + # validate two lists length + if len(self.estimator_list) != len(self.standardization_flag_list): + raise ValueError( + "estimator_list length ({0}) is not equal " + "to standardization_flag_list length ({1})".format( + len(self.estimator_list), + len(self.standardization_flag_list))) + + # validate the estimator list is not empty + check_parameter(len(self.estimator_list), low=1, + param_name='number of estimators', + include_left=True, include_right=True) + + for estimator in self.estimator_list: + check_detector(estimator) + + return len(self.estimator_list) - dist_arr, _ = self.neigh_.kneighbors(n_neighbors=self.n_neighbors, - return_distance=True) + def _generate_new_features(self, X): + X_add = np.zeros([X.shape[0], self.n_detector_]) - dist = np.zeros(shape=(X.shape[0], 1)) - if self.method == 'largest': - dist = dist_arr[:, -1] - elif self.method == 'mean': - dist = np.mean(dist_arr, axis=1) - elif self.method == 'median': - dist = np.median(dist_arr, axis=1) + # keep the standardization scalar for test conversion + X_norm = self._scalar.transform(X) - self.decision_scores_ = dist.ravel() - self._process_decision_scores() + for ind, estimator in enumerate(self.estimator_list): + if self.standardization_flag_list[ind]: + X_add[:, ind] = estimator.decision_function(X_norm) + + else: + X_add[:, ind] = estimator.decision_function(X) + return X_add + + def fit(self, X, y): + """Fit the model using X and y as training data. + + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + Training data. 
+ + y : numpy array of shape (n_samples,) + The ground truth (binary label) + + - 0 : inliers + - 1 : outliers + + Returns + ------- + self : object + """ + + # Validate inputs X and y + X, y = check_X_y(X, y) + X = check_array(X) + self._set_n_classes(y) + self.n_detector_ = self._validate_estimator(X) + self.X_train_add_ = np.zeros([X.shape[0], self.n_detector_]) + + # keep the standardization scalar for test conversion + X_norm, self._scalar = standardizer(X, keep_scalar=True) + + for ind, estimator in enumerate(self.estimator_list): + if self.standardization_flag_list[ind]: + estimator.fit(X_norm) + self.X_train_add_[:, ind] = estimator.decision_scores_ + + else: + estimator.fit(X) + self.X_train_add_[:, ind] = estimator.decision_scores_ + + # construct the new feature space + self.X_train_new_ = np.concatenate((X, self.X_train_add_), axis=1) + + # initialize, train, and predict on XGBoost + self.clf_ = clf = XGBClassifier(max_depth=self.max_depth, + learning_rate=self.learning_rate, + n_estimators=self.n_estimators, + silent=self.silent, + objective=self.objective, + booster=self.booster, + n_jobs=self.n_jobs, + nthread=self.nthread, + gamma=self.gamma, + min_child_weight=self.min_child_weight, + max_delta_step=self.max_delta_step, + subsample=self.subsample, + colsample_bytree=self.colsample_bytree, + colsample_bylevel=self.colsample_bylevel, + reg_alpha=self.reg_alpha, + reg_lambda=self.reg_lambda, + scale_pos_weight=self.scale_pos_weight, + base_score=self.base_score, + random_state=self.random_state, + missing=self.missing, + **self.kwargs) + self.clf_.fit(self.X_train_new_, y) + self.decision_scores_ = self.clf_.predict_proba( + self.X_train_new_)[:, 1] + self.labels_ = self.clf_.predict(self.X_train_new_).ravel() return self def decision_function(self, X): - check_is_fitted(self, - ['tree_', 'decision_scores_', 'threshold_', 'labels_']) + check_is_fitted(self, ['clf_', 'decision_scores_', + 'labels_', '_scalar']) X = check_array(X) - # initialize the output score - pred_scores = np.zeros([X.shape[0], 1]) + # construct the new feature space + X_add = self._generate_new_features(X) + X_new = np.concatenate((X, X_add), axis=1) - for i in range(X.shape[0]): - x_i = X[i, :] - x_i = np.asarray(x_i).reshape(1, x_i.shape[0]) + pred_scores = self.clf_.predict_proba(X_new)[:, 1] + return pred_scores.ravel() - # get the distance of the current point - dist_arr, _ = self.tree_.query(x_i, k=self.n_neighbors) + def predict(self, X): - if self.method == 'largest': - dist = dist_arr[:, -1] - elif self.method == 'mean': - dist = np.mean(dist_arr, axis=1) - elif self.method == 'median': - dist = np.median(dist_arr, axis=1) + check_is_fitted(self, ['clf_', 'decision_scores_', + 'labels_', '_scalar']) - pred_score_i = dist[-1] + X = check_array(X) - # record the current item - pred_scores[i, :] = pred_score_i + # construct the new feature space + X_add = self._generate_new_features(X) + X_new = np.concatenate((X, X_add), axis=1) + pred_scores = self.clf_.predict(X_new) return pred_scores.ravel() + + def predict_proba(self, X): + return self.decision_function(X) + + def fit_predict(self, X, y): + self.fit(X, y) + return self.labels_ + + def fit_predict_score(self, X, y, scoring='roc_auc_score'): + """Fit the detector, predict on samples, and evaluate the model by + ROC and Precision @ rank n + + :param X: The input samples + :type X: numpy array of shape (n_samples, n_features) + + :param y: Outlier labels of the input samples + :type y: array, shape (n_samples,) + + :param scoring: Evaluation 
metric + + -' roc_auc_score': ROC score + - 'prc_n_score': Precision @ rank n score + :type scoring: str, optional (default='roc_auc_score') + + :return: Evaluation score + :rtype: float + """ + + self.fit(X, y) + + if scoring == 'roc_auc_score': + score = roc_auc_score(y, self.decision_scores_) + elif scoring == 'prc_n_score': + score = precision_n_scores(y, self.decision_scores_) + else: + raise NotImplementedError('PyOD built-in scoring only supports ' + 'ROC and Precision @ rank n') + + print("{metric}: {score}".format(metric=scoring, score=score)) + + return score diff --git a/pyod/test/data/pima.mat b/pyod/test/data/pima.mat new file mode 100644 index 0000000000000000000000000000000000000000..15e626e9228f82a7965f0677f4bff5e2cf4421dc Binary files /dev/null and b/pyod/test/data/pima.mat differ diff --git a/pyod/test/test_utility.py b/pyod/test/test_utility.py index a11b4872d5cad068f304876744a7b10b60bbd99c..c09f8ba2478e2d8aaeb51720102e1499d3c14f84 100644 --- a/pyod/test/test_utility.py +++ b/pyod/test/test_utility.py @@ -13,6 +13,7 @@ from sklearn.utils.testing import assert_allclose from sklearn.utils.testing import assert_less_equal from sklearn.utils.testing import assert_raises from sklearn.metrics import precision_score +from sklearn.utils import check_random_state import numpy as np @@ -195,24 +196,55 @@ class TestParameters(unittest.TestCase): class TestScaler(unittest.TestCase): def setUp(self): - self.X_train = np.random.rand(500, 5) - self.X_test = np.random.rand(50, 5) + random_state = check_random_state(42) + self.X_train = random_state.rand(500, 5) + self.X_test = random_state.rand(100, 5) + self.X_test_diff = random_state.rand(100, 10) self.scores1 = [0.1, 0.3, 0.5, 0.7, 0.2, 0.1] self.scores2 = np.array([0.1, 0.3, 0.5, 0.7, 0.2, 0.1]) def test_normalization(self): - norm_X_train, norm_X_test = standardizer(self.X_train, self.X_train) + + # test when X_t is presented and no scalar + norm_X_train, norm_X_test = standardizer(self.X_train, self.X_test) assert_allclose(norm_X_train.mean(), 0, atol=0.05) assert_allclose(norm_X_train.std(), 1, atol=0.05) assert_allclose(norm_X_test.mean(), 0, atol=0.05) assert_allclose(norm_X_test.std(), 1, atol=0.05) - # test when X_t is not presented + # test when X_t is not presented and no scalar norm_X_train = standardizer(self.X_train) assert_allclose(norm_X_train.mean(), 0, atol=0.05) assert_allclose(norm_X_train.std(), 1, atol=0.05) + # test when X_t is presented and the scalar is kept + norm_X_train, norm_X_test, scalar = standardizer(self.X_train, + self.X_test, + keep_scalar=True) + + assert_allclose(norm_X_train.mean(), 0, atol=0.05) + assert_allclose(norm_X_train.std(), 1, atol=0.05) + + assert_allclose(norm_X_test.mean(), 0, atol=0.05) + assert_allclose(norm_X_test.std(), 1, atol=0.05) + + if not hasattr(scalar, 'fit') or not hasattr(scalar, 'transform'): + raise AttributeError("%s is not a detector instance." % (scalar)) + + # test when X_t is not presented and the scalar is kept + norm_X_train, scalar = standardizer(self.X_train, keep_scalar=True) + + assert_allclose(norm_X_train.mean(), 0, atol=0.05) + assert_allclose(norm_X_train.std(), 1, atol=0.05) + + if not hasattr(scalar, 'fit') or not hasattr(scalar, 'transform'): + raise AttributeError("%s is not a detector instance." 
% (scalar)) + + # test shape difference + with assert_raises(ValueError): + standardizer(self.X_train, self.X_test_diff) + def test_invert_order(self): target = np.array([-0.1, -0.3, -0.5, -0.7, -0.2, -0.1]).ravel() scores1 = invert_order(self.scores1) @@ -265,12 +297,14 @@ class TestCheckDetector(unittest.TestCase): class DummyNegativeModel(): def fit_negative(self): return + def decision_function_negative(self): return class DummyPostiveModel(): def fit(self): return + def decision_function(self): return diff --git a/pyod/test/test_xgbod.py b/pyod/test/test_xgbod.py new file mode 100644 index 0000000000000000000000000000000000000000..a90b7a1b22865488dc0e4b4e89528eca0f7267f6 --- /dev/null +++ b/pyod/test/test_xgbod.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- + +from __future__ import division +from __future__ import print_function + +import os +import sys +from os import path + +import unittest +# noinspection PyProtectedMember +from sklearn.utils.testing import assert_allclose +from sklearn.utils.testing import assert_array_less +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_greater_equal +from sklearn.utils.testing import assert_less_equal +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_true +from sklearn.utils.estimator_checks import check_estimator + +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import train_test_split +from sklearn.utils.validation import check_X_y +from scipy.io import loadmat +from scipy.stats import rankdata + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from pyod.models.xgbod import XGBOD +from pyod.utils.data import generate_data + + +class TestXGBOD(unittest.TestCase): + def setUp(self): + # Define data file and read X and y + # Generate some data if the source data is missing + this_directory = path.abspath(path.dirname(__file__)) + mat_file = 'pima.mat' + try: + mat = loadmat(path.join(*[this_directory, 'data', mat_file])) + + except TypeError: + print('{data_file} does not exist. Use generated data'.format( + data_file=mat_file)) + X, y = generate_data(train_only=True) # load data + except IOError: + print('{data_file} does not exist. 
Use generated data'.format( + data_file=mat_file)) + X, y = generate_data(train_only=True) # load data + else: + X = mat['X'] + y = mat['y'].ravel() + X, y = check_X_y(X, y) + + self.X_train, self.X_test, self.y_train, self.y_test = \ + train_test_split(X, y, test_size=0.4, random_state=42) + + self.clf = XGBOD(random_state=42) + self.clf.fit(self.X_train, self.y_train) + + self.roc_floor = 0.8 + + def test_parameters(self): + assert_true(hasattr(self.clf, 'clf_') and + self.clf.decision_scores_ is not None) + assert_true(hasattr(self.clf, '_scalar') and + self.clf.labels_ is not None) + assert_true(hasattr(self.clf, 'n_detector_') and + self.clf.labels_ is not None) + assert_true(hasattr(self.clf, 'X_train_add_') and + self.clf.labels_ is not None) + assert_true(hasattr(self.clf, 'decision_scores_') and + self.clf.decision_scores_ is not None) + assert_true(hasattr(self.clf, 'labels_') and + self.clf.labels_ is not None) + + def test_train_scores(self): + assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) + + def test_prediction_scores(self): + pred_scores = self.clf.decision_function(self.X_test) + + # check score shapes + assert_equal(pred_scores.shape[0], self.X_test.shape[0]) + + # check performance + assert_greater(roc_auc_score(self.y_test, pred_scores), self.roc_floor) + + def test_prediction_labels(self): + pred_labels = self.clf.predict(self.X_test) + assert_equal(pred_labels.shape, self.y_test.shape) + + def test_prediction_proba(self): + pred_proba = self.clf.predict_proba(self.X_test) + assert_greater_equal(pred_proba.min(), 0) + assert_less_equal(pred_proba.max(), 1) + # check performance + assert_greater(roc_auc_score(self.y_test, pred_proba), self.roc_floor) + + def test_fit_predict(self): + pred_labels = self.clf.fit_predict(self.X_train, self.y_train) + assert_equal(pred_labels.shape, self.y_train.shape) + + def test_fit_predict_score(self): + self.clf.fit_predict_score(self.X_test, self.y_test) + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='roc_auc_score') + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='prc_n_score') + with assert_raises(NotImplementedError): + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='something') + + def test_predict_rank(self): + pred_socres = self.clf.decision_function(self.X_test) + pred_ranks = self.clf._predict_rank(self.X_test) + print(pred_ranks) + + # assert the order is reserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=4) + assert_array_less(pred_ranks, self.X_train.shape[0] + 1) + assert_array_less(-0.1, pred_ranks) + + def test_predict_rank_normalized(self): + pred_socres = self.clf.decision_function(self.X_test) + pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) + + # assert the order is reserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=4) + assert_array_less(pred_ranks, 1.01) + assert_array_less(-0.1, pred_ranks) + + def tearDown(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/pyod/utils/utility.py b/pyod/utils/utility.py index cc377943ceae99e724ab2b7afef30132e651e78c..42f100ffe07f3a7057ce95ce358ffd3f098f60f0 100644 --- a/pyod/utils/utility.py +++ b/pyod/utils/utility.py @@ -115,11 +115,12 @@ def check_detector(detector): """ - if not hasattr(detector, 'fit') or not hasattr(detector, 'decision_function'): + if not hasattr(detector, 'fit') or not hasattr(detector, + 'decision_function'): raise AttributeError("%s is not a detector instance." 
% (detector)) -def standardizer(X, X_t=None): +def standardizer(X, X_t=None, keep_scalar=False): """Conduct Z-normalization on data to turn input samples become zero-mean and unit variance. @@ -128,9 +129,12 @@ def standardizer(X, X_t=None): X : ndarray (n_samples, n_features) The training samples - X_t : ndarray (n_samples_new, n_features), default=None + X_t : ndarray (n_samples_new, n_features), optional (default=None) The data to be converted + keep_scalar : bool, optional (default=False) + The flag to indicate whether to return the scalar + Returns ------- X_norm : ndarray (n_samples, n_features) @@ -139,15 +143,29 @@ def standardizer(X, X_t=None): X_t_norm : ndarray (n_samples, n_features) X_t after the Z-score normalization + scalar : sklearn scalar object + The scalar used in conversion + """ X = check_array(X) - if X_t is None: - return StandardScaler().fit_transform(X) - - X_t = check_array(X_t) - assert_equal(X.shape[1], X_t.shape[1]) scaler = StandardScaler().fit(X) - return scaler.transform(X), scaler.transform(X_t) + + if X_t is None: + if keep_scalar: + return scaler.transform(X), scaler + else: + return scaler.transform(X) + else: + X_t = check_array(X_t) + if X.shape[1] != X_t.shape[1]: + raise ValueError( + "The number of input data feature should be consistnt" + "X has {0} features and X_t has {1} features.".format( + X.shape[1], X_t.shape[1])) + if keep_scalar: + return scaler.transform(X), scaler.transform(X_t), scaler + else: + return scaler.transform(X), scaler.transform(X_t) def score_to_label(pred_scores, outliers_fraction=0.1): @@ -183,6 +201,7 @@ def precision_n_scores(y, y_pred, n=None): return precision_score(y, y_pred) + def generate_bagging_indices(random_state, bootstrap_features, n_features, min_features, max_features): """ Randomly draw feature indices. Internal use only. @@ -217,7 +236,7 @@ def generate_bagging_indices(random_state, bootstrap_features, n_features, # Draw indices feature_indices = generate_indices(random_state, bootstrap_features, - n_features, random_n_features) + n_features, random_n_features) return feature_indices diff --git a/requirements_ci.txt b/requirements_ci.txt index ad8518553d6abc258d2e7e04240f3b32d6ef14d6..77947a4b422cc33f524569c1ad900ced0a090fe4 100644 --- a/requirements_ci.txt +++ b/requirements_ci.txt @@ -5,4 +5,5 @@ numpy>=1.13 numba>=0.35 scipy>=0.19.1 scikit_learn>=0.19.1 -tensorflow \ No newline at end of file +tensorflow +xgboost>=0.7 \ No newline at end of file
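As a usage note on the ``keep_scalar`` flag added to ``pyod.utils.utility.standardizer`` above: it is what lets XGBOD keep the training-time scaler and reuse it on test data in ``decision_function``. A minimal sketch under the patched signature (the data here is illustrative only)::

    import numpy as np
    from pyod.utils.utility import standardizer

    rng = np.random.RandomState(0)
    X_train, X_test = rng.rand(500, 5), rng.rand(100, 5)

    # fit the scaler on the training data and keep it for later conversion
    X_train_norm, scaler = standardizer(X_train, keep_scalar=True)

    # reuse the same scaler on unseen data, as XGBOD does internally
    X_test_norm = scaler.transform(X_test)

    # passing both sets at once also works and can return the scaler
    X_train_norm, X_test_norm, scaler = standardizer(X_train, X_test,
                                                     keep_scalar=True)

    # a feature-count mismatch now raises ValueError instead of asserting
    standardizer(X_train, rng.rand(100, 10))  # raises ValueError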