Study/Graph Neural Network

[5] GNN (GraphSAGE / GAT) - Cora Citation Network

by emilia park 2025. 7. 2.

📊 Performance Comparison (in general)

| Model | Core idea | Performance | Characteristics |
|---|---|---|---|
| GCN | Averages neighbor information (normalized sum) | baseline | Fast, simple |
| GraphSAGE | Aggregates neighbor information (mean / LSTM, etc.) | ★★☆ | Supports inductive learning |
| GAT | Applies an attention weight per neighbor | ★★★ | Strong performance, but slower |

 

🧠 Goals

  • Implement GraphSAGE on the Cora dataset
  • Implement GAT on the Cora dataset

 

💡 GraphSAGE (Graph Sample and Aggregate)

A method that updates node embeddings by aggregating information from neighboring nodes.

🧠 Core idea:

  • Aggregate the neighbor nodes with a simple mean (MEAN), LSTM, MAX, etc.
  • Combine (concat) the aggregated value with the node's own value and pass it to the next layer

📦 Features

| Feature | Description |
|---|---|
| Inductive | Generalizes to nodes unseen during training! |
| Multiple aggregators | mean, max, LSTM, pooling, etc. |
| Fast and flexible | Designed to support mini-batch training |
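The aggregate-then-concat update described above can be sketched in a few lines of plain Python (a toy illustration, not the PyG implementation; the linear transform and non-linearity that follow in a real layer are omitted):

```python
def sage_step(h_self, neighbor_feats):
    """One GraphSAGE mean-aggregation update for a single node:
    element-wise mean over neighbors, then concat with the node's own vector."""
    dim = len(h_self)
    mean = [sum(f[i] for f in neighbor_feats) / len(neighbor_feats) for i in range(dim)]
    return h_self + mean  # list concatenation = concat(self, aggregated)

# node with feature [1.0, 0.0] and two neighbors
out = sage_step([1.0, 0.0], [[0.0, 2.0], [2.0, 0.0]])
# -> [1.0, 0.0, 1.0, 1.0]: first half is the node itself, second half the neighbor mean
```

Keeping the node's own vector in a separate slot (instead of mixing it into the average) is what lets the model treat "me" and "my neighborhood" differently.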

 

 

# Cora + GraphSAGE

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv
  • SAGEConv: the GraphSAGE layer provided by PyG.

 

dataset = Planetoid(root='data/', name='Cora')
data = dataset[0]
  • Cora is a paper-citation graph.
  • data.x: paper features (1433 dimensions)
  • data.edge_index: citation edges, shape (2, 10556)
  • data.y: ground-truth labels (topics 0–6)
  • data.train_mask, data.test_mask: the nodes used for training / testing

class GraphSAGE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_node_features, 32)
        self.conv2 = SAGEConv(32, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

โœ”๏ธ ๋ชจ๋ธ ๊ตฌ์กฐ

๊ณ„์ธต์—ญํ• 
conv1 ์ž…๋ ฅ 1433 → ์ค‘๊ฐ„ ํ”ผ์ฒ˜ 32์ฐจ์›
conv2 32์ฐจ์› → ํด๋ž˜์Šค ์ˆ˜(7) ์ถœ๋ ฅ
relu ๋น„์„ ํ˜•์„ฑ ์ฃผ๊ธฐ

 

 

model = GraphSAGE()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

 

  • Adam: the weight-update algorithm (optimizer)
  • lr: learning rate
  • weight_decay: L2 regularization to prevent overfitting

 

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

 

| Step | Meaning |
|---|---|
| model.train() | switch to training mode |
| optimizer.zero_grad() | reset the gradients |
| model(...) | run the forward pass → predictions |
| cross_entropy(...) | compare predictions to labels and compute the loss |
| loss.backward() | backpropagation: compute gradients |
| optimizer.step() | update the weights |
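The `cross_entropy(...)` step computes, for each training node, the negative log of the softmax probability assigned to the true class. A minimal single-sample version in plain Python (illustrative only, not PyTorch's implementation):

```python
import math

def cross_entropy(logits, target):
    """-log(softmax(logits)[target]) for a single sample."""
    z = sum(math.exp(v) for v in logits)
    return -math.log(math.exp(logits[target]) / z)

# the loss is small when the target class's logit dominates, large otherwise
low = cross_entropy([4.0, 0.1, 0.2], target=0)
high = cross_entropy([4.0, 0.1, 0.2], target=1)
```

Gradient descent on this loss pushes the correct class's logit above the others, which is exactly what the training loop above repeats for 200 epochs.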

 

 

model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f"[GraphSAGE] Test Accuracy: {acc:.4f}")

 

| Step | Meaning |
|---|---|
| model.eval() | switch to evaluation mode |
| argmax(dim=1) | pick the highest-scoring class for each node |
| pred == y | compare against the labels and count correct predictions |
| acc | accuracy (correct / total) |
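What `argmax(dim=1)` and the accuracy computation do, sketched with plain Python lists (toy logits, not real model output):

```python
logits = [[2.0, 0.1, 0.3],   # node 0: class 0 scores highest
          [0.2, 1.5, 0.1],   # node 1: class 1 scores highest
          [0.4, 0.3, 0.9]]   # node 2: class 2 scores highest
pred = [row.index(max(row)) for row in logits]  # argmax per node -> [0, 1, 2]
y = [0, 1, 1]                                   # ground-truth labels
acc = sum(p == t for p, t in zip(pred, y)) / len(y)
# 2 of 3 nodes correct -> acc = 0.666...
```

The real code does the same thing, only restricted to the nodes where `data.test_mask` is true.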

 

 

🧠 Summary at a glance

| Component | GraphSAGE |
|---|---|
| Model structure | average neighbor info → combine with self |
| PyG layer | SAGEConv |
| Strength | fast and inductive (applies to new nodes) |
| Usage | nearly identical to GCN |
| Performance | generally more flexible than GCN, and often more accurate |

 

 

 

 

💡 GAT (Graph Attention Network)

🧠 1. What is GAT?

**Graph Attention Network (GAT)** is a type of graph neural network (GNN) that learns an 'importance' (attention weight) for each **node-to-node connection (edge)** and uses it when propagating information.

GAT was proposed to improve on limitations of earlier GNNs (such as GCN),
and is designed to dynamically judge which neighboring nodes' information matters more.

✅ Why is attention needed?

A plain GCN uses neighbor information by simply taking its **average or sum**.
→ In other words, every neighbor is treated as equally important.

But in real-world graphs, not all neighbors are equally important.

Examples:

  • On a social network, an influential user's information may matter more
  • In a chemical molecule, certain atoms may play a larger role in reactivity

To handle this, GAT learns "which neighbors' information to focus on."

 

📌 GCN/SAGE vs. GAT

| GCN / GraphSAGE | GAT |
|---|---|
| Combines neighbor info by averaging or a fixed scheme | Computes "how important is it?" attention weights per neighbor and combines accordingly |
| All neighbors treated equally | Important neighbors weighted more, unimportant ones less |
 

💡 Intuition

"Listen more closely to important neighbors; tune out the unimportant ones."
→ In a paper-citation network, "key papers" end up carrying more influence.
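These per-neighbor weights come from a softmax over learned scores. In the GAT paper, the raw score for an edge is produced by a small neural network over the two nodes' transformed features; in the sketch below the raw scores are assumed given, and only the LeakyReLU + softmax normalization is shown:

```python
import math

def leaky_relu(v, slope=0.2):
    return v if v > 0 else slope * v

def attention_coeffs(raw_scores):
    """Normalize one node's raw neighbor scores into attention weights alpha_ij."""
    e = [math.exp(leaky_relu(s)) for s in raw_scores]
    z = sum(e)
    return [v / z for v in e]

alpha = attention_coeffs([2.0, 0.5, -1.0])
# weights sum to 1; the highest-scoring neighbor gets the largest share
```

Because the weights are normalized per node, "listening more" to one neighbor automatically means "listening less" to the rest.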

 

 

💡 Initial setup

## GAT
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# 1. Load the dataset
dataset = Planetoid(root='data/', name='Cora')
data = dataset[0]

 

💡 GAT model definition

# 2. Define the GAT model
class GAT(torch.nn.Module):
    def __init__(self, use_dropout=True):
        super().__init__()
        self.use_dropout = use_dropout
        self.gat1 = GATConv(dataset.num_node_features, 8, heads=8, dropout=0.6)
        self.gat2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        if self.use_dropout:
            x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        if self.use_dropout:
            x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat2(x, edge_index)
        return x

💡 Training function

# 3. Training function
def train(model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
    model.train()
    for epoch in range(200):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

💡 Visualization function

# 4. Visualization function
def visualize(model, title=''):
    model.eval()
    with torch.no_grad():
        x = data.x
        x = model.gat1(x, data.edge_index)  # extract the intermediate embedding
    z = TSNE(n_components=2).fit_transform(x.cpu().numpy())
    y = data.y.cpu().numpy()
    plt.figure(figsize=(8, 6))
    plt.title(title)
    for i in range(dataset.num_classes):
        plt.scatter(z[y == i, 0], z[y == i, 1], label=f"Class {i}", alpha=0.6)
    plt.legend()
    plt.grid(True)
    plt.show()

1. model.eval()

  • Switches the model to "test/visualization" mode rather than training.
  • Required because layers like Dropout and BatchNorm behave differently in each mode.
  • In short, it guarantees consistent, deterministic outputs.

2. with torch.no_grad():

  • Turns off PyTorch's gradient tracking (autograd recording).
  • Since we are only "inspecting results" here, not training,
    → this saves GPU memory and runs faster.

3. x = data.x

  • The feature vector representing each paper in the Cora data
  • shape: [2708, 1433]
    → 2708 nodes (papers), each a 1433-dim vector (word-level bag-of-words)

4. x = model.gat1(x, data.edge_index)

  • gat1 is a GATConv layer.
  • What it does:
    • each node looks at the information of its neighbors (the papers it cites),
    • computes attention scores,
    • and produces its own vector representation.
  • Result: shape [2708, 64]
    → each node is compressed into a 64-dim vector

5. z = TSNE(n_components=2).fit_transform(...)

✅ What is t-SNE?

  • A method for visualizing high-dimensional vectors (64-dim) by reducing them to 2D.
  • It preserves cluster structure better than plain PCA.
  • Input: 2708 vectors (64-dim each)

    t-SNE → similar vectors placed close together, dissimilar ones far apart

    Output: 2708 2D coordinates (each [x, y])
  • Result shape: [2708, 2]
  • These are then scattered on the (x, y) plane.

6. y = data.y.cpu().numpy()

  • Retrieves each node's **ground-truth label (classes 0–6)**.
  • shape: [2708]

7. Plotting with matplotlib

for i in range(dataset.num_classes):
    plt.scatter(z[y == i, 0], z[y == i, 1], label=f"Class {i}", alpha=0.6)

 

  • Points are drawn in a different color for each class
  • z[y == i]: selects only the nodes belonging to class i
  • .scatter(...): plots those nodes' x, y coordinates in 2D

๐Ÿ” ์™œ gat1๊นŒ์ง€๋งŒ?

model.gat1(...)๊นŒ์ง€๋งŒ ํ†ต๊ณผ์‹œํ‚จ ์ด์œ ๋Š”
์ถœ๋ ฅ์ธต(gat2)์€ softmax ์ด์ „์˜ ๋กœ์ง“์ด๋ฏ€๋กœ, ๊ตฐ์ง‘ ํ˜•ํƒœ๊ฐ€ ๋ช…ํ™•ํ•˜์ง€ ์•Š์•„์š”.

gat1์€ "๋‚ด๊ฐ€ ์ด์›ƒ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์–ด๋–ค ๋ฒกํ„ฐ๋กœ ๋ณ€ํ–ˆ๋Š”์ง€"๋ฅผ ๋ณด์—ฌ์ฃผ๋Š” ์ค‘๊ฐ„ ํ‘œํ˜„์ด๋ผ
๋‚ด๋ถ€ ํ‘œํ˜„ ๊ณต๊ฐ„์ด ์–ด๋–ป๊ฒŒ ํ˜•์„ฑ๋๋Š”์ง€ ํ™•์ธํ•˜๊ธฐ์— ์ข‹์Šต๋‹ˆ๋‹ค.

✅ Summary

| Component | Description |
|---|---|
| model.gat1(...) | extract the first GAT layer's embedding |
| TSNE(n_components=2) | reduce 64 dims → 2 dims |
| scatter(...) | plot each class in a different color |
| Purpose | visually check how well the model separates the nodes |

 

 

# 5. Train and visualize
# (1) With dropout
model_dropout = GAT(use_dropout=True)
train(model_dropout)
visualize(model_dropout, title='GAT with Dropout')

# (2) Without dropout
model_nodrop = GAT(use_dropout=False)
train(model_nodrop)
visualize(model_nodrop, title='GAT without Dropout')

 

 

 

 

 

[Notes]

🎯 What's the difference between GCN and GraphSAGE?

One-line summary

| Model | How does it handle neighbor information? |
|---|---|
| GCN | Updates each node with the **mean** of its neighbors' information |
| GraphSAGE | Updates each node by aggregating neighbor information with a **customizable aggregator (MEAN, LSTM, MAX, etc.)** |

🧠 Structural differences

| Item | GCN | GraphSAGE |
|---|---|---|
| Core idea | Linear transform, then normalized averaging | Summarize neighbor info with various aggregator functions |
| Includes self info | Yes | Yes |
| Parameter sharing | Yes (shared W) | Yes (shared W); the aggregator is selectable |
| Operation | Behaves like a "normalized convolution" | "Sampling + aggregation" structure |
| Typical aggregator | MEAN | MEAN, LSTM, MAX, POOL, etc. |

🔎 Intuitive example

Say node A's neighbors = {B, C}

🟢 GCN

  • A's new features = average the vectors of A, B, and C → linear transform
  • "Merge neighbor information on equal footing"

🔵 GraphSAGE

  • A's new features =
    • (1) A's own vector
    • (2) the mean of B's and C's vectors
      → concatenate the two → linear transform
  • "Process self information and neighbor information separately"
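The two update rules for node A can be contrasted directly (toy 2-dim vectors; the linear transforms that follow are omitted):

```python
def mean(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]

a, b, c = [3.0, 0.0], [0.0, 3.0], [3.0, 3.0]

gcn_in = mean([a, b, c])     # GCN: self and neighbors averaged into ONE vector
sage_in = a + mean([b, c])   # GraphSAGE: self vector || neighbor mean (concat)
# gcn_in  -> [2.0, 2.0]            (A's own signal is diluted into the average)
# sage_in -> [3.0, 0.0, 1.5, 3.0]  (A's own signal kept in its own slot)
```

The concat is what lets GraphSAGE's linear layer learn different weights for "self" versus "neighborhood."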

 

 

 


🎯 What is Dropout?

Dropout is a regularization technique devised to prevent overfitting in deep learning models.
During training, it randomly turns off (i.e., zeroes out) some neurons.

📌 Why is it needed?

When a neural network fits the training data too closely,
→ it **fails to generalize to new data (overfitting)**.

Example: memorizing the training set but failing the exam

With Dropout, some neurons are forced off during training,
nudging the model toward learning "more general" patterns.

 

🧪 Training vs. evaluation

| Phase | Dropout | Description |
|---|---|---|
| Training (train) | ✅ used | randomly turns off some neurons |
| Evaluation (eval) | ❌ unused | all neurons used (outputs are averaged out) |

🧩 Cora data structure recap

  • data.x: 1433-dim vectors (which words appear in each paper)
  • e.g.:
    data.x[0] = [0, 1, 0, 0, 1, ..., 0] # paper 0 contains word 2, word 5, ...
  • The GAT model takes these and builds the node embeddings.

🔥 Where Dropout kicks in

x = F.dropout(x, p=0.6, training=self.training)

That is,

  • x is each paper's (node's) 1433-dim feature vector
  • During training, each dimension is randomly zeroed with probability 60%
  • In effect, the model trains "as if" some of paper 0's word information were missing.

 

🎯 Why do this?

Reason:
👉 To keep the model from relying too heavily on any particular feature (word).

If we train without dropout,

  • when some single word happens to "give away" the answer,
  • the model just memorizes that word → poor generalization

With dropout, on the other hand,

  • that word sometimes disappears at random
  • → the model has to work well across many different word combinations

 

data.x → Dropout → GATConv (attention) → Dropout → GATConv → output

| Location | Role |
|---|---|
| First Dropout | removes some input features to encourage robust representations |
| Second Dropout | removes part of the intermediate embedding (prevents overfitting) |

 

 

x = F.dropout(x, p=0.6, training=True)

Internally, this does the following:

  1. Generate a mask with the same shape as x (each entry 1 or 0)
  2. For each element:
    • with 60% probability → 0
    • with 40% probability → scale by 1 / 0.4 (= 2.5)

In other words, it works like this:

| Original x | Mask | After dropout |
|---|---|---|
| 1.0 | 0 | 0.0 |
| 0.0 | 1 | 0.0 |
| 1.0 | 1 | 1.0 / 0.4 = 2.5 |
| 0.0 | 0 | 0.0 |

👉 values that are already 0 stay 0 regardless,
👉 values of 1 randomly survive or die, and survivors are rescaled to 2.5 (to keep the training-time activations balanced)
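This mask-and-rescale behavior ("inverted dropout") can be reproduced in plain Python — an illustrative re-implementation, not PyTorch's actual code:

```python
import random

def dropout(x, p=0.6, training=True):
    """Inverted dropout: zero each element with probability p,
    scale survivors by 1/(1-p) so the expected value stays unchanged."""
    if not training:
        return list(x)          # eval mode: no-op, all neurons used
    scale = 1.0 / (1.0 - p)
    return [v * scale if random.random() >= p else 0.0 for v in x]

random.seed(0)
out = dropout([1.0] * 10, p=0.6)
# every element is either 0.0 (dropped) or 2.5 (survived and rescaled)
```

Because survivors are rescaled during training, evaluation needs no compensation — which is exactly why `model.eval()` can simply turn dropout off.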

 

✋ Key point: Dropout is purely random!

| Misconception | Truth |
|---|---|
| "Small values get dropped, large ones don't?" | ❌ No |
| "Unimportant values get dropped?" | ❌ No |
| "Values are dropped completely at random" | ✅ Yes |

 

 

GAT:

  • via the **attention coefficients α_ij**,
    → gives more weight to the important neighbors.
  • This by itself already has the effect of down-weighting unimportant neighbors.

Dropout:

  • injects **randomness** into these attention coefficients or the input features,
  • which makes training more robust,
  • and keeps the model from depending on any single feature or connection.

So, to summarize:

✅ Dropout plays a role complementary to GAT's attention.
If GAT's job is to weight connections differently,
Dropout's job is to add randomness to those connections and features, preventing overfitting.

 

 


🎯 What is a head in GAT?

GATConv uses **multi-head attention**.

That is: "train several independent attention layers at the same time and combine their results."

A head in GAT is "a parallel attention unit that integrates neighbor information in its own distinct way."

Physically:

  • Each head is a GNN layer with its own independent weights.
  • On the Cora data, when a **paper (node)** gathers information from the papers it cites,
    each head "judges their importance differently" as it aggregates.

 

 

 

๐Ÿ” ์–ด๋–ป๊ฒŒ 64์ฐจ์›์ด ๋งŒ๋“ค์–ด์ง€๋‚˜?

1๊ฐœ์˜ GAT head๋Š”:

  • ์ด์›ƒ ์ •๋ณด๋ฅผ attention์œผ๋กœ ๋ฐ›์•„์„œ → out_channels = 8์ฐจ์› ์ถœ๋ ฅ

8๊ฐœ์˜ GAT head๋Š”:

  • ์ด ๊ณผ์ •์„ 8๋ฒˆ ๋…๋ฆฝ์ ์œผ๋กœ ๋ณ‘๋ ฌ ์ˆ˜ํ–‰
  • ๋งˆ์ง€๋ง‰์— **concat (์—ฐ๊ฒฐ)**ํ•ฉ๋‹ˆ๋‹ค.

Head 1: 8์ฐจ์›
Head 2: 8์ฐจ์›
...
Head 8: 8์ฐจ์›
-------------
Concat → ์ด 64์ฐจ์›
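The dimension bookkeeping is plain concatenation, mirroring `GATConv(..., 8, heads=8)` followed by `GATConv(8 * 8, ...)`. A sketch with toy per-head outputs (not real attention results); note that with `concat=False`, as in gat2, PyG averages the heads instead of concatenating:

```python
# pretend each of 8 heads produced its own 8-dim embedding for one node
head_outputs = [[float(h)] * 8 for h in range(8)]

# concat=True (the gat1 setting): join the heads side by side -> 8 * 8 = 64 dims
concatenated = [v for head in head_outputs for v in head]

# concat=False (the gat2 setting): average the heads -> stays 8 dims
averaged = [sum(head[i] for head in head_outputs) / 8 for i in range(8)]
```

That is why gat2's input size must be written as 8 * 8 while its output stays at `num_classes`.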

 

 

🎯 Key takeaways

| Item | Meaning |
|---|---|
| 💡 What is a head? | an independent attention unit looking at the neighbors |
| 📦 What's inside? | each head has its own weights and attention computation |
| 📌 Role | weighs neighbor importance differently and combines features |
| 🤝 Why multiple heads? | view the neighbors from several perspectives for a richer representation |

✅ So how many heads is appropriate?

| Dataset | Typical number of heads |
|---|---|
| Cora, Citeseer | 4–8 |
| PubMed | 4–16 |
| Complex graphs | more is possible, but use with care |

In short, a head is like a function that computes the weights for each node!

 

 

 

 
