图神经网络学习task07(图预测任务实践)
2021/7/9 23:19:20
本文主要是介绍图神经网络学习task07(图预测任务实践),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
一、本阶段的组队学习网站地址:[datawhale]
二、本期主要学习内容:
学习样本按需获取的数据集类的构造方法
最后学习基于图表征学习的图预测任务的实践
三、超大规模数据集类的创建
在前面的学习中我们只接触了数据可全部储存于内存的数据集,这些数据集对应的数据集类在创建对象时就将所有数据都加载到内存。然而在一些应用场景中,数据集规模超级大,我们很难有足够大的内存完全存下所有数据。因此需要一个按需加载样本到内存的数据集类。
在PyG中,我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。
四、创建超大规模数据集类实践
PCQM4M-LSC是一个分子图的量子特性回归数据集,它包含了3,803,453个图。
注意以下代码依赖于ogb包,通过pip install ogb命令可安装此包。
主要代码如下:
import os import os.path as osp import pandas as pd import torch from ogb.utils import smiles2graph from ogb.utils.torch_util import replace_numpy_with_torchtensor from ogb.utils.url import download_url, extract_zip from rdkit import RDLogger from torch_geometric.data import Data, Dataset import shutil RDLogger.DisableLog('rdApp.*') class MyPCQM4MDataset(Dataset): def __init__(self, root): self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip' super(MyPCQM4MDataset, self).__init__(root) filepath = osp.join(root, 'raw/data.csv.gz') data_df = pd.read_csv(filepath) self.smiles_list = data_df['smiles'] self.homolumogap_list = data_df['homolumogap'] @property def raw_file_names(self): return 'data.csv.gz' def download(self): path = download_url(self.url, self.root) extract_zip(path, self.root) os.unlink(path) shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz')) def len(self): return len(self.smiles_list) def get(self, idx): smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx] graph = smiles2graph(smiles) assert(len(graph['edge_feat']) == graph['edge_index'].shape[1]) assert(len(graph['node_feat']) == graph['num_nodes']) x = torch.from_numpy(graph['node_feat']).to(torch.int64) edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64) edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64) y = torch.Tensor([homolumogap]) num_nodes = int(graph['num_nodes']) data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes) return data # 获取数据集划分 def get_idx_split(self): split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt'))) return split_dict if __name__ == "__main__": dataset = MyPCQM4MDataset('dataset2') from torch_geometric.data import DataLoader from tqdm import tqdm dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4) for batch in tqdm(dataloader): pass
在生成一个该数据集类的对象时,程序首先会检查指定的文件夹下是否存在data.csv.gz文件,如果不在,则会执行download方法,这一过程是在运行super类的__init__方法中发生的。然后程序继续执行__init__方法的剩余部分,读取data.csv.gz文件,获取存储图信息的smiles格式的字符串,以及回归预测的目标homolumogap。我们将由smiles格式的字符串转成图的过程在get()方法中实现,这样我们在生成一个DataLoader变量时,通过指定num_workers可以实现并行执行生成多个图。
五、图预测任务实践
基于GIN的图表示学习神经网络,和在上半节中自己定义的数据集来实现分子图的量子性质预测任务。
代码见codes\gin_regression文件夹
试验运行开始后,程序会在saves目录下创建一个task_name参数指定名称的文件夹用于记录试验过程,当saves目录下已经有一个同名的文件夹时,程序会在task_name参数末尾增加一个后缀作为文件夹名称。试验运行过程中,所有的print输出都会写入到试验文件夹下的output文件,tensorboard.SummaryWriter记录的信息也存储在试验文件夹下的文件中。
这篇关于图神经网络学习task07(图预测任务实践)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-21《鸿蒙HarmonyOS应用开发从入门到精通(第2版)》简介
- 2024-12-21后台管理系统开发教程:新手入门全指南
- 2024-12-21后台开发教程:新手入门及实战指南
- 2024-12-21后台综合解决方案教程:新手入门指南
- 2024-12-21接口模块封装教程:新手必备指南
- 2024-12-21请求动作封装教程:新手必看指南
- 2024-12-21RBAC的权限教程:从入门到实践
- 2024-12-21登录鉴权实战:新手入门教程
- 2024-12-21动态权限实战入门指南
- 2024-12-21功能权限实战:新手入门指南