GridSearchCV中的scoring

发布时间 2023-05-16 16:40:33作者: 温小皮

说明

scoring参数输入形式

包括字符串、可调用对象或评分函数。以下是常用的评分规则示例:

  1. 使用预定义的字符串指定评分规则:

    • 'accuracy':准确率(分类问题)
      'precision':精确率(分类问题)
      'recall':召回率(分类问题)
      'f1':F1分数(分类问题)
      'r2':R2分数(回归问题)
      'mean_squared_error':均方误差(回归问题)
  2. 使用自定义的评分函数:
    可以自定义一个评分函数,该函数接受真实标签和预测标签作为输入,并返回一个评分值。

  3. 使用可调用对象:
    可以传递一个实现了scoring接口的可调用对象作为评分规则。这个可调用对象接受真实标签和预测标签作为输入,并返回一个评分值。

GridSearchCV默认的scoring

  • GridSearchCV 默认的评估指标取决于所使用的模型。
  • 对于许多机器学习模型,如回归和分类,评估指标可以通过指定 scoring 参数来进行选择或修改。
  • 对于一些模型,默认的评估指标可以在文档中查找,例如 SVM 模型的默认评估指标是 accuracy。
  • 如果没有指定 scoring 参数,GridSearchCV 将默认使用模型的默认评估指标。
  • 在 sklearn 中,不同模型的默认评估指标可以在相应的文档中查找。

一些模型的默认评估指标:

  • 对于分类问题,SVC 模型的默认评估指标是 accuracy(准确率),即正确分类的样本占总样本数的比例。

例子

scoring 参数可以使用多种形式指定,包括字符串(使用内置指标),可调用函数(自定义指标),或一个字典(指定不同的指标)。下面是一些示例:

  1. 使用内置指标:可以使用字符串指定内置指标,例如:
#这里指定使用 accuracy 作为评估指标。
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression

param_grid = {'C': [0.1, 1, 10], 'penalty': ['l1', 'l2']}
clf = LogisticRegression(max_iter=1000)
grid_search = GridSearchCV(clf, param_grid, scoring='accuracy')
  1. 自定义指标:可以使用一个可调用函数指定自定义指标,例如:
#这里使用 make_scorer() 函数将 f1_score 转换为一个可调用的评估函数,并使用其作为评估指标。
from sklearn.metrics import make_scorer, f1_score

def custom_scorer(y_true, y_pred):
    return f1_score(y_true, y_pred, average='macro')

param_grid = {'C': [0.1, 1, 10], 'penalty': ['l1', 'l2']}
clf = LogisticRegression(max_iter=1000)
grid_search = GridSearchCV(clf, param_grid, scoring=make_scorer(custom_scorer))
  1. 多个指标:可以使用一个字典来指定多个指标,例如:
"""
这里指定了两个指标:accuracy 和 recall,并使用 refit 参数指定在搜索过程结束后使用哪个指标来重新拟合模型。
refit这里使用 accuracy 来重新拟合模型。
"""
from sklearn.metrics import accuracy_score, recall_score, make_scorer

scoring = {'accuracy': make_scorer(accuracy_score), 
           'recall': make_scorer(recall_score)}

param_grid = {'C': [0.1, 1, 10], 'penalty': ['l1', 'l2']}
clf = LogisticRegression(max_iter=1000)
grid_search = GridSearchCV(clf, param_grid, scoring=scoring, refit='accuracy')