class Edge_Discriminator(nn.Module):
def __init__(self, nnodes, input_dim, alpha, sparse, hidden_dim=128, temperature=1.0, bias=0.0 + 0.0001):
"""
그래프 엣지를 판별하는 모델인 Edge Discriminator의 클래스입니다.
Parameters:
- nnodes: 노드의 수
- input_dim: 입력 특성의 차원
- alpha: Negative sampling에서 사용되는 가중치
- sparse: 희소 그래프 여부
- hidden_dim: 은닉 레이어의 차원 (기본값: 128)
- temperature: Gumbel-Softmax 샘플링의 온도 매개변수 (기본값: 1.0)
- bias: Gumbel-Softmax 샘플링에서의 바이어스 (기본값: 0.0001)
Returns:
- None
"""
super(Edge_Discriminator, self).__init__()
# 임베딩 레이어 및 MLP 레이어 정의
self.embedding_layers = nn.ModuleList()
self.embedding_layers.append(nn.Linear(input_dim, hidden_dim))
self.edge_mlp = nn.Linear(hidden_dim * 2, 1)
# 모델 파라미터 설정
self.temperature = temperature
self.bias = bias
self.nnodes = nnodes
self.sparse = sparse
self.alpha = alpha
모델은 위처럼 구현되어있다. 노드를 임베딩해주는 layer, 노드 두 개를 넣어서 edge정보를 계산해주는 edge_mlp레이어가 있다.
def weight_forward(self, features, edges):
"""
모델을 통해 엣지의 가중치를 계산합니다.
Parameters:
- features: 노드 특성
- edges: 그래프의 엣지 정보
Returns:
- Gumbel-Softmax로 샘플링된 엣지의 가중치 (Low Probability, High Probability)
"""
embeddings = self.get_node_embedding(features)
edges_weights_raw = self.get_edge_weight(embeddings, edges)
weights_lp = self.gumbel_sampling(edges_weights_raw)
weights_hp = 1 - weights_lp
# high probability가 heomo, low probability가 hetero
return weights_lp, weights_hp
이걸 보면 어디에 뭐가 사용되는지 알 수 있다.
get_node_embedding으로 노드의 feature vector를 노드임베딩으로 변환해준다. 이 함수에선 embedding_layer가 사용된다.
get_dege_weight에는 노드 임베딩과 엣지를 넣는다. edge_mlp를 통과해서 엣지의 가중치가 생성된다. 이때 생성된 edge_weight_raw에는 10556개의 weight(스칼라값)이 들어있다.
gumbel_sampling은 수학적으로 뭔가 sampling을 하는 것이다. 굼벨 샘플링은 Generative Models이나 강화 학습(Reinforcement Learning)에 많이 사용된다고 한다. sigmoid를 사용했기 때문에 0~1 사이이다. 이를 통과하면 엣지가 low probability한 것에 대한 가중치를 나타내준다. 여기서 high probability한 것에 대한 가중치를 얻으려면 1에서 빼주면 된다. (low probability는 heterophily, high probability는 homophily라고 생각하면 될 듯 하다.)
위에서 얻은 weight_lp, weight_hp를 이용해서 homophily한 버전의 인접행렬(adj_hp)과 heterophily한 버전의 인접행렬(adj_lp)를 얻을 수 있다. weight_to_adj라는 함수에서 이 동작이 수행된다. (코드는 github 참고)
Reference
논문 - https://arxiv.org/pdf/2211.14065.pdf
github - https://github.com/yixinliu233/GREET
'공부 > 실전문제연구단' 카테고리의 다른 글
[실전문제연구단] GREET 코드 뜯어보기 - 4. train 과정 (1) | 2024.01.15 |
---|---|
[실전문제연구단] GREET 코드 뜯어보기 - 2. GCL 모델 (0) | 2024.01.12 |
[실전문제연구단] GREET 코드 뜯어보기 - 1. data 확인 (0) | 2024.01.12 |
[실전문제연구단] 주제 선정 (0) | 2024.01.11 |