CORA Dataloader 分析

发布时间 2023-06-20 18:40:03 作者: ZZX11
import os

import pandas as pd
from stellargraph import StellarDiGraph, StellarGraph
from stellargraph.datasets import DatasetLoader


class ant_1_4(
    DatasetLoader,
    name="ant-1.4",
    directory_name="ant-1.4",
    url="",
    url_archive_format="",
    expected_files=[],
    description="",
    source="",
):
    """Loader for the local "ant-1.4" defect dataset.

    The download metadata (url, archive format, expected files) is left
    empty; ``load`` delegates to ``_load_defect_data`` in this module.
    """

    # NOTE(review): not referenced in this file; presumably the number of
    # feature columns in the dataset -- confirm before relying on it.
    _NUM_FEATURES = 20

    def load(
        self,
        directed=False,
        largest_connected_component_only=False,
        subject_as_feature=False,
        edge_weights=None,
        str_node_ids=False,
    ):
        """Load this dataset's graph and node labels.

        Args:
            directed: forwarded to ``_load_defect_data``.
            largest_connected_component_only: forwarded to ``_load_defect_data``.
            subject_as_feature: forwarded to ``_load_defect_data``.
            edge_weights: forwarded to ``_load_defect_data``.
            str_node_ids: when true, node IDs are parsed as ``str``;
                otherwise as ``int``.

        Returns:
            The value returned by ``_load_defect_data`` for this dataset.
        """
        if str_node_ids:
            node_id_type = str
        else:
            node_id_type = int

        return _load_defect_data(
            dataset=self,
            directed=directed,
            largest_connected_component_only=largest_connected_component_only,
            subject_as_feature=subject_as_feature,
            edge_weights=edge_weights,
            nodes_dtype=node_id_type,
        )


# Original (hard-coded) root directory of the local defect datasets; each
# dataset lives in a sub-directory named after ``dataset.name``.
_DEFAULT_DEFECT_DATA_DIR = "D:\\CGCN-main\\CGCN-main\\downstream_task\\data\\ant"


def _load_defect_data(
    dataset,
    directed,
    largest_connected_component_only,
    subject_as_feature,
    edge_weights,
    nodes_dtype,
    data_dir=None,
):
    """Load a software-defect dataset as a StellarGraph plus per-node labels.

    Node rows are read from ``<data_dir>/<dataset.name>/Process-Binary.csv``.
    Edges are read from ``<data_dir>/<dataset.name>/edges.txt``, which is
    tab-separated with no header; its columns are ``target`` then ``source``.

    Args:
        dataset: Loader instance. Its ``name`` locates the files; its
            ``_NODES_DTYPE`` (if any) is used when ``nodes_dtype`` is None.
        directed: Build a ``StellarDiGraph`` instead of a ``StellarGraph``.
        largest_connected_component_only: Restrict the graph and labels to
            the first component yielded by ``graph.connected_components()``.
        subject_as_feature: Accepted for interface compatibility; not used.
        edge_weights: Optional callable ``(graph, subjects, edgelist)``
            returning one weight per edge. When given, the graph is rebuilt
            with an extra ``weight`` edge column.
        nodes_dtype: dtype used when parsing the edge endpoint IDs.
        data_dir: Root directory containing one sub-directory per dataset.
            Defaults to the original hard-coded location.

    Returns:
        ``(graph, subjects)``. ``subjects`` is a string Series that holds
        ``"buggy"`` where the ``bug`` column equals 1 and ``"clean"``
        otherwise.
    """
    if nodes_dtype is None:
        # The loader class does not necessarily declare _NODES_DTYPE; fall
        # back to int, the loader's default when str_node_ids is False.
        nodes_dtype = getattr(dataset, "_NODES_DTYPE", int)
    if data_dir is None:
        data_dir = _DEFAULT_DEFECT_DATA_DIR
    dataset_dir = os.path.join(data_dir, dataset.name)

    node_data = pd.read_csv(os.path.join(dataset_dir, "Process-Binary.csv"))
    edgelist = pd.read_csv(
        os.path.join(dataset_dir, "edges.txt"),
        sep="\t",
        header=None,
        names=["target", "source"],
        dtype=nodes_dtype,
    )

    # A `bug` value of 1 maps to "buggy"; every other value maps to "clean".
    subjects = pd.Series(
        ["buggy" if value == 1 else "clean" for value in node_data["bug"]],
        dtype="str",
    )

    cls = StellarDiGraph if directed else StellarGraph

    # Features: the 4th through the second-to-last column, min-max scaled.
    features = node_data.iloc[:, 3:-1]
    features_std = preprocessing.MinMaxScaler().fit_transform(features)

    # Single node type "class", single edge type "to".
    graph = cls({"class": features_std}, {"to": edgelist})

    if edge_weights is not None:
        # A weighted graph means computing a second StellarGraph after using
        # the unweighted one to compute the weights. Reuse the same scaled
        # features: column 2 holds per-row values, not column names, so it
        # cannot be used to select feature columns from node_data.
        edgelist["weight"] = edge_weights(graph, subjects, edgelist)
        graph = cls({"class": features_std}, {"to": edgelist})

    if largest_connected_component_only:
        cc_ids = next(graph.connected_components())
        return graph.subgraph(cc_ids), subjects[cc_ids]

    return graph, subjects