Implementing SVM with Scikit-Learn

Implementation Pattern

The overall workflow is: import the library, split the data, build the model, train it, evaluate it, and then make predictions.

# Import library
from sklearn import svm
# Assume you have X (predictors) and y (target) for the training set
# and x_test (predictors) for the test set
# Create an SVM classification object
model = svm.SVC(kernel='linear', C=1, gamma=1)
# There are various options associated with it, such as the kernel, gamma and C values.
# They are discussed in the next section.
# Train the model on the training set and check the score
model.fit(X, y)
model.score(X, y)
# Predict output
predicted = model.predict(x_test)

Example

# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features; we could
                      # avoid this ugly slicing by using a two-dim dataset
y = iris.target

# we create an instance of SVC and fit our data; we do not scale the
# data because we want to plot the support vectors
C = 1.0  # SVM regularization parameter
svc = svm.SVC(kernel='rbf', C=C, gamma='auto').fit(X, y)  # gamma='auto' means 1 / n_features
print(svc.score(X, y))

# create a mesh to plot in
h = 0.01  # mesh step size
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

plt.subplot(1, 1, 1)
Z = svc.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title('SVC with RBF kernel')
plt.show()

Parameter Tuning

Three influential parameters are taken as examples: kernel, gamma, and C. The illustrations in the original article give a very intuitive picture of their effects.

Kernel type (kernel)

Here "rbf" and "poly" are useful for non-linear hyperplanes.

I would suggest going for a linear kernel if you have a large number of features (>1000), because it is more likely that the data is linearly separable in a high-dimensional space. You can also use RBF, but do not forget to cross-validate its parameters to avoid over-fitting. A kernel-comparison sketch is shown below.
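As a quick illustration, here is a minimal sketch (not from the original article) that fits an SVC with each kernel on the iris data and prints the held-out accuracy; the 70/30 split and random_state are arbitrary choices:

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
# arbitrary 70/30 train/test split, for illustration only
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0)

for kernel in ('linear', 'rbf', 'poly'):
    clf = svm.SVC(kernel=kernel, C=1, gamma='auto').fit(X_train, y_train)
    print(kernel, clf.score(X_test, y_test))  # accuracy on the held-out set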

gamma

Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. The higher the value of gamma, the harder the model tries to fit the training data exactly, which increases the generalization error and causes over-fitting.
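To see this over-fitting tendency, the following minimal sketch (the gamma grid is an arbitrary choice) sweeps gamma on the iris data and prints train versus test accuracy; as gamma grows, the training score approaches 1 while the test score typically drops:

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0)

for gamma in (0.01, 0.1, 1, 10, 100):  # arbitrary grid, for illustration only
    clf = svm.SVC(kernel='rbf', C=1, gamma=gamma).fit(X_train, y_train)
    print(gamma, clf.score(X_train, y_train), clf.score(X_test, y_test))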

C (penalty parameter)

Penalty parameter C of the error term. It controls the trade-off between a smooth decision boundary and classifying the training points correctly.
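A small C keeps the margin wide at the cost of more training errors, while a large C fits the training points more aggressively. The sketch below (my own illustration, with arbitrarily chosen C values) prints the number of support vectors per class alongside the training accuracy; fewer support vectors usually means a smoother boundary:

from sklearn import svm, datasets

iris = datasets.load_iris()
X, y = iris.data, iris.target

for C in (0.01, 1, 100):  # arbitrary values, for illustration only
    clf = svm.SVC(kernel='rbf', C=C, gamma='auto').fit(X, y)
    print(C, clf.n_support_, clf.score(X, y))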

For the non-linearly separable case, C appears directly in the soft-margin primal problem (the standard formulation, as in 《统计学习方法》):
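$$ \min_{w,b,\xi}\ \frac{1}{2}\|w\|^{2}+C\sum_{i=1}^{N}\xi_{i} \quad \text{s.t.}\quad y_{i}(w\cdot x_{i}+b)\ge 1-\xi_{i},\ \xi_{i}\ge 0,\ i=1,\dots,N $$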

Model Evaluation

Score on the training data (note that svc.score(X, y) evaluated on the data the model was trained on is not a true cross-validation score):

svc.score(X,y)
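For an actual cross-validation score, sklearn.model_selection.cross_val_score can be used; a minimal sketch (5-fold is an arbitrary choice):

from sklearn import svm, datasets
from sklearn.model_selection import cross_val_score

iris = datasets.load_iris()
clf = svm.SVC(kernel='rbf', C=1, gamma='auto')
scores = cross_val_score(clf, iris.data, iris.target, cv=5)  # 5-fold CV
print(scores.mean(), scores.std())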

Thoughts on the Exercise Question

Imagine an auxiliary circle with center Q(0, y1).
You can see that the star class always stays relatively close to Q, while the circle class stays relatively far from it.

So the transformation should be
$$ Z=(x-0)^{2}+(y-y_{1})^{2} $$
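As a sketch of this idea (the points below are made-up placeholders, not the actual figure data, and y1=1.0 is an arbitrary choice), Z can be appended as an extra feature, after which a linear SVM separates the two classes:

import numpy as np
from sklearn import svm

def add_radial_feature(points, y1):
    # squared distance to the auxiliary center Q(0, y1)
    z = (points[:, 0] - 0) ** 2 + (points[:, 1] - y1) ** 2
    return np.column_stack([points, z])

# made-up placeholder points, just to show the call pattern
stars = np.array([[0.0, 1.0], [0.5, 1.5], [-0.5, 0.5]])
circles = np.array([[2.0, 3.0], [-2.0, -1.0], [3.0, 1.0]])
X = add_radial_feature(np.vstack([stars, circles]), y1=1.0)
y = np.array([0, 0, 0, 1, 1, 1])

clf = svm.SVC(kernel='linear', C=1).fit(X, y)
print(clf.score(X, y))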

References

[1] Understanding Support Vector Machine algorithm from examples (along with code)
[2] Li Hang, 《统计学习方法》 (Statistical Learning Methods)