๐ ์ฑ๋ฅ ๋น๊ต (์ผ๋ฐ์ ์ผ๋ก)
GCN | ํ๊ท ์ด์ ์ ๋ณด (์ ๊ทํ๋ ํฉ์ฐ) | ๊ธฐ๋ณธ | ๋น ๋ฆ, ๊ฐ๋จ |
GraphSAGE | ์ด์ ์ ๋ณด ์ง๊ณ (mean / LSTM ๋ฑ) | โ โ โ | inductive ๊ฐ๋ฅ |
GAT | ์ด์๋ง๋ค attention ๊ฐ์ค์น ์ ์ฉ | โ โ โ | ์ฑ๋ฅ ์ข์ง๋ง ๋๋ฆผ |
๐ง ๋ชฉํ
- Cora ๋ฐ์ดํฐ๋ก GrapheSAGE ๊ตฌํ
- Cora ๋ฐ์ดํฐ๋ก GAT ๊ตฌํ
๐กGraphSAGE (Graph Sample and Aggregation)
์ด์ ๋ ธ๋์ ์ ๋ณด๋ฅผ ์ง๊ณ(aggregate) ํด์ ๋ ธ๋ ์๋ฒ ๋ฉ์ ์ ๋ฐ์ดํธ ํ๋ ๋ฐฉ์
๐ง ํต์ฌ ์์ด๋์ด:
- ์ด์ ๋ ธ๋๋ค์ ๋จ์ ํ๊ท (MEAN), LSTM, MAX ๋ฑ์ผ๋ก ์ง๊ณ
- ์ด๋ ๊ฒ ์ง๊ณํ ๊ฐ์ ์๊ธฐ ์์ ์ ๊ฐ๊ณผ ๊ฒฐํฉ(concat)ํ์ฌ ๋ค์ ๋ ์ด์ด๋ก ์ ๋ฌ
๐ฆ ํน์ง
Inductive | ํ์ต ์ ์ ๋ณธ ๋ ธ๋์๋ ์ผ๋ฐํ ๊ฐ๋ฅ! |
๋ค์ํ Aggregator ์ง์ | mean, max, LSTM, pooling ๋ฑ |
๋น ๋ฅด๊ณ ์ ์ฐํจ | ๋ฏธ๋๋ฐฐ์น ํ์ต๋ ๊ฐ๋ฅํ๊ฒ ์ค๊ณ๋จ |
#Cora + GraphSAGE
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv
- SAGEConv: PyG์์ ์ ๊ณตํ๋ GraphSAGE ๋ ์ด์ด์ ๋๋ค.
dataset = Planetoid(root='data/', name='Cora')
data = dataset[0]
- Cora๋ ๋ ผ๋ฌธ ์ธ์ฉ ๊ทธ๋ํ์ ๋๋ค.
- data.x: ๋ ผ๋ฌธ ํผ์ฒ (1433์ฐจ์)
- data.edge_index: ์ธ์ฉ ๊ด๊ณ (2, 10556)
- data.y: ์ ๋ต ๋ ์ด๋ธ (์ฃผ์ 0~6)
- data.train_mask, data.test_mask: ํ์ต/ํ ์คํธ์ ์ฌ์ฉํ ๋ ธ๋
class GraphSAGE(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = SAGEConv(dataset.num_node_features, 32)
self.conv2 = SAGEConv(32, dataset.num_classes)
def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
x = self.conv2(x, edge_index)
return x
โ๏ธ ๋ชจ๋ธ ๊ตฌ์กฐ
conv1 | ์ ๋ ฅ 1433 → ์ค๊ฐ ํผ์ฒ 32์ฐจ์ |
conv2 | 32์ฐจ์ → ํด๋์ค ์(7) ์ถ๋ ฅ |
relu | ๋น์ ํ์ฑ ์ฃผ๊ธฐ |
model = GraphSAGE()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
- Adam: ๊ฐ์ค์น ์ ๋ฐ์ดํธ ์๊ณ ๋ฆฌ์ฆ
- lr: ํ์ต๋ฅ
- weight_decay: ๊ณผ์ ํฉ ๋ฐฉ์ง ์ ๊ทํ
for epoch in range(200):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
model.train() | ํ์ต ๋ชจ๋๋ก ์ค์ |
optimizer.zero_grad() | ๊ธฐ์ธ๊ธฐ ์ด๊ธฐํ |
model(...) | forward ์คํ → ์์ธก |
cross_entropy(...) | ์์ธก๊ณผ ์ ๋ต ๋น๊ตํด ์์ค ๊ณ์ฐ |
loss.backward() | ์ญ์ ํ: ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ |
optimizer.step() | ๊ฐ์ค์น ์ ๋ฐ์ดํธ |
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f"[GraphSAGE] Test Accuracy: {acc:.4f}")
model.eval() | ํ๊ฐ ๋ชจ๋ ์ค์ |
argmax(dim=1) | ๊ฐ ๋ ธ๋๋ง๋ค ๊ฐ์ฅ ๋์ ํด๋์ค ์ ํ |
correct == y | ์ ๋ต๊ณผ ๋น๊ตํด ๋ง์ถ ๊ฐ์ ๊ณ์ฐ |
acc | ์ ํ๋ (์ ๋ต ์ / ์ ์ฒด ์) |
๐ง ํ๋์ ์์ฝ
๋ชจ๋ธ ๊ตฌ์กฐ | ์ด์ ์ ๋ณด๋ฅผ ํ๊ท → ์์ ๊ณผ ๊ฒฐํฉ |
PyG ๋ ์ด์ด | SAGEConv |
์ฅ์ | ๋น ๋ฅด๊ณ inductive (์๋ก์ด ๋ ธ๋์๋ ์ ์ฉ ๊ฐ๋ฅ) |
์ฌ์ฉ ๋ฐฉ์ | GCN๊ณผ ๊ฑฐ์ ๋์ผํ๊ฒ ์ฌ์ฉ ๊ฐ๋ฅ |
์ฑ๋ฅ | GCN๋ณด๋ค ์ผ๋ฐ์ ์ผ๋ก ๋ ์ ์ฐํ๊ณ , ์ข ์ข ๋ ์ ํํจ |
๐กGAT (Graph Attention Network)
๐ง 1. GAT๋?
**Graph Attention Network (GAT)**๋ ๊ทธ๋ํ ์ ๊ฒฝ๋ง(GNN)์ ํ ์ข
๋ฅ๋ก,
**๋
ธ๋ ๊ฐ ์ฐ๊ฒฐ(edge)**์ ๋ํด ‘์ค์๋(attention)’๋ฅผ ์ค์ค๋ก ํ์ตํ์ฌ ์ ๋ณด ์ ํ๋ฅผ ์ํํ๋ ๋ชจ๋ธ์
๋๋ค.
GAT์ ๊ธฐ์กด GNN(GCN ๋ฑ)์ด ๊ฐ์ง๋ ํ๊ณ๋ฅผ ๊ฐ์ ํ๊ธฐ ์ํด ์ ์๋์์ผ๋ฉฐ,
๊ทธ๋ํ์ ์ด์ ๋
ธ๋๋ค ์ค ์ด๋ค ๋
ธ๋์ ์ ๋ณด๊ฐ ๋ ์ค์ํ์ง ๋์ ์ผ๋ก ํ๋จํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค.
โ ์ Attention์ด ํ์ํ๊ฐ?
๊ธฐ์กด GCN์ ์ด์ ๋
ธ๋์ ์ ๋ณด๋ฅผ ๋จ์ํ **ํ๊ท (average) ๋๋ ํฉ(sum)**ํด์ ์ฌ์ฉํฉ๋๋ค.
→ ์ฆ, ๋ชจ๋ ์ด์ ์ ๋ณด๋ฅผ ๋์ผํ๊ฒ ๊ฐ์ฃผํฉ๋๋ค.
ํ์ง๋ง ํ์ค ์ธ๊ณ์ ๊ทธ๋ํ์์๋ ๋ชจ๋ ์ด์์ด ๋๋ฑํ๊ฒ ์ค์ํ์ง ์์ต๋๋ค.
์:
- SNS์์ ์น๊ตฌ ์ค ์ํฅ๋ ฅ ์๋ ์ฌ์ฉ์์ ์ ๋ณด๊ฐ ๋ ์ค์ํ ์ ์์
- ํํ ๋ถ์ ๊ตฌ์กฐ์์ ํน์ ์์๋ ๋ฐ์์ฑ์ ๋ ์ค์ํ ์ญํ ์ ํ ์ ์์
GAT์ ์ด๋ฐ ์ํฉ์ ํด๊ฒฐํ๊ธฐ ์ํด, "์ด๋ค ์ด์์ ์ ๋ณด์ ๋ ์ง์คํ ๊ฒ์ธ์ง"๋ฅผ ํ์ตํฉ๋๋ค.
๐ ๊ธฐ์กด GCN/SAGE vs GAT์ ์ฐจ์ด์
์ด์ ์ ๋ณด๋ฅผ ํ๊ท ํ๊ฑฐ๋ ๊ณ ์ ๋ฐฉ์์ผ๋ก ํฉ์นจ | ์ด์๋ง๋ค "์ผ๋ง๋ ์ค์ํ์ง" attention ๊ฐ์ค์น๋ฅผ ๊ณ์ฐํด์ ํฉ์นจ |
๋ชจ๋ ์ด์์ด ๋๋ฑํ๊ฒ ์ทจ๊ธ๋จ | ์ค์ํ ์ด์์ ๋ ๋ง์ด ๋ฐ์, ๋ ์ค์ํ ์ด์์ ๋ ๋ฐ์ |

๐ก ์ง๊ด
“์ค์ํ ์ด์์ ๋ง์ ๋ ๊ท๋ด์ ๋ฃ๊ณ , ๋ ์ค์ํ ์ด์์ ๋ฌด์ํ๋ค”
→ ๋ ผ๋ฌธ ์ธ์ฉ ๋คํธ์ํฌ์์ "ํต์ฌ ๋ ผ๋ฌธ"์ ๋ ์ํฅ๋ ฅ์ ๊ฐ๊ฒ ๋จ
๐ก ์ด๊ธฐ ์ ํ
##GAT
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# 1. ๋ฐ์ดํฐ์
๋ก๋
dataset = Planetoid(root='data/', name='Cora')
data = dataset[0]
๐ก GAT ๋ชจ๋ธ ์ ์
# 2. GAT ๋ชจ๋ธ ์ ์
class GAT(torch.nn.Module):
def __init__(self, use_dropout=True):
super().__init__()
self.use_dropout = use_dropout
self.gat1 = GATConv(dataset.num_node_features, 8, heads=8, dropout=0.6)
self.gat2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=False, dropout=0.6)
def forward(self, x, edge_index):
if self.use_dropout:
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.gat1(x, edge_index))
if self.use_dropout:
x = F.dropout(x, p=0.6, training=self.training)
x = self.gat2(x, edge_index)
return x
๐ก ํ์ต ํจ์
# 3. ํ์ต ํจ์
def train(model):
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
๐ก ์๊ฐํ ํจ์ ์ ์
# 4. ์๊ฐํ ํจ์
def visualize(model, title=''):
model.eval()
with torch.no_grad():
x = data.x
x = model.gat1(x, data.edge_index) # ์ค๊ฐ ์๋ฒ ๋ฉ ์ถ์ถ
z = TSNE(n_components=2).fit_transform(x.cpu().numpy())
y = data.y.cpu().numpy()
plt.figure(figsize=(8, 6))
plt.title(title)
for i in range(dataset.num_classes):
plt.scatter(z[y == i, 0], z[y == i, 1], label=f"Class {i}", alpha=0.6)
plt.legend()
plt.grid(True)
plt.show()
1. model.eval()
- ํ์ต์ด ์๋ “ํ ์คํธ/์๊ฐํ” ๋ชจ๋๋ก ์ ํํฉ๋๋ค.
- Dropout, BatchNorm ๊ฐ์ ๊ณ์ธต์ด ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ์๋ํ๋ฏ๋ก ๋ฐ๋์ ํ์ํฉ๋๋ค.
- ์ฆ, ํญ์ ์ผ์ ํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฅํด์ฃผ๊ธฐ ์ํด ์๋๋ค.
2. with torch.no_grad():
- PyTorch์ ์ฐ์ฐ ๊ธฐ๋ก ๊ธฐ๋ฅ์ ๋
- ์ฐ๋ฆฌ๊ฐ ์ง๊ธ์ "ํ์ต"์ด ์๋๋ผ **“๊ฒฐ๊ณผ ๋ณด๊ธฐ”**๋ง ํ๋๊น
→ GPU ๋ฉ๋ชจ๋ฆฌ ์๋ผ๊ณ , ์๋๋ ๋นจ๋ผ์ง๋๋ค.
3. x = data.x
- Cora ๋ฐ์ดํฐ์์ ๊ฐ ๋ ผ๋ฌธ์ ๋ํ๋ด๋ ๋ฒกํฐ (feature)
- shape: [2708, 1433]
→ 2708๊ฐ์ ๋ ธ๋(๋ ผ๋ฌธ), ๊ฐ๊ฐ 1433์ฐจ์ (๋จ์ด ๊ธฐ์ค bag-of-words)
4. x = model.gat1(x, data.edge_index)
- gat1์ GATConv ๋ ์ด์ด์ ๋๋ค.
- ์ฌ๊ธฐ์ ํ๋ ์ผ:
- ๊ฐ ๋ ธ๋๊ฐ ์๊ธฐ ์ด์(์ธ์ฉํ ๋ ผ๋ฌธ)์ ์ ๋ณด๋ฅผ ๋ณด๊ณ
- attention score๋ฅผ ๊ณ์ฐํด์,
- ์์ ๋ง์ ๋ฒกํฐ ํํ์ ์์ฑํฉ๋๋ค.
- ๊ฒฐ๊ณผ: shape [2708, 64]
→ ๊ฐ ๋ ธ๋๊ฐ 64์ฐจ์์ ๋ฒกํฐ๋ก ์์ถ๋จ
5. z = TSNE(n_components=2).fit_transform(...)
โ TSNE๋?
- ๊ณ ์ฐจ์ ๋ฒกํฐ(64์ฐจ์)๋ฅผ 2์ฐจ์์ผ๋ก ์ค์ฌ์ ์๊ฐํํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
- ๋จ์ํ PCA๋ณด๋ค ๋ ๊ตฐ์ง ๊ตฌ์กฐ๋ฅผ ์ ๋ณด์กดํฉ๋๋ค.
- Input : 2708๊ฐ ๋ฒกํฐ (๊ฐ๊ฐ 64์ฐจ์)
↓
TSNE → ๋น์ทํ ๋ฒกํฐ๋ผ๋ฆฌ ๊ฐ๊น๊ฒ, ๋ค๋ฅธ ๋ฒกํฐ๋ ๋ฉ๋ฆฌ
↓
Output : 2708๊ฐ 2D ์ขํ (๊ฐ๊ฐ [x, y]) - ๊ฒฐ๊ณผ shape: [2708, 2]
- ์ด๊ฑธ ํ๋ฉด(x, y) ์์ ๋ฟ๋ฆฝ๋๋ค.
6. y = data.y.cup().numpy()
- ๊ฐ ๋ ธ๋์ **์ ๋ต ๋ผ๋ฒจ (0~6 ํด๋์ค)**์ ๊ฐ์ ธ์ต๋๋ค.
- shape: [2708]
7. Plotting with matplotlib
for i in range(dataset.num_classes):
plt.scatter(z[y == i, 0], z[y == i, 1], label=f"Class {i}", alpha=0.6)
- ํด๋์ค๋ณ๋ก ๋๋ ์ ๋ค๋ฅธ ์๊น๋ก ์ ๋ค์ ์ฐ์ด์
- z[y == i]: ํด๋์ค i์ ์ํ ๋ ธ๋๋ค๋ง ์ ํ
- .scatter(...): ํด๋น ๋ ธ๋๋ค์ x, y ์ขํ๋ฅผ 2D์ ์ฐ๊ธฐ
๐ ์ gat1๊น์ง๋ง?
model.gat1(...)๊น์ง๋ง ํต๊ณผ์ํจ ์ด์ ๋
์ถ๋ ฅ์ธต(gat2)์ softmax ์ด์ ์ ๋ก์ง์ด๋ฏ๋ก, ๊ตฐ์ง ํํ๊ฐ ๋ช ํํ์ง ์์์.gat1์ "๋ด๊ฐ ์ด์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ด๋ค ๋ฒกํฐ๋ก ๋ณํ๋์ง"๋ฅผ ๋ณด์ฌ์ฃผ๋ ์ค๊ฐ ํํ์ด๋ผ
๋ด๋ถ ํํ ๊ณต๊ฐ์ด ์ด๋ป๊ฒ ํ์ฑ๋๋์ง ํ์ธํ๊ธฐ์ ์ข์ต๋๋ค.
โ ์์ฝ ์ ๋ฆฌ
model.gat1(...) | GAT ์ฒซ ๋ ์ด์ด์ ์๋ฒ ๋ฉ ์ถ์ถ |
TSNE(n_components=2) | 64์ฐจ์ → 2์ฐจ์ ์ถ์ |
scatter(...) | ๊ฐ ํด๋์ค๋ฅผ ๋ค๋ฅธ ์์ผ๋ก ์ฐ๊ธฐ |
๋ชฉ์ | ๋ชจ๋ธ์ด ๋ ธ๋๋ฅผ ์ผ๋ง๋ ์ ๊ตฌ๋ถํ๋์ง ์๊ฐ์ ์ผ๋ก ํ์ธ |
# 5. ํ์ต ๋ฐ ์๊ฐํ
# (1) Dropout ํฌํจ
model_dropout = GAT(use_dropout=True)
train(model_dropout)
visualize(model_dropout, title='GAT with Dropout')
# (2) Dropout ์์ด
model_nodrop = GAT(use_dropout=False)
train(model_nodrop)
visualize(model_nodrop, title='GAT without Dropout')
[์ฐธ๊ณ ์ฌํญ]
๐ฏ GCN์ด๋ GraphSAGE ์ฐจ์ด๋ ?
ํ ์ค ์์ฝ
GCN | ์ด์์ ์ ๋ณด๋ฅผ **ํ๊ท (MEAN)**ํด์ ์ ๋ฐ์ดํธ |
GraphSAGE | ์ด์ ์ ๋ณด๋ฅผ **์ปค์คํฐ๋ง์ด์ฆ๋ ๋ฐฉ์(MEAN, LSTM, MAX ๋ฑ)**์ผ๋ก ์ง๊ณํด์ ์ ๋ฐ์ดํธ |
ํต์ฌ ์์ด๋์ด | ์ ํ ๋ณํ ํ ์ ๊ทํ๋ ํ๊ท | ๋ค์ํ Aggregator ํจ์๋ก ์ด์ ์ ๋ณด ์์ฝ |
์๊ธฐ ์ ๋ณด ํฌํจ | YES | YES |
ํ๋ผ๋ฏธํฐ ๊ณต์ | YES (๊ณตํต W) | YES (๊ณตํต W), Aggregator๋ ์ ํ ๊ฐ๋ฅ |
์ฐ์ฐ ๋ฐฉ์ | "์ ๊ทํ๋ ํฉ์ฑ๊ณฑ"์ฒ๋ผ ๋์ | "์ํ๋ง + Aggregation" ๊ตฌ์กฐ |
๋ํ Aggregator | ํ๊ท (MEAN) | MEAN, LSTM, MAX, POOL ๋ฑ |
๐ ์ง๊ด์ ์์
์: ๋ ธ๋ A์ ์ด์ = {B, C}
๐ข GCN
- A์ ์ ํผ์ฒ = B, C, A์ ๋ฒกํฐ๋ฅผ ํ๊ท → ์ ํ ๋ณํ
- "์ด์ ์ ๋ณด๋ฅผ ํ๋ฑํ๊ฒ ํฉ์ณ์ ์ ๋ฌ"
๐ต GraphSAGE
- A์ ์ ํผ์ฒ =
- (1) A ์์ฒด์ ๋ฒกํฐ
- (2) B, C์ ๋ฒกํฐ๋ฅผ ํ๊ท
→ ์ด ๋์ ์ด์ด๋ถ์ธ ๋ค → ์ ํ๋ณํ
- "์๊ธฐ ์ ๋ณด + ์ด์ ์ ๋ณด๋ฅผ ๋๋ ์ ์ฒ๋ฆฌ"
๐ฏ Dropout์ด๋?
Dropout์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๊ณผ์ ํฉ(overfitting)์ ๋ง๊ธฐ ์ํด ๊ณ ์๋ ์ ๊ทํ ๊ธฐ๋ฒ์
๋๋ค.
ํ์ตํ ๋ ๋ฌด์์๋ก ์ผ๋ถ ๋ด๋ฐ์ ๊บผ์(= 0์ผ๋ก ๋ง๋ค์ด์) ํ์ตํ๊ฒ ํฉ๋๋ค.
๐ ์ ํ์ํ๊ฐ์?
์ ๊ฒฝ๋ง์ ํ์ต ๋ฐ์ดํฐ์ ๋๋ฌด ์ ๋ง๊ฒ ๋๋ฉด,
→ **์๋ก์ด ๋ฐ์ดํฐ์ ์ผ๋ฐํ๊ฐ ์ ์ ๋๋ ํ์ (๊ณผ์ ํฉ)**์ด ๋ฐ์ํฉ๋๋ค.
์: ํ๋ จ ๋ฐ์ดํฐ๋ง ์ธ์ฐ๊ณ , ์ํ์ ๋ชป ๋ณด๋ ์ํ
๊ทธ๋์ Dropout์ ์ฐ๋ฉด,
→ ํ์ต ์ค ์ผ๋ถ ๋ด๋ฐ์ ๊บผ์ ๊ฐ์ ๋ก ๋ชจ๋ธ์ด “๋ ์ผ๋ฐ์ ์ธ” ํจํด์ ํ์ตํ๋๋ก ์ ๋ํฉ๋๋ค.
๐งช ํ์ต vs ํ ์คํธ
ํ์ต (train) | โ ์ฌ์ฉ | ์ผ๋ถ ๋ด๋ฐ์ ๋๋คํ๊ฒ ๋ |
ํ๊ฐ (eval) | โ ๋ฏธ์ฌ์ฉ | ๋ชจ๋ ๋ด๋ฐ ์ฌ์ฉ (ํ๊ท ํ๋จ) |
๐งฉ Cora ๋ฐ์ดํฐ ๊ตฌ์กฐ ๋ณต์ต
- data.x: 1433์ฐจ์ ๋ฒกํฐ (๋ ผ๋ฌธ์์ ์ด๋ค ๋จ์ด๊ฐ ๋ฑ์ฅํ๋์ง)
- ์:
data.x[0] = [0, 1, 0, 0, 1, ..., 0] # ๋ ผ๋ฌธ 0์ ๋จ์ด2, ๋จ์ด5 ๋ฑ์ฅ
- GAT ๋ชจ๋ธ์ ์ด๊ฑธ ๋ฐ์์ ๋ ธ๋ ์๋ฒ ๋ฉ์ ๋ง๋ญ๋๋ค.
๐ฅ Dropout์ด ์๋ํ๋ ๋ถ๋ถ
x = F.dropout(x, p=0.6, training=self.training)
์ฆ,
- x๋ ๊ฐ ๋ ผ๋ฌธ(๋ ธ๋)์ 1433์ฐจ์ ํผ์ฒ ๋ฒกํฐ
- ์ด๊ฑธ ํ์ต ์ค์๋ 60% ํ๋ฅ ๋ก ๋๋คํ๊ฒ ์ผ๋ถ ์ฐจ์์ ๊บผ๋ฒ๋ฆฝ๋๋ค
- ์ฆ, ๋ ผ๋ฌธ 0์ ํผ์ฒ ์ค ์ผ๋ถ ๋จ์ด ์ ๋ณด๋ "์๋ ๊ฒ์ฒ๋ผ" ํ์ตํ๊ฒ ๋ฉ๋๋ค.
๐ฏ ์ ์ด๋ ๊ฒ ํ๋?
์ด์ :
๐ ํน์ ํผ์ฒ(๋จ์ด)์ ๋๋ฌด ์์กดํ์ง ์๋๋ก ํ๊ธฐ ์ํด์์
๋๋ค.
๋ง์ฝ dropout ์์ด ํ์ตํ๋ฉด,
- ์ด๋ค ๋จ์ด ํ๋๊ฐ "๋ต"์ ์์ํ๋ ๊ฒฝ์ฐ,
- ๋ชจ๋ธ์ด ๊ทธ ๋จ์ด๋ง ์ธ์ฐ๊ฒ ๋จ → ์ผ๋ฐํ ์ ์ ๋จ
๋ฐ๋ฉด dropout์ ์ฐ๋ฉด,
- ๊ทธ ๋จ์ด๊ฐ ๋๋คํ๊ฒ ์ฌ๋ผ์ง๊ธฐ๋ ํจ
- → ๋ชจ๋ธ์ ๋ค์ํ ๋จ์ด ์กฐํฉ์์ ์ ์๋ํด์ผ ํจ
data.x → Dropout → GATConv (attention) → Dropout → GATConv → ๊ฒฐ๊ณผ
์์น ์ญํ
์ฒซ ๋ฒ์งธ Dropout ์
๋ ฅ ํผ์ฒ ์ผ๋ถ ์ ๊ฑฐํด์ ๊ฐ๊ฑดํ ํํ ์ ๋
๋ ๋ฒ์งธ Dropout ์ค๊ฐ ์๋ฒ ๋ฉ ์ผ๋ถ ์ ๊ฑฐ (์ค๋ฒํผํ
๋ฐฉ์ง)
x = F.dropout(x, p=0.6, training=True)
์ด๊ฑด ๋ด๋ถ์ ์ผ๋ก ๋ค์์ ํฉ๋๋ค:
- x์ ๊ฐ์ shape์ ๋ง์คํฌ mask๋ฅผ ์์ฑ (1 ๋๋ 0)
- ๊ฐ ์์์ ๋ํด:
- 60% ํ๋ฅ ๋ก 0
- 40% ํ๋ฅ ๋ก 1 / 0.4 (= 2.5)๋ก ์ค์ผ์ผ
์ฆ, ๋ค์์ฒ๋ผ ์๋ํฉ๋๋ค:
1.0 | 0 | 0.0 |
0.0 | 1 | 0.0 |
1.0 | 1 | 1.0 / 0.4 ≈ 2.5 |
0.0 | 0 | 0.0 |
๐ 0์ธ ๊ฐ์ ๋ฌด์กฐ๊ฑด 0 ๊ทธ๋๋ก ์ ์ง,
๐ 1์ธ ๊ฐ์ ๋๋คํ๊ฒ ์ด๊ฑฐ๋ ์ฃฝ๊ณ , ์ด์๋จ์ผ๋ฉด 2.5๋ก ์กฐ์ ๋ฉ๋๋ค (ํ์ต ํํ ์ ์ง์ฉ)
โ ์ค์ ๊ฐ๋ : Dropout์ ๋ฌด์กฐ๊ฑด ๋ฌด์์!
"๊ฐ์ด ์์ผ๋ฉด ๋๋กญ๋๊ณ , ํฌ๋ฉด ์ ๋๋กญ๋๋ค?" | โ X |
"์ค์ํ์ง ์์ผ๋ฉด ๋๋กญ๋๋ค?" | โ X |
"๋ฌด์กฐ๊ฑด ๋๋คํ๊ฒ ๋๋กญ๋๋ค" | โ O |
GAT๋:
- **attention ๊ณ์ α<sub>ij</sub>**๋ฅผ ํตํด
→ ์ด์ ์ค ์ค์ํ ๋ ธ๋์ ๋ ๋ง์ ๊ฐ์ค์น๋ฅผ ์ค๋๋ค. - ์ด ์์ฒด๋ก ์ค์ํ์ง ์์ ์ด์์ ๋ฌด์ํ๋ ํจ๊ณผ๊ฐ ์์ด์.
Dropout์:
- ์ด attention ๊ณ์ ๋๋ ์ ๋ ฅ ํผ์ฒ์ **๋ฌด์์์ฑ(randomness)**์ ๋ฃ์ด์
- ํ์ต์ ๋ robustํ๊ฒ ๋ง๋ค๊ณ ,
- ํน์ ํผ์ฒ๋ ์ฐ๊ฒฐ์๋ง ์์กดํ์ง ์๋๋ก ๋์์ค๋๋ค.
๊ทธ๋์ ์ ๋ฆฌํ๋ฉด:
โ Dropout์ GAT์ attention๊ณผ ์ํธ ๋ณด์์ ์ธ ์ญํ ์ ํฉ๋๋ค.
GAT์ด ์ฐ๊ฒฐ์ ๊ฐ์ค์น ์ฐจ์ด๋ฅผ ๋๋ ๊ธฐ๋ฅ์ด๋ผ๋ฉด,
Dropout์ ์ด ์ฐ๊ฒฐ์ด๋ ํผ์ฒ ์์ฒด์ ๋ฌด์์์ฑ์ ๋ถ์ฌํด์ ๊ณผ์ ํฉ์ ๋ฐฉ์งํฉ๋๋ค.
๐ฏ GAT ์์ head๋?
GATConv๋ **๋ฉํฐ-ํค๋ ์ดํ ์ (Multi-head Attention)**์ ์ฌ์ฉํฉ๋๋ค.
์ฆ, "์ฌ๋ฌ ๊ฐ์ ๋ ๋ฆฝ๋ attention layer๋ฅผ ๋์์ ํ์ตํด์ ๊ฒฐ๊ณผ๋ฅผ ํฉ์น๋ค"
GAT์์ head๋, "๊ฐ๊ธฐ ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ์ด์ ์ ๋ณด๋ฅผ ํตํฉํ๋ ๋ณ๋ ฌ์ ์ธ attention ๊ณ์ฐ๊ธฐ"์ ๋๋ค.
๋ฌผ๋ฆฌ์ ์ผ๋ก๋:
- ๊ฐ head๋ ๋ ๋ฆฝ๋ weight(๊ฐ์ค์น)๋ฅผ ๊ฐ๋ ํ๋์ GNN ๋ ์ด์ด์์.
- Cora ๋ฐ์ดํฐ์์ ๊ฐ head๋ **๋ ผ๋ฌธ(๋ ธ๋)**๊ฐ ์ด์ ๋ ผ๋ฌธ์ผ๋ก๋ถํฐ ์ ๋ณด๋ฅผ ๋ฐ์์ฌ ๋,
→ “๋ค๋ฅด๊ฒ ์ค์๋๋ฅผ ํ๋จํด์” ์ ๋ณด๋ฅผ ์ทจํฉํฉ๋๋ค.
๐ ์ด๋ป๊ฒ 64์ฐจ์์ด ๋ง๋ค์ด์ง๋?
1๊ฐ์ GAT head๋:
- ์ด์ ์ ๋ณด๋ฅผ attention์ผ๋ก ๋ฐ์์ → out_channels = 8์ฐจ์ ์ถ๋ ฅ
8๊ฐ์ GAT head๋:
- ์ด ๊ณผ์ ์ 8๋ฒ ๋ ๋ฆฝ์ ์ผ๋ก ๋ณ๋ ฌ ์ํ
- ๋ง์ง๋ง์ **concat (์ฐ๊ฒฐ)**ํฉ๋๋ค.
Head 1: 8์ฐจ์
Head 2: 8์ฐจ์
...
Head 8: 8์ฐจ์
-------------
Concat → ์ด 64์ฐจ์
๐ฏ ํต์ฌ ์ ๋ฆฌ
๐ก head๋? | ์ด์์ ๋ฐ๋ผ๋ณด๋ ๋ ๋ฆฝ๋ attention ๊ณ์ฐ๊ธฐ |
๐ฆ ๋ด๋ถ์ ๋ญ ์์? | ๊ฐ head๋ง๋ค ๋ ๋ฆฝ์ ์ธ weight, attention ๊ณ์ฐ |
๐ ์ญํ | ์ด์ ๋ ธ๋์ ์ค์๋๋ฅผ ๋ค๋ฅด๊ฒ ํ๊ฐํด์ ํผ์ฒ ์กฐํฉ |
๐ค ์ฌ๋ฌ head ์ฐ๋ ์ด์ | ๋ค์ํ ๊ด์ ์ผ๋ก ์ด์์ ๋ณด๊ณ ๋ ๊ฐ๋ ฅํ ํํ ๋ง๋ค๊ธฐ |
โ ๊ทธ๋ผ head๋ ๋ช ๊ฐ๊ฐ ์ ๋นํ ๊น?
Cora, Citeseer | 4~8๊ฐ |
PubMed | 4~16๊ฐ |
๋ณต์กํ ๊ทธ๋ํ | ๋ ๋ง๊ฒ ๊ฐ๋ฅํ์ง๋ง ์ฃผ์ ํ์ |
์ฆ, head ๋ ๊ฐ ๋ ธ๋์ ๊ฐ์ค์น๋ฅผ ๊ณ์ฐํ๋ function ๊ฐ์๊ฑฐ !
'Study > Graph Neural Network' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[6] GAT - Attention ์๊ฐํ (node attention, heatmap) (0) | 2025.07.02 |
---|---|
[4] GNN(GCN) - Cora ๋ ผ๋ฌธ ์ธ์ฉ Network (0) | 2025.07.02 |
[3] GNN - Graph Classification (0) | 2025.07.02 |
[2] GNN - Node Classification (3) | 2025.07.02 |
[1] What is Graph Neural Network(GNN) ? (0) | 2025.07.01 |