SK都是干什么的

发布时间 2023-11-24 12:57:36作者: 黑逍逍

参考文档:非常全面的Sklearn介绍 (qq.com)

scikit-learn: machine learning in Python — scikit-learn 1.3.2 documentation

分类(Classification): 实现了多种监督学习分类算法,例如支持向量机(SVM)、决策树、随机森林等。

  

from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
 

  

# 创建XXX分类器
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
# 训练模型
rf_classifier.fit(X_train, Y_train)
# 在测试集上进行预测
y_pred = rf_classifier.predict(x_test)

  

回归(Regression): 提供了多种监督学习回归算法,例如线性回归、岭回归、Lasso回归等。

聚类(Clustering): 包括了一系列无监督学习的聚类算法,如K均值聚类、层次聚类等。

降维(Dimensionality Reduction): 提供了降维算法,例如主成分分析(PCA)、奇异值分解(SVD)等。

模型选择(Model Selection): 包含了用于模型评估、参数调优和交叉验证的工具。

预处理(Preprocessing): 提供了数据预处理的工具,如标准化、归一化、缺失值填充等。

特征工程(Feature Engineering): 包括了一些用于特征选择和特征变换的工具。

集成方法(Ensemble Methods): 支持集成学习方法,如随机森林、梯度提升等。

 

计算指标(metrics):

 

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

分类指标:

  1. 准确性(Accuracy):

    • accuracy_score(y_true, y_pred): 计算分类准确性。
  2. 精确度(Precision):

    • precision_score(y_true, y_pred): 计算正类别的精确度。
  3. 召回率(Recall):

    • recall_score(y_true, y_pred): 计算正类别的召回率。
  4. F1分数(F1 Score):

    • f1_score(y_true, y_pred): 结合精确度和召回率的指标。
  5. 混淆矩阵(Confusion Matrix):

    • confusion_matrix(y_true, y_pred): 计算混淆矩阵。
  6. 分类报告(Classification Report):

    • classification_report(y_true, y_pred): 显示包括精确度、召回率、F1分数等在内的多个分类指标。
  7. ROC曲线和AUC值:

    • roc_curve(y_true, y_score): 计算ROC曲线的值。
    • roc_auc_score(y_true, y_score): 计算AUC值。
  8. Log Loss(对数损失):

    • log_loss(y_true, y_prob): 适用于概率输出的多分类对数损失。

回归指标:

  1. 均方误差(Mean Squared Error,MSE):

    • mean_squared_error(y_true, y_pred): 计算均方误差。
  2. 平均绝对误差(Mean Absolute Error,MAE):

    • mean_absolute_error(y_true, y_pred): 计算平均绝对误差。
  3. R²分数(R-squared Score):

    • r2_score(y_true, y_pred): 计算R²分数。

聚类指标:

  1. 轮廓系数(Silhouette Coefficient):

    • silhouette_score(X, labels): 计算聚类的轮廓系数。
  2. 调整兰德指数(Adjusted Rand Index):

    • adjusted_rand_score(labels_true, labels_pred): 计算调整兰德指数。
  3. 标准化互信息(Normalized Mutual Information):

    • normalized_mutual_info_score(labels_true, labels_pred): 计算标准化互信息