Kernel SGD
==========

Usage
-----

The kernel SGD module can be imported as:

.. code-block:: python

    import sklearn_extensions as ske

    mdl = ske.kernel_sgd.KernelSGD()
    mdl.fit_predict(X, y)

It also comes with two helper classes:

.. code-block:: python

    ske.kernel_sgd.HingeLoss()
    ske.kernel_sgd.GaussianKernel()

These can be used to create KernelSGD models, or you can bring your own loss and kernel classes.

Examples
--------

.. code-block:: python

    from sklearn_extensions.kernel_sgd.kernel_sgd import GaussianKernel, KernelSGD, HingeLoss
    import numpy as np

    def gen_non_lin_separable_data():
        mean1 = [-1, 2]
        mean2 = [1, -1]
        mean3 = [4, -4]
        mean4 = [-4, 4]
        cov = [[1.0,0.8], [0.8, 1.0]]
        X1 = np.random.multivariate_normal(mean1, cov, 50)
        X1 = np.vstack((X1, np.random.multivariate_normal(mean3, cov, 50)))
        y1 = np.ones(len(X1))
        X2 = np.random.multivariate_normal(mean2, cov, 50)
        X2 = np.vstack((X2, np.random.multivariate_normal(mean4, cov, 50)))
        y2 = np.ones(len(X2)) * -1
        return X1, y1, X2, y2

    def split_train(X1, y1, X2, y2):
        X1_train = X1[:90]
        y1_train = y1[:90]
        X2_train = X2[:90]
        y2_train = y2[:90]
        X_train = np.vstack((X1_train, X2_train))
        y_train = np.hstack((y1_train, y2_train))
        return X_train, y_train

    def split_test(X1, y1, X2, y2):
        X1_test = X1[90:]
        y1_test = y1[90:]
        X2_test = X2[90:]
        y2_test = y2[90:]
        X_test = np.vstack((X1_test, X2_test))
        y_test = np.hstack((y1_test, y2_test))
        return X_test, y_test

    def test():
        X1, y1, X2, y2 = gen_non_lin_separable_data()
        X_train, y_train = split_train(X1, y1, X2, y2)
        X_test, y_test = split_test(X1, y1, X2, y2)

        clf = KernelSGD(kernel=GaussianKernel(), loss=HingeLoss(threshold=0), n_iter=5)
        clf.fit(X_train, y_train)

        y_predict = clf.predict(X_test)
        correct = np.sum(y_predict == y_test)
        print("%d out of %d predictions correct" % (correct, len(y_predict)))

    test()

Running this yields output similar to the following. The exact counts can vary between runs, because the example data are drawn at random without a fixed seed.

.. code-block:: text

    8 support vectors out of 180 points
    20 out of 20 predictions correct

Third Party Docs
----------------

The original, unmodified version of this module's code is from a gist that can be found here: `Kernel SGD `_