博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
02scikit-learn模型训练
阅读量:6247 次
发布时间:2019-06-22

本文共 8892 字,大约阅读时间需要 29 分钟。

模型训练

In [6]:
import numpy as npimport matplotlib.pyplot as pltfrom sklearn.linear_model import LinearRegressionfrom sklearn.datasets import load_bostondata = load_boston()clf = LinearRegression()n_samples, n_features = data.data.shapen_samples, n_features
Out[6]:
(506, 13)
In [11]:
data.keys()
Out[11]:
dict_keys(['data', 'target', 'feature_names', 'DESCR', 'filename'])
In [12]:
data.feature_names
Out[12]:
array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='
In [15]:
# play with featurescolumn_i = 5plt.scatter(data.data[:, column_i], data.target)data.feature_names[5]  # room
Out[15]:
'RM'
 
In [16]:
from sklearn.metrics import mean_absolute_errorclf.fit(data.data, data.target)predicted = clf.predict(data.data)mean_absolute_error(data.target, predicted)
Out[16]:
3.270862810900314
In [17]:
plt.scatter(data.target, predicted)plt.xlabel('true_price')plt.ylabel('predict_price')plt.plot(data.target, data.target, color='red')
Out[17]:
[
]
 
In [20]:
# try another non_linear modelfrom sklearn.tree import DecisionTreeRegressorclf2 = DecisionTreeRegressor()clf2.fit(data.data, data.target)predicted2 = clf2.predict(data.data)mean_absolute_error(data.target, predicted2)
Out[20]:
0.0
In [21]:
plt.scatter(data.target, predicted2)plt.xlabel('true_price')plt.ylabel('predict_price')plt.plot(data.target, data.target, color='red')
Out[21]:
[
]
 
 

上图训练的非常好,可能会产生过拟合

In [25]:
# practice classification model# example Logistic Regression and probability predictionfrom sklearn.datasets import load_irisfrom sklearn.linear_model import LogisticRegressioniris = load_iris()clf = LogisticRegression(solver='liblinear', multi_class='auto')clf.fit(iris.data, iris.target)probability = clf.predict_proba(iris.data)  # 返回预测属于某标签的概率probability  # 例如下面第一行,有87.8%的概率是属于第一类的,有12.2%的概率是属于第# 第二类的,依次类推
Out[25]:
array([[0.87803031, 0.1219589 , 0.00001079],       [0.79705829, 0.20291141, 0.00003029],       [0.85199767, 0.14797648, 0.00002586],       [0.82340602, 0.17653616, 0.00005782],       [0.89603497, 0.10395384, 0.00001119],       [0.92623425, 0.07375278, 0.00001296],       [0.89409685, 0.10586394, 0.00003922],       [0.86003441, 0.13994671, 0.00001888],       [0.80102864, 0.19888675, 0.0000846 ],       [0.79266239, 0.207312  , 0.00002561],       [0.89048611, 0.10950773, 0.00000616],       [0.86180067, 0.13816496, 0.00003437],       [0.78536437, 0.21460826, 0.00002737],       [0.83312233, 0.1668456 , 0.00003207],       [0.92710508, 0.07289396, 0.00000097],       [0.96420978, 0.03578796, 0.00000226],       [0.94024468, 0.05975048, 0.00000484],       [0.89038364, 0.1096022 , 0.00001416],       [0.89499643, 0.1049968 , 0.00000677],       [0.92281833, 0.07716985, 0.00001182],       [0.82816884, 0.17181599, 0.00001517],       [0.9211629 , 0.07881927, 0.00001783],       [0.92583055, 0.07416099, 0.00000846],       [0.86642505, 0.13350683, 0.00006812],       [0.83957906, 0.16034845, 0.00007249],       [0.77438785, 0.22557074, 0.00004141],       [0.88014221, 0.11981603, 0.00004176],       [0.86814212, 0.13184633, 0.00001156],       [0.85798154, 0.14200808, 0.00001039],       [0.83013655, 0.16980939, 0.00005406],       [0.80548889, 0.1944593 , 0.00005181],       [0.87080741, 0.12917648, 0.00001611],       [0.9331403 , 0.06685591, 0.00000379],       [0.94556305, 0.05443497, 0.00000198],       [0.8091041 , 0.19086202, 0.00003387],       [0.84540667, 0.15458141, 0.00001192],       [0.8678451 , 0.13215071, 0.00000419],       [0.88781581, 0.11217401, 0.00001018],       [0.82917332, 0.17076892, 0.00005776],       [0.85578733, 0.14419685, 0.00001582],       [0.89902143, 0.10096537, 0.00001321],       [0.68760966, 0.31222729, 0.00016305],       [0.86468741, 0.13526866, 0.00004393],       [0.91572506, 0.08421267, 0.00006227],       [0.91483865, 0.08511969, 0.00004165],       [0.81813982, 0.18181229, 0.00004789],       [0.90880255, 0.09118589, 0.00001156],       [0.84952236, 0.15043821, 0.00003943],       [0.89405093, 0.10594173, 0.00000734],       [0.84954119, 0.15044186, 0.00001695],       [0.02960053, 0.86126971, 0.10912976],       [0.03735662, 0.70599864, 0.25664474],       [0.01171882, 0.74918029, 0.23910089],       [0.01323293, 0.65261527, 0.3341518 ],       [0.0109261 , 0.69975168, 0.28932221],       [0.01074757, 0.58352345, 0.40572897],       [0.02155363, 0.53736735, 0.44107901],       [0.10779613, 0.76976403, 0.12243984],       [0.01756481, 0.82847163, 0.15396355],       [0.0331001 , 0.52843482, 0.43846508],       [0.02909409, 0.77362145, 0.19728446],       [0.04095529, 0.61978222, 0.33926249],       [0.01930641, 0.88041176, 0.10028182],       [0.00871112, 0.59711982, 0.39416907],       [0.16693542, 0.7134194 , 0.11964518],       [0.0471498 , 0.84453959, 0.1083106 ],       [0.01229146, 0.42322112, 0.56448741],       [0.03811694, 0.85107181, 0.11081125],       [0.00308283, 0.59723043, 0.39968674],       [0.03569679, 0.80966589, 0.15463732],       [0.00624631, 0.27162577, 0.72212792],       [0.05767621, 0.82064253, 0.12168126],       [0.00195123, 0.53464684, 0.46340193],       [0.0087242 , 0.70558697, 0.28568882],       [0.03660593, 0.83989907, 0.123495  ],       [0.0358882 , 0.82963776, 0.13447404],       [0.00807402, 0.77816015, 0.21376583],       [0.00463307, 0.52364059, 0.47172634],       [0.01333998, 0.56347986, 0.42318016],       [0.12711691, 0.8329313 , 0.03995179],       [0.03581044, 0.80413792, 0.16005163],       [0.05003383, 0.84711273, 0.10285344],       [0.05656025, 0.81218015, 0.1312596 ],       [0.001226  , 0.39930356, 0.59947044],       [0.01035901, 0.36404062, 0.62560038],       [0.04192755, 0.47659618, 0.48147627],       [0.01894857, 0.74644236, 0.23460907],       [0.00699118, 0.75788979, 0.23511903],       [0.05570461, 0.66760154, 0.27669385],       [0.0210041 , 0.6630778 , 0.3159181 ],       [0.00895359, 0.60041736, 0.39062905],       [0.01518493, 0.63284951, 0.35196556],       [0.03451475, 0.79953197, 0.16595328],       [0.09088978, 0.79694616, 0.11216406],       [0.01979054, 0.64145332, 0.33875613],       [0.0479463 , 0.73160741, 0.22044629],       [0.03437866, 0.67792614, 0.2876952 ],       [0.03365277, 0.79775962, 0.16858761],       [0.25317619, 0.69233045, 0.05449335],       [0.03622935, 0.70484125, 0.2589294 ],       [0.00018858, 0.14637262, 0.8534388 ],       [0.00081403, 0.29344714, 0.70573883],       [0.000279  , 0.33023907, 0.66948192],       [0.00045801, 0.33833991, 0.66120208],       [0.00025341, 0.25571436, 0.74403222],       [0.00006041, 0.38291757, 0.61702202],       [0.00206351, 0.2798062 , 0.71813029],       [0.00012312, 0.42493487, 0.57494202],       [0.0001599 , 0.42361552, 0.57622458],       [0.00035986, 0.1507501 , 0.84889004],       [0.00301206, 0.27686098, 0.72012696],       [0.00064689, 0.35551374, 0.64383936],       [0.00068392, 0.29819313, 0.70112296],       [0.00063265, 0.29576839, 0.70359896],       [0.00061817, 0.17242508, 0.82695675],       [0.0011076 , 0.17147616, 0.82741624],       [0.00080141, 0.34867019, 0.6505284 ],       [0.00019462, 0.23736752, 0.76243786],       [0.00001303, 0.42032294, 0.57966403],       [0.00067999, 0.47061514, 0.52870487],       [0.00050811, 0.22392596, 0.77556593],       [0.0013505 , 0.22852031, 0.77012918],       [0.00003818, 0.42849746, 0.57146436],       [0.00205514, 0.40046885, 0.59747602],       [0.00068243, 0.23590885, 0.76340873],       [0.00045642, 0.39783264, 0.60171094],       [0.00320524, 0.38361431, 0.61318045],       [0.00343776, 0.32673717, 0.66982508],       [0.00030239, 0.29738763, 0.70230998],       [0.00067575, 0.51161821, 0.48770605],       [0.00016147, 0.42854842, 0.57129011],       [0.00064593, 0.34460403, 0.65475004],       [0.00027729, 0.27685415, 0.72286856],       [0.00207367, 0.49125249, 0.50667385],       [0.00035439, 0.44307281, 0.5565728 ],       [0.00018237, 0.34196617, 0.65785146],       [0.00063838, 0.1265659 , 0.87279573],       [0.00092554, 0.32031446, 0.67876   ],       [0.00431283, 0.31746907, 0.6782181 ],       [0.00117132, 0.3003005 , 0.69852818],       [0.00045021, 0.20080005, 0.79874973],       [0.00216404, 0.24761373, 0.75022222],       [0.00081403, 0.29344714, 0.70573883],       [0.00029358, 0.22339673, 0.77630969],       [0.00045525, 0.15204928, 0.84749547],       [0.00116469, 0.23233015, 0.76650517],       [0.0009204 , 0.37926299, 0.6198166 ],       [0.00146455, 0.29758429, 0.70095116],       [0.00110986, 0.12983185, 0.8690583 ],       [0.00169379, 0.27997339, 0.71833282]])
In [28]:
from sklearn.svm import SVCfrom sklearn.metrics import accuracy_scoreclf2 = SVC(gamma='auto')clf2.fit(iris.data, iris.target)predicted = clf.predict(iris.data)predicted2 = clf2.predict(iris.data)print(accuracy_score(iris.target, predicted))print(accuracy_score(iris.target, predicted2))
 
0.960.9866666666666667

转载于:https://www.cnblogs.com/xinmomoyan/p/10408695.html

你可能感兴趣的文章
深入浅出事件流处理NEsper(二)
查看>>
技术人生:如何做非正式的交流
查看>>
利用共享内存和信号灯集实现进程间同步一例
查看>>
类的基础
查看>>
Sql Server系列:使用Transact-SQL编程
查看>>
新增题目功能模块总结
查看>>
三、mono for android 学习:参考书籍
查看>>
javascript练习:8-10事件与this运算符
查看>>
Linux下SVN部署/安全及权限配置,实现web同步更新
查看>>
PHPSPY2013
查看>>
Android学习笔记(四)时钟、时间
查看>>
SQL SERVER 查询性能优化——分析事务与锁(二)
查看>>
动画实现实现上下滚动的TextView
查看>>
HDU-4461 The Power of Xiangqi 签到题
查看>>
方法线程SwingWorker的用法
查看>>
hdu 4313(类似于kruskal)
查看>>
【数据存储】数据查询与Cursor接口(4)
查看>>
DoTA与人生
查看>>
〖Android〗/system/etc/media_codecs.xml
查看>>
ESN 与 MEID
查看>>