《scikit-learn機器學習常用算法原理及編程實戰》—2.6 scikit-learn簡介
2.6? scikit-learn簡介
scikit-learn是一個開源的Python語言機器學習工具包,它涵蓋了幾乎所有主流機器學習算法的實現,并且提供了一致的調用接口。它基于Numpy和scipy等Python數值計算庫,提供了高效的算法實現。總結起來,scikit-learn工具包有以下幾個優點。
* 文檔齊全:官方文檔齊全,更新及時。
* 接口易用:針對所有的算法提供了一致的接口調用規則,不管是KNN、K-Mean還是PCA。
* 算法全面:涵蓋主流機器學習任務的算法,包括回歸算法、分類算法、聚類分析、數據降維處理等。
當然,scikit-learn不支持分布式計算,不適合用來處理超大型數據。但這并不影響 scikit-learn作為一個優秀的機器學習工具庫這個事實。許多知名的公司,包括Evernote和Spotify都使用scikit-learn來開發他們的機器學習應用。
2.6.1? scikit-learn示例
回顧前面章節介紹的機器學習應用開發的典型步驟,我們使用scikit-learn來完成一個手寫數字識別的例子。這是一個有監督的學習,數據是標記過的手寫數字的圖片。即通過采集足夠多的手寫數字樣本數據,選擇合適的模型,并使用采集到的數據進行模型訓練,最后驗證手寫識別程序的正確性。
1.數據采集和標記
如果我們從頭實現一個數字手寫識別的程序,需要先采集數據,即讓盡量多不同書寫習慣的用戶,寫出從0~9的所有數字,然后把用戶寫出來的數據進行標記,即用戶每寫出一個數字,就標記他寫出的是哪個數字。
為什么要采集盡量多不同書寫習慣的用戶寫的數字呢?因為只有這樣,采集到的數據才有代表性,才能保證最終訓練出來的模型的準確性。極端的例子,我們采集的都是習慣寫出瘦高形數字的人,那么針對習慣寫出矮胖形數字的人寫出來的數字,模型的識別成功率就會很低。
所幸我們不需要從頭開始這項工作,scikit-learn自帶了一些數據集,其中一個是手寫數字識別圖片的數據,使用以下代碼來加載數據。
from sklearn import datasets
digits = datasets.load_digits()
可以在ipython notebook環境下把數據所表示的圖片用Mathplotlib顯示出來:
# 把數據所代表的圖片顯示出來
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6), dpi=200)
for index, (image, label) in enumerate(images_and_labels[:8]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Digit: %i' % label, fontsize=20)
其結果如圖2-19所示。
圖2-19? 數字圖片
從圖2-19中可以看出,圖片是一個個手寫的數字。
2.特征選擇
針對一個手寫的圖片數據,應該怎么樣來選擇特征呢?一個直觀的方法是,直接使用圖片的每個像素點作為一個特征。比如一個圖片是200 ? 200的分辨率,那么我們就有 40000個特征,即特征向量的長度是40000。
實際上,scikit-learn使用Numpy的array對象來表示數據,所有的圖片數據保存在 digits.images里,每個元素都是一個8?8尺寸的灰階圖片。我們在進行機器學習時,需要把數據保存為樣本個數?特征個數格式的array對象,針對手寫數字識別這個案例,scikit-learn已經為我們轉換好了,它就保存在digits.data數據里,可以通過digits.data.shape來查看它的數據格式為:
print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
輸出為:
shape of raw image data: (1797, 8, 8)
shape of data: (1797, 64)
可以看到,總共有1797個訓練樣本,其中原始的數據是8?8的圖片,而用來訓練的數據是把圖片的64個象素點都轉換為特征。下面將直接使用digits.data作為訓練數據。
3.數據清洗
人們不可能在8?8這么小的分辨率的圖片上寫出數字,在采集數據的時候,是讓用戶在一個大圖片上寫出這些數字,如果圖片是200 ? 200分辨率,那么一個訓練樣例就有40000個特征,計算量將是巨大的。為了減少計算量,也為了模型的穩定性,我們需要把200 ? 200的圖片縮小為8?8的圖片。這個過程就是數據清洗,即把采集到的、不適合用來做機器學習訓練的數據進行預處理,從而轉換為適合機器學習的數據。
4.模型選擇
不同的機器學習算法模型針對特定的機器學習應用有不同的效率,模型的選擇和驗證留到后面章節詳細介紹。此處,我們使用支持向量機來作為手寫識別算法的模型。關于支持向量機,后面章節也會詳細介紹。
5.模型訓練
在開始訓練我們的模型之前,需要先把數據集分成訓練數據集和測試數據集。為什么要這樣做呢?第1章的模型訓練和測試里有詳細的介紹。我們可以使用下面代碼把數據集分出20%作為測試數據集。
# 把數據分成訓練數據集和測試數據集
from sklearn.cross_validation import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data,
digits.target, test_size=0.20, random_state=2);
接著,使用訓練數據集Xtrain和Ytrain來訓練模型。
# 使用支持向量機來訓練模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100.)
clf.fit(Xtrain, Ytrain);
訓練完成后,clf對象就會包含我們訓練出來的模型參數,可以使用這個模型對象來進行預測。
6.模型測試
我們來測試一下訓練出來的模型的準確度。一個直觀的方法是,我們用訓練出來的模型clf預測測試數據集,然后把預測結果Ypred和真正的結果Ytest比較,看有多少個是正確的,這樣就能評估出模型的準確度了。所幸,scikit-learn提供了現成的方法來完成這項工作:
clf.score(Xtest, Ytest)
筆者計算機上的輸出結果為:
0.9*********5
顯示出模型有97.8%的準確率。讀者如果運行這段代碼的話,在準確率上可能會稍有差異。
除此之外,還可以直接把測試數據集里的部分圖片顯示出來,并且在圖片的左下角顯示預測值,右下角顯示真實值。運行效果如圖2-20所示。
# 查看預測的情況
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r,
interpolation='nearest')
ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
transform=ax.transAxes,
color='green' if Ypred[i] == Ytest[i] else 'red')
ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
transform=ax.transAxes,
color='black')
ax.set_xticks([])
ax.set_yticks([])
圖2-20? 預測值與真實值
從圖2-20中可以看出來,第二行第一個圖片預測出錯了,真實的數字是4,但預測成了8。
7.模型保存與加載
當我們對模型的準確度感到滿意后,就可以把模型保存下來。這樣下次需要預測時,可以直接加載模型來進行預測,而不是重新訓練一遍模型??梢允褂孟旅娴拇a來保存模型:
# 保存模型參數
from sklearn.externals import joblib
joblib.dump(clf, 'digits_svm.pkl');
當我們需要這個模型來進行預測時,直接加載模型即可進行預測。
# 導入模型參數,直接進行預測
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Ytest, Ypred)
筆者計算機上的輸出結果是:
0.9*********5
這個例子包含在隨書代碼ch02.06.ipynb上,讀者可以下載下來運行并參考。
機器學習 scikit-learn
版權聲明:本文內容由網絡用戶投稿,版權歸原作者所有,本站不擁有其著作權,亦不承擔相應法律責任。如果您發現本站中有涉嫌抄襲或描述失實的內容,請聯系我們jiasou666@gmail.com 處理,核實后本網站將在24小時內刪除侵權內容。