图融合GCN(Graph Convolutional Networks)
2021/9/8 6:36:13
本文主要是介绍图融合GCN(Graph Convolutional Networks),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
图融合GCN(Graph Convolutional Networks)
数据其实是图(graph),图在生活中无处不在,如社交网络,知识图谱,蛋白质结构等。本文介绍GNN(Graph Neural Networks)中的分支:GCN(Graph Convolutional Networks)。
GCN的PyTorch实现
虽然GCN从数学上较难理解,但是,实现是非常简单的,值得注意的一点是,一般情况下邻接矩阵是稀疏矩阵,所以,在实现矩阵乘法时,采用稀疏运算会更高效。首先,图卷积层的实现:
import torch
import torch.nn as nn
class GraphConvolution(nn.Module):
"""GCN layer"""
def __init__(self, in_features, out_features, bias=True):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight)
if self.bias isnotNone:
nn.init.zeros_(self.bias)
def forward(self, input, adj):
support = torch.mm(input, self.weight)
output = torch.spmm(adj, support)
if self.bias isnotNone:
return output + self.bias
else:
return output
def extra_repr(self):
return'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias isnotNone
)
对于GCN,只需要将图卷积层堆积起来就可以,这里,实现一个两层的GCN:
class GCN(nn.Module):
"""a simple two layer GCN"""
def __init__(self, nfeat, nhid, nclass):
super(GCN, self).__init__()
self.gc1 = GraphConvolution(nfeat, nhid)
self.gc2 = GraphConvolution(nhid, nclass)
def forward(self, input, adj):
h1 = F.relu(self.gc1(input, adj))
logits = self.gc2(h1, adj)
return logits
这里的激活函数采用ReLU,后面,将用这个网络实现一个图中节点的半监督分类任务。
数据的提取,只需要load就可以:
# https://github.com/tkipf/pygcn/blob/master/pygcn/utils.py
adj, features, labels, idx_train, idx_val, idx_test = load_data(path="./data/cora/")
值得注意的有两点,一是论文引用应该是单向图,但是在网络时,要先将其转成无向图,或者说建立双向引用,这个对模型训练结果影响较大:
# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
另外,官方实现中对邻接矩阵采用的是普通均值归一化,当然,也可以采用对称归一化方式:
def normalize_adj(adj):
"""compute
L=D^-0.5 * (A+I) * D^-0.5"""
adj += sp.eye(adj.shape[0])
degree = np.array(adj.sum(1))
d_hat = sp.diags(np.power(degree, -0.5).flatten())
norm_adj = d_hat.dot(adj).dot(d_hat)
return norm_adj
这里,只采用图中140个有标签样本对GCN进行训练,每个epoch计算出这些节点特征,然后计算loss:
loss_history = []
val_acc_history = []
for epoch in range(epochs):
model.train()
logits = model(features, adj)
loss = criterion(logits[idx_train],
labels[idx_train])
train_acc =
accuracy(logits[idx_train], labels[idx_train])
optimizer.zero_grad()
loss.backward()
optimizer.step()
val_acc = test(idx_val)
loss_history.append(loss.item())
val_acc_history.append(val_acc.item())
print("Epoch {:03d}: Loss
{:.4f}, TrainAcc {:.4}, ValAcc {:.4f}".format(
epoch, loss.item(),
train_acc.item(), val_acc.item()))
只需要训练200个epoch,就可以在测试集上达到80%左右的分类准确,GCN的强大可想而知:
融合BN和Conv层
在PyTorch中实现这个融合操作:nn.Conv2d参数:
- filter weights,W: conv.weight;
- bias,b: conv.bias;
nn.BatchNorm2d参数:
具体的实现代码如下(Google Colab, https://colab.research.google.com/drive/1mRyq_LlJW4u_rArzzhEe_T6tmEWoNN1K):
import torch
import
torchvision
def fuse(conv,
bn):
fused = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
#
setting weights
w_conv =
conv.weight.clone().view(conv.out_channels, -1)
w_bn =
torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused.weight.copy_(
torch.mm(w_bn, w_conv).view(fused.weight.size()) )
#
setting bias
if
conv.bias isnotNone:
b_conv = conv.bias
else:
b_conv = torch.zeros(
conv.weight.size(0) )
b_bn = bn.bias -
bn.weight.mul(bn.running_mean).div(
torch.sqrt(bn.running_var + bn.eps)
)
fused.bias.copy_( b_conv + b_bn )
return
fused
#
Testing
# we
need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16,
3, 256,
256)
resnet18 =
torchvision.models.resnet18(pretrained=True)
#
removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0],
model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
参考链接:
- Semi-Supervised Classification with Graph Convolutional Networks https://arxiv.org/abs/1609.02907
- How to do Deep Learning on Graphs with Graph Convolutional Networks https://towardsdatascience.com/how-to-do-deep-learning-on-graphs-with-graph-convolutional-networks-7d2250723780
- Graph Convolutional Networks http://tkipf.github.io/graph-convolutional-networks
- Graph Convolutional Networks in PyTorch https://github.com/tkipf/pygcn
- 回顾频谱图卷积的经典工作:从ChebNet到GCN https://www.jianshu.com/p/2fd5a2454781
- 图数据集之cora数据集介绍- 用pyton处理 - 可用于GCN任务 https://blog.csdn.net/yeziand01/article/details/93374216
- Speeding up model with fusing batch normalization and convolution (http://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3)
这篇关于图融合GCN(Graph Convolutional Networks)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-06PingCAP 连续两年入选 Gartner 云数据库管理系统魔力象限“荣誉提及”
- 2025-01-05Easysearch 可搜索快照功能,看这篇就够了
- 2025-01-04BOT+EPC模式在基础设施项目中的应用与优势
- 2025-01-03用LangChain构建会检索和搜索的智能聊天机器人指南
- 2025-01-03图像文字理解,OCR、大模型还是多模态模型?PalliGema2在QLoRA技术上的微调与应用
- 2025-01-03混合搜索:用LanceDB实现语义和关键词结合的搜索技术(应用于实际项目)
- 2025-01-03停止思考数据管道,开始构建数据平台:介绍Analytics Engineering Framework
- 2025-01-03如果 Azure-Samples/aks-store-demo 使用了 Score 会怎样?
- 2025-01-03Apache Flink概述:实时数据处理的利器
- 2025-01-01使用 SVN合并操作时,怎么解决冲突的情况?-icode9专业技术文章分享