【算法】决策树算法:ID3

发布时间 2023-12-20 18:40:10作者: 我爱我家喵喵
import math
from collections import Counter

# 创建数据集
def create_dataset():
    dataset = [
        # 年龄, 工作, 房子,信用,标签
        ['青年', 0, 0, '一般', '0'],
        ['青年', 0, 0, '好', '0'],
        ['青年', 1, 0, '好', '1'],
        ['青年', 1, 1, '一般', '1'],
        ['青年', 0, 0, '一般', '0'],
        ['中年', 0, 0, '一般', '0'],
        ['中年', 0, 0, '好', '0'],
        ['中年', 1, 1, '好', '1'],
        ['中年', 0, 1, '很好', '1'],
        ['中年', 0, 1, '很好', '1'],
        ['老年', 0, 1, '很好', '1'],
        ['老年', 0, 1, '好', '1'],
        ['老年', 1, 0, '好', '1'],
        ['老年', 1, 0, '很好', '1'],
        ['老年', 0, 0, '一般', '0']
    ]
    return dataset

# 计算熵
def cal_entropy(dataset):
    label_count = {}
    # 统计样本标签
    for item in dataset:
        # 样本标签
        label = item[-1]
        # 不在字典中
        if label not in label_count:
            label_count[label] = 0
        # 计数+1
        label_count[label] += 1
    # 计算熵
    entropy = 0.0
    for label in label_count:
        # 概率 = 样本数 / 样本总数
        p = label_count[label] / len(dataset)
        # 计算熵
        if p == 0:
            continue
        entropy -= p * math.log(p, 2)
    return entropy

# 计算条件熵
def cal_cond_entropy(dataset, feature, value):
    ret_dataset = []
    for item in dataset:
        if item[feature] == value:
            # 抽取当前特征左侧的数据
            except_item = item[:feature]
            # 抽取当前特征右侧的数据
            except_item.extend(item[feature + 1:])
            ret_dataset.append(except_item)
    return ret_dataset

# 计算信息增益
def cal_info_gain(dataset):
    # 样本数
    num_feature = len(dataset[0]) - 1
    # 计算基本熵
    base_entropy = cal_entropy(dataset)
    # 最优的信息增益
    best_info_gain = 0.0
    # 最优的信息增益的索引
    best_info_gain_feature = 0
    for i in range(num_feature):
        feature_list = [example[i] for example in dataset]
        feature_set = set(feature_list)
        conditional_entropy = 0.0

        for value in feature_set:
            # 计算条件熵
            sub_dataset = cal_cond_entropy(dataset, i, value)
            p = float(len(sub_dataset)) / len(dataset)
            conditional_entropy += p * cal_entropy(sub_dataset)

        info_gain = base_entropy - conditional_entropy
        # 选取最大的信息索引
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_info_gain_feature = i
    return best_info_gain_feature, best_info_gain

# 多数表决法决定叶子节点分类
def majority_cnt(class_list):
    class_count = Counter(class_list)
    sorted_class_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
    return sorted_class_count[0][0]

# 构建决策树
def build_decision_tree(dataset, labels):
    class_list = [data[-1] for data in dataset]
    if class_list.count(class_list[0]) == len(class_list):  # 类别完全相同则停止继续划分
        return class_list[0]
    if len(dataset[0]) == 1:  # 遍历完所有特征时返回出现次数最多的类别
        return majority_cnt(class_list)
    best_feat, best_info_gain = cal_info_gain(dataset)
    best_feat_label = labels[best_feat]
    my_tree = {best_feat_label: {}}
    new_labels = labels[:]
    del(new_labels[best_feat])
    feat_values = [data[best_feat] for data in dataset]
    unique_vals = set(feat_values)
    for value in unique_vals:
        sub_labels = new_labels[:]
        my_tree[best_feat_label][value] = build_decision_tree(cal_cond_entropy(dataset, best_feat, value), sub_labels)
    return my_tree

# 使用决策树进行分类
def classify(input_tree, feat_labels, test_data):
    first_str = list(input_tree.keys())[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    key = test_data[feat_index]
    value_of_feat = second_dict[key]
    if isinstance(value_of_feat, dict):
        class_label = classify(value_of_feat, feat_labels, test_data)
    else:
        class_label = value_of_feat
    return class_label

# ID3 算法举例
if __name__ == '__main__':
    dataset = create_dataset()
    labels = ['年龄', '工作', '房子', '信用']
    print("熵:", cal_entropy(dataset))
    best_info_gain_feature, best_info_gain = cal_info_gain(dataset)
    print("信息增益:", best_info_gain_feature, best_info_gain)

    tree = build_decision_tree(dataset, labels)
    print("决策树:", tree)
    print("测试数据:", dataset[0])
    result = classify(tree, labels, ['老年', 1, 0, '一般'])
    print("预测结果:", result)

运行效果: