
Grid Search in sklearn

➜ test ✗ cat gridSearchCV.py
import numpy as np
from sklearn import cross_validation
from sklearn import datasets
from sklearn import svm
from sklearn import grid_search
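# Use the iris dataset
iris = datasets.load_iris()
print(iris.data.shape)
print(iris.target.shape)
# Parameter grid to search: kernel types and C values
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
# Create the base classifier (estimator)
svr = svm.SVC()
# Wrap the classifier in a grid search over the parameter grid
clf = grid_search.GridSearchCV(svr, parameters)
# Try every parameter combination in parameters and fit the data
clf.fit(iris.data, iris.target)
# Print the accuracy on the full dataset and the best parameter combination
print(clf.score(iris.data, iris.target))
print(clf.best_params_)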
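The grid search in this script tries every combination of the values in parameters. A rough, hand-rolled sketch of that loop is shown here, written against sklearn.model_selection (available since scikit-learn 0.18) instead of the deprecated modules. It is an illustration under stated assumptions, not how GridSearchCV is actually implemented: it assumes 5-fold cross-validation (cv=5) and ignores parallelism and error handling.

```python
from sklearn import datasets, svm
from sklearn.model_selection import ParameterGrid, cross_val_score

iris = datasets.load_iris()
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}

# ParameterGrid expands the dict into every combination: 2 kernels x 2 C values = 4
scores = {}
for params in ParameterGrid(parameters):
    model = svm.SVC(**params)
    # Mean accuracy over 5 cross-validation folds for this one candidate
    key = tuple(sorted(params.items()))
    scores[key] = cross_val_score(model, iris.data, iris.target, cv=5).mean()

# Pick the candidate with the highest mean score (first one wins on ties)
best_key = max(scores, key=scores.get)
best_params = dict(best_key)
best_score = scores[best_key]

# Refit the winning combination on all the data, like GridSearchCV's default refit=True
best_model = svm.SVC(**best_params).fit(iris.data, iris.target)
print(best_params, best_score)
```

GridSearchCV does the same bookkeeping for you and exposes the per-candidate results in cv_results_.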

Running gridSearchCV.py produces:

➜ test ✗ python3 gridSearchCV.py
/usr/local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
  "This module will be removed in 0.20.", DeprecationWarning)
/usr/local/lib/python3.6/site-packages/sklearn/grid_search.py:42: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. This module will be removed in 0.20.
  DeprecationWarning)
(150, 4)
(150,)
0.9933333333333333
{'C': 1, 'kernel': 'linear'}
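The two DeprecationWarnings say that cross_validation and grid_search were replaced by sklearn.model_selection in 0.18 and removed in 0.20, so the script above no longer runs on current scikit-learn. A minimal sketch of the same script against model_selection follows; it also drops the unused numpy and cross_validation imports. Its numbers may not match the log above: GridSearchCV's default cv changed from 3-fold to 5-fold in 0.22, and SVC's default gamma changed from 'auto' to 'scale' in the same release.

```python
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

# Use the iris dataset
iris = datasets.load_iris()
print(iris.data.shape)
print(iris.target.shape)

# Same parameter grid as before: kernel types and C values
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
svc = svm.SVC()

# GridSearchCV now lives in sklearn.model_selection
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)

# Accuracy of the refit best model on the full dataset, and the winning parameters
print(clf.score(iris.data, iris.target))
print(clf.best_params_)
```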

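Note that for a classifier such as SVC, clf.score returns mean accuracy, not an R² score. The 0.9933 here is higher than the result we got earlier by using cross_validation to split the data into training and test sets (see Cross Validation in sklearn), but it is not a more accurate estimate: GridSearchCV refits the best parameters on the whole iris dataset, and the score is then computed on that same data, so it is optimistic. The mean cross-validated score of the best combination is available as clf.best_score_.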
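To get an estimate that is not computed on the training data, one option is to hold out a test set before the grid search and score on it only at the end. A sketch using sklearn.model_selection; the 30% test fraction, stratify and random_state=0 are arbitrary choices for this example, not values from the original post.

```python
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV, train_test_split

iris = datasets.load_iris()

# Hold out 30% of the samples; the grid search never sees them
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0, stratify=iris.target)

parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
clf = GridSearchCV(svm.SVC(), parameters, cv=5)
clf.fit(X_train, y_train)  # cross-validation happens inside the training portion only

print('best params:', clf.best_params_)
print('mean CV accuracy of best params:', clf.best_score_)
print('accuracy on training data:', clf.score(X_train, y_train))
print('accuracy on held-out test data:', clf.score(X_test, y_test))
```

The held-out accuracy is the number to compare against the earlier cross_validation experiment; the training accuracy is shown only to make the optimism of scoring on seen data visible.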
