【Kaggle】模型调参利器 gridSearchCV(网格搜索)
nanshan 2024-10-12 05:41 27 浏览 0 评论
gridSearchCV(网格搜索)的参数、方法及示例
1.简介
GridSearchCV的sklearn官方网址:
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV
GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果。这个时候就是需要动脑筋了。数据量比较大的时候可以使用一个快速调优的方法——坐标下降。它其实是一种贪心算法:拿当前对模型影响最大的参数调优,直到最优化;再拿下一个影响最大的参数调优,如此下去,直到所有的参数调整完毕。这个方法的缺点就是可能会调到局部最优而不是全局最优,但是省时间省力,巨大的优势面前,还是试一试吧,后续可以再拿bagging再优化。
通常算法不够好,需要调试参数时必不可少。比如SVM的惩罚因子C,核函数kernel,gamma参数等,对于不同的数据使用不同的参数,结果效果可能差1-5个点,sklearn为我们提供专门调试参数的函数grid_search。
2.参数说明
class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch=‘2*n_jobs’, error_score=’raise’, return_train_score=’warn’)
(1)estimator
选择使用的分类器,并且传入除需要确定最佳的参数之外的其他参数。每一个分类器都需要一个scoring参数,或者score方法:estimator=RandomForestClassifier(min_samples_split=100,min_samples_leaf=20,max_depth=8,max_features='sqrt',random_state=10),
(2)param_grid
需要最优化的参数的取值,值为字典或者列表,例如:param_grid =param_test1,param_test1 = {'n_estimators':range(10,71,10)}。
(3)scoring=None
模型评价标准,默认None,这时需要使用score函数;或者如scoring='roc_auc',根据所选模型不同,评价准则不同。字符串(函数名),或是可调用对象,需要其函数签名形如:scorer(estimator, X, y);如果是None,则使用estimator的误差估计函数。具体值的选取看本篇第三节内容。
(4)fit_params=None
(5) n_jobs=1
n_jobs: 并行数,int:个数,-1:跟CPU核数一致, 1:默认值
(6) iid=True
iid:默认True,为True时,默认为各个样本fold概率分布一致,误差估计为所有样本之和,而非各个fold的平均。
(7) refit=True
默认为True,程序将会以交叉验证训练集得到的最佳参数,重新对所有可用的训练集与开发集进行,作为最终用于性能评估的最佳模型参数。即在搜索参数结束后,用最佳参数结果再次fit一遍全部数据集。
(8)cv=None
交叉验证参数,默认None,使用三折交叉验证。指定fold数量,默认为3,也可以是yield训练/测试数据的生成器。
(9)verbose=0, scoring=None
verbose:日志冗长度,int:冗长度,0:不输出训练过程,1:偶尔输出,>1:对每个子模型都输出。
(10)pre_dispatch=‘2*n_jobs’
指定总共分发的并行任务数。当n_jobs大于1时,数据将在每个运行点进行复制,这可能导致OOM,而设置pre_dispatch参数,则可以预先划分总共的job数量,使数据最多被复制pre_dispatch次
(11)error_score=’raise’
(12)return_train_score=’warn’
如果“False”,cv_results_属性将不包括训练分数
回到sklearn里面的GridSearchCV,GridSearchCV用于系统地遍历多种参数组合,通过交叉验证确定最佳效果参数。
3. Scoring parameter:评价标准参数详细说明
Model-evaluation tools using cross-validation (such as model_selection.cross_val_score andmodel_selection.GridSearchCV) rely on an internal scoring strategy. This is discussed in the section The scoring parameter: defining model evaluation rules.
For the most common use cases, you can designate a scorer object with the scoring parameter; the table below shows all possible values. All scorer objects follow the convention that higher return values are better than lower return values. Thus metrics which measure the distance between the model and the data, like metrics.mean_squared_error, are available as neg_mean_squared_error which return the negated value of the metric.
4.属性
(1)cv_results_ : dict of numpy (masked) ndarrays
具有键作为列标题和值作为列的dict,可以导入到DataFrame中。注意,“params”键用于存储所有参数候选项的参数设置列表。
(2)best_estimator_ : estimator
通过搜索选择的估计器,即在左侧数据上给出最高分数(或指定的最小损失)的估计器。 如果refit = False,则不可用。
(3)best_score_ : float best_estimator的分数
(4)best_params_ : dict 在保存数据上给出最佳结果的参数设置
(5)best_index_ : int 对应于最佳候选参数设置的索引(cv_results_数组)。
search.cv_results _ ['params'] [search.best_index_]中的dict给出了最佳模型的参数设置,给出了最高的平均分数(search.best_score_)。
(6)scorer_ : function
Scorer function used on the held out data to choose the best parameters for the model.
(7)n_splits_ : int
The number of cross-validation splits (folds/iterations).
(8)grid_scores_:给出不同参数情况下的评价结果
6.网格搜索实例
import pandas as pd # 数据科学计算工具
import numpy as np # 数值计算工具
import matplotlib.pyplot as plt # 可视化
import seaborn as sns # matplotlib的高级API
from sklearn.model_selection import StratifiedKFold #交叉验证
from sklearn.model_selection import GridSearchCV #网格搜索
from sklearn.model_selection import train_test_split #将数据集分开成训练集和测试集
from xgboost import XGBClassifier #xgboost
pima = pd.read_csv("pima_indians-diabetes.csv")
print(pima.head())
x = pima.iloc[:,0:8]
y = pima.iloc[:,8]
seed = 7 #重现随机生成的训练
test_size = 0.33 #33%测试,67%训练
X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=test_size, random_state=seed
model = XGBClassifier()
learning_rate = [0.0001,0.001,0.01,0.1,0.2,0.3] #学习率
gamma = [1, 0.1, 0.01, 0.001]
param_grid = dict(learning_rate = learning_rate,gamma = gamma)#转化为字典格式,网络搜索要求
kflod = StratifiedKFold(n_splits=10, shuffle = True,random_state=7)#将训练/测试数据集划分10个互斥子集,
grid_search = GridSearchCV(model,param_grid,scoring = 'neg_log_loss',n_jobs = -1,cv = kflod)
#scoring指定损失函数类型,n_jobs指定全部cpu跑,cv指定交叉验证
grid_result = grid_search.fit(X_train, Y_train) #运行网格搜索
print("Best: %f using %s" % (grid_result.best_score_,grid_search.best_params_))
#grid_scores_:给出不同参数情况下的评价结果。best_params_:描述了已取得最佳结果的参数的组合
#best_score_:成员提供优化过程期间观察到的最好的评分
#具有键作为列标题和值作为列的dict,可以导入到DataFrame中。
#注意,“params”键用于存储所有参数候选项的参数设置列表。
means = grid_result.cv_results_['mean_test_score']
params = grid_result.cv_results_['params']
for mean,param in zip(means,params):
print("%f with: %r" % (mean,param))
结果如下:
参考:机器学习(四)--模型调参利器 gridSearchCV(网格搜索)
- 上一篇:以太坊stratum协议原理
- 下一篇:去中心化矿池协议 Stratum V2 概述
相关推荐
- 服务器数据恢复—Raid5数据灾难不用愁,Raid5数据恢复原理了解下
-
Raid5数据恢复算法原理:分布式奇偶校验的独立磁盘结构(被称之为raid5)的数据恢复有一个“奇偶校验”的概念。可以简单的理解为二进制运算中的“异或运算”,通常使用的标识是xor。运算规则:若二者值...
- 服务器数据恢复—多次异常断电导致服务器raid不可用的数据恢复
-
服务器数据恢复环境&故障:由于机房多次断电导致一台服务器中raid阵列信息丢失。该阵列中存放的是文档,上层安装的是Windowsserver操作系统,没有配置ups。因为服务器异常断电重启后,rai...
- 服务器数据恢复-V7000存储更换磁盘数据同步失败的数据恢复案例
-
服务器数据恢复环境:P740+AIX+Sybase+V7000存储,存储阵列柜上共12块SAS机械硬盘(其中一块为热备盘)。服务器故障:存储阵列柜中有磁盘出现故障,工作人员发现后更换磁盘,新更换的磁盘...
- 「服务器数据恢复」重装系统导致XFS文件系统分区丢失的数据恢复
-
服务器数据恢复环境:DellPowerVault系列磁盘柜;用RAID卡创建的一组RAID5;分配一个LUN。服务器故障:在Linux系统层面对LUN进行分区,划分sdc1和sdc2两个分区。将sd...
- 服务器数据恢复-ESXi虚拟机被误删的数据恢复案例
-
服务器数据恢复环境:一台服务器安装的ESXi虚拟化系统,该虚拟化系统连接了多个LUN,其中一个LUN上运行了数台虚拟机,虚拟机安装WindowsServer操作系统。服务器故障&分析:管理员因误操作...
- 「服务器数据恢复」Raid5阵列两块硬盘亮黄灯掉线的数据恢复案例
-
服务器数据恢复环境:HPStorageWorks某型号存储;虚拟化平台为vmwareexsi;10块磁盘组成raid5(有1块热备盘)。服务器故障:raid5阵列中两块硬盘指示灯变黄掉线,无法读取...
- 服务器数据恢复—基于oracle数据库的SAP数据恢复案例
-
服务器存储数据恢复环境:某品牌服务器存储中有一组由6块SAS硬盘组建的RAID5阵列,其中有1块硬盘作为热备盘使用。上层划分若干lun,存放Oracle数据库数据。服务器存储故障&分析:该RAID5阵...
- 「服务器虚拟化数据恢复」Xen Server环境下数据库数据恢复案例
-
服务器虚拟化数据恢复环境:Dell某型号服务器;数块STAT硬盘通过raid卡组建的RAID10;XenServer服务器虚拟化系统;故障虚拟机操作系统:WindowsServer,部署Web服务...
- 服务器数据恢复—RAID故障导致oracle无法启动的数据恢复案例
-
服务器数据恢复环境:某品牌服务器中有一组由4块SAS磁盘做的RAID5磁盘阵列。该服务器操作系统为windowsserver,运行了一个单节点Oracle,数据存储为文件系统,无归档。该oracle...
- 服务器数据恢复—服务器磁盘阵列常见故障表现&解决方案
-
RAID(磁盘阵列)是一种将多块物理硬盘整合成一个虚拟存储的技术,raid模块相当于一个存储管理的中间层,上层接收并执行操作系统及文件系统的数据读写指令,下层管理数据在各个物理硬盘上的存储及读写。相对...
- 「服务器数据恢复」IBM某型号服务器RAID5磁盘阵列数据恢复案例
-
服务器数据恢复环境:IBM某型号服务器;5块SAS硬盘组成RAID5磁盘阵列;存储划分为1个LUN和3个分区:第一个分区存放windowsserver系统,第二个分区存放SQLServer数据库,...
- 服务器数据恢复—Zfs文件系统下误删除文件如何恢复数据?
-
服务器故障:一台zfs文件系统服务器,管理员误操作删除服务器上的数据。服务器数据恢复过程:1、将故障服务器所有磁盘编号后取出,硬件工程师检测所有硬盘后没有发现有磁盘存在硬件故障。以只读方式将全部磁盘做...
- 服务器数据恢复—Linux+raid5服务器数据恢复案例
-
服务器数据恢复环境:某品牌linux操作系统服务器,服务器中有4块SAS接口硬盘组建一组raid5阵列。服务器中存放的数据有数据库、办公文档、代码文件等。服务器故障&检测:服务器在运行过程中突然瘫痪,...
- 服务器数据恢复—Sql Server数据库数据恢复案例
-
服务器数据恢复环境:一台安装windowsserver操作系统的服务器。一组由8块硬盘组建的RAID5,划分LUN供这台服务器使用。在windows服务器内装有SqlServer数据库。存储空间LU...
- 服务器数据恢复—阿里云ECS网站服务器数据恢复案例
-
云服务器数据恢复环境:阿里云ECS网站服务器,linux操作系统+mysql数据库。云服务器故障:在执行数据库版本更新测试时,在生产库误执行了本来应该在测试库执行的sql脚本,导致生产库部分表被tru...
你 发表评论:
欢迎- 一周热门
-
-
爱折腾的特斯拉车主必看!手把手教你TESLAMATE的备份和恢复
-
如何在安装前及安装后修改黑群晖的Mac地址和Sn系列号
-
[常用工具] OpenCV_contrib库在windows下编译使用指南
-
WindowsServer2022|配置NTP服务器的命令
-
Ubuntu系统Daphne + Nginx + supervisor部署Django项目
-
WIN11 安装配置 linux 子系统 Ubuntu 图形界面 桌面系统
-
解决Linux终端中“-bash: nano: command not found”问题
-
NBA 2K25虚拟内存不足/爆内存/内存占用100% 一文速解
-
Linux 中的文件描述符是什么?(linux 打开文件表 文件描述符)
-
K3s禁用Service Load Balancer,解决获取浏览器IP不正确问题
-
- 最近发表
-
- 服务器数据恢复—Raid5数据灾难不用愁,Raid5数据恢复原理了解下
- 服务器数据恢复—多次异常断电导致服务器raid不可用的数据恢复
- 服务器数据恢复-V7000存储更换磁盘数据同步失败的数据恢复案例
- 「服务器数据恢复」重装系统导致XFS文件系统分区丢失的数据恢复
- 服务器数据恢复-ESXi虚拟机被误删的数据恢复案例
- 「服务器数据恢复」Raid5阵列两块硬盘亮黄灯掉线的数据恢复案例
- 服务器数据恢复—基于oracle数据库的SAP数据恢复案例
- 「服务器虚拟化数据恢复」Xen Server环境下数据库数据恢复案例
- 服务器数据恢复—RAID故障导致oracle无法启动的数据恢复案例
- 服务器数据恢复—服务器磁盘阵列常见故障表现&解决方案
- 标签列表
-
- linux 查询端口号 (58)
- docker映射容器目录到宿主机 (66)
- 杀端口 (60)
- yum更换阿里源 (62)
- internet explorer 增强的安全配置已启用 (65)
- linux自动挂载 (56)
- 禁用selinux (55)
- sysv-rc-conf (69)
- ubuntu防火墙状态查看 (64)
- windows server 2022激活密钥 (56)
- 无法与服务器建立安全连接是什么意思 (74)
- 443/80端口被占用怎么解决 (56)
- ping无法访问目标主机怎么解决 (58)
- fdatasync (59)
- 405 not allowed (56)
- 免备案虚拟主机zxhost (55)
- linux根据pid查看进程 (60)
- dhcp工具 (62)
- mysql 1045 (57)
- 宝塔远程工具 (56)
- ssh服务器拒绝了密码 请再试一次 (56)
- ubuntu卸载docker (56)
- linux查看nginx状态 (63)
- tomcat 乱码 (76)
- 2008r2激活序列号 (65)