close
我們以鳶尾花的資料來做KNN分類。
首先載入python套件
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
載入鳶尾花的資料
iris = datasets.load_iris()
把資料跟標記抓出來
iris_data = iris.data
iris_label = iris.target
看資料有幾筆
print(iris_data.shape, iris_label.shape)
可以發現有150筆資料 ,每個樣本有四項特徵,分別是花萼和花瓣的長度和寬度,分為3個種類(setosa, versicolor, virginica)。
更詳細的資料可以參考連結。
接下來把資料切分為訓練跟測試資料
train_data, test_data, train_label, test_label = train_test_split(iris_data, iris_label, test_size=0.2)
準備訓練knn模型
knn = KNeighborsClassifier()
knn.fit(train_data,train_label)
檢查訓練的模型
pred_label = knn.predict(test_data)
print('預測結果:', pred_label)
print('正確答案:', test_label)
print('正確率:', np.mean(pred_label == test_label))
來看訓練出來的效果如何吧
有93%的準確率看起來還可以阿。
全站熱搜
留言列表