
Cross Validation in sklearn

➜ test ✗ cat cross_validation.py
import numpy as np
from sklearn import cross_validation
from sklearn import datasets
from sklearn import svm
# Use the iris dataset
iris = datasets.load_iris()
print(iris.data.shape)
print(iris.target.shape)
# Split the dataset into two parts: 60% training set, 40% test set
X_train, X_test, y_train, y_test = cross_validation.train_test_split(iris.data, iris.target, test_size = 0.4, random_state = 0)
# Print the sizes of the resulting training/test sets
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
# Fit a linear-kernel SVC classifier on the training data
clf = svm.SVC(kernel = 'linear', C = 1)
clf.fit(X_train, y_train)
# Print the score (mean accuracy) on the test set
print(clf.score(X_test, y_test))
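The DeprecationWarning in the output that follows comes from `sklearn.cross_validation`, which was deprecated in scikit-learn 0.18 and removed in 0.20, so the import above fails on current releases. `train_test_split` now lives in `sklearn.model_selection`. Below is a minimal sketch of the same script against that module, assuming scikit-learn 0.20 or later; only the import and the call site change.

```python
from sklearn import datasets, svm
from sklearn.model_selection import train_test_split

# Load the iris dataset (150 samples, 4 features, 3 classes)
iris = datasets.load_iris()
print(iris.data.shape)
print(iris.target.shape)

# Hold out 40% of the samples as the test set; random_state fixes the shuffle
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.4, random_state=0
)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

# Fit a linear-kernel SVC on the training split only
clf = svm.SVC(kernel="linear", C=1)
clf.fit(X_train, y_train)

# score() returns the mean accuracy on the held-out test split
test_accuracy = clf.score(X_test, y_test)
print(test_accuracy)
```

With `test_size=0.4`, the test set gets 0.4 × 150 = 60 samples and the training set the remaining 90, which matches the shapes printed below.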

Output of running the script:

➜ test ✗ python3 cross_validation.py
/usr/local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
  "This module will be removed in 0.20.", DeprecationWarning)
(150, 4)
(150,)
(90, 4)
(90,)
(60, 4)
(60,)
0.9666666666666667
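Despite the title, the script evaluates the classifier on a single 60/40 holdout split, so the 0.967 score depends on which 60 samples happened to land in the test set. K-fold cross-validation repeats the fit/score cycle over k disjoint test folds and reports one score per fold. The sketch below uses `sklearn.model_selection.cross_val_score` and assumes scikit-learn 0.20 or later. The choice of 5 folds is an illustrative assumption, not taken from the original script. The explicit `StratifiedKFold` loop reproduces the same per-fold scores by hand.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold, cross_val_score

iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = svm.SVC(kernel="linear", C=1)

# With an integer cv and a classifier, cross_val_score splits the data with
# StratifiedKFold (no shuffling) and returns one accuracy per fold
scores = cross_val_score(clf, X, y, cv=5)
print(scores)
print("mean accuracy: %.3f (std %.3f)" % (scores.mean(), scores.std()))

# The same procedure spelled out: train on 4 folds, score on the 5th, repeat
skf = StratifiedKFold(n_splits=5)
manual_scores = []
for train_idx, test_idx in skf.split(X, y):
    model = clone(clf)  # fresh, unfitted copy of the estimator for each fold
    model.fit(X[train_idx], y[train_idx])
    manual_scores.append(model.score(X[test_idx], y[test_idx]))
manual_scores = np.array(manual_scores)
print(manual_scores)
```

Stratification matters here because iris is stored sorted by class. Unshuffled plain `KFold` would put whole classes into a single test fold, whereas `StratifiedKFold` keeps the class proportions roughly equal in every fold.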
