-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTransNAR-GCN-Conjectura-de-Hadwiger.py
209 lines (167 loc) · 7.95 KB
/
TransNAR-GCN-Conjectura-de-Hadwiger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj
class SubgraphDetectorModule(nn.Module):
    """Score graphs for the presence of a complete subgraph K_(k+1).

    Simplified placeholder: node embeddings are mean-pooled and passed through
    a single linear layer producing one score per pooled group. ``k`` is stored
    but not used by the computation, and ``edge_index`` is accepted for API
    compatibility but ignored.

    Args:
        embed_dim: Width of the incoming node embeddings.
        k: Clique parameter (target subgraph is K_(k+1)); stored on ``self.k``.
    """

    def __init__(self, embed_dim, k):
        super(SubgraphDetectorModule, self).__init__()
        self.k = k
        self.fc = nn.Linear(embed_dim, 1)  # one score: presence of K_(k+1)

    def forward(self, x, edge_index, batch=None):
        """Return presence scores.

        Args:
            x: Node embeddings, presumably ``(num_nodes, embed_dim)`` -- TODO confirm.
            edge_index: Graph connectivity (currently unused).
            batch: Optional graph-assignment vector (one graph id per node, as in
                a PyG ``Batch.batch``). When given, nodes are averaged per graph,
                so each graph in a mini-batch gets its own score. When ``None``,
                all nodes are averaged together.

        Returns:
            ``(1, 1)`` scores when ``batch`` is ``None``, else ``(num_graphs, 1)``.
        """
        if batch is None:
            pooled = x.mean(dim=0, keepdim=True)
        else:
            num_graphs = int(batch.max()) + 1 if batch.numel() else 0
            sums = x.new_zeros(num_graphs, x.size(-1)).index_add(0, batch, x)
            # clamp avoids 0/0 for graph ids with no nodes (their mean stays 0).
            counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
            pooled = sums / counts.unsqueeze(-1).to(x.dtype)
        return self.fc(pooled)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings, then apply dropout.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / embed_dim)``.

    Accepted inputs:
        * 2-D ``(seq_len, embed_dim)``: an unbatched sequence; row ``i`` gets
          the encoding of position ``i``.
        * 3-D ``(batch, seq_len, embed_dim)``: batch-first; the encoding is
          broadcast across the batch dimension.

    Args:
        embed_dim: Feature width (odd widths are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest supported sequence.
    """

    def __init__(self, embed_dim, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd embed_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, embed_dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + encoding)`` with the same shape as ``x``.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        # The sequence axis is dim 0 for unbatched input and dim 1 for batch-first.
        # Slicing by x.size(1) on 2-D input would index by the embedding width.
        seq_len = x.size(0) if x.dim() == 2 else x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        if x.dim() == 2:
            x = x + self.pe[0, :seq_len]
        else:
            x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class TransformerLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Multi-head self-attention and a two-layer ReLU feed-forward block, each
    followed by dropout, a residual connection and LayerNorm. The input layout
    is whatever ``nn.MultiheadAttention`` accepts with its default settings
    (``batch_first=False``).

    Args:
        embed_dim: Model width.
        num_heads: Number of attention heads.
        ffn_dim: Hidden width of the feed-forward block.
        dropout: Dropout probability (attention weights and both residual paths).
    """

    def __init__(self, embed_dim, num_heads, ffn_dim, dropout):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply attention then feed-forward sub-layers; output shape equals input shape."""
        attended = self.self_attention(x, x, x)[0]
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))
class GCNModule(nn.Module):
    """Two-layer graph convolutional network.

    ``GCNConv -> ReLU -> GCNConv``; the second convolution has no activation.

    Args:
        in_channels: Input node feature width.
        hidden_channels: Width after the first convolution.
        out_channels: Output node feature width.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Propagate node features ``x`` along ``edge_index`` through both layers."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
class NAR(nn.Module):
    """Graph reasoner: a two-layer GCN followed by an MLP back to ``embed_dim``.

    Args:
        embed_dim: Width of the incoming and returned node embeddings.
        hidden_dim: Hidden width shared by the GCN and the MLP.
        gcn_out_dim: Output width of the GCN (input width of the MLP).
    """

    def __init__(self, embed_dim, hidden_dim, gcn_out_dim):
        super().__init__()
        self.gcn = GCNModule(embed_dim, hidden_dim, gcn_out_dim)
        self.mlp = nn.Sequential(
            nn.Linear(gcn_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x, edge_index):
        """Return MLP(GCN(x, edge_index)); node embeddings keep width ``embed_dim``."""
        return self.mlp(self.gcn(x, edge_index))
class CrossAttention(nn.Module):
    """Residual multi-head cross-attention followed by LayerNorm.

    The query attends to ``key``/``value``. The attention output passes
    through dropout, is added back onto the query, and the sum is
    layer-normalized.

    Args:
        embed_dim: Model width.
        num_heads: Number of attention heads.
        dropout: Dropout probability (attention weights and residual path).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Return ``norm(query + dropout(attn(query, key, value)))``."""
        attended = self.multihead_attn(query, key, value)[0]
        residual = query + self.dropout(attended)
        return self.norm(residual)
class TransNAR_GCN(nn.Module):
    """Hybrid Transformer + GCN ("TransNAR") model over graph node features.

    Node features are embedded, given sinusoidal positional encodings and
    passed through a stack of Transformer layers. A GCN-based reasoner (NAR)
    processes the same embeddings along the graph edges. A cross-attention
    block then lets the Transformer stream attend to the NAR output, and a
    linear decoder produces per-node predictions. A separate linear probe
    (``subgraph_detector``) returns a score intended to indicate the presence
    of a complete subgraph K_(k+1).

    An Adam optimizer (lr=1e-3) over all parameters is created in ``__init__``
    and used by ``train_model``.

    Args:
        input_dim: Number of input features per node.
        output_dim: Number of output features per node.
        embed_dim: Width of the Transformer/NAR embeddings.
        num_heads: Attention heads (PyTorch requires it to divide ``embed_dim``).
        num_layers: Number of Transformer layers.
        ffn_dim: Hidden width of each Transformer feed-forward block.
        gcn_hidden_dim: Hidden width of the NAR's GCN and MLP.
        gcn_out_dim: Output width of the NAR's GCN.
        k: Clique parameter forwarded to the subgraph detector.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, input_dim, output_dim, embed_dim, num_heads, num_layers, ffn_dim, gcn_hidden_dim, gcn_out_dim, k, dropout=0.1):
        super(TransNAR_GCN, self).__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ffn_dim, dropout)
            for _ in range(num_layers)
        ])
        self.nar = NAR(embed_dim, gcn_hidden_dim, gcn_out_dim)
        self.subgraph_detector = SubgraphDetectorModule(embed_dim, k)
        self.cross_attention = CrossAttention(embed_dim, num_heads, dropout)
        self.decoder = nn.Linear(embed_dim, output_dim)
        # Kept for checkpoint compatibility; only applied when output_dim > 1 (see forward).
        self.final_norm = nn.LayerNorm(output_dim)
        self.initialize_weights()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def initialize_weights(self):
        """Re-initialize ``nn.Linear`` and ``nn.MultiheadAttention`` parameters.

        Linear layers get Xavier-uniform weights and zero biases. Attention
        input/output projections get N(0, 0.02) weights. The attention pass
        runs second on purpose: ``MultiheadAttention.out_proj`` is itself an
        ``nn.Linear`` subclass and is visited after its parent module. In a
        single pass its N(0, 0.02) init would be overwritten by Xavier.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        for m in self.modules():
            if isinstance(m, nn.MultiheadAttention):
                # in_proj_weight is None when separate q/k/v projections are used.
                if m.in_proj_weight is not None:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                nn.init.normal_(m.out_proj.weight, std=0.02)

    def forward(self, x, edge_index):
        """Run the model on one (possibly mini-batched) graph.

        Args:
            x: Node feature matrix, presumably ``(num_nodes, input_dim)`` as
                produced by a PyG loader -- TODO confirm.
            edge_index: Graph connectivity passed to the NAR's GCN layers.

        Returns:
            ``(output, subgraph_scores)``: the decoder output (``output_dim``
            features per node) and the subgraph detector's score.
        """
        x = self.embedding(x)
        x = self.pos_encoding(x)
        # NOTE(review): in a PyG mini-batch all nodes of all graphs form one
        # attention sequence, so nodes attend across graph boundaries. Confirm
        # this is intended, or add an attention mask built from `batch.batch`.
        for layer in self.transformer_layers:
            x = layer(x)
        nar_output = self.nar(x, edge_index)
        # Score for the presence of a complete subgraph K_(k+1).
        subgraph_scores = self.subgraph_detector(x, edge_index)
        # Let the Transformer stream attend to the graph-structured NAR output.
        x = self.cross_attention(x, nar_output, nar_output)
        output = self.decoder(x)
        # LayerNorm over a single feature maps every value to 0 (x minus its own
        # mean), so the prediction would collapse to the norm's bias regardless
        # of the input. Only normalize when there is more than one output feature.
        if output.size(-1) > 1:
            output = self.final_norm(output)
        return output, subgraph_scores

    @staticmethod
    def _graph_mean(node_values, graph_index, num_graphs):
        """Average the rows of ``node_values`` per graph id in ``graph_index``.

        Returns a ``(num_graphs, features)`` tensor; graph ids with no nodes
        get zeros.
        """
        sums = node_values.new_zeros(num_graphs, node_values.size(-1))
        sums = sums.index_add(0, graph_index, node_values)
        counts = torch.bincount(graph_index, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(-1).to(node_values.dtype)

    @classmethod
    def _prediction_loss(cls, output, batch):
        """MSE between the per-node ``output`` and the targets in ``batch.y``.

        If ``batch.y`` holds one value per output element, nodes are compared
        directly. Otherwise ``batch.y`` is treated as graph-level targets, and
        node outputs are averaged per graph before comparison. Graph membership
        comes from ``batch.batch``; if that attribute is absent, all nodes are
        treated as one graph. Comparing ``(num_nodes, 1)`` outputs with
        ``(num_graphs,)`` targets directly would broadcast to a
        ``(num_nodes, num_graphs)`` grid and mix every node with every label.
        """
        target = batch.y.to(output.dtype)
        if target.numel() == output.numel():
            return F.mse_loss(output, target.view_as(output))
        graph_index = getattr(batch, "batch", None)
        if graph_index is None:
            graph_index = torch.zeros(output.size(0), dtype=torch.long, device=output.device)
        num_graphs = target.numel() // output.size(-1)
        preds = cls._graph_mean(output, graph_index, num_graphs)
        return F.mse_loss(preds, target.view_as(preds))

    def train_model(self, train_loader, val_loader, num_epochs):
        """Train with ``self.optimizer``, validating and checkpointing every epoch.

        Only the main output is supervised; ``subgraph_scores`` is not part of
        the loss. Prints the mean train/validation loss per epoch and writes
        ``transnar_gcn_checkpoint_epoch_{epoch}.pth`` to the working directory.

        Args:
            train_loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``
                and (for graph-level targets) ``batch``.
            val_loader: Same, for validation.
            num_epochs: Number of passes over ``train_loader``.
        """
        for epoch in range(num_epochs):
            self.train()
            train_loss = 0
            for batch in train_loader:
                self.optimizer.zero_grad()
                output, _ = self(batch.x, batch.edge_index)
                loss = self._prediction_loss(output, batch)
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

            self.eval()
            val_loss = 0
            with torch.no_grad():
                for batch in val_loader:
                    output, _ = self(batch.x, batch.edge_index)
                    val_loss += self._prediction_loss(output, batch).item()

            # max(..., 1) avoids ZeroDivisionError when a loader is empty.
            avg_train = train_loss / max(len(train_loader), 1)
            avg_val = val_loss / max(len(val_loader), 1)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train:.4f}, Val Loss: {avg_val:.4f}")
            torch.save(self.state_dict(), f'transnar_gcn_checkpoint_epoch_{epoch+1}.pth')
# Example usage
input_dim = 1  # node feature width; assumed to match the generated graphs' `x` -- confirm
output_dim = 1  # one output value per node
embed_dim = 256
num_heads = 8
num_layers = 6
ffn_dim = 1024
gcn_hidden_dim = 128
gcn_out_dim = 256
k = 4  # detect complete subgraphs K_(k+1); with k = 4 this is K_5
model = TransNAR_GCN(input_dim, output_dim, embed_dim, num_heads, num_layers, ffn_dim, gcn_hidden_dim, gcn_out_dim, k)

# Load the pre-generated graphs.
# SECURITY: weights_only=False unpickles arbitrary Python objects. This is needed
# for lists of PyG Data objects, since PyTorch >= 2.6 defaults to weights_only=True.
# Only load files you generated yourself or otherwise trust.
graphs_with_k_clique = torch.load('graphs_with_k_clique.pt', weights_only=False)
graphs_without_k_clique = torch.load('graphs_without_k_clique.pt', weights_only=False)

# Graph-level labels: 1 when K_(k+1) is present, 0 when absent.
for data in graphs_with_k_clique:
    data.y = torch.tensor([1], dtype=torch.float)
for data in graphs_without_k_clique:
    data.y = torch.tensor([0], dtype=torch.float)

# Class-balanced validation hold-out: up to 100 graphs (at most 20%) per class.
# Validation graphs must not also be trained on, otherwise the validation loss
# only measures memorization.
n_val_with = min(100, len(graphs_with_k_clique) // 5)
n_val_without = min(100, len(graphs_without_k_clique) // 5)
val_data = graphs_with_k_clique[:n_val_with] + graphs_without_k_clique[:n_val_without]
train_data = graphs_with_k_clique[n_val_with:] + graphs_without_k_clique[n_val_without:]

# Data loaders
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

# Train the model
model.train_model(train_loader, val_loader, num_epochs=100)