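ML: xgboost: Training an xgboost model (sklearn + GridSearchCV) on the mushroom dataset (22 features + 1 label, 6513 training + 1611 test samples) to predict whether a mushroom is poisonous (binary classification)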

2021/6/15 20:22:26


Contents

Output

Design approach

Core code

More output


 

 

 

 

Output

Update in progress…

 

Design approach

Update in progress…

 

Core code

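from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20

# bst is assumed to be the xgboost.XGBClassifier built earlier (not shown); X_train/y_train are the training data
param_test = {'n_estimators': range(1, 51)}                      # candidate boosting rounds: 1..50
clf = GridSearchCV(estimator=bst, param_grid=param_test, cv=5)   # 5-fold cross-validation per candidate
clf.fit(X_train, y_train)
# grid_scores_ was removed in scikit-learn 0.20; cv_results_ holds the per-candidate mean/std scores
print(clf.cv_results_['mean_test_score'], clf.cv_results_['std_test_score'])
print(clf.best_params_, clf.best_score_)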
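With cv=5, GridSearchCV trains and scores the estimator once per fold for every candidate value. It then records the mean and standard deviation of the fold scores and keeps the candidate with the highest mean. The stdlib-only sketch below imitates those mechanics on a toy 1-D nearest-neighbour model. It is not scikit-learn's implementation. The data and the helper names `kfold_indices`, `grid_search` and `knn_fit_score` are hypothetical. It uses plain unshuffled contiguous folds, like `KFold(shuffle=False)`. By contrast, GridSearchCV uses stratified folds when the estimator is a classifier. The std is the population standard deviation, which matches what scikit-learn reports.

```python
from statistics import mean, pstdev


def kfold_indices(n, k):
    """Yield (train, test) index lists for k contiguous folds (unshuffled, unstratified)."""
    start = 0
    for fold in range(k):
        size = n // k + (1 if fold < n % k else 0)  # the first n % k folds get one extra sample
        test = list(range(start, start + size))
        train = [i for i in range(n) if not start <= i < start + size]
        yield train, test
        start += size


def grid_search(fit_score, X, y, candidates, k=5):
    """Cross-validate every candidate; return [(candidate, mean, std)] and the best row."""
    results = []
    for p in candidates:
        scores = []
        for train, test in kfold_indices(len(X), k):
            scores.append(fit_score(p,
                                    [X[i] for i in train], [y[i] for i in train],
                                    [X[i] for i in test], [y[i] for i in test]))
        results.append((p, mean(scores), pstdev(scores)))  # pstdev: population std (ddof=0)
    best = max(results, key=lambda row: row[1])  # max() keeps the first of tied rows
    return results, best


def knn_fit_score(p, X_tr, y_tr, X_te, y_te):
    """Toy model: majority vote of the p nearest training points (1-D); returns accuracy."""
    correct = 0
    for x, target in zip(X_te, y_te):
        nearest = sorted(range(len(X_tr)), key=lambda i: abs(X_tr[i] - x))[:p]
        predicted = 1 if 2 * sum(y_tr[i] for i in nearest) > len(nearest) else 0
        correct += predicted == target
    return correct / len(y_te)


# Hypothetical data: the label is 1 when x >= 5
X = list(range(10))
y = [int(x >= 5) for x in X]
results, best = grid_search(knn_fit_score, X, y, candidates=[1, 3, 5])
for p, m, s in results:
    print(f"n_neighbors={p}: mean={m:.5f}, std={s:.5f}")
print("best:", best)
```

In the real run, the candidates are the 50 `n_estimators` values and the model is the xgboost classifier. That makes 250 fits before the best candidate is refitted on the whole training set.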



# Mean / std of the 5-fold CV accuracy, one value per n_estimators = 1..50 (copied from the grid_scores_ output)
grid_scores_mean = [0.90542, 0.94749, 0.90542, 0.94749, 0.90573, 0.94718,
                    0.90542, 0.94242, 0.94473, 0.97482, 0.94887, 0.97850,
                    0.97298, 0.97850, 0.97298, 0.97850, 0.97850, 0.97850,
                    0.97850, 0.97850, 0.97850, 0.97850, 0.97850, 0.97850,
                    0.97850, 0.97804, 0.97774, 0.97835, 0.98296, 0.98419,
                    0.98342, 0.98372, 0.98419, 0.98419, 0.98419, 0.98419,
                    0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419,
                    0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419,
                    0.98419, 0.98419]


grid_scores_std = [0.08996, 0.07458, 0.08996, 0.07458, 0.09028, 0.07436,
                   0.08996, 0.07331, 0.07739, 0.02235, 0.07621, 0.02387,
                   0.03186, 0.02387, 0.03186, 0.02387, 0.02387, 0.02387,
                   0.02387, 0.02387, 0.02387, 0.02387, 0.02387, 0.02387,
                   0.02387, 0.02365, 0.02337, 0.02383, 0.01963, 0.02040,
                   0.01988, 0.02008, 0.02040, 0.02040, 0.02040, 0.02040,
                   0.02040, 0.02040, 0.02040, 0.02040, 0.02040, 0.02040,
                   0.02040, 0.02040, 0.02040, 0.02040, 0.02040, 0.02040,
                   0.02040, 0.02040]
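The two lists above were copied by hand from the printed grid_scores_ output. A small regular expression can rebuild them from that text instead. The sketch below assumes the exact `mean: <m>, std: <s>, params: {'n_estimators': <n>}` layout shown in this post's output, and the helper name `parse_grid_scores` is hypothetical. With scikit-learn 0.20 or later there is nothing to parse: `clf.cv_results_['mean_test_score']` and `clf.cv_results_['std_test_score']` already hold these numbers as arrays.

```python
import re

# One printed entry looks like: mean: 0.90542, std: 0.08996, params: {'n_estimators': 1}
ENTRY = re.compile(r"mean: ([\d.]+), std: ([\d.]+), params: \{'n_estimators': (\d+)\}")


def parse_grid_scores(text):
    """Parse a printed grid_scores_ dump into (n_estimators, means, stds) lists."""
    rows = sorted((int(n), float(m), float(s)) for m, s, n in ENTRY.findall(text))
    n_estimators = [r[0] for r in rows]
    means = [r[1] for r in rows]
    stds = [r[2] for r in rows]
    return n_estimators, means, stds


# First two entries of the logged output, used as sample input
log = ("[mean: 0.90542, std: 0.08996, params: {'n_estimators': 1}, "
       "mean: 0.94749, std: 0.07458, params: {'n_estimators': 2}]")
n_estimators, grid_scores_mean, grid_scores_std = parse_grid_scores(log)
print(n_estimators, grid_scores_mean, grid_scores_std)
# → [1, 2] [0.90542, 0.94749] [0.08996, 0.07458]
```

The rows are sorted by `n_estimators`, so the plot's x values and the two lists stay aligned even if the entries are pasted out of order.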
 
 
# 7-CrVa: visualize the cross-validation curve
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['Times New Roman']  # plot font; for Chinese labels use e.g. ['SimHei'] or ['FangSong'] instead
# myfont = matplotlib.font_manager.FontProperties(fname='C:/Windows/Fonts/msyh.ttf')  # alternatively, load a font file by its Windows path
plt.rcParams['axes.unicode_minus'] = False  # render minus signs on the axes correctly

x = range(1, len(grid_scores_mean) + 1)  # n_estimators values (the grid starts at 1, not 0)
y1 = grid_scores_mean
y2 = grid_scores_std
Xlabel = 'n_estimators'
Ylabel = 'value'
title = 'mushroom dataset: xgboost(sklearn+GridSearchCV) model'

plt.plot(x, y1, 'r', label='Mean')   # mean CV accuracy curve
plt.plot(x, y2, 'g', label='Std')    # std of the CV accuracy curve

plt.xlabel(Xlabel)
plt.ylabel(Ylabel)
plt.title(title)

plt.legend(loc=1)  # loc=1: upper right
plt.show()

 

More output

GridSearchCV time: 79.7655139499154
clf.grid_scores_: [mean: 0.90542, std: 0.08996, params: {'n_estimators': 1}, mean: 0.94749, std: 0.07458, params: {'n_estimators': 2}, 
                   mean: 0.90542, std: 0.08996, params: {'n_estimators': 3}, mean: 0.94749, std: 0.07458, params: {'n_estimators': 4}, 
                   mean: 0.90573, std: 0.09028, params: {'n_estimators': 5}, mean: 0.94718, std: 0.07436, params: {'n_estimators': 6}, 
                   mean: 0.90542, std: 0.08996, params: {'n_estimators': 7}, mean: 0.94242, std: 0.07331, params: {'n_estimators': 8}, 
                   mean: 0.94473, std: 0.07739, params: {'n_estimators': 9}, mean: 0.97482, std: 0.02235, params: {'n_estimators': 10}, 
                   mean: 0.94887, std: 0.07621, params: {'n_estimators': 11}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 12}, 
                   mean: 0.97298, std: 0.03186, params: {'n_estimators': 13}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 14}, 
                   mean: 0.97298, std: 0.03186, params: {'n_estimators': 15}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 16}, 
                   mean: 0.97850, std: 0.02387, params: {'n_estimators': 17}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 18}, 
                   mean: 0.97850, std: 0.02387, params: {'n_estimators': 19}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 20}, 
                   mean: 0.97850, std: 0.02387, params: {'n_estimators': 21}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 22}, 
                   mean: 0.97850, std: 0.02387, params: {'n_estimators': 23}, mean: 0.97850, std: 0.02387, params: {'n_estimators': 24}, 
                   mean: 0.97850, std: 0.02387, params: {'n_estimators': 25}, mean: 0.97804, std: 0.02365, params: {'n_estimators': 26}, 
                   mean: 0.97774, std: 0.02337, params: {'n_estimators': 27}, mean: 0.97835, std: 0.02383, params: {'n_estimators': 28}, 
                   mean: 0.98296, std: 0.01963, params: {'n_estimators': 29}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 30}, 
                   mean: 0.98342, std: 0.01988, params: {'n_estimators': 31}, mean: 0.98372, std: 0.02008, params: {'n_estimators': 32}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 33}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 34}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 35}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 36}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 37}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 38}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 39}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 40}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 41}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 42}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 43}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 44}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 45}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 46}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 47}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 48}, 
                   mean: 0.98419, std: 0.02040, params: {'n_estimators': 49}, mean: 0.98419, std: 0.02040, params: {'n_estimators': 50}]
clf.best_params_: {'n_estimators': 30}
clf.best_score_: 0.9841854752034392
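The mean score is flat from n_estimators=30 onward, yet best_params_ reports 30 rather than a larger value. GridSearchCV ranks candidates by mean test score, and tied candidates share a rank. best_index_ is the first candidate that holds rank 1, so among equal means the earliest one in the grid wins. The sketch below reproduces that choice from the rounded means printed above. Because ties are judged on those five-decimal values, it assumes the underlying means really are equal.

```python
# Rounded mean CV accuracies for n_estimators = 1..50, copied from the grid_scores_ output
grid_scores_mean = [
    0.90542, 0.94749, 0.90542, 0.94749, 0.90573, 0.94718, 0.90542, 0.94242, 0.94473, 0.97482,
    0.94887, 0.97850, 0.97298, 0.97850, 0.97298, 0.97850, 0.97850, 0.97850, 0.97850, 0.97850,
    0.97850, 0.97850, 0.97850, 0.97850, 0.97850, 0.97804, 0.97774, 0.97835, 0.98296, 0.98419,
    0.98342, 0.98372, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419,
    0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419, 0.98419,
]
n_estimators = list(range(1, 51))

best_mean = max(grid_scores_mean)
best_index = grid_scores_mean.index(best_mean)  # list.index returns the FIRST position of the max
tied = [n for n, m in zip(n_estimators, grid_scores_mean) if m == best_mean]
print(n_estimators[best_index], best_mean, len(tied))
# → 30 0.98419 19
```

Nineteen settings (30 and 33 to 50) share the top score, so any n_estimators from 30 upward gives the same cross-validated accuracy on this data.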



 


