【PYG】GNN和全连接层(FC)分别在不同的类中,使用反向传播联合训练,实现端到端的训练过程

文章目录

    • 基本步骤
    • GNN和全连接层(FC)联合训练
      • 1. 定义GNN模型类
      • 2. 定义FC模型类
      • 3. 训练循环中的联合优化
      • 解释
      • 完整代码
    • GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新
      • 解释

基本步骤

要从GNN(图神经网络)中提取特征,并使用全连接层(FC,Fully Connected Layer)进行后续处理,可以按照以下步骤进行:

  1. 构建图神经网络模型:选择一种GNN架构,例如GCN(Graph Convolutional Network)、GAT(Graph Attention Network)等。你可以使用深度学习框架(如PyTorch、TensorFlow)来实现。

  2. 获取节点特征和图结构:准备好节点特征矩阵和邻接矩阵,这些是GNN模型的输入。

  3. 通过GNN提取特征

    • 设计GNN模型的前向传播过程,将节点特征和邻接矩阵输入GNN层。
    • 从GNN层的输出中提取节点的嵌入特征。
  4. 连接全连接层进行分类或回归

    • 将GNN提取的节点特征作为输入传递给一个或多个全连接层。
    • 通过全连接层进行后续的分类、回归等任务。

GNN和全连接层(FC)联合训练

如果GNN和全连接层(FC)分别在不同的类中,并且你希望它们可以联合训练,你可以通过以下步骤实现端到端的训练过程,并确保反向传播能够正确进行:

  1. 定义GNN和FC模型:分别定义GNN和FC模型类。
  2. 特征提取与分类:在训练循环中,将GNN提取的特征传递给FC进行分类。
  3. 联合优化:使用一个优化器来更新两个模型的参数。

以下是具体的实现步骤和代码示例:

1. 定义GNN模型类

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from sklearn.preprocessing import StandardScaler

class GNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        gnn_features = F.relu(x)
        return gnn_features

2. 定义FC模型类

class FC(nn.Module):
    def __init__(self, in_features, num_classes):
        super(FC, self).__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        out = self.fc(x)
        return out

3. 训练循环中的联合优化

# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3

# 创建多个图数据
graphs = []
for _ in range(num_graphs):
    x = torch.randn((num_nodes_per_graph, num_node_features))
    scaler = StandardScaler()
    x = torch.tensor(scaler.fit_transform(x), dtype=torch.float)  # 标准化
    edge_index = torch_geometric.utils.grid(num_nodes_per_graph)
    graphs.append(Data(x=x, edge_index=edge_index))

# 批处理数据
batch = Batch.from_data_list(graphs)

# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)

# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))

# 训练模型
for epoch in range(100):
    gnn_model.train()
    fc_model.train()
    
    optimizer.zero_grad()
    
    # 前向传播通过GNN模型
    gnn_features = gnn_model(batch)
    
    # 前向传播通过FC模型
    output = fc_model(gnn_features)
    
    # 计算损失
    loss = criterion(output, target)
    
    # 反向传播
    loss.backward()
    
    # 优化器步
    optimizer.step()
    
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# 查看特征
print("Extracted GNN features:", gnn_features)

解释

  1. GNN模型类GNN类定义了一个简单的两层GCN模型,用于特征提取。
  2. FC模型类FC类定义了一个全连接层模型,用于分类。
  3. 联合优化
    • 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
    • 使用一个优化器来同时优化GNN和FC模型的参数。
    • 通过调用optimizer.zero_grad()清除梯度,调用loss.backward()进行反向传播,最后调用optimizer.step()更新参数。

通过这种方式,尽管GNN和FC模型分别在不同的类中,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。

使用正确的参数来生成随机图。torch_geometric.utils.erdos_renyi_graph需要使用num_nodes和edge_prob参数

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScaler

class GNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        gnn_features = F.relu(x)
        return gnn_features

class FC(nn.Module):
    def __init__(self, in_features, num_classes):
        super(FC, self).__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        out = self.fc(x)
        return out

# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3

# 创建多个图数据
graphs = []
for _ in range(num_graphs):
    x = torch.randn((num_nodes_per_graph, num_node_features))
    scaler = StandardScaler()
    x = torch.tensor(scaler.fit_transform(x), dtype=torch.float)  # 标准化
    edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5)  # 生成随机图
    graphs.append(Data(x=x, edge_index=edge_index))

# 批处理数据
batch = Batch.from_data_list(graphs)

# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)

# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))

# 训练模型
for epoch in range(100):
    gnn_model.train()
    fc_model.train()
    
    optimizer.zero_grad()
    
    # 前向传播通过GNN模型
    gnn_features = gnn_model(batch)
    
    # 前向传播通过FC模型
    output = fc_model(gnn_features)
    
    # 计算损失
    loss = criterion(output, target)
    
    # 反向传播
    loss.backward()
    
    # 优化器步
    optimizer.step()
    
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# 查看特征
print("Extracted GNN features:", gnn_features)

GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新

如果你想为GNN和全连接层(FC)分别使用不同的优化器和学习率,可以按照以下步骤进行:

  1. 定义两个优化器:一个用于GNN模型,另一个用于FC模型。
  2. 分别进行参数更新:在训练循环中,分别对两个模型进行前向传播、损失计算和反向传播,然后使用各自的优化器更新参数。

以下是实现代码示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScaler

class GNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        gnn_features = F.relu(x)
        return gnn_features

class FC(nn.Module):
    def __init__(self, in_features, num_classes):
        super(FC, self).__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        out = self.fc(x)
        return out

# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3

# 创建多个图数据
graphs = []
for _ in range(num_graphs):
    x = torch.randn((num_nodes_per_graph, num_node_features))
    scaler = StandardScaler()
    x = torch.tensor(scaler.fit_transform(x), dtype=torch.float)  # 标准化
    edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5)  # 生成随机图
    graphs.append(Data(x=x, edge_index=edge_index))

# 批处理数据
batch = Batch.from_data_list(graphs)

# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)

# 使用两个优化器分别优化GNN和FC模型的参数
optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=1e-3)  # GNN使用较高的学习率
optimizer_fc = torch.optim.Adam(fc_model.parameters(), lr=1e-4)  # FC使用较低的学习率
criterion = nn.CrossEntropyLoss()

# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))

# 训练模型
for epoch in range(100):
    gnn_model.train()
    fc_model.train()
    
    optimizer_gnn.zero_grad()
    optimizer_fc.zero_grad()
    
    # 前向传播通过GNN模型
    gnn_features = gnn_model(batch)
    
    # 前向传播通过FC模型
    output = fc_model(gnn_features)
    
    # 计算损失
    loss = criterion(output, target)
    
    # 反向传播
    loss.backward()
    
    # 使用各自的优化器更新参数
    optimizer_gnn.step()
    optimizer_fc.step()
    
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# 查看特征
print("Extracted GNN features:", gnn_features)

解释

  1. GNN模型类GNN类定义了一个简单的两层GCN模型,用于特征提取。
  2. FC模型类FC类定义了一个全连接层模型,用于分类。
  3. 数据生成:使用torch_geometric.utils.erdos_renyi_graph生成随机图数据,并确保参数正确。
  4. 联合优化
    • 定义两个优化器,分别用于GNN和FC模型,并为它们设置不同的学习率。
    • 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
    • 使用各自的优化器来分别清除梯度、进行反向传播和更新参数。

通过这种方式,尽管GNN和FC模型分别在不同的类中,并使用不同的优化器和学习率,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/781163.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

cross attention交叉熵注意力机制

交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中…

Java请求webService,IDEA生成客户端调用代码

Axis是Apache开放源代码组织的一个项目,全称为Apache Extensible Interaction System,简称Axis。它是一个基于Java的SOAP(Simple Object Access Protocol,简单对象访问协议)引擎,提供创建服务器端、客户端和…

Linux基础: 二. Linux的目录和文件

文章目录 二. Linux的目录和文件1.1 目录概要1.2 目录详细说明 二. Linux的目录和文件 1.1 目录概要 command:ls / Linux的文件系统像一棵树一样,树干是根目录(/),树枝是子目录,树叶是文件; …

QML:Settings介绍

用途 提供持久的独立于平台的应用程序设置。 用户通常希望应用程序在会话中记住其设置(窗口大小、位置、选项等)。Settings能够以最小的工作量保存和恢复此类应用程序设置。 通过在Settings元素中声明属性来指定各个设置值。仅支持由QSettings识别的值…

2024 JuniorCryptCTF reppc 部分wp

Random cipher 文本编辑器打开附件 比较简单。脚本 Mutated Caesar 文本编辑器打开附件 比较简单。脚本 Pizza 附件拖入dnSpy 比较简单。脚本 l33t Leet,又称黑客语,是指一种发源于欧美地区的BBS、线上游戏和黑客社群所使用的文字书写方式,通…

Polkadot(DOT)即将爆雷?治理无能还歧视亚洲!资金将在两年内耗尽!是下一个FTX吗?

近期,关于Polkadot(DOT)生态圈的一系列负面消息引发了业界和投资者的广泛关注。从高昂的营销开支、缺乏实际业务亮点,再到治理问题和种族歧视指控,Polkadot似乎正面临着严峻的危机。业内人士警告,Polkadot的财政状况堪忧&#xff…

【C语言】qsort()函数详解:能给万物排序的神奇函数

🦄个人主页:修修修也 🎏所属专栏:C语言 ⚙️操作环境:Visual Studio 2022 目录 一.qsort()函数的基本信息及功能 二.常见的排序算法及冒泡排序 三.逐一解读qsort()函数的参数及其原理 1.void* base 2.size_t num 3.size_t size 4.int (*compar)(c…

机器学习Day12:特征选择与稀疏学习

1.子集搜索与评价 相关特征:对当前学习任务有用的特征 无关特征:对当前学习任务没用的特征 特征选择:从给定的特征集合中选择出相关特征子集的过程 为什么要特征选择? 1.任务中经常碰到维数灾难 2.去除不相关的特征能降低学习的…

认证授权auth

什么是认证授权 认证授权包含 认证和授权两部分。 什么是用户身份认证? 用户身份认证即当用户访问系统资源时,系统要求验证用户的身份信息,身份合法方可继续访问常见的用户身份认证表现形式有 用户名密码登录微信扫码登录等 什么是用户授…

【数据结构】链表带环问题分析及顺序表链表对比分析

【C语言】链表带环问题分析及顺序表链表对比分析 🔥个人主页:大白的编程日记 🔥专栏:C语言学习之路 文章目录 【C语言】链表带环问题分析及顺序表链表对比分析前言一.顺序表和链表对比1.1顺序表和链表的区别1.2缓存利用率&#…

ID3算法决策树

步骤: 先计算出信息量;信息熵;信息增量; 再比较信息增量的大小,确定分类依据。 信息量: 信息熵: 信息增益:

【网络安全】实验五(身份隐藏与ARP欺骗)

一、本次实验的实验目的 (1)了解网络攻击中常用的身份隐藏技术,掌握代理服务器的配置及使用方法 (2)通过实现ARP欺骗攻击,了解黑客利用协议缺陷进行网络攻击的一般方法 二、搭配环境 打开三台虚拟机&#…

IntelliJ IDEA 同时多行同时编辑操作快捷键

首先 点击要编辑的地方,长按鼠标左键不放,同时按住 Ctrl Shift Alt,然后就可以进行多行编辑了

【Unity】RPG2D龙城纷争(八)寻路系统

更新日期:2024年7月4日。 项目源码:第五章发布(正式开始游戏逻辑的章节) 索引 简介一、寻路系统二、寻路规则(角色移动)三、寻路规则(角色攻击)四、角色移动寻路1.自定义寻路规则2.寻…

如何根据控制框图写传递函数

控制框图(也称为方块图或信号流图)是控制系统工程中常用的一种图形表示方法,用于描述系统中各个组件之间的关系以及信号流向。传递函数则是描述线性时不变系统动态特性的数学模型,通常用于分析和设计控制系统。 识别组件&#xff…

Learn To Rank

在信息检索中,给定一个query,搜索引擎召回一系列相关的Documents,然后对这些Documents进行排序,最后将Top N的Documents输出。 排序问题最关注的是各Documents之间的相对顺序关系,而不是各个Documents的预测分最准确。…

GD32实战篇-双向数控BUCK-BOOST-BOOST升压理论基础

本文章基于兆易创新GD32 MCU所提供的2.2.4版本库函数开发 向上代码兼容GD32F450ZGT6中使用 后续项目主要在下面该专栏中发布: https://blog.csdn.net/qq_62316532/category_12608431.html?spm1001.2014.3001.5482 感兴趣的点个关注收藏一下吧! 电机驱动开发可以跳转…

即时通讯平台项目测试(登录/注册页面)

http://8.130.98.211:8080/login.html项目访问地址:即时通讯平台http://8.130.98.211:8080/login.html 本篇文章进行登录和注册页面的测试。自动化脚本的依赖在文章末尾。 登录页面测试 UI测试 测试环境:Win11;IntelliJ IDEA 2023.2&#…

C语言课设--读取文件并统计数据

读取文件并统计数据 // 统计记事本英文字符数 //功能 &#xff1a; 读 文件 a.txt 统计出各种 数字 大写字母 小写字母 0~9 A~Z a ~z各有多少个 程序主体框架如下&#xff1a; #include<stdio.h> int a[128]{0};//其中数组元素a[i]保存 ASCII 码为i的字符的数量 v…

css样式学习样例之边框

成品效果 边框固定 .login_box{width: 450px;height: 300px;background-color: aliceblue;border-radius: 3px;position: absolute;left: 50%;top: 50%;transform: translate(-50%,-50%); }这段CSS代码定义了一个名为.login_box的类的样式&#xff0c;它主要用于创建一个登录框…