[DGL Basics Series] Node Classification on a Heterogeneous Graph with RGCN

Digest   2024-08-21 12:36   Zhejiang

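Heterogeneous graphs are fairly complex, so let's start with a simple demo. Sometimes you want to train a graph neural network on a heterogeneous graph, and node classification is a good first case. We create a graph in the format below and then train a node classification task on it.

The heterogeneous graph hetero_graph has the following edge types: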

  • ('user', 'follow', 'user')

  • ('user', 'followed-by', 'user')

  • ('user', 'click', 'item')

  • ('item', 'clicked-by', 'user')

  • ('user', 'dislike', 'item')

  • ('item', 'disliked-by', 'user')
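
Every relation in the list above is paired with a reverse relation, so information can travel along an edge in both directions. The pairing can be sketched in plain Python without DGL; the helper name `add_reverse_relations` and the toy edge lists are made up for illustration and are not part of the original code:

```python
def add_reverse_relations(forward, reverse_names):
    """Return a dict holding every forward relation plus its reversed copy.

    forward: {(src_type, rel, dst_type): (src_ids, dst_ids)}
    reverse_names: {rel: name_of_reverse_rel}
    """
    edges = {}
    for (src_type, rel, dst_type), (src_ids, dst_ids) in forward.items():
        edges[(src_type, rel, dst_type)] = (list(src_ids), list(dst_ids))
        # The reverse relation swaps the node types and the endpoint lists.
        edges[(dst_type, reverse_names[rel], src_type)] = (list(dst_ids), list(src_ids))
    return edges


# Toy data (hypothetical): user 0 follows user 1, user 2 clicks item 5.
forward = {
    ("user", "follow", "user"): ([0], [1]),
    ("user", "click", "item"): ([2], [5]),
}
reverse_names = {"follow": "followed-by", "click": "clicked-by"}

edges = add_reverse_relations(forward, reverse_names)
for key, (src, dst) in edges.items():
    print(key, src, dst)
```

The resulting dict has the same layout as the one this article later passes to dgl.heterograph: a (source type, edge type, destination type) triple mapped to a pair of ID lists.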

Load the required libraries

import dgl
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

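Building the dataset

n_users = 1000            # number of user nodes
n_items = 500             # number of item nodes
n_follows = 3000          # number of follow edges
n_clicks = 5000           # number of click edges
n_dislikes = 500          # number of dislike edges
n_hetero_features = 10    # node feature dimension
n_user_classes = 5        # number of user classes; used as the output-layer size during training
n_max_clicks = 10         # exclusive upper bound for the click-edge labels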

Building the edge data

# np.random.randint(low, high, size) draws integers from [low, high)
# follow edges: 3000
follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
# click edges: 5000
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
# dislike edges: 500
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

Create the DGL heterogeneous graph

hetero_graph = dgl.heterograph({
    # build every relation in both directions
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})
hetero_graph

Building features

# user features: 1000 users, 10 dimensions each
hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
# item features: 500 items, 10 dimensions each
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
# user class labels: a vector of length 1000
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
# edge labels on the click edges
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()

# training masks: each user node and each click edge is selected with probability 0.6
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)
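
The bernoulli(0.6) call above includes each node in the training set independently with probability 0.6, so the training set size is only roughly 60% of the nodes and changes from run to run. A dependency-free sketch of the same idea follows; the function name `bernoulli_mask` and the fixed seed are assumptions made for illustration:

```python
import random


def bernoulli_mask(n, p, seed=0):
    """Return n booleans, each True independently with probability p."""
    rng = random.Random(seed)  # fixed seed so the sketch is repeatable
    return [rng.random() < p for _ in range(n)]


mask = bernoulli_mask(1000, 0.6)
train_fraction = sum(mask) / len(mask)
print(f"{sum(mask)} of {len(mask)} nodes selected ({train_fraction:.1%})")
```

Nodes whose mask entry is False never enter the loss in this article's training loop; if needed, they could serve as a held-out set.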

Node classification task

Users fall into five classes, and the task is to predict each user's class.

# Define the feature-aggregation module
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        # one GraphConv per relation; outputs for the same destination node type are summed
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs: dict mapping node type to its feature tensor
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h
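
To give a feel for what aggregate='sum' does in the module above, here is a heavily simplified sketch in plain Python. It only shows the bookkeeping: messages are collected per relation and then summed per destination node. It deliberately leaves out the learned per-relation weight matrices, the degree normalization and the bias that the real GraphConv layers apply, and the function name `hetero_sum_aggregate` and the toy data are hypothetical:

```python
def hetero_sum_aggregate(edges, features):
    """Sum incoming neighbour features over all relations.

    edges: {(src_type, rel, dst_type): [(src_id, dst_id), ...]}
    features: {node_type: {node_id: [floats]}}
    Returns {dst_type: {dst_id: [floats]}} for nodes that receive a message.
    """
    out = {}
    for (src_type, _rel, dst_type), pairs in edges.items():
        for src_id, dst_id in pairs:
            msg = features[src_type][src_id]
            # Create a zero accumulator the first time a destination node is seen.
            acc = out.setdefault(dst_type, {}).setdefault(dst_id, [0.0] * len(msg))
            for i, value in enumerate(msg):
                acc[i] += value
    return out


features = {
    "user": {0: [1.0, 2.0], 1: [3.0, 4.0]},
    "item": {0: [10.0, 20.0]},
}
edges = {
    ("user", "follow", "user"): [(0, 1)],      # user 0 -> user 1
    ("item", "clicked-by", "user"): [(0, 1)],  # item 0 -> user 1
    ("user", "click", "item"): [(1, 0)],       # user 1 -> item 0
}

out = hetero_sum_aggregate(edges, features)
print(out)
```

In this toy graph user 1 receives one message through follow and one through clicked-by, and the two are added together, which is the role the sum aggregator plays across relation types.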

Building the model

# Build the model
model = RGCN(in_feats=n_hetero_features, hid_feats=20,
             out_feats=n_user_classes, rel_names=hetero_graph.etypes)

user_feats = hetero_graph.nodes['user'].data['feature']   # user features
item_feats = hetero_graph.nodes['item'].data['feature']   # item features
labels = hetero_graph.nodes['user'].data['label']         # user class labels
train_mask = hetero_graph.nodes['user'].data['train_mask']

The model uses hetero_graph.etypes, so let's see what it contains:

hetero_graph.etypes

['clicked-by', 'disliked-by', 'click', 'dislike', 'follow', 'followed-by']

As you can see, it is simply the list of edge types.

node_features = {'user': user_feats, 'item': item_feats}  # feature dict keyed by node type
h_dict = model(hetero_graph, node_features)


Model training

# Optimizer
opt = torch.optim.Adam(model.parameters())

best_train_acc = 0.0
loss_list = []
train_score_list = []

# Training loop
for epoch in range(500):
    model.train()
    # forward pass on the graph and node features; keep only the 'user' outputs
    logits = model(hetero_graph, node_features)['user']
    # loss on the training nodes
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # predicted user class
    pred = logits.argmax(1)
    # accuracy on the training nodes, converted to a Python float
    train_acc = (pred[train_mask] == labels[train_mask]).float().mean().item()
    if best_train_acc < train_acc:
        best_train_acc = train_acc
    train_score_list.append(train_acc)

    # backward pass and parameter update
    opt.zero_grad()
    loss.backward()
    opt.step()

    loss_list.append(loss.item())
    # report training progress
    print('Loss %.4f, Train Acc %.4f (Best %.4f)' % (
        loss.item(),
        train_acc,
        best_train_acc,
    ))
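
The two quantities tracked in the loop, the cross-entropy loss and the accuracy on the masked training nodes, can be worked through by hand. The sketch below recomputes both in plain Python on three made-up nodes; the function names and the toy logits are assumptions for illustration, and the loss is averaged over the selected nodes as F.cross_entropy does by default:

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean of -log softmax(logits)[label] over the rows where mask is True."""
    total, count = 0.0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
        count += 1
    return total / count


def masked_accuracy(logits, labels, mask):
    """Fraction of masked rows whose argmax equals the label."""
    hits, count = 0, 0
    for row, label, keep in zip(logits, labels, mask):
        if keep:
            pred = max(range(len(row)), key=row.__getitem__)
            hits += int(pred == label)
            count += 1
    return hits / count


# Three nodes, three classes; the last node is outside the training mask.
logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 5.0]]
labels = [0, 2, 1]
mask = [True, True, False]

loss = masked_cross_entropy(logits, labels, mask)
acc = masked_accuracy(logits, labels, mask)
print(f"loss={loss:.4f} acc={acc:.2f}")
```

Because the labels in this demo are drawn at random from five classes, a high training accuracy mostly reflects the model memorizing the training nodes; chance level is 1/5.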

小伍哥聊风控