Kernel SGD

Usage

The kernel SGD module can be imported as:

import sklearn_extensions as ske

mdl = ske.kernel_sgd.KernelSGD()
mdl.fit_predict(X, y)

It also comes with two helper classes:

ske.kernel_sgd.HingeLoss()
ske.kernel_sgd.GaussianKernel()

These can be used to build KernelSGD models; alternatively, you can bring your own loss and kernel classes.

Examples

from sklearn_extensions.kernel_sgd.kernel_sgd import GaussianKernel, KernelSGD, HingeLoss
import numpy as np


def gen_non_lin_separable_data(n_per_cluster=50, random_state=None):
    """Generate a 2-D, two-class toy dataset intended to be non-linearly separable.

    Each class is a mixture of two correlated Gaussian clusters that share
    the covariance ``[[1.0, 0.8], [0.8, 1.0]]``.

    Parameters
    ----------
    n_per_cluster : int, default 50
        Number of samples drawn from each of the four clusters, so each
        class has ``2 * n_per_cluster`` samples.
    random_state : None, int or numpy.random.RandomState, default None
        Source of randomness. ``None`` draws from the global ``np.random``
        state (the original behaviour). An int seeds a fresh
        ``RandomState`` so that results are reproducible. A ``RandomState``
        instance is used as-is.

    Returns
    -------
    X1, y1, X2, y2 : ndarray
        Samples and labels of class +1 (``X1``, ``y1``) and of
        class -1 (``X2``, ``y2``).
    """
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    mean1 = [-1, 2]
    mean2 = [1, -1]
    mean3 = [4, -4]
    mean4 = [-4, 4]
    cov = [[1.0, 0.8], [0.8, 1.0]]

    # Draw order matches the original (mean1, mean3, mean2, mean4), so the
    # default call consumes the global RNG exactly as before.
    X1 = rng.multivariate_normal(mean1, cov, n_per_cluster)
    X1 = np.vstack((X1, rng.multivariate_normal(mean3, cov, n_per_cluster)))
    y1 = np.ones(len(X1))
    X2 = rng.multivariate_normal(mean2, cov, n_per_cluster)
    X2 = np.vstack((X2, rng.multivariate_normal(mean4, cov, n_per_cluster)))
    y2 = np.ones(len(X2)) * -1
    return X1, y1, X2, y2


def split_train(X1, y1, X2, y2, n_train=90):
    """Build the training set from the first ``n_train`` samples of each class.

    Parameters
    ----------
    X1, y1 : array-like
        Samples and labels of the first class.
    X2, y2 : array-like
        Samples and labels of the second class.
    n_train : int, default 90
        Number of leading samples taken from each class. Use the same value
        as in ``split_test`` so the two sets do not overlap.

    Returns
    -------
    X_train, y_train : ndarray
        The first-class rows stacked above the second-class rows, with the
        labels in matching order.
    """
    X_train = np.vstack((X1[:n_train], X2[:n_train]))
    y_train = np.hstack((y1[:n_train], y2[:n_train]))
    return X_train, y_train


def split_test(X1, y1, X2, y2, n_train=90):
    """Build the test set from the samples of each class after the first ``n_train``.

    Parameters
    ----------
    X1, y1 : array-like
        Samples and labels of the first class.
    X2, y2 : array-like
        Samples and labels of the second class.
    n_train : int, default 90
        Number of leading samples per class reserved for training. Use the
        same value as in ``split_train`` so the two sets do not overlap.

    Returns
    -------
    X_test, y_test : ndarray
        The held-out first-class rows stacked above the held-out
        second-class rows, with the labels in matching order.
    """
    X_test = np.vstack((X1[n_train:], X2[n_train:]))
    y_test = np.hstack((y1[n_train:], y2[n_train:]))
    return X_test, y_test


def test():
    """Train a KernelSGD model on synthetic data and print its test accuracy.

    Generates the toy dataset, trains on the first 90 samples of each class,
    predicts the remaining samples, and prints how many predictions match
    the true labels.
    """
    X1, y1, X2, y2 = gen_non_lin_separable_data()
    X_train, y_train = split_train(X1, y1, X2, y2)
    X_test, y_test = split_test(X1, y1, X2, y2)

    clf = KernelSGD(kernel=GaussianKernel(), loss=HingeLoss(threshold=0), n_iter=5)
    clf.fit(X_train, y_train)

    y_predict = clf.predict(X_test)
    correct = np.sum(y_predict == y_test)
    print(f"{correct} out of {len(y_predict)} predictions correct")


# Guard the demo so that importing this file does not train a model as a
# side effect; it still runs when executed as a script.
if __name__ == "__main__":
    test()

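Running this yields output similar to the following (the data is generated without a fixed random seed, so the exact numbers may vary between runs):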

8 support vectors out of 180 points
20 out of 20 predictions correct

Third Party Docs

The original, unmodified version of this module’s code is from a gist, which can be found here: Kernel SGD