链接预测任务旨在预测图中两个节点间是否存在边。这一任务在社交推荐、知识图谱完成等场景中颇为关键。实现链接预测通常视作二分类任务,将图中已存在的边作为正样本,而负样本则为图中不存在的边。通过正负样本的合并,形成训练集与测试集。评估模型效果时,常使用如AUC值的二分类模型指标。
在大规模推荐系统或信息检索中,评估模型效果时,需考虑top-k预测结果的准确性,因此引入了其他指标,如MR(MeanRank)、MRR(Mean Reciprocal Rank)和Hit@n。MR指标计算整个图谱中正确三元组排序后的序号平均值,MRR指标计算排序后序号倒数的平均值,而Hit@n则衡量正确三元组排序后序号小于n的比例。
在Cora引文数据集上,预测论文间的引用关系或被引用关系时,通过采用上述指标,可以评估模型性能。
使用DGL框架实现GNN模型进行链接预测时,首先导入所需库,如dgl和pytorch。数据加载部分通过dgl库提供的Cora数据对象进行,其中dataset可能包含多个图,但Cora数据集仅由单个图组成。正负数据划分需随机选取数据集中10%边作为测试集的正样本,剩余边作为训练集。为保证训练集和测试集的正负样本比例为1:1,随机生成等量的负样本。在模型训练时,需确保测试集中边不在训练集中,避免数据泄露。
在链接预测任务中,定义节点对得分函数至关重要。DGL提供了一种方式,将节点对视为一个图,基于节点特征和原始图中边特征计算节点对得分。使用DGLGraph.apply_edges方法计算新边特征,便于在多个图上进行节点特征传递。可使用官方提供的预测函数,如DotPredictor或自定义函数,如MLPPredictor,以实现节点对的得分预测。
训练过程与pytorch模型训练相似,通常包括迭代优化、损失计算和梯度更新。通过上述步骤,可以实现GNN模型在链接预测任务中的应用,并评估模型在Cora数据集上的性能。
本文如未解决您的问题请添加抖音号:51dongshi(抖音搜索懂视),直接咨询即可。